processor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
"""
PaddleX unified processor.

Document-processing class that supports multiple pipelines
(PaddleOCR-VL and PP-StructureV3).
"""
import os
import time
import traceback
import warnings
from pathlib import Path
from typing import List, Dict, Any
from loguru import logger
# Suppress specific warnings. These filters are registered before the
# `paddlex` import below, so they are already active while it is imported.
warnings.filterwarnings("ignore", message="To copy construct from a tensor")
warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
from paddlex import create_pipeline
# Import utility functions.
# The directory containing this file is put at the front of sys.path
# (only if it is not already listed) before the package-local imports run.
import sys
paddle_common_root = Path(__file__).parent
if str(paddle_common_root) not in sys.path:
    sys.path.insert(0, str(paddle_common_root))
from .utils import (
    convert_pruned_result_to_json,
    save_output_images,
    save_markdown_content
)
# Import adapters (apply/restore pairs used by PaddleXProcessor).
from .adapters import (
    apply_table_recognition_adapter,
    restore_original_function,
    apply_enhanced_doc_preprocessor,
    restore_paddlex_doc_preprocessor
)
  34. class PaddleXProcessor:
  35. """PaddleX 统一处理器,支持多种 pipeline"""
  36. def __init__(self,
  37. pipeline_name: str = "PP-StructureV3",
  38. device: str = "gpu:0",
  39. normalize_numbers: bool = True,
  40. use_enhanced_adapter: bool = True,
  41. log_level: str = "INFO",
  42. **kwargs):
  43. """
  44. 初始化处理器
  45. Args:
  46. pipeline_name: Pipeline 名称或配置文件路径
  47. device: 设备字符串(如 'gpu:0', 'cpu')
  48. normalize_numbers: 是否标准化数字
  49. use_enhanced_adapter: 是否使用增强适配器
  50. log_level: 日志级别(DEBUG, INFO, WARNING, ERROR),当为 DEBUG 时会打印详细错误信息
  51. **kwargs: 其他预测参数
  52. """
  53. self.pipeline_name = pipeline_name
  54. self.device = device
  55. self.normalize_numbers = normalize_numbers
  56. self.use_enhanced_adapter = use_enhanced_adapter
  57. self.log_level = log_level
  58. self.predict_kwargs = kwargs
  59. # 检测 pipeline 类型
  60. self.is_paddleocr_vl = 'PaddleOCR-VL'.lower() in str(pipeline_name).lower()
  61. # 应用适配器
  62. self.adapter_applied = False
  63. if use_enhanced_adapter:
  64. self.adapter_applied = apply_table_recognition_adapter() and apply_enhanced_doc_preprocessor()
  65. if self.adapter_applied:
  66. logger.info("🎯 Enhanced table recognition adapter activated and document preprocessor applied")
  67. else:
  68. logger.warning("⚠️ Failed to apply adapter, using original implementation")
  69. # 初始化 pipeline
  70. self.pipeline = None
  71. self._initialize_pipeline()
  72. logger.info(f"PaddleX Processor 初始化完成:")
  73. logger.info(f" - Pipeline: {pipeline_name}")
  74. logger.info(f" - 设备: {device}")
  75. logger.info(f" - Pipeline 类型: {'PaddleOCR-VL' if self.is_paddleocr_vl else 'PP-StructureV3'}")
  76. logger.info(f" - 数字标准化: {normalize_numbers}")
  77. logger.info(f" - 增强适配器: {use_enhanced_adapter}")
  78. logger.info(f" - 日志级别: {log_level}")
  79. def _initialize_pipeline(self):
  80. """初始化 pipeline"""
  81. try:
  82. # 设置环境变量以减少警告
  83. os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
  84. logger.info(f"Initializing pipeline '{self.pipeline_name}' on device '{self.device}'...")
  85. self.pipeline = create_pipeline(self.pipeline_name, device=self.device)
  86. logger.info(f"Pipeline initialized successfully on {self.device}")
  87. except Exception as e:
  88. logger.error(f"Failed to initialize pipeline: {e}")
  89. if self.log_level == "DEBUG":
  90. traceback.print_exc()
  91. if self.adapter_applied:
  92. restore_original_function()
  93. restore_paddlex_doc_preprocessor()
  94. raise
  95. def _get_predict_kwargs(self) -> Dict[str, Any]:
  96. """根据 pipeline 类型获取预测参数"""
  97. if self.is_paddleocr_vl:
  98. # PaddleOCR-VL 使用驼峰命名
  99. return {
  100. 'use_layout_detection': self.predict_kwargs.get('use_layout_detection', True),
  101. 'use_doc_orientation_classify': self.predict_kwargs.get('use_doc_orientation', True),
  102. 'use_doc_unwarping': self.predict_kwargs.get('use_doc_unwarping', False),
  103. }
  104. else:
  105. # PP-StructureV3 使用下划线命名
  106. return {
  107. 'use_doc_orientation_classify': self.predict_kwargs.get('use_doc_orientation', True),
  108. 'use_doc_unwarping': self.predict_kwargs.get('use_doc_unwarping', False),
  109. 'use_layout_detection': self.predict_kwargs.get('use_layout_detection', True),
  110. 'use_seal_recognition': self.predict_kwargs.get('use_seal_recognition', True),
  111. 'use_table_recognition': self.predict_kwargs.get('use_table_recognition', True),
  112. 'use_formula_recognition': self.predict_kwargs.get('use_formula_recognition', False),
  113. 'use_chart_recognition': self.predict_kwargs.get('use_chart_recognition', True),
  114. 'use_ocr_results_with_table_cells': self.predict_kwargs.get('use_ocr_results_with_table_cells', True),
  115. 'use_table_orientation_classify': self.predict_kwargs.get('use_table_orientation_classify', False),
  116. 'use_wired_table_cells_trans_to_html': self.predict_kwargs.get('use_wired_table_cells_trans_to_html', True),
  117. 'use_wireless_table_cells_trans_to_html': self.predict_kwargs.get('use_wireless_table_cells_trans_to_html', True),
  118. }
  119. def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
  120. """
  121. 处理单张图片
  122. Args:
  123. image_path: 图片路径
  124. output_dir: 输出目录
  125. Returns:
  126. dict: 处理结果,包含 success 字段(基于输出文件存在性判断)
  127. """
  128. start_time = time.time()
  129. image_path_obj = Path(image_path)
  130. image_name = image_path_obj.stem
  131. # 判断是否为PDF页面(根据文件名模式)
  132. is_pdf_page = "_page_" in image_path_obj.name
  133. result_info = {
  134. "image_path": image_path,
  135. "processing_time": 0,
  136. "success": False,
  137. "device": self.device,
  138. "error": None,
  139. "output_files": {},
  140. "is_pdf_page": is_pdf_page,
  141. "processing_info": {}
  142. }
  143. try:
  144. if self.pipeline is None:
  145. raise Exception("Pipeline not initialized")
  146. # 准备预测参数
  147. predict_kwargs = self._get_predict_kwargs()
  148. predict_kwargs['input'] = image_path
  149. # 使用 pipeline 预测
  150. results = self.pipeline.predict(**predict_kwargs)
  151. # 处理结果(应该只有一个结果)
  152. # 使用迭代方式处理生成器,与原始实现保持一致
  153. result = None
  154. for idx, res in enumerate(results):
  155. if idx > 0:
  156. raise ValueError("Multiple results found for a single image")
  157. result = res
  158. break # 只处理第一个结果
  159. if result is None:
  160. raise Exception("No results returned from pipeline")
  161. input_path = Path(result["input_path"])
  162. # 生成输出文件名
  163. # 使用输入文件名(PaddleX 的 result["input_path"] 可能包含页面信息)
  164. output_filename = input_path.stem
  165. # 转换并保存标准JSON格式
  166. json_content = result.json['res']
  167. json_output_path, converted_json = convert_pruned_result_to_json(
  168. json_content,
  169. str(input_path),
  170. output_dir,
  171. output_filename,
  172. normalize_numbers=self.normalize_numbers
  173. )
  174. # 保存输出图像
  175. img_content = result.img
  176. saved_images = save_output_images(img_content, str(output_dir), output_filename)
  177. # 保存Markdown内容
  178. markdown_content = result.markdown
  179. md_output_path = save_markdown_content(
  180. markdown_content,
  181. output_dir,
  182. output_filename,
  183. normalize_numbers=self.normalize_numbers,
  184. key_text='markdown_texts',
  185. key_images='markdown_images',
  186. json_data=converted_json
  187. )
  188. # 根据实际保存的文件路径判断成功(成功判断标准:.md 和 .json 文件都存在)
  189. # 使用实际保存的文件路径
  190. actual_md_path = Path(md_output_path) if md_output_path else Path(output_dir) / f"{output_filename}.md"
  191. actual_json_path = Path(json_output_path) if json_output_path else Path(output_dir) / f"{output_filename}.json"
  192. if actual_md_path.exists() and actual_json_path.exists():
  193. result_info.update({
  194. "success": True,
  195. "output_files": {
  196. "md": str(actual_md_path),
  197. "json": str(actual_json_path),
  198. **saved_images
  199. },
  200. "processing_info": converted_json.get('processing_info', {})
  201. })
  202. logger.info(f"✅ 处理成功: {image_name}")
  203. else:
  204. # 文件不存在,标记为失败
  205. missing_files = []
  206. if not actual_md_path.exists():
  207. missing_files.append("md")
  208. if not actual_json_path.exists():
  209. missing_files.append("json")
  210. result_info["error"] = f"输出文件不存在: {', '.join(missing_files)}"
  211. result_info["success"] = False
  212. logger.error(f"❌ 处理失败: {image_name} - {result_info['error']}")
  213. except Exception as e:
  214. result_info["error"] = str(e)
  215. result_info["success"] = False
  216. logger.error(f"Error processing {image_name}: {e}")
  217. if self.log_level == "DEBUG":
  218. traceback.print_exc()
  219. finally:
  220. result_info["processing_time"] = time.time() - start_time
  221. return result_info
  222. def __del__(self):
  223. """清理资源"""
  224. if self.adapter_applied:
  225. try:
  226. restore_original_function()
  227. restore_paddlex_doc_preprocessor()
  228. logger.info("🔄 Original function restored")
  229. except Exception as e:
  230. logger.warning(f"Failed to restore original function: {e}")