Sfoglia il codice sorgente

优化ppstructv3数据解析逻辑,新增table_recognition_v2数据解析功能,支持嵌套字段提取

zhch158_admin 1 mese fa
parent
commit
e3ebf68f19
1 ha cambiato i file con 74 aggiunte e 35 eliminazioni
  1. 74 35
      ocr_validator_utils.py

+ 74 - 35
ocr_validator_utils.py

@@ -333,12 +333,7 @@ def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
     tool_config = config['ocr']['tools']['ppstructv3']
     parsed_data = []
     
-    # 获取解析结果列表
-    parsing_results_field = tool_config['parsing_results_field']
-    if parsing_results_field not in data:
-        return parsed_data
-    
-    parsing_results = data[parsing_results_field]
+    parsing_results = data.get(tool_config['parsing_results_field'], [])
     if not isinstance(parsing_results, list):
         return parsed_data
     
@@ -346,55 +341,87 @@ def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
         if not isinstance(item, dict):
             continue
             
-        # 提取字段
         text = item.get(tool_config['text_field'], '')
         bbox = item.get(tool_config['bbox_field'], [])
         category = item.get(tool_config['category_field'], 'text')
-        confidence = item.get(tool_config.get('confidence_field', 'confidence'), 
-                            config['ocr']['default_confidence'])
+        confidence = item.get(
+            tool_config.get('confidence_field', 'confidence'),
+            config['ocr']['default_confidence']
+        )
         
         if text and bbox and len(bbox) >= 4:
             parsed_data.append({
                 'text': str(text).strip(),
-                'bbox': bbox[:4],  # 确保只取前4个坐标
+                'bbox': bbox[:4],
                 'category': category,
                 'confidence': confidence,
                 'source_tool': 'ppstructv3'
             })
     
-    # 如果有OCR文本识别结果,也添加进来
-    if 'overall_ocr_res' in data:
-        ocr_res = data['overall_ocr_res']
-        if isinstance(ocr_res, dict) and 'rec_texts' in ocr_res and 'rec_boxes' in ocr_res:
-            texts = ocr_res['rec_texts']
-            boxes = ocr_res['rec_boxes']
-            scores = ocr_res.get('rec_scores', [])
-            
-            for i, (text, box) in enumerate(zip(texts, boxes)):
-                if text and len(box) >= 4:
-                    confidence = scores[i] if i < len(scores) else config['ocr']['default_confidence']
-                    parsed_data.append({
-                        'text': str(text).strip(),
-                        'bbox': box[:4],
-                        'category': 'OCR_Text',
-                        'confidence': confidence,
-                        'source_tool': 'ppstructv3_ocr'
-                    })
+    rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
+    rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
+    if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
+        for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
+            if text and isinstance(box, list) and len(box) >= 4:
+                parsed_data.append({
+                    'text': str(text).strip(),
+                    'bbox': box[:4],
+                    'category': 'OCR_Text',
+                    'source_tool': 'ppstructv3_ocr'
+                })
     
     return parsed_data
 
 
+def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
+    tool_config = config['ocr']['tools']['table_recognition_v2']
+    parsed_data = []
+    tables = data.get(tool_config['parsing_results_field'], [])
+    if not isinstance(tables, list):
+        return parsed_data
+
+    for item in tables:
+        if not isinstance(item, dict):
+            continue
+
+        html_text = item.get(tool_config['text_field'], '')
+        bbox = item.get(tool_config['bbox_field'], [])
+        if bbox and len(bbox) >= 4:
+            bbox = bbox[:4]
+        else:
+            bbox = [0, 0, 0, 0]
+
+        parsed_data.append({
+            'text': str(html_text).strip(),
+            'bbox': bbox,
+            'category': item.get(tool_config.get('category_field', ''), 'table'),
+            'confidence': item.get(tool_config.get('confidence_field', ''), config['ocr']['default_confidence']),
+            'source_tool': 'table_recognition_v2',
+        })
+
+    rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
+    rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
+    if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
+        for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
+            if text and isinstance(box, list) and len(box) >= 4:
+                parsed_data.append({
+                    'text': str(text).strip(),
+                    'bbox': box[:4],
+                    'category': 'OCR_Text',
+                    'source_tool': 'ppstructv3_ocr'
+                })
+    
+    return parsed_data
+
 def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
-    """统一不同OCR工具的数据格式"""
-    # 自动检测OCR工具类型
     tool_type = detect_ocr_tool_type(raw_data, config)
-    
     if tool_type == 'dots_ocr':
         return parse_dots_ocr_data(raw_data, config)
-    elif tool_type == 'ppstructv3':
+    if tool_type == 'ppstructv3':
         return parse_ppstructv3_data(raw_data, config)
-    else:
-        raise ValueError(f"不支持的OCR工具类型: {tool_type}")
+    if tool_type == 'table_recognition_v2':
+        return parse_table_recognition_v2_data(raw_data, config)
+    raise ValueError(f"不支持的OCR工具类型: {tool_type}")
 
 
 def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
@@ -1096,4 +1123,16 @@ def get_data_source_display_name(source_config: Dict) -> str:
     }
     
     tool_display = tool_name_map.get(tool, tool)
-    return f"{name} ({tool_display})"
+    return f"{name} ({tool_display})"
+
+def get_nested_value(data: Dict, path: str, default=None):
+    if not path:
+        return default
+    keys = path.split('.')
+    value = data
+    for key in keys:
+        if isinstance(value, dict) and key in value:
+            value = value[key]
+        else:
+            return default
+    return value