Преглед изворног кода

feat: 添加表格分类器支持,优化表格识别路径选择

zhch158_admin пре 1 дан
родитељ
комит
630cf15a2d
1 измењен фајл са 41 додато и 7 уклоњено
  1. 41 7
      ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

+ 41 - 7
ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

@@ -139,7 +139,17 @@ class EnhancedDocPipeline:
                 self.config['ocr_recognition']
             )
 
-            # 5. 有线表格识别器(可选)
+            # 5. 表格分类器(可选)
+            self.table_classifier = None
+            table_cls_config = self.config.get('table_classification', {})
+            if table_cls_config.get('enabled', False):
+                try:
+                    self.table_classifier = ModelFactory.create_table_classifier(table_cls_config)
+                    logger.info("✅ Table classifier initialized")
+                except Exception as e:
+                    logger.warning(f"⚠️ Table classifier init failed: {e}")
+
+            # 6. 有线表格识别器(可选)
             self.table_config = self.config.get('table_recognition_wired', {})
             self.wired_table_recognizer = None
             if self.table_config.get('use_wired_unet', False):
@@ -178,6 +188,7 @@ class EnhancedDocPipeline:
             vl_recognizer=self.vl_recognizer,
             table_cell_matcher=table_cell_matcher,
             wired_table_recognizer=getattr(self, 'wired_table_recognizer', None),
+            table_classifier=getattr(self, 'table_classifier', None),
         )
     
     # ==================== 主处理流程 ====================
@@ -784,12 +795,32 @@ class EnhancedDocPipeline:
             try:
                 spans = get_matched_spans_for_item(item)
                 
-                # 🔑 关键:根据配置选择表格识别路径
+                # 🔑 关键:智能选择表格识别路径(支持自动分类)
                 use_wired_unet = self.table_config.get('use_wired_unet', False)
+                use_table_classification = self.config.get('table_classification', {}).get('enabled', False)
+                
+                # Step 1: 如果启用了自动分类,先对表格进行分类
+                table_type = None
+                if use_table_classification and self.table_classifier:
+                    bbox = item.get('bbox', [])
+                    table_img = CoordinateUtils.crop_region(detection_image, bbox)
+                    cls_result = self.table_classifier.classify(table_img)
+                    table_type = cls_result.get('table_type', 'wireless')
+                    confidence = cls_result.get('confidence', 0.0)
+                    logger.info(f"📊 Table {idx} classified as '{table_type}' (conf: {confidence:.3f})")
+                
+                # Step 2: 根据分类结果或配置选择识别器
+                should_use_wired = False
+                if use_table_classification:
+                    # 自动分类模式:根据分类结果决定
+                    should_use_wired = (table_type == 'wired' and self.wired_table_recognizer)
+                # else:
+                #     # 手动配置模式:根据配置决定
+                #     should_use_wired = (use_wired_unet and self.wired_table_recognizer)
                 
-                if use_wired_unet and self.wired_table_recognizer:
+                if should_use_wired:
                     # 有线表格路径:UNet 识别
-                    logger.info(f"🔷 Using wired UNet table recognition (configured)")
+                    logger.info(f"🔷 Table {idx}: Using wired UNet recognition")
                     element = self.element_processors.process_table_element_wired(
                         detection_image, item, scale, pre_matched_spans=spans, pdf_type=pdf_type,
                         output_dir=output_dir, basename=f"{basename}_{idx}"
@@ -797,10 +828,13 @@ class EnhancedDocPipeline:
                     
                     # 如果有线识别失败(返回空 HTML),fallback 到 VLM
                     if not element['content'].get('html') and not element['content'].get('cells'):
-                        raise ValueError(f"Wired UNet table recognition failed, element: {item}")
+                        logger.warning(f"⚠️ Wired recognition failed for table {idx}, fallback to VLM")
+                        element = self.element_processors.process_table_element_vlm(
+                            detection_image, item, scale, pre_matched_spans=spans
+                        )
                 else:
                     # VLM 无线表格路径(默认)
-                    logger.info(f"🔷 Using VLM table recognition (configured)")
+                    logger.info(f"🔷 Table {idx}: Using VLM recognition")
                     element = self.element_processors.process_table_element_vlm(
                         detection_image, item, scale, pre_matched_spans=spans
                     )
@@ -867,7 +901,7 @@ class EnhancedDocPipeline:
                 self.preprocessor.cleanup()
             if hasattr(self, 'layout_detector'):
                 self.layout_detector.cleanup()
-            if hasattr(self, 'vl_recognizer'):
+            if hasattr(self, 'vl_recognizer') and self.vl_recognizer is not None:
                 self.vl_recognizer.cleanup()
             if hasattr(self, 'ocr_recognizer'):
                 self.ocr_recognizer.cleanup()