Browse Source

feat(zhch): 添加从文件列表读取图像文件的功能,支持批量处理

zhch158_admin 3 months ago
parent
commit
082d3f7f65
1 changed file with 41 additions and 21 deletions
  1. 41 21
      zhch/ppstructurev3_single_process.py

+ 41 - 21
zhch/ppstructurev3_single_process.py

@@ -169,6 +169,7 @@ def main():
     
     # 参数定义
     parser.add_argument("--input_dir", type=str, default="../../OmniDocBench/OpenDataLab___OmniDocBench/images", help="Input directory")
+    parser.add_argument("--input_file_list", type=str, default=None, help="Input file list (one file per line)")  # 新增
     parser.add_argument("--output_dir", type=str, default="./OmniDocBench_Results_Single", help="Output directory")
     parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
     parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
@@ -180,29 +181,48 @@ def main():
     
     try:
         # 获取图像文件列表
-        input_dir = Path(args.input_dir).resolve()
-        output_dir = Path(args.output_dir).resolve()
+        if args.input_file_list:
+            # 从文件列表读取
+            print(f"Reading file list from: {args.input_file_list}")
+            with open(args.input_file_list, 'r', encoding='utf-8') as f:
+                image_files = [line.strip() for line in f if line.strip()]
+            
+            # 验证文件存在
+            valid_files = []
+            for file_path in image_files:
+                if Path(file_path).exists():
+                    valid_files.append(file_path)
+                else:
+                    print(f"Warning: File not found: {file_path}")
+            image_files = valid_files
+            
+        else:
+            # 从目录读取(原有逻辑)
+            input_dir = Path(args.input_dir).resolve()
+            output_dir = Path(args.output_dir).resolve()
+            
+            print(f"Input dir: {input_dir}")
+            
+            if not input_dir.exists():
+                print(f"Input directory does not exist: {input_dir}")
+                return 1
+            
+            # 查找图像文件
+            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+            image_files = []
+            for ext in image_extensions:
+                image_files.extend(list(input_dir.glob(f"*{ext}")))
+                image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+            
+            if not image_files:
+                print(f"No image files found in {input_dir}")
+                return 1
+            
+            # 去重并排序
+            image_files = sorted(list(set(str(f) for f in image_files)))
         
-        print(f"Input dir: {input_dir}")
+        output_dir = Path(args.output_dir).resolve()
         print(f"Output dir: {output_dir}")
-        
-        if not input_dir.exists():
-            print(f"Input directory does not exist: {input_dir}")
-            return 1
-        
-        # 查找图像文件
-        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
-        image_files = []
-        for ext in image_extensions:
-            image_files.extend(list(input_dir.glob(f"*{ext}")))
-            image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-        
-        if not image_files:
-            print(f"No image files found in {input_dir}")
-            return 1
-        
-        # 去重并排序
-        image_files = sorted(list(set(str(f) for f in image_files)))
         print(f"Found {len(image_files)} image files")
         
         if args.test_mode: