Răsfoiți Sursa

feat: 更新process_images_unified函数,支持多种pipeline和可选参数

zhch158_admin 1 lună în urmă
părinte
comite
4f6565c516
1 fișier modificat, cu 41 de adăugiri și 20 de ștergeri
  1. 41 20
      zhch/ppstructurev3_single_process.py

+ 41 - 20
zhch/ppstructurev3_single_process.py

@@ -43,9 +43,10 @@ def process_images_unified(image_paths: List[str],
                          device: str = "gpu:0",
                          output_dir: str = "./output",
                          normalize_numbers: bool = True,
-                         use_enhanced_adapter: bool = True) -> List[Dict[str, Any]]:  # 🎯 新增参数
+                         use_enhanced_adapter: bool = True,
+                         **kwargs) -> List[Dict[str, Any]]:  # 🎯 新增 **kwargs
     """
-    统一的图像处理函数,支持数字标准化
+    统一的图像处理函数,支持数字标准化和多种 pipeline
     """
     # 创建输出目录
     output_path = Path(output_dir)
@@ -85,6 +86,9 @@ def process_images_unified(image_paths: List[str],
         print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
         print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
         
+        # 🎯 检测 pipeline 类型
+        is_paddleocr_vl = 'PaddleOCR-VL'.lower() in str(pipeline_name).lower()
+        
         # 使用tqdm显示进度
         with tqdm(total=total_images, desc="Processing images", unit="img", 
                   bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
@@ -94,21 +98,34 @@ def process_images_unified(image_paths: List[str],
                 start_time = time.time()
                 
                 try:
-                    # 使用pipeline预测单个图像
-                    results = pipeline.predict(
-                        img_path,
-                        use_doc_orientation_classify=True,
-                        use_doc_unwarping=False,
-                        use_layout_detection=True,
-                        use_seal_recognition=True,
-                        use_table_recognition=True,
-                        use_formula_recognition=False,
-                        use_chart_recognition=True,
-                        use_ocr_results_with_table_cells=True,
-                        use_table_orientation_classify=False,
-                        use_wired_table_cells_trans_to_html=True,
-                        use_wireless_table_cells_trans_to_html=True,
-                    )
+                    # 🎯 根据 pipeline 类型使用不同的参数
+                    if is_paddleocr_vl:
+                        # PaddleOCR-VL 使用驼峰命名
+                        predict_kwargs = {
+                            'input': img_path,
+                            'useLayoutDetection': kwargs.get('use_layout_detection', False),
+                            'useDocOrientationClassify': kwargs.get('use_doc_orientation', False),
+                            'useDocUnwarping': kwargs.get('use_doc_unwarping', False),
+                        }
+                    else:
+                        # PP-StructureV3 使用下划线命名
+                        predict_kwargs = {
+                            'img_path': img_path,
+                            'use_doc_orientation_classify': kwargs.get('use_doc_orientation', True),
+                            'use_doc_unwarping': kwargs.get('use_doc_unwarping', False),
+                            'use_layout_detection': kwargs.get('use_layout_detection', True),
+                            'use_seal_recognition': kwargs.get('use_seal_recognition', True),
+                            'use_table_recognition': kwargs.get('use_table_recognition', True),
+                            'use_formula_recognition': kwargs.get('use_formula_recognition', False),
+                            'use_chart_recognition': kwargs.get('use_chart_recognition', True),
+                            'use_ocr_results_with_table_cells': kwargs.get('use_ocr_results_with_table_cells', True),
+                            'use_table_orientation_classify': kwargs.get('use_table_orientation_classify', False),
+                            'use_wired_table_cells_trans_to_html': kwargs.get('use_wired_table_cells_trans_to_html', True),
+                            'use_wireless_table_cells_trans_to_html': kwargs.get('use_wireless_table_cells_trans_to_html', True),
+                        }
+                    
+                    # 使用pipeline预测
+                    results = pipeline.predict(**predict_kwargs)
                     
                     processing_time = time.time() - start_time
                     
@@ -209,7 +226,7 @@ def process_images_unified(image_paths: List[str],
 
 def main():
     """主函数"""
-    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Unified PDF/Image Processor")
+    parser = argparse.ArgumentParser(description="PaddleX Unified PDF/Image Processor")
     
     # 参数定义
     input_group = parser.add_mutually_exclusive_group(required=True)
@@ -230,7 +247,10 @@ def main():
     args = parser.parse_args()
     
     normalize_numbers = not args.no_normalize
-    use_enhanced_adapter = not args.no_adapter  # 🎯 新增
+    use_enhanced_adapter = not args.no_adapter
+    
+    # 🎯 构建 predict 参数
+    predict_kwargs = {}
     
     try:
         # 获取并预处理输入文件
@@ -255,7 +275,8 @@ def main():
             args.device,
             args.output_dir,
             normalize_numbers=normalize_numbers,
-            use_enhanced_adapter=use_enhanced_adapter  # 🎯 传递参数
+            use_enhanced_adapter=use_enhanced_adapter,
+            **predict_kwargs  # 🎯 传递所有参数
         )
         total_time = time.time() - start_time