فهرست منبع

feat: 新增增强适配器支持,优化图像处理流程并添加相关参数

zhch158_admin 1 ماه پیش
والد
کامیت
e986069da0
1 فایل تغییر یافته به همراه 142 افزوده شده و 109 حذف شده
  1. 142 109
      zhch/ppstructurev3_single_process.py

+ 142 - 109
zhch/ppstructurev3_single_process.py

@@ -34,11 +34,16 @@ from ppstructurev3_utils import (
     save_markdown_content
 )
 
+# 🎯 新增:导入适配器
+from adapters import apply_table_recognition_adapter, restore_original_function
+
+
 def process_images_unified(image_paths: List[str],
                          pipeline_name: str = "PP-StructureV3",
                          device: str = "gpu:0",
                          output_dir: str = "./output",
-                         normalize_numbers: bool = True) -> List[Dict[str, Any]]:
+                         normalize_numbers: bool = True,
+                         use_enhanced_adapter: bool = True) -> List[Dict[str, Any]]:  # 🎯 新增参数
     """
     统一的图像处理函数,支持数字标准化
     """
@@ -46,6 +51,15 @@ def process_images_unified(image_paths: List[str],
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
     
+    # 🎯 应用适配器
+    adapter_applied = False
+    if use_enhanced_adapter:
+        adapter_applied = apply_table_recognition_adapter()
+        if adapter_applied:
+            print("🎯 Enhanced table recognition adapter activated")
+        else:
+            print("⚠️  Failed to apply adapter, using original implementation")
+    
     print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
     
     try:
@@ -59,122 +73,138 @@ def process_images_unified(image_paths: List[str],
     except Exception as e:
         print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
         traceback.print_exc()
+        if adapter_applied:
+            restore_original_function()
         return []
     
-    all_results = []
-    total_images = len(image_paths)
-    
-    print(f"Processing {total_images} images one by one")
-    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
-    
-    # 使用tqdm显示进度
-    with tqdm(total=total_images, desc="Processing images", unit="img", 
-              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+    try:
+        all_results = []
+        total_images = len(image_paths)
+        
+        print(f"Processing {total_images} images one by one")
+        print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
+        print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
         
-        # 逐个处理图像
-        for img_path in image_paths:
-            start_time = time.time()
+        # 使用tqdm显示进度
+        with tqdm(total=total_images, desc="Processing images", unit="img", 
+                  bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
             
-            try:
-                # 使用pipeline预测单个图像
-                results = pipeline.predict(
-                    img_path,
-                    use_doc_orientation_classify=True,
-                    use_doc_unwarping=False,
-                    use_seal_recognition=True,
-                    use_table_recognition=True,
-                    use_formula_recognition=False,
-                    use_chart_recognition=True,
-                )
+            # 逐个处理图像
+            for img_path in image_paths:
+                start_time = time.time()
                 
-                processing_time = time.time() - start_time
-                
-                # 处理结果
-                for idx, result in enumerate(results):
-                    if idx > 0:
-                        raise ValueError("Multiple results found for a single image")
-                    try:
-                        input_path = Path(result["input_path"])
-                        
-                        # 生成输出文件名
-                        if result.get("page_index") is not None:
-                            output_filename = f"{input_path.stem}_{result['page_index']}"
-                        else:
-                            output_filename = f"{input_path.stem}"
-                        
-                        # 转换并保存标准JSON格式
-                        json_content = result.json['res']
-                        json_output_path, converted_json = convert_pruned_result_to_json(
-                            json_content, 
-                            str(input_path), 
-                            output_dir,
-                            output_filename,
-                            normalize_numbers=normalize_numbers
-                        )
+                try:
+                    # 使用pipeline预测单个图像
+                    results = pipeline.predict(
+                        img_path,
+                        use_doc_orientation_classify=False,
+                        use_doc_unwarping=False,
+                        use_layout_detection=True,
+                        use_seal_recognition=True,
+                        use_table_recognition=True,
+                        use_formula_recognition=False,
+                        use_chart_recognition=True,
+                        use_ocr_results_with_table_cells=True,
+                        use_table_orientation_classify=True,
+                        use_wired_table_cells_trans_to_html=True,
+                        use_wireless_table_cells_trans_to_html=True,
+                    )
+                    
+                    processing_time = time.time() - start_time
+                    
+                    # 处理结果
+                    for idx, result in enumerate(results):
+                        if idx > 0:
+                            raise ValueError("Multiple results found for a single image")
+                        try:
+                            input_path = Path(result["input_path"])
+                            
+                            # 生成输出文件名
+                            if result.get("page_index") is not None:
+                                output_filename = f"{input_path.stem}_{result['page_index']}"
+                            else:
+                                output_filename = f"{input_path.stem}"
+                            
+                            # 转换并保存标准JSON格式
+                            json_content = result.json['res']
+                            json_output_path, converted_json = convert_pruned_result_to_json(
+                                json_content, 
+                                str(input_path), 
+                                output_dir,
+                                output_filename,
+                                normalize_numbers=normalize_numbers
+                            )
 
-                        # 保存输出图像
-                        img_content = result.img
-                        saved_images = save_output_images(img_content, str(output_dir), output_filename) 
+                            # 保存输出图像
+                            img_content = result.img
+                            saved_images = save_output_images(img_content, str(output_dir), output_filename) 
 
-                        # 保存Markdown内容
-                        markdown_content = result.markdown
-                        md_output_path = save_markdown_content(
-                            markdown_content, 
-                            output_dir, 
-                            output_filename,
-                            normalize_numbers=normalize_numbers,
-                            key_text='markdown_texts',
-                            key_images='markdown_images'
-                        )
-                        
-                        # 记录处理结果
-                        all_results.append({
-                            "image_path": str(input_path),
-                            "processing_time": processing_time,
-                            "success": True,
-                            "device": device,
-                            "output_json": json_output_path,
-                            "output_md": md_output_path,
-                            "is_pdf_page": "_page_" in input_path.name,  # 标记是否为PDF页面
-                            "processing_info": converted_json.get('processing_info', {})
-                        })                        
+                            # 保存Markdown内容
+                            markdown_content = result.markdown
+                            md_output_path = save_markdown_content(
+                                markdown_content, 
+                                output_dir, 
+                                output_filename,
+                                normalize_numbers=normalize_numbers,
+                                key_text='markdown_texts',
+                                key_images='markdown_images'
+                            )
+                            
+                            # 记录处理结果
+                            all_results.append({
+                                "image_path": str(input_path),
+                                "processing_time": processing_time,
+                                "success": True,
+                                "device": device,
+                                "output_json": json_output_path,
+                                "output_md": md_output_path,
+                                "is_pdf_page": "_page_" in input_path.name,
+                                "processing_info": converted_json.get('processing_info', {})
+                            })                        
 
-                    except Exception as e:
-                        print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
-                        traceback.print_exc()
-                        all_results.append({
-                            "image_path": str(img_path),
-                            "processing_time": 0,
-                            "success": False,
-                            "device": device,
-                            "error": str(e)
-                        })
-                
-                # 更新进度条
-                success_count = sum(1 for r in all_results if r.get('success', False))
-                
-                pbar.update(1)
-                pbar.set_postfix({
-                    'time': f"{processing_time:.2f}s",
-                    'success': f"{success_count}/{len(all_results)}",
-                    'rate': f"{success_count/len(all_results)*100:.1f}%"
-                })
-                
-            except Exception as e:
-                print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
-                traceback.print_exc()
-                
-                # 添加错误结果
-                all_results.append({
-                    "image_path": str(img_path),
-                    "processing_time": 0,
-                    "success": False,
-                    "device": device,
-                    "error": str(e)
-                })
-                pbar.update(1)
+                        except Exception as e:
+                            print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
+                            traceback.print_exc()
+                            all_results.append({
+                                "image_path": str(img_path),
+                                "processing_time": 0,
+                                "success": False,
+                                "device": device,
+                                "error": str(e)
+                            })
+                    
+                    # 更新进度条
+                    success_count = sum(1 for r in all_results if r.get('success', False))
+                    
+                    pbar.update(1)
+                    pbar.set_postfix({
+                        'time': f"{processing_time:.2f}s",
+                        'success': f"{success_count}/{len(all_results)}",
+                        'rate': f"{success_count/len(all_results)*100:.1f}%"
+                    })
+                    
+                except Exception as e:
+                    print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
+                    traceback.print_exc()
+                    
+                    # 添加错误结果
+                    all_results.append({
+                        "image_path": str(img_path),
+                        "processing_time": 0,
+                        "success": False,
+                        "device": device,
+                        "error": str(e)
+                    })
+                    pbar.update(1)
+        
+        return all_results
     
-    return all_results
+    finally:
+        # 🎯 清理:恢复原始函数
+        if adapter_applied:
+            restore_original_function()
+            print("🔄 Original function restored")
+
 
 def main():
     """主函数"""
@@ -194,10 +224,12 @@ def main():
     parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化")
     parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
     parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
+    parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器")  # 🎯 新增参数
 
     args = parser.parse_args()
     
     normalize_numbers = not args.no_normalize
+    use_enhanced_adapter = not args.no_adapter  # 🎯 新增
     
     try:
         # 获取并预处理输入文件
@@ -221,7 +253,8 @@ def main():
             args.pipeline,
             args.device,
             args.output_dir,
-            normalize_numbers=normalize_numbers
+            normalize_numbers=normalize_numbers,
+            use_enhanced_adapter=use_enhanced_adapter  # 🎯 传递参数
         )
         total_time = time.time() - start_time