Quellcode durchsuchen

feat: 添加从目录和文件列表获取图像文件的功能,支持文件存在性验证

zhch158_admin vor 3 Monaten
Ursprung
Commit
402f7a8a4e
1 geänderte Dateien mit 165 neuen und 23 gelöschten Zeilen
  1. 165 23
      zhch/ppstructurev3_scheduler.py

+ 165 - 23
zhch/ppstructurev3_scheduler.py

@@ -56,6 +56,149 @@ def create_temp_file_list(file_chunk: List[str]) -> str:
             f.write(f"{file_path}\n")
         return f.name
 
def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str]:
    """
    Collect the image files located directly inside a directory.

    Matching is done on the file-name ending, case-insensitively, so
    ``photo.jpg``, ``photo.JPG`` and ``photo.Jpg`` are all picked up.
    Sub-directories are not searched, and entries that are not regular
    files (e.g. a directory named ``scan.png``) are ignored.

    Args:
        input_dir: Directory to scan.
        max_files: Optional cap on the number of returned paths; a falsy
            value (``None`` or ``0``) means "no limit".

    Returns:
        Sorted list of image file paths as strings. Empty when the
        directory does not exist or holds no images.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

    # A missing directory yields an empty result instead of raising.
    if not input_dir.is_dir():
        return []

    # Each directory entry is visited exactly once, so no de-duplication
    # step is needed even on case-insensitive file systems.
    image_files = sorted(
        str(entry)
        for entry in input_dir.iterdir()
        if entry.name.lower().endswith(image_extensions) and entry.is_file()
    )

    # Apply the optional limit after sorting so the selection is stable.
    if max_files:
        image_files = image_files[:max_files]

    return image_files
+
def get_image_files_from_list(file_list_path: str) -> List[str]:
    """
    Read image paths from a text file and keep only those that exist.

    The list file holds one path per line; surrounding whitespace is
    stripped and blank lines are skipped. A summary of missing entries
    (at most the first five) is printed as a warning.

    Args:
        file_list_path: Path of the text file containing the list.

    Returns:
        The listed paths that exist on disk, in file order.
    """
    print(f"📄 Reading file list from: {file_list_path}")

    # 'utf-8-sig' also decodes plain UTF-8, but additionally drops a leading
    # BOM; otherwise the BOM sticks to the first path and that file is
    # wrongly reported as missing.
    with open(file_list_path, 'r', encoding='utf-8-sig') as f:
        image_files = [line.strip() for line in f if line.strip()]

    # Split the entries by whether the path exists.
    valid_files = []
    missing_files = []
    for file_path in image_files:
        if Path(file_path).exists():
            valid_files.append(file_path)
        else:
            missing_files.append(file_path)

    if missing_files:
        print(f"⚠️ Warning: {len(missing_files)} files not found:")
        for missing_file in missing_files[:5]:  # show only the first five
            print(f"  - {missing_file}")
        if len(missing_files) > 5:
            print(f"  ... and {len(missing_files) - 5} more")

    print(f"✅ Found {len(valid_files)} valid files out of {len(image_files)} in list")
    return valid_files
+
+
def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
    """
    Collect per-file outcomes from one worker process's statistics file.

    The statistics file is JSON shaped like::

        {
          "results": [
            {
              "image_path": "some_image.jpg",
              "processing_time": 2.02e-06,
              "success": true,
              "device": "gpu:3",
              "output_json": ".../process_3/some_image.json",
              "output_md": ".../process_3/some_image.md"
            },
            ...
          ]
        }

    Args:
        pid_output_file: Path of the per-process statistics JSON file.

    Returns:
        List of ``(image_path, status)`` tuples where ``status`` is
        ``"success"`` when the entry's ``"success"`` value is truthy and
        ``"fail"`` otherwise. An empty list is returned (with a printed
        warning) when the file is missing, is not valid JSON, or does not
        have the expected structure.
    """
    # Open directly instead of checking existence first, so a file that
    # vanishes between the check and the open cannot crash the scheduler.
    try:
        with open(pid_output_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"⚠️ Warning: PID output file not found: {pid_output_file}")
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated or corrupt file is treated like a malformed one.
        print(f"⚠️ Warning: Invalid PID output file format: {pid_output_file}")
        return []

    results = data.get("results") if isinstance(data, dict) else None
    if not isinstance(results, list):
        print(f"⚠️ Warning: Invalid PID output file format: {pid_output_file}")
        return []

    # Map each entry to (path, status); entries that are not JSON objects
    # are skipped because they carry no usable fields.
    return [
        (entry.get("image_path", ""), "success" if entry.get("success", False) else "fail")
        for entry in results
        if isinstance(entry, dict)
    ]
+
def collect_processed_files(results: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """
    Gather per-file outcomes across all worker-process results.

    Each element of ``results`` is a per-process summary such as::

        {
          "process_id": 1,
          "success": true,
          "processing_time": 42.74,
          "file_count": 5,
          "device": "gpu:1",
          "output_dir": ".../OmniDocBench_Results_Scheduler/process_1",
          ...
        }

    For every process the statistics file
    ``<output_dir>/process_<id>/process_<id>.json`` is read through
    ``collect_pid_files``. When that file is absent and the process did not
    succeed, the entries of its ``"failed_files"`` list are reported as
    failed instead.

    Args:
        results: Per-process result dictionaries; each must provide
            ``"output_dir"`` and ``"process_id"``.

    Returns:
        List of ``(file_path, status)`` tuples with ``status`` being
        ``"success"`` or ``"fail"``.
    """
    processed_files = []

    for result in results:
        process_id = result['process_id']
        # NOTE(review): the sample record above has an output_dir that already
        # ends in "process_<id>", which would make this path repeat that
        # segment — confirm against what the single-process script writes.
        pid_output_file = Path(result["output_dir"]) / f"process_{process_id}" / f"process_{process_id}.json"

        if not pid_output_file.exists():
            print(f"⚠️ Warning: Output file not found for process {process_id}: {pid_output_file}")
            if not result.get("success", False):
                # The whole process failed: fall back to its failed-file list.
                process_failed_files = result.get("failed_files", [])
                processed_files.extend([(f, "fail") for f in process_failed_files if f])
            # Nothing to read for this process; skipping avoids a second,
            # redundant "not found" warning from collect_pid_files.
            continue

        processed_files.extend(collect_pid_files(str(pid_output_file)))

    return processed_files
+
 def run_single_process(args: Tuple[List[str], Dict[str, Any], int]) -> Dict[str, Any]:
     """
     运行单个ppstructurev3_single_process.py进程
