Эх сурвалжийг харах

feat: 新增增强适配器支持,优化表格识别流程并添加相关参数

zhch158_admin 1 сар өмнө
parent
commit
68aff2d1ca

+ 93 - 70
zhch/table_recognition_v2_single_process.py

@@ -25,6 +25,9 @@ from ppstructurev3_utils import (
 )
 from utils import normalize_markdown_table, get_input_files
 
+# 🎯 新增:导入适配器
+from adapters.table_recognition_adapter import apply_table_recognition_adapter, restore_original_function
+
 def html_table_to_markdown(html: str) -> str:
     """
     将简单HTML表格转换为Markdown表格。
@@ -94,7 +97,8 @@ def process_images_with_table_pipeline(
     pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
     device: str = "gpu:0",
     output_dir: str = "./output",
-    normalize_numbers: bool = True
+    normalize_numbers: bool = True,
+    use_enhanced_adapter: bool = True  # 🎯 新增参数
 ) -> List[Dict[str, Any]]:
     """
     运行 table_recognition_v2 管线,输出 JSON、可视化图,且将每个表格HTML转为Markdown保存。
@@ -102,6 +106,15 @@ def process_images_with_table_pipeline(
     output_path = Path(output_dir).resolve()
     output_path.mkdir(parents=True, exist_ok=True)
 
+    # 🎯 应用适配器
+    adapter_applied = False
+    if use_enhanced_adapter:
+        adapter_applied = apply_table_recognition_adapter()
+        if adapter_applied:
+            print("🎯 Enhanced table recognition adapter activated")
+        else:
+            print("⚠️  Failed to apply adapter, using original implementation")
+
     print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
     try:
         os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
@@ -109,82 +122,89 @@ def process_images_with_table_pipeline(
         print(f"Pipeline initialized successfully on {device}")
     except Exception as e:
         print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
-        traceback.print_exc()
+        if adapter_applied:
+            restore_original_function()
         return []
 
-    results_all: List[Dict[str, Any]] = []
-    total = len(image_paths)
-    print(f"Processing {total} images with table_recognition_v2")
-    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
-
-    with tqdm(total=total, desc="Processing images", unit="img",
-              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
-        for img_path in image_paths:
-            start = time.time()
-            try:
-                outputs = pipeline.predict(
-                    img_path,
-                    use_doc_orientation_classify=True,
-                    use_doc_unwarping=False,
-                    use_layout_detection=True,
-                    use_ocr_results_with_table_cells=True,
-                    use_table_orientation_classify=True,
-                    use_wired_table_cells_trans_to_html=True,
-                    # 新增:关闭单元格内拆分,整格识别以保留折行文本, 
-                    # 修改paddlex/inference/pipelines/table_recognition/pipeline_v2.py
-                    # get_table_recognition_res传入参数self.cells_split_ocr=False,保证单元格内换行不被拆分
-                    use_table_cells_split_ocr=False,
-                )
-                cost = time.time() - start
-
-                # 一般每张图片只返回一个结果
-                for idx, res in enumerate(outputs):
-                    if idx > 0:
-                        raise ValueError("Multiple results found for a single image")
-
-                    input_path = Path(res["input_path"])
-                    base_name = input_path.stem
-
-                    res.save_all(save_path=output_path.as_posix())  # 保存所有结果到指定路径
-                    # 保存结构化JSON
-                    json_res = res.json.get("res", res.json)
-
-                    saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
+    try:
+        results_all: List[Dict[str, Any]] = []
+        total = len(image_paths)
+        print(f"Processing {total} images with table_recognition_v2")
+        print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
+        print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
+
+        with tqdm(total=total, desc="Processing images", unit="img",
+                  bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+            for img_path in image_paths:
+                start = time.time()
+                try:
+                    outputs = pipeline.predict(
+                        img_path,
+                        use_doc_orientation_classify=True,
+                        use_doc_unwarping=False,
+                        use_layout_detection=True,
+                        use_ocr_results_with_table_cells=True,
+                        use_table_orientation_classify=True,
+                        use_wired_table_cells_trans_to_html=True,
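+                        # 🎯 NOTE: not needed when the enhanced adapter is active; it is also no longer passed when the adapter is disabled (--no-adapter)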
+                        # use_table_cells_split_ocr=False,
+                    )
+                    cost = time.time() - start
+
+                    # 一般每张图片只返回一个结果
+                    for idx, res in enumerate(outputs):
+                        if idx > 0:
+                            raise ValueError("Multiple results found for a single image")
+
+                        input_path = Path(res["input_path"])
+                        base_name = input_path.stem
+
+                        res.save_all(save_path=output_path.as_posix())  # 保存所有结果到指定路径
+                        # 保存结构化JSON
+                        json_res = res.json.get("res", res.json)
+
+                        saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
     
-                    saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
+                        saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                      normalize_numbers=normalize_numbers)
 
+                        results_all.append({
+                            "image_path": str(input_path),
+                            "success": True,
+                            "time_sec": cost,
+                            "device": device,
+                            "json_path": saved_json,
+                            "markdown_path": saved_md,
+                            "tables_detected": len(md_tables),
+                            "is_pdf_page": "_page_" in input_path.name,
+                            "normalize_numbers": normalize_numbers,
+                            "changes_applied": changes_count > 0,
+                            "character_changes_count": changes_count,
+                        })
+
+                    pbar.update(1)
+                    ok = sum(1 for r in results_all if r.get("success"))
+                    pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
+
+                except Exception as e:
+                    traceback.print_exc()
                     results_all.append({
-                        "image_path": str(input_path),
-                        "success": True,
-                        "time_sec": cost,
+                        "image_path": str(img_path),
+                        "success": False,
+                        "time_sec": 0,
                         "device": device,
-                        "json_path": saved_json,
-                        "markdown_path": saved_md,
-                        "tables_detected": len(md_tables),
-                        "is_pdf_page": "_page_" in input_path.name,
-                        "normalize_numbers": normalize_numbers,
-                        "changes_applied": changes_count > 0,
-                        "character_changes_count": changes_count,
+                        "error": str(e)
                     })
+                    pbar.update(1)
+                    pbar.set_postfix_str("error")
 
-                pbar.update(1)
-                ok = sum(1 for r in results_all if r.get("success"))
-                pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
-
-            except Exception as e:
-                traceback.print_exc()
-                results_all.append({
-                    "image_path": str(img_path),
-                    "success": False,
-                    "time_sec": 0,
-                    "device": device,
-                    "error": str(e)
-                })
-                pbar.update(1)
-                pbar.set_postfix_str("error")
-
-    return results_all
+        return results_all
+        
+    finally:
+        # 🎯 清理:恢复原始函数
+        if adapter_applied:
+            restore_original_function()
+            print("🔄 Original function restored")
 
 def main():
     parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
@@ -202,9 +222,11 @@ def main():
     parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
     parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
     parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")
-
+    parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器")
+    
     args = parser.parse_args()
     normalize_numbers = not args.no_normalize
+    use_enhanced_adapter = not args.no_adapter
 
     # 复用 ppstructurev3_utils 的输入收集逻辑
     input_files = get_input_files(args)
@@ -222,7 +244,8 @@ def main():
         args.pipeline,
         args.device,
         args.output_dir,
-        normalize_numbers=normalize_numbers
+        normalize_numbers=normalize_numbers,
+        use_enhanced_adapter=use_enhanced_adapter  # 🎯 传递参数
     )
     total_time = time.time() - start