processor.py 12 KB

  1. """
  2. DotsOCR vLLM processor
  3. Document-processing class built on DotsOCR
  4. """
  5. import os
  6. import shutil
  7. import time
  8. import tempfile
  9. import uuid
  10. import traceback
  11. from pathlib import Path
  12. from typing import List, Dict, Any
  13. from PIL import Image
  14. from loguru import logger
  15. # 导入 dots.ocr 相关模块
  16. from dots_ocr.parser import DotsOCRParser
  17. from dots_ocr.utils import dict_promptmode_to_prompt
  18. from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
  19. # 导入 ocr_utils
  20. import sys
  21. ocr_platform_root = Path(__file__).parents[2]
  22. if str(ocr_platform_root) not in sys.path:
  23. sys.path.insert(0, str(ocr_platform_root))
  24. from ocr_utils import normalize_markdown_table, normalize_json_table
  25. class DotsOCRProcessor:
  26. """DotsOCR 处理器"""
  def __init__(
      self,
      ip: str = "127.0.0.1",
      port: int = 8101,
      model_name: str = "DotsOCR",
      prompt_mode: str = "prompt_layout_all_en",
      dpi: int = 200,
      min_pixels: int = MIN_PIXELS,
      max_pixels: int = MAX_PIXELS,
      normalize_numbers: bool = False,
      debug: bool = False,
  ):
      """Store the configuration and build the underlying DotsOCR parser.

      Args:
          ip: IP address of the vLLM server.
          port: Port of the vLLM server.
          model_name: Name of the model to use.
          prompt_mode: Prompt mode passed to the parser when parsing images.
          dpi: DPI used for PDF processing.
          min_pixels: Minimum pixel count.
          max_pixels: Maximum pixel count.
          normalize_numbers: Whether to normalize numbers in the outputs.
          debug: Whether to enable debug mode (tracebacks are printed on errors).
      """
      # Keep every setting on the instance so the other methods can read it.
      self.ip = ip
      self.port = port
      self.model_name = model_name
      self.prompt_mode = prompt_mode
      self.dpi = dpi
      self.min_pixels = min_pixels
      self.max_pixels = max_pixels
      self.normalize_numbers = normalize_numbers
      self.debug = debug

      # Build the parser from the subset of settings it accepts.
      parser_options = dict(
          ip=ip,
          port=port,
          dpi=dpi,
          min_pixels=min_pixels,
          max_pixels=max_pixels,
          model_name=model_name,
      )
      self.parser = DotsOCRParser(**parser_options)

      # Report the effective configuration, one log record per line.
      startup_report = (
          "DotsOCR Parser 初始化完成:",
          f" - 服务器: {ip}:{port}",
          f" - 模型: {model_name}",
          f" - 提示模式: {prompt_mode}",
          f" - 像素范围: {min_pixels} - {max_pixels}",
          f" - 数字标准化: {normalize_numbers}",
          f" - 调试模式: {debug}",
      )
      for report_line in startup_report:
          logger.info(report_line)
  def create_temp_session_dir(self) -> tuple:
      """Create a fresh temporary directory for one processing session.

      The directory is created under the system temp location and named
      ``dotsocr_batch_<session_id>``, where ``session_id`` is the first eight
      hex characters of a random UUID.  Creation is exclusive: if a directory
      with the generated name already exists, a new id is drawn instead of
      silently reusing it, so two sessions never share a directory (and one
      session's cleanup can never delete another session's files).

      Returns:
          tuple: ``(temp_dir, session_id)`` - the directory path and the id
          embedded in its name.

      Raises:
          FileExistsError: If no unused directory name could be found after
              several attempts.
      """
      max_attempts = 10
      for _ in range(max_attempts):
          session_id = uuid.uuid4().hex[:8]
          temp_dir = os.path.join(tempfile.gettempdir(), f"dotsocr_batch_{session_id}")
          try:
              # No exist_ok: an existing directory means a name collision.
              os.makedirs(temp_dir)
          except FileExistsError:
              continue
          return temp_dir, session_id
      raise FileExistsError(
          f"Could not create a unique temp session directory after {max_attempts} attempts"
      )
  81. def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, str]:
  82. """
  83. 将处理结果保存到输出目录
  84. Args:
  85. result: 解析结果
  86. image_name: 图片文件名(不含扩展名)
  87. output_dir: 输出目录
  88. Returns:
  89. dict: 保存的文件路径
  90. """
  91. saved_files = {}
  92. try:
  93. # 1. 保存 Markdown 文件
  94. output_md_path = os.path.join(output_dir, f"{image_name}.md")
  95. md_content = ""
  96. # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
  97. if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
  98. with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
  99. md_content = f.read()
  100. elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
  101. with open(result['md_content_path'], 'r', encoding='utf-8') as f:
  102. md_content = f.read()
  103. else:
  104. md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
  105. # 如果启用数字标准化,处理 Markdown 内容
  106. original_text = md_content
  107. if self.normalize_numbers:
  108. md_content = normalize_markdown_table(md_content)
  109. # 统计标准化的变化
  110. changes_count = len([1 for o, n in zip(original_text, md_content) if o != n])
  111. if changes_count > 0:
  112. saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
  113. else:
  114. saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
  115. with open(output_md_path, 'w', encoding='utf-8') as f:
  116. f.write(md_content)
  117. saved_files['md'] = output_md_path
  118. # 如果启用了标准化,也保存原始版本用于对比
  119. if self.normalize_numbers and original_text != md_content:
  120. original_markdown_path = Path(output_dir) / f"{Path(image_name).stem}_original.md"
  121. with open(original_markdown_path, 'w', encoding='utf-8') as f:
  122. f.write(original_text)
  123. # 2. 保存 JSON 文件
  124. output_json_path = os.path.join(output_dir, f"{image_name}.json")
  125. if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
  126. with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
  127. json_content = f.read()
  128. else:
  129. json_content = '{"error": "未能提取到有效的布局信息"}'
  130. # 对json中的表格内容进行数字标准化
  131. original_json_text = json_content
  132. if self.normalize_numbers:
  133. json_content = normalize_json_table(json_content)
  134. # 统计标准化的变化
  135. changes_count = len([1 for o, n in zip(original_json_text, json_content) if o != n])
  136. if changes_count > 0:
  137. saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
  138. else:
  139. saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
  140. with open(output_json_path, 'w', encoding='utf-8') as f:
  141. f.write(json_content)
  142. saved_files['json'] = output_json_path
  143. # 如果启用了标准化,也保存原始版本用于对比
  144. if self.normalize_numbers and original_json_text != json_content:
  145. original_json_path = Path(output_dir) / f"{Path(image_name).stem}_original.json"
  146. with open(original_json_path, 'w', encoding='utf-8') as f:
  147. f.write(original_json_text)
  148. # 3. 保存带标注的布局图片
  149. output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
  150. if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
  151. # 直接复制布局图片
  152. shutil.copy2(result['layout_image_path'], output_layout_image_path)
  153. saved_files['layout_image'] = output_layout_image_path
  154. else:
  155. # 如果没有布局图片,使用原始图片作为占位符
  156. try:
  157. original_image = Image.open(result.get('original_image_path', ''))
  158. original_image.save(output_layout_image_path, 'JPEG', quality=95)
  159. saved_files['layout_image'] = output_layout_image_path
  160. except Exception as e:
  161. logger.warning(f"Failed to save layout image: {e}")
  162. saved_files['layout_image'] = None
  163. except Exception as e:
  164. logger.error(f"Error saving results for {image_name}: {e}")
  165. if self.debug:
  166. traceback.print_exc()
  167. return saved_files
  def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
      """Process a single image and write its results into ``output_dir``.

      Success is decided by the existence of the output files: the run counts
      as successful only if both ``<stem>.md`` and ``<stem>.json`` exist in
      ``output_dir``.  If both already exist before processing, the image is
      skipped and reported as successful with ``"skipped": True``.

      Args:
          image_path: Path of the image to process.
          output_dir: Directory that receives the output files.

      Returns:
          dict: Result record with the keys ``image_path``,
          ``processing_time`` (seconds), ``success``, ``device``
          (``"<ip>:<port>"``), ``error`` (message or ``None``),
          ``output_files``, ``is_pdf_page`` and, for skipped images,
          ``skipped``.  Errors are logged and reported in the record rather
          than raised.
      """
      start_time = time.time()
      image_path_obj = Path(image_path)
      image_name = image_path_obj.stem

      # PDF pages are recognized purely by the "_page_" marker in the file name.
      is_pdf_page = "_page_" in image_path_obj.name

      # Output files whose existence defines success.
      expected_md_path = Path(output_dir) / f"{image_name}.md"
      expected_json_path = Path(output_dir) / f"{image_name}.json"

      result_info = {
          "image_path": image_path,
          "processing_time": 0,
          "success": False,
          "device": f"{self.ip}:{self.port}",
          "error": None,
          "output_files": {},
          "is_pdf_page": is_pdf_page
      }

      try:
          # Skip work that has already been done.
          if expected_md_path.exists() and expected_json_path.exists():
              result_info.update({
                  "success": True,
                  "processing_time": 0,
                  "output_files": {
                      "md": str(expected_md_path),
                      "json": str(expected_json_path)
                  },
                  "skipped": True
              })
              logger.info(f"✅ 文件已存在,跳过处理: {image_name}")
              return result_info

          temp_dir, session_id = self.create_temp_session_dir()
          try:
              # Close the image handle as soon as the parser is done with it.
              with Image.open(image_path) as image:
                  results = self.parser.parse_image(
                      input_path=image,
                      filename=f"dotsocr_{session_id}",
                      prompt_mode=self.prompt_mode,
                      save_dir=temp_dir,
                      fitz_preprocess=True  # apply fitz preprocessing to the image
                  )

              if not results:
                  raise RuntimeError("未返回解析结果")
              result = results[0]  # a single result is expected; only the first entry is used

              # Let the save step fall back to the source image when the
              # parser produced no layout image (never overrides a value the
              # parser already set).
              if isinstance(result, dict):
                  result.setdefault('original_image_path', image_path)

              saved_files = self.save_results_to_output_dir(result, image_name, output_dir)

              # Re-check the outputs on disk: this is the success criterion.
              missing_files = [
                  label
                  for label, path in (("md", expected_md_path), ("json", expected_json_path))
                  if not path.exists()
              ]
              if not missing_files:
                  result_info.update({
                      "success": True,
                      "output_files": saved_files
                  })
                  logger.info(f"✅ 处理成功: {image_name}")
              else:
                  result_info["error"] = f"输出文件不存在: {', '.join(missing_files)}"
                  result_info["success"] = False
                  logger.error(f"❌ 处理失败: {image_name} - {result_info['error']}")
          finally:
              # Always remove the temporary session directory.
              shutil.rmtree(temp_dir, ignore_errors=True)
      except Exception as e:
          result_info["error"] = str(e)
          result_info["success"] = False
          logger.error(f"Error processing {image_name}: {e}")
          if self.debug:
              traceback.print_exc()
      finally:
          # Runs on every path, including the early "skipped" return.
          result_info["processing_time"] = time.time() - start_time
      return result_info