瀏覽代碼

feat: 增强OCR工具类型检测功能,支持条件检查和嵌套字段路径

zhch158_admin 1 月之前
父節點
當前提交
4d655f7e2d
共有 1 個文件被更改,包括 141 次插入19 次删除
  1. 141 19
      ocr_validator_utils.py

+ 141 - 19
ocr_validator_utils.py

@@ -413,32 +413,154 @@ def detect_mineru_structure(data: Union[List, Dict]) -> bool:
     return False
 
 def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
     """
     Detect which OCR tool produced ``data`` using the configured rules.

     Rules are read from ``config['ocr']['auto_detection']['rules']`` and are
     tried in ascending ``priority`` order (a rule with no priority, or a
     priority of None, sorts as 999). The first rule whose conditions are all
     satisfied decides the result.

     Args:
         data: OCR data, either a list or a dict.
         config: Configuration dict. ``config['ocr']['auto_detection']`` must
             provide ``enabled`` and, when enabled, ``rules``.

     Returns:
         The ``tool_type`` of the first matching rule, ``'mineru'`` when
         auto-detection is disabled, or ``'dots_ocr'`` when no rule matches.

     Raises:
         KeyError: If a required config key or a rule's ``tool_type`` is missing.
     """
     auto_detection = config['ocr']['auto_detection']
     if not auto_detection['enabled']:
         return 'mineru'  # default type when auto-detection is disabled

     # A rules value of None is treated the same as an empty rule list.
     rules = auto_detection['rules'] or []

     def _priority(rule: Dict):
         # An explicit None priority cannot be compared with ints by sorted(),
         # so it gets the same fallback as a missing priority.
         value = rule.get('priority')
         return 999 if value is None else value

     # sorted() is stable, so rules sharing a priority keep their config order.
     for rule in sorted(rules, key=_priority):
         tool_type = rule['tool_type']
         conditions = rule.get('conditions', [])

         # A rule matches only when all of its conditions are satisfied.
         if _check_all_conditions(data, conditions):
             return tool_type

     # No rule matched.
     # NOTE(review): this fallback ('dots_ocr') differs from the default used
     # when auto-detection is disabled ('mineru') - confirm this is intended.
     return 'dots_ocr'

 
+
def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
    """
    Return True only if every condition in ``conditions`` holds for ``data``.

    Supported condition ``type`` values:
      - ``field_exists``: ``field`` must exist.
      - ``field_not_exists``: ``field`` must not exist.
      - ``json_structure``: ``structure`` of ``'array'`` requires ``data`` to
        be a list and ``'object'`` requires a dict; any other structure value
        is not enforced.
      - ``field_value``: the value at ``field`` must equal ``value``.
      - ``field_contains``: the value at ``field`` must be one of ``values``.

    Field lookups are delegated to ``_check_field_exists`` and
    ``_get_field_value``.

    Args:
        data: OCR data, either a list or a dict.
        conditions: List of condition dicts. An empty list or None is
            vacuously satisfied.

    Returns:
        True if all conditions are satisfied, otherwise False. A condition
        with a missing or unrecognised ``type`` is treated as NOT satisfied,
        so a misspelled type cannot make a rule match every input.
    """
    # Tolerate None so a rule whose conditions value is empty does not crash.
    for condition in conditions or []:
        condition_type = condition.get('type', '')
        field = condition.get('field', '')

        if condition_type == 'field_exists':
            if not _check_field_exists(data, field):
                return False

        elif condition_type == 'field_not_exists':
            if _check_field_exists(data, field):
                return False

        elif condition_type == 'json_structure':
            expected_structure = condition.get('structure', '')
            if expected_structure == 'array' and not isinstance(data, list):
                return False
            if expected_structure == 'object' and not isinstance(data, dict):
                return False

        elif condition_type == 'field_value':
            # A missing field reads as None here.
            if _get_field_value(data, field) != condition.get('value'):
                return False

        elif condition_type == 'field_contains':
            # None (or any other falsy value) for 'values' means nothing can match.
            allowed_values = condition.get('values') or []
            if _get_field_value(data, field) not in allowed_values:
                return False

        else:
            # Fail closed on unknown condition types.
            return False

    return True
+
+
def _resolve_field_path(data: Union[List, Dict], field_path: str) -> tuple:
    """
    Walk a dot-separated field path and report whether it resolves.

    Args:
        data: A dict, or a list whose first element is inspected.
        field_path: Dot-separated path, e.g. "doc_preprocessor_res.angle".

    Returns:
        ``(True, value)`` if every path segment is a key of a nested dict,
        otherwise ``(False, None)``.
    """
    if not field_path:
        return False, None

    # For list input only the first element is inspected; an empty list has
    # nothing to inspect.
    if isinstance(data, list):
        if not data:
            return False, None
        data = data[0]

    current = data
    for key in field_path.split('.'):
        # Every intermediate node must be a dict that contains the next key.
        if not isinstance(current, dict) or key not in current:
            return False, None
        current = current[key]

    return True, current


def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
    """
    Check whether a field exists (nested paths supported).

    Args:
        data: A dict, or a list whose first element is inspected.
        field_path: Dot-separated path, e.g. "doc_preprocessor_res.angle".

    Returns:
        True if the path resolves, even when the stored value is None;
        False for an empty path, an empty list, or a missing segment.
    """
    found, _ = _resolve_field_path(data, field_path)
    return found


def _get_field_value(data: Union[List, Dict], field_path: str):
    """
    Get the value of a field (nested paths supported).

    Args:
        data: A dict, or a list whose first element is inspected.
        field_path: Dot-separated path.

    Returns:
        The value at the path, or None if the path does not resolve. A stored
        None is indistinguishable from a missing field here; use
        ``_check_field_exists`` when that distinction matters.
    """
    _, value = _resolve_field_path(data, field_path)
    return value

+
 def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
     """标准化OCR数据 - 支持多种工具"""
     tool_type = detect_ocr_tool_type(raw_data, config)