# file_utils.py (9.1 KB)


  1. import tempfile
  2. from pathlib import Path
  3. from typing import List, Tuple
  4. import json
  5. from .doc_utils import load_images_from_pdf
  6. import traceback
  7. def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
  8. """
  9. 将文件列表分割成指定数量的子列表
  10. Args:
  11. file_list: 文件路径列表
  12. num_splits: 分割数量
  13. Returns:
  14. 分割后的文件列表
  15. """
  16. if num_splits <= 0:
  17. return [file_list]
  18. chunk_size = len(file_list) // num_splits
  19. remainder = len(file_list) % num_splits
  20. chunks = []
  21. start = 0
  22. for i in range(num_splits):
  23. # 前remainder个chunk多分配一个文件
  24. current_chunk_size = chunk_size + (1 if i < remainder else 0)
  25. if current_chunk_size > 0:
  26. chunks.append(file_list[start:start + current_chunk_size])
  27. start += current_chunk_size
  28. return [chunk for chunk in chunks if chunk] # 过滤空列表
  29. def create_temp_file_list(file_chunk: List[str]) -> str:
  30. """
  31. 创建临时文件列表文件
  32. Args:
  33. file_chunk: 文件路径列表
  34. Returns:
  35. 临时文件路径
  36. """
  37. with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
  38. for file_path in file_chunk:
  39. f.write(f"{file_path}\n")
  40. return f.name
  41. def get_image_files_from_dir(input_dir: Path, pattern: str = "*", max_files: int = None) -> List[str]:
  42. """
  43. 从目录获取图像文件列表
  44. Args:
  45. input_dir: 输入目录
  46. max_files: 最大文件数量限制
  47. Returns:
  48. 图像文件路径列表
  49. """
  50. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  51. image_files = []
  52. for ext in image_extensions:
  53. image_files.extend(list(input_dir.glob(f"{pattern}{ext}")))
  54. image_files.extend(list(input_dir.glob(f"{pattern}{ext.upper()}")))
  55. # 去重并排序
  56. image_files = sorted(list(set(str(f) for f in image_files)))
  57. # 限制文件数量
  58. if max_files:
  59. image_files = image_files[:max_files]
  60. return image_files
  61. def get_image_files_from_list(file_list_path: str) -> List[str]:
  62. """
  63. 从文件列表获取图像文件列表
  64. Args:
  65. file_list_path: 文件列表路径
  66. Returns:
  67. 图像文件路径列表
  68. """
  69. print(f"📄 Reading file list from: {file_list_path}")
  70. with open(file_list_path, 'r', encoding='utf-8') as f:
  71. image_files = [line.strip() for line in f if line.strip()]
  72. # 验证文件存在性
  73. valid_files = []
  74. missing_files = []
  75. for file_path in image_files:
  76. if Path(file_path).exists():
  77. valid_files.append(file_path)
  78. else:
  79. missing_files.append(file_path)
  80. if missing_files:
  81. print(f"⚠️ Warning: {len(missing_files)} files not found:")
  82. for missing_file in missing_files[:5]: # 只显示前5个
  83. print(f" - {missing_file}")
  84. if len(missing_files) > 5:
  85. print(f" ... and {len(missing_files) - 5} more")
  86. print(f"✅ Found {len(valid_files)} valid files out of {len(image_files)} in list")
  87. return valid_files
  88. def get_image_files_from_csv(csv_file: str, status_filter: str = "fail") -> List[str]:
  89. """
  90. 从CSV文件获取图像文件列表
  91. Args:
  92. csv_file: CSV文件路径
  93. status_filter: 状态过滤器
  94. Returns:
  95. 图像文件路径列表
  96. """
  97. print(f"📄 Reading image files from CSV: {csv_file}")
  98. # 读取CSV文件, 表头:image_path,status
  99. image_files = []
  100. with open(csv_file, 'r', encoding='utf-8') as f:
  101. for line in f:
  102. # 需要去掉表头, 按“,”分割,读取文件名,状态
  103. image_file, status = line.strip().split(",")
  104. if status.lower() == status_filter.lower():
  105. image_files.append(image_file)
  106. return image_files
  107. def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
  108. """
  109. 从进程输出文件中收集文件
  110. Args:
  111. pid_output_file: 进程输出文件路径
  112. Returns:
  113. 文件列表(文件路径,处理结果)
  114. """
  115. """
  116. 单进程结果统计文件格式
  117. "results": [
  118. {
  119. "image_path": "docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.jpg",
  120. "processing_time": 2.0265579223632812e-06,
  121. "success": true,
  122. "device": "gpu:3",
  123. "output_json": "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Results_Scheduler/process_3/docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.json",
  124. "output_md": "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Results_Scheduler/process_3/docstructbench_dianzishu_zhongwenzaixian-o.O-61520612.pdf_140.md"
  125. },
  126. ...
  127. """
  128. if not Path(pid_output_file).exists():
  129. print(f"⚠️ Warning: PID output file not found: {pid_output_file}")
  130. return []
  131. with open(pid_output_file, 'r', encoding='utf-8') as f:
  132. data = json.load(f)
  133. if not isinstance(data, dict) or "results" not in data:
  134. print(f"⚠️ Warning: Invalid PID output file format: {pid_output_file}")
  135. return []
  136. # 返回文件路径和处理状态, 如果“success”: True, 则状态为“success”, 否则为“fail”
  137. file_list = []
  138. for file_result in data.get("results", []):
  139. image_path = file_result.get("image_path", "")
  140. status = "success" if file_result.get("success", False) else "fail"
  141. file_list.append((image_path, status))
  142. return file_list
  143. def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
  144. """
  145. 将PDF转换为图像文件
  146. Args:
  147. pdf_file: PDF文件路径
  148. output_dir: 输出目录
  149. dpi: 图像分辨率
  150. Returns:
  151. 生成的图像文件路径列表
  152. """
  153. pdf_path = Path(pdf_file)
  154. if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
  155. print(f"❌ Invalid PDF file: {pdf_path}")
  156. return []
  157. # 如果没有指定输出目录,使用PDF同名目录
  158. if output_dir is None:
  159. output_path = pdf_path.parent / f"{pdf_path.stem}"
  160. else:
  161. output_path = Path(output_dir) / f"{pdf_path.stem}"
  162. output_path = output_path.resolve()
  163. output_path.mkdir(parents=True, exist_ok=True)
  164. try:
  165. # 使用doc_utils中的函数加载PDF图像
  166. images = load_images_from_pdf(str(pdf_path), dpi=dpi)
  167. image_paths = []
  168. for i, image in enumerate(images):
  169. # 生成图像文件名
  170. image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
  171. image_path = output_path / image_filename
  172. # 保存图像
  173. image.save(str(image_path))
  174. image_paths.append(str(image_path))
  175. print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
  176. return image_paths
  177. except Exception as e:
  178. print(f"❌ Error converting PDF {pdf_path}: {e}")
  179. traceback.print_exc()
  180. return []
  181. def get_input_files(args) -> List[str]:
  182. """
  183. 获取输入文件列表,统一处理PDF和图像文件
  184. Args:
  185. args: 命令行参数
  186. Returns:
  187. 处理后的图像文件路径列表
  188. """
  189. input_files = []
  190. # 获取原始输入文件
  191. if args.input_csv:
  192. raw_files = get_image_files_from_csv(args.input_csv, "fail")
  193. elif args.input_file_list:
  194. raw_files = get_image_files_from_list(args.input_file_list)
  195. elif args.input_file:
  196. raw_files = [Path(args.input_file).resolve()]
  197. else:
  198. input_dir = Path(args.input_dir).resolve()
  199. if not input_dir.exists():
  200. print(f"❌ Input directory does not exist: {input_dir}")
  201. return []
  202. # 获取所有支持的文件(图像和PDF)
  203. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  204. pdf_extensions = ['.pdf']
  205. raw_files = []
  206. for ext in image_extensions + pdf_extensions:
  207. raw_files.extend(list(input_dir.glob(f"*{ext}")))
  208. raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
  209. raw_files = [str(f) for f in raw_files]
  210. # 分别处理PDF和图像文件
  211. pdf_count = 0
  212. image_count = 0
  213. for file_path in raw_files:
  214. file_path = Path(file_path)
  215. if file_path.suffix.lower() == '.pdf':
  216. # 转换PDF为图像
  217. print(f"📄 Processing PDF: {file_path.name}")
  218. pdf_images = convert_pdf_to_images(
  219. str(file_path),
  220. args.output_dir,
  221. dpi=args.pdf_dpi
  222. )
  223. input_files.extend(pdf_images)
  224. pdf_count += 1
  225. else:
  226. # 直接添加图像文件
  227. if file_path.exists():
  228. input_files.append(str(file_path))
  229. image_count += 1
  230. print(f"📊 Input summary:")
  231. print(f" PDF files processed: {pdf_count}")
  232. print(f" Image files found: {image_count}")
  233. print(f" Total image files to process: {len(input_files)}")
  234. return sorted(list(set(str(f) for f in input_files)))