Ver código fonte

feat: 更新单进程图像处理功能,支持从CSV文件读取图像路径,优化输入参数处理

zhch158_admin 3 meses atrás
pai
commit
f2b89a7ce6
1 arquivo alterado com 27 adições e 39 exclusões
  1. +27 -39
      zhch/ppstructurev3_single_process.py

+ 27 - 39
zhch/ppstructurev3_single_process.py

@@ -1,3 +1,4 @@
+"""单进程运行稳定"""
 import json
 import time
 import os
@@ -22,6 +23,11 @@ from tqdm import tqdm
 from dotenv import load_dotenv
 load_dotenv(override=True)
 
+from utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+)
 
 def process_images_single_process(image_paths: List[str],
                                 pipeline_name: str = "PP-StructureV3",
@@ -168,12 +174,15 @@ def main():
     parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Single Process Processing")
     
     # 参数定义
-    parser.add_argument("--input_dir", type=str, default="../../OmniDocBench/OpenDataLab___OmniDocBench/images", help="Input directory")
-    parser.add_argument("--input_file_list", type=str, default=None, help="Input file list (one file per line)")  # 新增
-    parser.add_argument("--output_dir", type=str, default="./OmniDocBench_Results_Single", help="Output directory")
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_dir", type=str, help="Input directory")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
+
+    parser.add_argument("--output_dir", type=str, help="Output directory")
     parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
     parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
-    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
     parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
     parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
     
@@ -181,46 +190,25 @@ def main():
     
     try:
         # 获取图像文件列表
-        if args.input_file_list:
+        if args.input_csv:
+            # 从CSV文件读取
+            image_files = get_image_files_from_csv(args.input_csv, "fail")
+            print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
+        elif args.input_file_list:
             # 从文件列表读取
-            print(f"Reading file list from: {args.input_file_list}")
-            with open(args.input_file_list, 'r', encoding='utf-8') as f:
-                image_files = [line.strip() for line in f if line.strip()]
-            
-            # 验证文件存在
-            valid_files = []
-            for file_path in image_files:
-                if Path(file_path).exists():
-                    valid_files.append(file_path)
-                else:
-                    print(f"Warning: File not found: {file_path}")
-            image_files = valid_files
-            
+            image_files = get_image_files_from_list(args.input_file_list)
         else:
-            # 从目录读取(原有逻辑)
+            # 从目录读取
             input_dir = Path(args.input_dir).resolve()
-            output_dir = Path(args.output_dir).resolve()
-            
-            print(f"Input dir: {input_dir}")
+            print(f"📁 Input dir: {input_dir}")
             
             if not input_dir.exists():
-                print(f"Input directory does not exist: {input_dir}")
+                print(f"❌ Input directory does not exist: {input_dir}")
                 return 1
-            
-            # 查找图像文件
-            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
-            image_files = []
-            for ext in image_extensions:
-                image_files.extend(list(input_dir.glob(f"*{ext}")))
-                image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-            
-            if not image_files:
-                print(f"No image files found in {input_dir}")
-                return 1
-            
-            # 去重并排序
-            image_files = sorted(list(set(str(f) for f in image_files)))
-        
+
+            print(f"Input dir: {input_dir}")
+            image_files = get_image_files_from_dir(input_dir, args.max_files)
+
         output_dir = Path(args.output_dir).resolve()
         print(f"Output dir: {output_dir}")
         print(f"Found {len(image_files)} image files")
@@ -320,6 +308,6 @@ if __name__ == "__main__":
             sys.argv.extend([f"--{key}", str(value)])
         
         # 测试模式
-        # sys.argv.append("--test_mode")
+        sys.argv.append("--test_mode")
     
     sys.exit(main())