|
|
@@ -139,7 +139,17 @@ class EnhancedDocPipeline:
|
|
|
self.config['ocr_recognition']
|
|
|
)
|
|
|
|
|
|
- # 5. 有线表格识别器(可选)
|
|
|
+ # 5. 表格分类器(可选)
|
|
|
+ self.table_classifier = None
|
|
|
+ table_cls_config = self.config.get('table_classification', {})
|
|
|
+ if table_cls_config.get('enabled', False):
|
|
|
+ try:
|
|
|
+ self.table_classifier = ModelFactory.create_table_classifier(table_cls_config)
|
|
|
+ logger.info("✅ Table classifier initialized")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"⚠️ Table classifier init failed: {e}")
|
|
|
+
|
|
|
+ # 6. 有线表格识别器(可选)
|
|
|
self.table_config = self.config.get('table_recognition_wired', {})
|
|
|
self.wired_table_recognizer = None
|
|
|
if self.table_config.get('use_wired_unet', False):
|
|
|
@@ -178,6 +188,7 @@ class EnhancedDocPipeline:
|
|
|
vl_recognizer=self.vl_recognizer,
|
|
|
table_cell_matcher=table_cell_matcher,
|
|
|
wired_table_recognizer=getattr(self, 'wired_table_recognizer', None),
|
|
|
+ table_classifier=getattr(self, 'table_classifier', None),
|
|
|
)
|
|
|
|
|
|
# ==================== 主处理流程 ====================
|
|
|
@@ -784,12 +795,32 @@ class EnhancedDocPipeline:
|
|
|
try:
|
|
|
spans = get_matched_spans_for_item(item)
|
|
|
|
|
|
- # 🔑 关键:根据配置选择表格识别路径
|
|
|
+ # 🔑 关键:智能选择表格识别路径(支持自动分类)
|
|
|
use_wired_unet = self.table_config.get('use_wired_unet', False)
|
|
|
+ use_table_classification = self.config.get('table_classification', {}).get('enabled', False)
|
|
|
+
|
|
|
+ # Step 1: 如果启用了自动分类,先对表格进行分类
|
|
|
+ table_type = None
|
|
|
+ if use_table_classification and self.table_classifier:
|
|
|
+ bbox = item.get('bbox', [])
|
|
|
+ table_img = CoordinateUtils.crop_region(detection_image, bbox)
|
|
|
+ cls_result = self.table_classifier.classify(table_img)
|
|
|
+ table_type = cls_result.get('table_type', 'wireless')
|
|
|
+ confidence = cls_result.get('confidence', 0.0)
|
|
|
+ logger.info(f"📊 Table {idx} classified as '{table_type}' (conf: {confidence:.3f})")
|
|
|
+
|
|
|
+ # Step 2: 根据分类结果或配置选择识别器
|
|
|
+ should_use_wired = False
|
|
|
+ if use_table_classification:
|
|
|
+ # 自动分类模式:根据分类结果决定
|
|
|
+ should_use_wired = (table_type == 'wired' and self.wired_table_recognizer)
|
|
|
+ # else:
|
|
|
+ # # 手动配置模式:根据配置决定
|
|
|
+ # should_use_wired = (use_wired_unet and self.wired_table_recognizer)
|
|
|
|
|
|
- if use_wired_unet and self.wired_table_recognizer:
|
|
|
+ if should_use_wired:
|
|
|
# 有线表格路径:UNet 识别
|
|
|
- logger.info(f"🔷 Using wired UNet table recognition (configured)")
|
|
|
+ logger.info(f"🔷 Table {idx}: Using wired UNet recognition")
|
|
|
element = self.element_processors.process_table_element_wired(
|
|
|
detection_image, item, scale, pre_matched_spans=spans, pdf_type=pdf_type,
|
|
|
output_dir=output_dir, basename=f"{basename}_{idx}"
|
|
|
@@ -797,10 +828,13 @@ class EnhancedDocPipeline:
|
|
|
|
|
|
# 如果有线识别失败(返回空 HTML),fallback 到 VLM
|
|
|
if not element['content'].get('html') and not element['content'].get('cells'):
|
|
|
- raise ValueError(f"Wired UNet table recognition failed, element: {item}")
|
|
|
+ logger.warning(f"⚠️ Wired recognition failed for table {idx}, fallback to VLM")
|
|
|
+ element = self.element_processors.process_table_element_vlm(
|
|
|
+ detection_image, item, scale, pre_matched_spans=spans
|
|
|
+ )
|
|
|
else:
|
|
|
# VLM 无线表格路径(默认)
|
|
|
- logger.info(f"🔷 Using VLM table recognition (configured)")
|
|
|
+ logger.info(f"🔷 Table {idx}: Using VLM recognition")
|
|
|
element = self.element_processors.process_table_element_vlm(
|
|
|
detection_image, item, scale, pre_matched_spans=spans
|
|
|
)
|
|
|
@@ -867,7 +901,7 @@ class EnhancedDocPipeline:
|
|
|
self.preprocessor.cleanup()
|
|
|
if hasattr(self, 'layout_detector'):
|
|
|
self.layout_detector.cleanup()
|
|
|
- if hasattr(self, 'vl_recognizer'):
|
|
|
+ if hasattr(self, 'vl_recognizer') and self.vl_recognizer is not None:
|
|
|
self.vl_recognizer.cleanup()
|
|
|
if hasattr(self, 'ocr_recognizer'):
|
|
|
self.ocr_recognizer.cleanup()
|