paddle_vl_adapter.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import sys
  2. from pathlib import Path
  3. from typing import Dict, Any, List, Union, Optional
  4. import numpy as np
  5. from PIL import Image
  6. from loguru import logger
  7. # 导入基类
  8. from .mineru_adapter import MinerUVLRecognizer
  9. # 导入 mineru-vl-utils 的客户端
  10. try:
  11. from mineru_vl_utils import MinerUClient
  12. MINERU_VL_UTILS_AVAILABLE = True
  13. except ImportError as e:
  14. logger.warning(f"mineru-vl-utils not available: {e}")
  15. MINERU_VL_UTILS_AVAILABLE = False
  16. class PaddleVLRecognizer(MinerUVLRecognizer):
  17. """
  18. PaddleOCR-VL识别适配器,继承自MinerUVLRecognizer
  19. 主要差异:
  20. 1. 强制使用 PaddleOCR-VL-0.9B 模型
  21. 2. 确保使用 vllm-server 后端
  22. 3. 复用所有MinerU的预处理/后处理逻辑
  23. """
  24. def __init__(self, config: Dict[str, Any]):
  25. # 🔧 强制设置 PaddleOCR-VL 模型名称
  26. config['model_name'] = 'PaddleOCR-VL-0.9B'
  27. # 🔧 确保使用正确的后端配置
  28. if config.get('backend') not in ['http-client']:
  29. logger.error(
  30. f"Backend '{config.get('backend')}' may not be optimal for PaddleOCR-VL. "
  31. f"must: 'http-client'"
  32. )
  33. # 调用父类初始化
  34. super().__init__(config)
  35. def initialize(self):
  36. """初始化VL模型 - 使用MinerU的客户端"""
  37. if not MINERU_VL_UTILS_AVAILABLE:
  38. raise ImportError("mineru-vl-utils is required for PaddleVLRecognizer")
  39. try:
  40. backend = self.config.get('backend', 'http-client')
  41. server_url = self.config.get('server_url')
  42. model_params = self.config.get('model_params', {})
  43. # 🔧 提取 MinerUClient 所需的参数
  44. # 从 model_params 中获取,如果没有则使用默认值
  45. max_concurrency = model_params.get('max_concurrency', 100)
  46. http_timeout = model_params.get('http_timeout', 600)
  47. # 🔧 PaddleOCR-VL 特定的提示词(可选)
  48. prompts = model_params.get('prompts', {
  49. "table": "\nTable Recognition:",
  50. "equation": "\nFormula Recognition:",
  51. "[default]": "\nText Recognition:",
  52. "[layout]": "\nLayout Detection:",
  53. })
  54. # 🔧 初始化 MinerUClient
  55. logger.info(f"Initializing PaddleOCR-VL with backend: {backend}")
  56. logger.info(f"Server URL: {server_url}")
  57. logger.info(f"Max concurrency: {max_concurrency}")
  58. # 根据后端类型调整参数
  59. if backend == 'http-client':
  60. # HTTP客户端模式
  61. self.vlm_model = MinerUClient(
  62. backend=backend,
  63. model_name=self.config['model_name'],
  64. server_url=server_url,
  65. prompts=prompts,
  66. max_concurrency=max_concurrency,
  67. http_timeout=http_timeout,
  68. use_tqdm=False, # 可根据需要调整
  69. )
  70. else:
  71. raise ValueError(f"Unsupported backend for PaddleOCR-VL: {backend}")
  72. logger.success(f"✅ PaddleOCR-VL recognizer initialized: {backend}")
  73. except Exception as e:
  74. logger.error(f"❌ Failed to initialize PaddleOCR-VL recognizer: {e}")
  75. raise
  76. # 以下方法都继承自 MinerUVLRecognizer,无需重写:
  77. # - cleanup()
  78. # - _preprocess_image()
  79. # - recognize_table()
  80. # - recognize_formula()
  81. # - recognize_text()
  82. # - batch_recognize_table()
  83. # - batch_recognize_formula()
  84. # - _clean_latex()
  85. # - _html_to_markdown()
  86. # - _extract_cells_from_html()
# Export the adapter class.
__all__ = ['PaddleVLRecognizer']