浏览代码

fix: 移除批处理大小参数,优化单进程图像处理逻辑

zhch158_admin 2 月之前
父节点
当前提交
67bead1448
共有 1 个文件被更改,包括 35 次插入和 53 次删除
  1. 35 53
      zhch/ppstructurev3_single_process.py

+ 35 - 53
zhch/ppstructurev3_single_process.py

@@ -33,7 +33,6 @@ from utils import (
 def process_images_single_process(image_paths: List[str],
                                 pipeline_name: str = "PP-StructureV3",
                                 device: str = "gpu:0",
-                                batch_size: int = 1,
                                 output_dir: str = "./output") -> List[Dict[str, Any]]:
     """
     单进程版本的图像处理函数
@@ -42,7 +41,6 @@ def process_images_single_process(image_paths: List[str],
         image_paths: 图像路径列表
         pipeline_name: Pipeline名称
         device: 设备字符串,如"gpu:0"或"cpu"
-        batch_size: 批处理大小
         output_dir: 输出目录
         
     Returns:
@@ -70,33 +68,31 @@ def process_images_single_process(image_paths: List[str],
     all_results = []
     total_images = len(image_paths)
     
-    print(f"Processing {total_images} images with batch size {batch_size}")
+    print(f"Processing {total_images} images one by one")
     
-    # 使用tqdm显示进度,添加更多统计信息
+    # 使用tqdm显示进度
     with tqdm(total=total_images, desc="Processing images", unit="img", 
               bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
         
-        # 按批次处理图像
-        for i in range(0, total_images, batch_size):
-            batch = image_paths[i:i + batch_size]
-            batch_start_time = time.time()
+        # 逐个处理图像
+        for img_path in image_paths:
+            start_time = time.time()
             
             try:
-                # 使用pipeline预测
+                # 使用pipeline预测单个图像
                 results = pipeline.predict(
-                    batch,
+                    img_path,  # 传入单个文件路径
                     use_doc_orientation_classify=True,
                     use_doc_unwarping=False,
                     use_seal_recognition=True,
-                    use_chart_recognition=True,
                     use_table_recognition=True,
-                    use_formula_recognition=True,
+                    use_formula_recognition=False,  # 暂时关闭公式识别以避免错误
+                    use_chart_recognition=True,
                 )
                 
-                batch_processing_time = time.time() - batch_start_time
-                batch_results = []
+                processing_time = time.time() - start_time
                 
-                # 处理每个结果
+                # 处理结果
                 for result in results:
                     try:
                         input_path = Path(result["input_path"])
@@ -115,9 +111,9 @@ def process_images_single_process(image_paths: List[str],
                         result.save_to_markdown(md_output_path)
                         
                         # 记录处理结果
-                        batch_results.append({
+                        all_results.append({
                             "image_path": str(input_path),
-                            "processing_time": batch_processing_time / len(batch),  # 平均时间
+                            "processing_time": processing_time,
                             "success": True,
                             "device": device,
                             "output_json": json_output_path,
@@ -127,45 +123,37 @@ def process_images_single_process(image_paths: List[str],
                     except Exception as e:
                         print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
                         traceback.print_exc()
-                        batch_results.append({
-                            "image_path": str(input_path),
+                        all_results.append({
+                            "image_path": str(img_path),
                             "processing_time": 0,
                             "success": False,
                             "device": device,
                             "error": str(e)
                         })
                 
-                all_results.extend(batch_results)
-                
                 # 更新进度条
-                success_count = sum(1 for r in batch_results if r.get('success', False))
-                total_success = sum(1 for r in all_results if r.get('success', False))
-                avg_time = batch_processing_time / len(batch)
+                success_count = sum(1 for r in all_results if r.get('success', False))
                 
-                pbar.update(len(batch))
+                pbar.update(1)
                 pbar.set_postfix({
-                    'batch_time': f"{batch_processing_time:.2f}s",
-                    'avg_time': f"{avg_time:.2f}s/img",
-                    'success': f"{total_success}/{len(all_results)}",
-                    'rate': f"{total_success/len(all_results)*100:.1f}%"
+                    'time': f"{processing_time:.2f}s",
+                    'success': f"{success_count}/{len(all_results)}",
+                    'rate': f"{success_count/len(all_results)*100:.1f}%"
                 })
                 
             except Exception as e:
-                print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
+                print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
                 traceback.print_exc()
                 
-                # 为批次中的所有图像添加错误结果
-                error_results = []
-                for img_path in batch:
-                    error_results.append({
-                        "image_path": str(img_path),
-                        "processing_time": 0,
-                        "success": False,
-                        "device": device,
-                        "error": str(e)
-                    })
-                all_results.extend(error_results)
-                pbar.update(len(batch))
+                # 添加错误结果
+                all_results.append({
+                    "image_path": str(img_path),
+                    "processing_time": 0,
+                    "success": False,
+                    "device": device,
+                    "error": str(e)
+                })
+                pbar.update(1)
     
     return all_results
 
@@ -183,7 +171,6 @@ def main():
     parser.add_argument("--output_dir", type=str, help="Output directory")
     parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
     parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
-    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
     parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
     parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
     parser.add_argument("--collect_results",type=str, help="收集处理结果到指定CSV文件")
@@ -220,15 +207,13 @@ def main():
             print(f"Test mode: processing only {len(image_files)} images")
         
         print(f"Using device: {args.device}")
-        print(f"Batch size: {args.batch_size}")
         
-        # 开始处理
+        # 开始处理(删除了 batch_size 参数)
         start_time = time.time()
         results = process_images_single_process(
             image_files,
             args.pipeline,
             args.device,
-            args.batch_size,
             str(output_dir)
         )
         total_time = time.time() - start_time
@@ -251,7 +236,7 @@ def main():
             print(f"  Throughput: {len(image_files) / total_time:.2f} images/second")
             print(f"  Avg time per image: {total_time / len(image_files):.2f} seconds")
         
-        # 保存结果统计
+        # 保存结果统计(删除了 batch_size 统计)
         stats = {
             "total_files": len(image_files),
             "success_count": success_count,
@@ -260,7 +245,6 @@ def main():
             "total_time": total_time,
             "throughput": len(image_files) / total_time if total_time > 0 else 0,
             "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
-            "batch_size": args.batch_size,
             "device": args.device,
             "pipeline": args.pipeline,
             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
@@ -304,22 +288,20 @@ if __name__ == "__main__":
         # 如果没有命令行参数,使用默认配置运行
         print("ℹ️  No command line arguments provided. Running with default configuration...")
         
-        # 默认配置
+        # 默认配置(删除了 batch_size)
         default_config = {
             "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
             "output_dir": "./OmniDocBench_PPStructureV3_Results",
-            "pipeline": "PP-StructureV3",
+            "pipeline": "./my_config/PP-StructureV3.yaml",
             "device": "gpu:0",
-            "batch_size": 2,
             "collect_results": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
         }
         
         # default_config = {
         #     "input_csv": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
         #     "output_dir": "./OmniDocBench_PPStructureV3_Results",
-        #     "pipeline": "PP-StructureV3",
+        #     "pipeline": "./my_config/PP-StructureV3.yaml",
         #     "device": "gpu:0",
-        #     "batch_size": 2,
         #     "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
         # }
         # 构造参数