|
|
@@ -413,32 +413,154 @@ def detect_mineru_structure(data: Union[List, Dict]) -> bool:
|
|
|
return False
|
|
|
|
|
|
def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
|
|
|
- """自动检测OCR工具类型 - 增强版"""
|
|
|
+ """
|
|
|
+ 自动检测OCR工具类型 - 增强版
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data: OCR数据(可能是列表或字典)
|
|
|
+ config: 配置字典
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 工具类型字符串
|
|
|
+ """
|
|
|
if not config['ocr']['auto_detection']['enabled']:
|
|
|
- return 'dots_ocr' # 默认类型
|
|
|
+ return 'mineru' # 默认类型
|
|
|
|
|
|
rules = config['ocr']['auto_detection']['rules']
|
|
|
|
|
|
- for rule in rules:
|
|
|
- # 检查字段存在性
|
|
|
- if 'field_exists' in rule:
|
|
|
- field_name = rule['field_exists']
|
|
|
- if isinstance(data, dict) and field_name in data:
|
|
|
- return rule['tool_type']
|
|
|
- elif isinstance(data, list) and data and isinstance(data[0], dict) and field_name in data[0]:
|
|
|
- # 如果是list,检查第一个元素
|
|
|
- return rule['tool_type']
|
|
|
+ # 按优先级排序
|
|
|
+ sorted_rules = sorted(rules, key=lambda x: x.get('priority', 999))
|
|
|
+
|
|
|
+ for rule in sorted_rules:
|
|
|
+ tool_type = rule['tool_type']
|
|
|
+ conditions = rule.get('conditions', [])
|
|
|
|
|
|
- # 检查是否为数组
|
|
|
- if 'json_is_array' in rule:
|
|
|
- if rule['json_is_array'] and isinstance(data, list):
|
|
|
- # 进一步区分是dots_ocr还是mineru
|
|
|
- if not detect_mineru_structure(data):
|
|
|
- return rule['tool_type']
|
|
|
-
|
|
|
- # 默认返回dots_ocr
|
|
|
+ # 检查所有条件是否满足
|
|
|
+ if _check_all_conditions(data, conditions):
|
|
|
+ return tool_type
|
|
|
+
|
|
|
+ # 如果所有规则都不匹配,返回默认类型
|
|
|
return 'dots_ocr'
|
|
|
|
|
|
+
|
|
|
+def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
|
|
|
+ """
|
|
|
+ 检查所有条件是否满足
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data: 数据
|
|
|
+ conditions: 条件列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 是否所有条件都满足
|
|
|
+ """
|
|
|
+ for condition in conditions:
|
|
|
+ condition_type = condition.get('type', '')
|
|
|
+
|
|
|
+ if condition_type == 'field_exists':
|
|
|
+ # 检查字段存在
|
|
|
+ field = condition.get('field', '')
|
|
|
+ if not _check_field_exists(data, field):
|
|
|
+ return False
|
|
|
+
|
|
|
+ elif condition_type == 'field_not_exists':
|
|
|
+ # 检查字段不存在
|
|
|
+ field = condition.get('field', '')
|
|
|
+ if _check_field_exists(data, field):
|
|
|
+ return False
|
|
|
+
|
|
|
+ elif condition_type == 'json_structure':
|
|
|
+ # 检查JSON结构类型
|
|
|
+ expected_structure = condition.get('structure', '')
|
|
|
+ if expected_structure == 'array' and not isinstance(data, list):
|
|
|
+ return False
|
|
|
+ elif expected_structure == 'object' and not isinstance(data, dict):
|
|
|
+ return False
|
|
|
+
|
|
|
+ elif condition_type == 'field_value':
|
|
|
+ # 检查字段值
|
|
|
+ field = condition.get('field', '')
|
|
|
+ expected_value = condition.get('value')
|
|
|
+ actual_value = _get_field_value(data, field)
|
|
|
+ if actual_value != expected_value:
|
|
|
+ return False
|
|
|
+
|
|
|
+ elif condition_type == 'field_contains':
|
|
|
+ # 检查字段包含某个值
|
|
|
+ field = condition.get('field', '')
|
|
|
+ expected_values = condition.get('values', [])
|
|
|
+ actual_value = _get_field_value(data, field)
|
|
|
+ if actual_value not in expected_values:
|
|
|
+ return False
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+
|
|
|
+def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
|
|
|
+ """
|
|
|
+ 检查字段是否存在(支持嵌套路径)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data: 数据
|
|
|
+ field_path: 字段路径(支持点分隔,如 "doc_preprocessor_res.angle")
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 字段是否存在
|
|
|
+ """
|
|
|
+ if not field_path:
|
|
|
+ return False
|
|
|
+
|
|
|
+ # 处理数组情况:检查第一个元素
|
|
|
+ if isinstance(data, list):
|
|
|
+ if not data:
|
|
|
+ return False
|
|
|
+ data = data[0]
|
|
|
+
|
|
|
+ # 处理嵌套字段路径
|
|
|
+ fields = field_path.split('.')
|
|
|
+ current = data
|
|
|
+
|
|
|
+ for field in fields:
|
|
|
+ if isinstance(current, dict) and field in current:
|
|
|
+ current = current[field]
|
|
|
+ else:
|
|
|
+ return False
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+
|
|
|
+def _get_field_value(data: Union[List, Dict], field_path: str):
|
|
|
+ """
|
|
|
+ 获取字段值(支持嵌套路径)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data: 数据
|
|
|
+ field_path: 字段路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 字段值,如果不存在返回 None
|
|
|
+ """
|
|
|
+ if not field_path:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 处理数组情况:检查第一个元素
|
|
|
+ if isinstance(data, list):
|
|
|
+ if not data:
|
|
|
+ return None
|
|
|
+ data = data[0]
|
|
|
+
|
|
|
+ # 处理嵌套字段路径
|
|
|
+ fields = field_path.split('.')
|
|
|
+ current = data
|
|
|
+
|
|
|
+ for field in fields:
|
|
|
+ if isinstance(current, dict) and field in current:
|
|
|
+ current = current[field]
|
|
|
+ else:
|
|
|
+ return None
|
|
|
+
|
|
|
+ return current
|
|
|
+
|
|
|
def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
|
|
|
"""标准化OCR数据 - 支持多种工具"""
|
|
|
tool_type = detect_ocr_tool_type(raw_data, config)
|