# ppstructurev3_scheduler.py

import json
import time
import os
import argparse
import sys
import subprocess
import tempfile
from pathlib import Path
from typing import List, Dict, Any, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading
from queue import Queue, Empty
from tqdm import tqdm


def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
    """
    Split a file list into the requested number of sub-lists.

    Args:
        file_list: List of file paths.
        num_splits: Number of chunks.

    Returns:
        The list of file chunks.
    """
    if num_splits <= 0:
        return [file_list]
    chunk_size = len(file_list) // num_splits
    remainder = len(file_list) % num_splits
    chunks = []
    start = 0
    for i in range(num_splits):
        # The first `remainder` chunks each receive one extra file
        current_chunk_size = chunk_size + (1 if i < remainder else 0)
        if current_chunk_size > 0:
            chunks.append(file_list[start:start + current_chunk_size])
            start += current_chunk_size
    return [chunk for chunk in chunks if chunk]  # Drop empty chunks
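
# Illustration (hypothetical values): splitting 10 files across 4 processes
# yields chunk sizes [3, 3, 2, 2], the remainder going to the leading chunks:
#   split_files([f"img_{i}.jpg" for i in range(10)], 4)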


def create_temp_file_list(file_chunk: List[str]) -> str:
    """
    Write a file chunk to a temporary list file (one path per line).

    Args:
        file_chunk: List of file paths.

    Returns:
        Path of the temporary file.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        for file_path in file_chunk:
            f.write(f"{file_path}\n")
        return f.name


def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str]:
    """
    Collect image files from a directory.

    Args:
        input_dir: Input directory.
        max_files: Optional cap on the number of files.

    Returns:
        List of image file paths.
    """
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
    image_files = []
    for ext in image_extensions:
        image_files.extend(list(input_dir.glob(f"*{ext}")))
        image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
    # Deduplicate and sort
    image_files = sorted(set(str(f) for f in image_files))
    # Apply the file count limit
    if max_files:
        image_files = image_files[:max_files]
    return image_files


def get_image_files_from_list(file_list_path: str) -> List[str]:
    """
    Read image file paths from a list file (one path per line).

    Args:
        file_list_path: Path of the list file.

    Returns:
        List of image file paths that exist on disk.
    """
    print(f"📄 Reading file list from: {file_list_path}")
    with open(file_list_path, 'r', encoding='utf-8') as f:
        image_files = [line.strip() for line in f if line.strip()]
    # Verify that each file exists
    valid_files = []
    missing_files = []
    for file_path in image_files:
        if Path(file_path).exists():
            valid_files.append(file_path)
        else:
            missing_files.append(file_path)
    if missing_files:
        print(f"⚠️ Warning: {len(missing_files)} files not found:")
        for missing_file in missing_files[:5]:  # Show only the first 5
            print(f" - {missing_file}")
        if len(missing_files) > 5:
            print(f" ... and {len(missing_files) - 5} more")
    print(f"✅ Found {len(valid_files)} valid files out of {len(image_files)} in list")
    return valid_files


def get_image_files_from_csv(csv_file: str, status_filter: str = "fail") -> List[str]:
    """
    Read image file paths from a CSV file with header `image_path,status`.

    Args:
        csv_file: CSV file path.
        status_filter: Only rows whose status matches are returned.

    Returns:
        List of image file paths.
    """
    print(f"📄 Reading image files from CSV: {csv_file}")
    image_files = []
    with open(csv_file, 'r', encoding='utf-8') as f:
        next(f, None)  # Skip the header row
        for line in f:
            if not line.strip():
                continue
            # Split on "," to get the file path and its status
            image_file, status = line.strip().split(",")
            if status.lower() == status_filter.lower():
                image_files.append(image_file)
    return image_files
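
# Expected CSV layout (rows below are illustrative):
#   image_path,status
#   /data/images/page_001.jpg,success
#   /data/images/page_002.jpg,fail
# With status_filter="fail", only page_002.jpg would be returned.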


def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
    """
    Collect files from a per-process output file.

    Args:
        pid_output_file: Path of the per-process output file.

    Returns:
        List of (file path, processing status) tuples.

    Per-process stats file format:
        "results": [
            {
                "image_path": "docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.jpg",
                "processing_time": 2.0265579223632812e-06,
                "success": true,
                "device": "gpu:3",
                "output_json": "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Results_Scheduler/process_3/docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.json",
                "output_md": "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Results_Scheduler/process_3/docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.md"
            },
            ...
        ]
    """
    if not Path(pid_output_file).exists():
        print(f"⚠️ Warning: PID output file not found: {pid_output_file}")
        return []
    with open(pid_output_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, dict) or "results" not in data:
        print(f"⚠️ Warning: Invalid PID output file format: {pid_output_file}")
        return []
    # Map each result to (path, status): "success" if `"success": true`, else "fail"
    file_list = []
    for file_result in data.get("results", []):
        image_path = file_result.get("image_path", "")
        status = "success" if file_result.get("success", False) else "fail"
        file_list.append((image_path, status))
    return file_list


def collect_processed_files(results: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """
    Collect per-file outcomes from the scheduler results.

    Args:
        results: List of per-process results.

    Returns:
        List of (file path, processing status) tuples.
    """
    processed_files = []
    for result in results:
        # Locate each process's stats file at output_dir/process_{process_id}.json.
        # A per-process result looks like:
        # {
        #     "process_id": 1,
        #     "success": true,
        #     "processing_time": 42.744526386260986,
        #     "file_count": 5,
        #     "device": "gpu:1",
        #     "output_dir": "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Results_Scheduler/process_1",
        #     ...
        # }
        pid_output_file = Path(result.get("output_dir", "")) / f"process_{result['process_id']}.json"
        if not pid_output_file.exists():
            print(f"⚠️ Warning: Output file not found for process {result['process_id']}: {pid_output_file}")
        if not result.get("success", False):
            # The whole process failed
            process_failed_files = result.get("failed_files", [])
            processed_files.extend([(f, "fail") for f in process_failed_files if f])
        else:
            pid_files = collect_pid_files(str(pid_output_file))
            processed_files.extend(pid_files)
    return processed_files


def run_single_process(args: Tuple[List[str], Dict[str, Any], int]) -> Dict[str, Any]:
    """
    Run one ppstructurev3_single_process.py worker as a subprocess.

    Args:
        args: (file_chunk, config, process_id)

    Returns:
        The processing result.
    """
    file_chunk, config, process_id = args
    if not file_chunk:
        return {"process_id": process_id, "success": False, "error": "Empty file chunk"}
    # Create the temporary file list
    temp_file_list = create_temp_file_list(file_chunk)
    try:
        # Create a per-process output directory
        process_output_dir = Path(config["output_dir"]) / f"process_{process_id}"
        process_output_dir.mkdir(parents=True, exist_ok=True)
        # Build the command line
        cmd = [
            sys.executable,
            config["single_process_script"],
            "--input_file_list", temp_file_list,  # The single-process script must support file lists
            "--output_dir", str(process_output_dir),
            "--pipeline", config["pipeline"],
            "--device", config["device"],
            "--batch_size", str(config["batch_size"]),
        ]
        # Optional arguments
        if config.get("test_mode", False):
            cmd.append("--test_mode")
        print(f"Process {process_id} starting with {len(file_chunk)} files on device {config['device']}")
        # Run the subprocess
        start_time = time.time()
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=config.get("timeout", 3600)  # Default: 1-hour timeout
        )
        processing_time = time.time() - start_time
        if result.returncode == 0:
            print(f"Process {process_id} completed successfully in {processing_time:.2f}s")
        # Collect the result files
            result_files = list(process_output_dir.glob("*.json"))
            return {
                "process_id": process_id,
                "success": True,
                "processing_time": processing_time,
                "file_count": len(file_chunk),
                "device": config["device"],
                "output_dir": str(process_output_dir),
                "result_files": [str(f) for f in result_files],
                "stdout": result.stdout,
                "stderr": result.stderr
            }
        else:
            print(f"Process {process_id} failed with return code {result.returncode}")
            return {
                "process_id": process_id,
                "success": False,
                "error": f"Process failed with return code {result.returncode}",
                "processing_time": processing_time,
                "file_count": len(file_chunk),
                "device": config["device"],
                "output_dir": str(process_output_dir),
                "failed_files": [str(f) for f in file_chunk],
                "stdout": result.stdout,
                "stderr": result.stderr
            }
    except subprocess.TimeoutExpired:
        print(f"Process {process_id} timed out")
        return {
            "process_id": process_id,
            "success": False,
            "error": "Process timeout",
            "device": config["device"],
            "output_dir": str(process_output_dir),
            "failed_files": [str(f) for f in file_chunk]
        }
    except Exception as e:
        print(f"Process {process_id} error: {e}")
        return {
            "process_id": process_id,
            "success": False,
            "error": str(e),
            "device": config["device"],
            "output_dir": str(process_output_dir),
            "failed_files": [str(f) for f in file_chunk]
        }
    finally:
        # Clean up the temporary file list
        try:
            os.unlink(temp_file_list)
        except OSError:
            pass
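
# The worker command assembled above is equivalent to (paths illustrative):
#   python ppstructurev3_single_process.py --input_file_list /tmp/tmpXXXX.txt \
#       --output_dir ./results/process_0 --pipeline PP-StructureV3 \
#       --device gpu:0 --batch_size 4 [--test_mode]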


def monitor_progress(total_files: int, completed_queue: Queue):
    """Monitor overall processing progress."""
    with tqdm(total=total_files, desc="Total Progress", unit="files") as pbar:
        completed_count = 0
        while completed_count < total_files:
            try:
                batch_count = completed_queue.get(timeout=1)
                completed_count += batch_count
                pbar.update(batch_count)
            except Empty:
                continue
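
# Usage sketch (hypothetical values): the main process pushes per-chunk
# completion counts onto the queue; this thread drains them into the bar:
#   q = Queue()
#   threading.Thread(target=monitor_progress, args=(100, q), daemon=True).start()
#   q.put(25)  # e.g., one worker finished a 25-file chunk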


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-Process Scheduler")
    # Input/output arguments
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_dir", type=str, help="Input directory")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--single_process_script", type=str,
                        default="./ppstructurev3_single_process.py",
                        help="Path to single process script")
    # Parallelism arguments
    parser.add_argument("--num_processes", type=int, default=4, help="Number of parallel processes")
    parser.add_argument("--devices", type=str, default="gpu:0,gpu:1,gpu:2,gpu:3",
                        help="Device list (comma separated)")
    # Pipeline arguments
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size per process")
    parser.add_argument("--timeout", type=int, default=3600, help="Process timeout in seconds")
    # Miscellaneous arguments
    parser.add_argument("--test_mode", action="store_true", help="Test mode")
    parser.add_argument("--max_files", type=int, default=None, help="Maximum files to process")
    args = parser.parse_args()
    try:
        # Build the image file list
        if args.input_csv:
            # From a CSV of previous results, retrying failed files
            image_files = get_image_files_from_csv(args.input_csv, "fail")
            print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
        elif args.input_file_list:
            # From a list file
            image_files = get_image_files_from_list(args.input_file_list)
        else:
            # From a directory
            input_dir = Path(args.input_dir).resolve()
            print(f"📁 Input dir: {input_dir}")
            if not input_dir.exists():
                print(f"❌ Input directory does not exist: {input_dir}")
                return 1
            image_files = get_image_files_from_dir(input_dir, args.max_files)
        output_dir = Path(args.output_dir).resolve()
        print(f"Output dir: {output_dir}")
        # Apply the file count limit
        if args.max_files:
            image_files = image_files[:args.max_files]
        if args.test_mode:
            image_files = image_files[:20]
            print(f"Test mode: processing only {len(image_files)} images")
        print(f"Found {len(image_files)} image files")
        # Parse the device list
        devices = [d.strip() for d in args.devices.split(',')]
        if len(devices) < args.num_processes:
            # Fewer devices than processes: reuse devices round-robin
            devices = devices * ((args.num_processes // len(devices)) + 1)
        devices = devices[:args.num_processes]
        print(f"Using {args.num_processes} processes with devices: {devices}")
        # Split the file list
        file_chunks = split_files(image_files, args.num_processes)
        print(f"Split into {len(file_chunks)} chunks: {[len(chunk) for chunk in file_chunks]}")
        # Create the output directory
        output_dir.mkdir(parents=True, exist_ok=True)
        # Prepare per-process arguments
        process_configs = []
        for i, (chunk, device) in enumerate(zip(file_chunks, devices)):
            config = {
                "single_process_script": str(Path(args.single_process_script).resolve()),
                "output_dir": str(output_dir),
                "pipeline": args.pipeline,
                "device": device,
                "batch_size": args.batch_size,
                "timeout": args.timeout,
                "test_mode": args.test_mode
            }
            process_configs.append((chunk, config, i))
        # Start the progress monitor
        completed_queue = Queue()
        progress_thread = threading.Thread(
            target=monitor_progress,
            args=(len(image_files), completed_queue)
        )
        progress_thread.daemon = True
        progress_thread.start()
        # Run the workers in parallel
        start_time = time.time()
        results = []
        with ProcessPoolExecutor(max_workers=args.num_processes) as executor:
            # Submit all tasks
            future_to_process = {
                executor.submit(run_single_process, config): i
                for i, config in enumerate(process_configs)
            }
            # Collect results
            for future in as_completed(future_to_process):
                process_id = future_to_process[future]
                try:
                    result = future.result()
                    results.append(result)
                    # Update progress
                    if result.get("success", False):
                        completed_queue.put(result.get("file_count", 0))
                    print(f"Process {process_id} finished: {result.get('success', False)}")
                except Exception as e:
                    print(f"Process {process_id} generated an exception: {e}")
                    results.append({
                        "process_id": process_id,
                        "success": False,
                        "error": str(e)
                    })
        total_time = time.time() - start_time
        # Summarize the results (guard divisions against zero counts)
        successful_processes = sum(1 for r in results if r.get('success', False))
        total_processed_files = sum(r.get('file_count', 0) for r in results if r.get('success', False))
        print("\n" + "=" * 60)
        print("🎉 Parallel processing completed!")
        print("📊 Statistics:")
        print(f" Total processes: {len(results)}")
        print(f" Successful processes: {successful_processes}")
        print(f" Total files processed: {total_processed_files}/{len(image_files)}")
        if image_files:
            print(f" Success rate: {total_processed_files / len(image_files) * 100:.2f}%")
        print("⏱️ Performance:")
        print(f" Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f" Throughput: {total_processed_files / total_time:.2f} files/second")
        if total_processed_files > 0:
            print(f" Avg time per file: {total_time / total_processed_files:.2f} seconds")
        # Save the scheduler results
        scheduler_stats = {
            "total_files": len(image_files),
            "total_processes": len(results),
            "successful_processes": successful_processes,
            "total_processed_files": total_processed_files,
            "success_rate": total_processed_files / len(image_files) if len(image_files) > 0 else 0,
            "total_time": total_time,
            "throughput": total_processed_files / total_time if total_time > 0 else 0,
            "avg_time_per_file": total_time / total_processed_files if total_processed_files > 0 else 0,
            "num_processes": args.num_processes,
            "devices": devices,
            "batch_size": args.batch_size,
            "pipeline": args.pipeline,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        final_results = {
            "scheduler_stats": scheduler_stats,
            "process_results": results
        }
        output_file = output_dir / f"scheduler_results_{args.num_processes}procs.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        print(f"💾 Scheduler results saved to: {output_file}")
        # Collect per-file processing outcomes
        processed_files = collect_processed_files(results)
        output_file_processed = output_dir / f"processed_files_{args.num_processes}procs_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        with open(output_file_processed, 'w', encoding='utf-8') as f:
            f.write("image_path,status\n")
            for file_path, status in processed_files:
                f.write(f"{file_path},{status}\n")
        print(f"💾 Processed files saved to: {output_file_processed}")
        return 0 if successful_processes == len(results) else 1
    except Exception as e:
        print(f"❌ Scheduler failed: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    print("🚀 Starting multi-process scheduler... each worker is expected to name its stats file process_{process_id}.json")
    if len(sys.argv) == 1:
        # Default configuration. The directory-input variant is kept for
        # reference; the CSV config below is the one actually used:
        # default_config = {
        #     "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
        #     "output_dir": "./OmniDocBench_Results_Scheduler",
        #     "num_processes": 4,
        #     "devices": "gpu:0,gpu:1,gpu:2,gpu:3",
        #     "batch_size": 2,
        # }
        default_config = {
            "input_csv": "./OmniDocBench_Results_Scheduler/processed_files_4procs_20250814_182650.csv",
            "output_dir": "./OmniDocBench_Results_Scheduler",
            "num_processes": 4,
            "devices": "gpu:0,gpu:1,gpu:2,gpu:3",
            "batch_size": 2,
        }
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])
        sys.argv.append("--test_mode")
    sys.exit(main())
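
# Example invocation (paths and device list are illustrative):
#   python ppstructurev3_scheduler.py \
#       --input_dir ./images \
#       --output_dir ./OmniDocBench_Results_Scheduler \
#       --num_processes 4 --devices gpu:0,gpu:1,gpu:2,gpu:3 --batch_size 4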