@@ -174,7 +317,11 @@ def main():
     parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-Process Scheduler")
     
     # 输入输出参数
-    parser.add_argument("--input_dir", type=str, required=True, help="Input directory")
+    # 输入输出参数
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_dir", type=str, help="Input directory")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+
     parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
     parser.add_argument("--single_process_script", type=str, 
                        default="./ppstructurev3_single_process.py", 
@@ -198,30 +345,25 @@ def main():
     
     try:
         # 获取图像文件列表
-        input_dir = Path(args.input_dir).resolve()
+        if args.input_file_list:
+            # 从文件列表读取
+            image_files = get_image_files_from_list(args.input_file_list)
+        else:
+            # 从目录读取
+            input_dir = Path(args.input_dir).resolve()
+            print(f"📁 Input dir: {input_dir}")
+            
+            if not input_dir.exists():
+                print(f"❌ Input directory does not exist: {input_dir}")
+                return 1
+
+            image_files = get_image_files_from_dir(input_dir, args.max_files)
+
         output_dir = Path(args.output_dir).resolve()
-        
+
         print(f"Input dir: {input_dir}")
         print(f"Output dir: {output_dir}")
         
-        if not input_dir.exists():
-            print(f"Input directory does not exist: {input_dir}")
-            return 1
-        
-        # 查找图像文件
-        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
-        image_files = []
-        for ext in image_extensions:
-            image_files.extend(list(input_dir.glob(f"*{ext}")))
-            image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-        
-        if not image_files:
-            print(f"No image files found in {input_dir}")
-            return 1
-        
-        # 去重并排序
-        image_files = sorted(list(set(str(f) for f in image_files)))
-        
         # 限制文件数量
         if args.max_files:
             image_files = image_files[:args.max_files]
@@ -359,8 +501,8 @@ def main():
         return 1
 
 if __name__ == "__main__":
-    print(f"🚀 启动多进程调度程序...")
-    
+    print(f"🚀 启动多进程调度程序..., 约定各进程统计文件名为: process_{{process_id}}.json")
+
     if len(sys.argv) == 1:
         # 默认配置
         default_config = {