| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
"""
DotsOCR vLLM processor.

Document-processing class built on top of DotsOCR.
"""
- import os
- import shutil
- import time
- import tempfile
- import uuid
- import traceback
- from pathlib import Path
- from typing import List, Dict, Any
- from PIL import Image
- from loguru import logger
- # 导入 dots.ocr 相关模块
- from dots_ocr.parser import DotsOCRParser
- from dots_ocr.utils import dict_promptmode_to_prompt
- from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
- # 导入 ocr_utils
- import sys
- ocr_platform_root = Path(__file__).parents[2]
- if str(ocr_platform_root) not in sys.path:
- sys.path.insert(0, str(ocr_platform_root))
- from ocr_utils import normalize_markdown_table, normalize_json_table
class DotsOCRProcessor:
    """Run DotsOCR (served through vLLM) on images and persist the results.

    For each input image the processor calls ``DotsOCRParser.parse_image``
    inside a throw-away temporary directory, then writes a Markdown file, a
    JSON layout file and a layout image into the caller's output directory.
    """

    def __init__(self,
                 ip: str = "127.0.0.1",
                 port: int = 8101,
                 model_name: str = "DotsOCR",
                 prompt_mode: str = "prompt_layout_all_en",
                 dpi: int = 200,
                 min_pixels: int = MIN_PIXELS,
                 max_pixels: int = MAX_PIXELS,
                 normalize_numbers: bool = False,
                 debug: bool = False):
        """
        Store the configuration and build the underlying ``DotsOCRParser``.

        Args:
            ip: IP address of the vLLM server.
            port: Port of the vLLM server.
            model_name: Model name forwarded to the parser.
            prompt_mode: Prompt mode passed to ``parse_image`` for every image.
            dpi: DPI forwarded to the parser (used for PDF processing).
            min_pixels: Minimum pixel count forwarded to the parser.
            max_pixels: Maximum pixel count forwarded to the parser.
            normalize_numbers: When true, Markdown and JSON outputs are passed
                through ``normalize_markdown_table`` / ``normalize_json_table``
                before being written.
            debug: When true, tracebacks are printed for caught exceptions.
        """
        self.ip = ip
        self.port = port
        self.model_name = model_name
        self.prompt_mode = prompt_mode
        self.dpi = dpi
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.normalize_numbers = normalize_numbers
        self.debug = debug

        # The parser owns the connection settings for the vLLM server.
        self.parser = DotsOCRParser(
            ip=ip,
            port=port,
            dpi=dpi,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            model_name=model_name
        )

        logger.info("DotsOCR Parser 初始化完成:")
        logger.info(f" - 服务器: {ip}:{port}")
        logger.info(f" - 模型: {model_name}")
        logger.info(f" - 提示模式: {prompt_mode}")
        logger.info(f" - 像素范围: {min_pixels} - {max_pixels}")
        logger.info(f" - 数字标准化: {normalize_numbers}")
        logger.info(f" - 调试模式: {debug}")

    def create_temp_session_dir(self) -> tuple:
        """
        Create a fresh temporary directory for one processing session.

        Returns:
            tuple: ``(temp_dir, session_id)`` where ``session_id`` is an
            8-character hex string and ``temp_dir`` is a newly created
            directory under the system temp dir whose name starts with
            ``dotsocr_batch_<session_id>``.
        """
        session_id = uuid.uuid4().hex[:8]
        # mkdtemp creates the directory atomically and never reuses an
        # existing one, so two sessions can never share (and later delete)
        # each other's working directory.
        temp_dir = tempfile.mkdtemp(prefix=f"dotsocr_batch_{session_id}_")
        return temp_dir, session_id

    @staticmethod
    def _count_changed_chars(before: str, after: str) -> int:
        """Count differing character positions, plus any difference in length."""
        differing = sum(1 for old, new in zip(before, after) if old != new)
        # zip() stops at the shorter string; account for the unmatched tail so
        # a pure length change is not reported as "nothing changed".
        return differing + abs(len(before) - len(after))

    @staticmethod
    def _read_first_existing(result: Dict, keys: tuple, fallback: str) -> str:
        """Return the text of the first ``result[key]`` path that exists, else ``fallback``."""
        for key in keys:
            candidate = result.get(key)
            # Skip missing/None entries instead of letting os.path.exists raise.
            if candidate and os.path.exists(candidate):
                with open(candidate, 'r', encoding='utf-8') as f:
                    return f.read()
        return fallback

    def _write_text_output(self, content: str, extension: str, image_name: str,
                           output_dir: str, saved_files: Dict[str, Any]) -> None:
        """Optionally normalize ``content`` and write ``<image_name>.<extension>``.

        Records the written path under ``saved_files[extension]`` and, when
        normalization is enabled, a status message under
        ``saved_files['<extension>_normalized']``. If normalization changed the
        text, the untouched version is also written next to the output as
        ``<image_name>_original.<extension>`` for comparison.
        """
        original_content = content
        if self.normalize_numbers:
            normalize = normalize_markdown_table if extension == 'md' else normalize_json_table
            content = normalize(content)

            changed = self._count_changed_chars(original_content, content)
            if changed > 0:
                saved_files[f'{extension}_normalized'] = f"✅ 已标准化 {changed} 个字符(全角→半角)"
            else:
                saved_files[f'{extension}_normalized'] = "ℹ️ 无需标准化(已是标准格式)"

        output_path = os.path.join(output_dir, f"{image_name}.{extension}")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(content)
        saved_files[extension] = output_path

        if self.normalize_numbers and original_content != content:
            # image_name is already extension-free; taking .stem again would
            # truncate dotted names (e.g. "doc_v1.2_page_001" -> "doc_v1") and
            # make different inputs overwrite each other's backup file.
            backup_path = os.path.join(output_dir, f"{image_name}_original.{extension}")
            with open(backup_path, 'w', encoding='utf-8') as f:
                f.write(original_content)

    def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, Any]:
        """
        Write the parser's results for one image into ``output_dir``.

        Three artifacts are produced: ``<image_name>.md``, ``<image_name>.json``
        and ``<image_name>_layout.jpg``. When the parser output for Markdown or
        JSON is missing, a placeholder document is written instead. Any
        exception is logged and swallowed; whatever was saved up to that point
        is still returned.

        Args:
            result: Parse result; the keys read here are
                ``md_content_nohf_path``, ``md_content_path``,
                ``layout_info_path``, ``layout_image_path`` and
                ``original_image_path``.
            image_name: Image file name without extension.
            output_dir: Destination directory (created if missing).

        Returns:
            dict: Paths of the saved files keyed by ``md``, ``json`` and
            ``layout_image`` (``None`` if the layout image could not be
            written), plus ``md_normalized`` / ``json_normalized`` status
            messages when number normalization is enabled.
        """
        saved_files: Dict[str, Any] = {}

        try:
            os.makedirs(output_dir, exist_ok=True)

            # 1. Markdown: prefer the version without page headers/footers
            #    (as required by the OmniDocBench evaluation).
            md_content = self._read_first_existing(
                result,
                ('md_content_nohf_path', 'md_content_path'),
                "# 解析失败\n\n未能提取到有效的文档内容。",
            )
            self._write_text_output(md_content, 'md', image_name, output_dir, saved_files)

            # 2. JSON layout information.
            json_content = self._read_first_existing(
                result,
                ('layout_info_path',),
                '{"error": "未能提取到有效的布局信息"}',
            )
            self._write_text_output(json_content, 'json', image_name, output_dir, saved_files)

            # 3. Annotated layout image.
            layout_target = os.path.join(output_dir, f"{image_name}_layout.jpg")
            layout_source = result.get('layout_image_path')

            if layout_source and os.path.exists(layout_source):
                shutil.copy2(layout_source, layout_target)
                saved_files['layout_image'] = layout_target
            else:
                # No layout image was produced: fall back to the original
                # image as a placeholder.
                try:
                    with Image.open(result.get('original_image_path', '')) as original_image:
                        placeholder = original_image
                        # JPEG cannot store alpha/palette modes; convert those.
                        if placeholder.mode not in ("RGB", "L", "CMYK"):
                            placeholder = placeholder.convert("RGB")
                        placeholder.save(layout_target, 'JPEG', quality=95)
                    saved_files['layout_image'] = layout_target
                except Exception as e:
                    logger.warning(f"Failed to save layout image: {e}")
                    saved_files['layout_image'] = None

        except Exception as e:
            logger.error(f"Error saving results for {image_name}: {e}")
            if self.debug:
                traceback.print_exc()

        return saved_files

    def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
        """
        Process a single image and store its outputs in ``output_dir``.

        If both ``<stem>.md`` and ``<stem>.json`` already exist in
        ``output_dir`` the image is skipped and reported as successful.
        Exceptions are caught and reported through the ``error`` field.

        NOTE(review): success is decided purely by the existence of the two
        output files, and ``save_results_to_output_dir`` writes placeholder
        files when the parser produced no content, so such a run is still
        reported as successful and will be skipped on the next invocation.

        Args:
            image_path: Path of the image to process.
            output_dir: Directory that receives the output files.

        Returns:
            dict: Result record with ``image_path``, ``processing_time``
            (seconds), ``success``, ``device``, ``error``, ``output_files``,
            ``is_pdf_page`` and, for skipped inputs, ``skipped``.
        """
        start_time = time.time()
        image_path_obj = Path(image_path)
        image_name = image_path_obj.stem

        # PDF pages are recognised by their file-name pattern.
        is_pdf_page = "_page_" in image_path_obj.name

        expected_md_path = Path(output_dir) / f"{image_name}.md"
        expected_json_path = Path(output_dir) / f"{image_name}.json"

        result_info = {
            "image_path": image_path,
            "processing_time": 0,
            "success": False,
            "device": f"{self.ip}:{self.port}",
            "error": None,
            "output_files": {},
            "is_pdf_page": is_pdf_page
        }

        try:
            # Already processed: both output files are present.
            if expected_md_path.exists() and expected_json_path.exists():
                result_info.update({
                    "success": True,
                    "processing_time": 0,
                    "output_files": {
                        "md": str(expected_md_path),
                        "json": str(expected_json_path)
                    },
                    "skipped": True
                })
                logger.info(f"✅ 文件已存在,跳过处理: {image_name}")
                return result_info

            temp_dir, session_id = self.create_temp_session_dir()

            try:
                # Keep the image open only for the duration of the parse so
                # the file handle is released even when parsing fails.
                with Image.open(image_path) as image:
                    results = self.parser.parse_image(
                        input_path=image,
                        filename=f"dotsocr_{session_id}",
                        prompt_mode=self.prompt_mode,
                        save_dir=temp_dir,
                        fitz_preprocess=True  # use fitz preprocessing for images
                    )

                if not results:
                    raise RuntimeError("未返回解析结果")

                # parse_image returns a list holding a single result.
                result = results[0]

                # Give the layout-image fallback a usable source image.
                # Presumably the parser does not set this key itself; an
                # existing value is left untouched.
                if isinstance(result, dict):
                    result.setdefault('original_image_path', image_path)

                saved_files = self.save_results_to_output_dir(result, image_name, output_dir)

                # Success criterion: both output files exist after saving.
                missing_files = [
                    label
                    for label, path in (("md", expected_md_path), ("json", expected_json_path))
                    if not path.exists()
                ]
                if not missing_files:
                    result_info.update({
                        "success": True,
                        "output_files": saved_files
                    })
                    logger.info(f"✅ 处理成功: {image_name}")
                else:
                    result_info["error"] = f"输出文件不存在: {', '.join(missing_files)}"
                    result_info["success"] = False
                    logger.error(f"❌ 处理失败: {image_name} - {result_info['error']}")

            finally:
                # Intermediate parser files must be removed before the
                # function ends, whatever happened above.
                shutil.rmtree(temp_dir, ignore_errors=True)

        except Exception as e:
            result_info["error"] = str(e)
            result_info["success"] = False
            logger.error(f"Error processing {image_name}: {e}")
            if self.debug:
                traceback.print_exc()

        finally:
            # Runs on the early-return path too, so skipped inputs also get
            # their (tiny) elapsed time recorded.
            result_info["processing_time"] = time.time() - start_time

        return result_info
|