# mineru_adapter.py
  1. import sys
  2. from pathlib import Path
  3. from typing import Dict, Any, List, Union, Optional
  4. import numpy as np
  5. import cv2
  6. from PIL import Image
  7. from loguru import logger
  8. # 添加MinerU路径
  9. mineru_path = Path(__file__).parents[4] / "mineru"
  10. if str(mineru_path) not in sys.path:
  11. sys.path.insert(0, str(mineru_path))
  12. from .base import BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, BaseOCRRecognizer
  13. # 导入MinerU组件
  14. try:
  15. from mineru.backend.pipeline.model_init import AtomModelSingleton
  16. from mineru.backend.vlm.vlm_analyze import ModelSingleton as VLMModelSingleton
  17. from mineru.backend.pipeline.model_list import AtomicModel
  18. from mineru.utils.config_reader import get_device
  19. MINERU_AVAILABLE = True
  20. except ImportError as e:
  21. print(f"Warning: MinerU components not available: {e}")
  22. MINERU_AVAILABLE = False
  23. class MinerUPreprocessor(BasePreprocessor):
  24. """MinerU预处理器适配器"""
  25. def __init__(self, config: Dict[str, Any]):
  26. super().__init__(config)
  27. if not MINERU_AVAILABLE:
  28. raise ImportError("MinerU components not available")
  29. self.atom_model_manager = AtomModelSingleton()
  30. self.orientation_classifier = None
  31. def initialize(self):
  32. """初始化预处理组件"""
  33. # 初始化方向分类器
  34. if self.config.get('orientation_classifier', {}).get('enabled', True):
  35. try:
  36. self.orientation_classifier = self.atom_model_manager.get_atom_model(
  37. atom_model_name=AtomicModel.ImgOrientationCls,
  38. )
  39. print("✅ Orientation classifier initialized")
  40. except Exception as e:
  41. print(f"⚠️ Failed to initialize orientation classifier: {e}")
  42. def cleanup(self):
  43. """清理资源"""
  44. pass
  45. def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
  46. """图像预处理"""
  47. # 转换为numpy数组
  48. if isinstance(image, Image.Image):
  49. image = np.array(image)
  50. rotate_map = {0: 0, 1: 90, 2: 180, 3: 270}
  51. rotate_label = 0
  52. processed_image = image
  53. # 方向校正
  54. if self.orientation_classifier is not None:
  55. try:
  56. rotate_label = self.orientation_classifier.predict(image)
  57. processed_image = self._apply_rotation(processed_image, rotate_label)
  58. logger.info(f"📐 Applied rotation: {rotate_label}")
  59. except Exception as e:
  60. logger.error(f"⚠️ Orientation classification failed: {e}")
  61. return processed_image, rotate_map.get(rotate_label, 0)
  62. class MinerULayoutDetector(BaseLayoutDetector):
  63. """MinerU版式检测适配器"""
  64. def __init__(self, config: Dict[str, Any]):
  65. super().__init__(config)
  66. if not MINERU_AVAILABLE:
  67. raise ImportError("MinerU components not available")
  68. self.atom_model_manager = AtomModelSingleton()
  69. self.layout_model = None
  70. def initialize(self):
  71. """初始化版式检测模型"""
  72. try:
  73. # 获取模型配置
  74. model_name = self.config.get('model_name', 'RT-DETR-H_layout_17cls')
  75. model_dir = self.config.get('model_dir')
  76. device = self.config.get('device', 'cpu')
  77. # 初始化版式检测模型
  78. if model_dir:
  79. # 使用自定义模型路径
  80. self.layout_model = self.atom_model_manager.get_atom_model(
  81. atom_model_name=AtomicModel.Layout,
  82. doclayout_yolo_weights=model_dir,
  83. device=device
  84. )
  85. else:
  86. # 使用默认模型
  87. self.layout_model = self.atom_model_manager.get_atom_model(
  88. atom_model_name=AtomicModel.Layout,
  89. device=device
  90. )
  91. print(f"✅ Layout detector initialized: {model_name}")
  92. except Exception as e:
  93. print(f"❌ Failed to initialize layout detector: {e}")
  94. raise
  95. def cleanup(self):
  96. """清理资源"""
  97. pass
  98. def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  99. """版式检测"""
  100. if self.layout_model is None:
  101. raise RuntimeError("Layout model not initialized")
  102. # 转换为PIL图像
  103. if isinstance(image, np.ndarray):
  104. image = Image.fromarray(image)
  105. # 进行检测
  106. try:
  107. layout_results = self.layout_model.predict([image])
  108. # 转换结果格式
  109. formatted_results = []
  110. for result in layout_results[0]: # 第一页结果
  111. # 提取坐标信息
  112. poly = result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])
  113. if len(poly) >= 8:
  114. # 转换8点坐标为4点坐标 [x1,y1,x2,y2]
  115. bbox = [poly[0], poly[1], poly[4], poly[5]]
  116. else:
  117. bbox = poly[:4] if len(poly) >= 4 else [0, 0, 0, 0]
  118. formatted_results.append({
  119. 'category': self._map_category_id(result.get('category_id', 1)),
  120. 'bbox': bbox,
  121. 'confidence': result.get('score', 0.0),
  122. 'raw': result
  123. })
  124. return formatted_results
  125. except Exception as e:
  126. print(f"❌ Layout detection failed: {e}")
  127. return []
  128. class MinerUVLRecognizer(BaseVLRecognizer):
  129. """MinerU VL识别适配器"""
  130. def __init__(self, config: Dict[str, Any]):
  131. super().__init__(config)
  132. if not MINERU_AVAILABLE:
  133. raise ImportError("MinerU components not available")
  134. self.vlm_model = None
  135. # 🔧 添加图片尺寸限制配置
  136. self.max_image_size = config.get('max_image_size', 1568) # VLM 模型的最大尺寸
  137. self.resize_mode = config.get('resize_mode', 'max') # 'max' or 'fixed'
  138. def initialize(self):
  139. """初始化VL模型"""
  140. try:
  141. backend = self.config.get('backend', 'http-client')
  142. server_url = self.config.get('server_url')
  143. model_params = self.config.get('model_params', {})
  144. # 初始化VLM模型
  145. self.vlm_model = VLMModelSingleton().get_model(
  146. backend=backend,
  147. model_path=None,
  148. server_url=server_url,
  149. **model_params
  150. )
  151. print(f"✅ VL recognizer initialized: {backend}")
  152. except Exception as e:
  153. print(f"❌ Failed to initialize VL recognizer: {e}")
  154. raise
  155. def cleanup(self):
  156. """清理资源"""
  157. pass
  158. def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
  159. """
  160. 预处理图片,控制尺寸避免序列长度超限
  161. Args:
  162. image: 输入图片
  163. Returns:
  164. 处理后的PIL图片
  165. """
  166. # 转换为PIL图像
  167. if isinstance(image, np.ndarray):
  168. image = Image.fromarray(image)
  169. # 获取原始尺寸
  170. orig_w, orig_h = image.size
  171. # 计算缩放比例
  172. if self.resize_mode == 'max':
  173. # 保持宽高比,最长边不超过 max_image_size
  174. max_dim = max(orig_w, orig_h)
  175. if max_dim > self.max_image_size:
  176. scale = self.max_image_size / max_dim
  177. new_w = int(orig_w * scale)
  178. new_h = int(orig_h * scale)
  179. logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
  180. image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
  181. elif self.resize_mode == 'fixed':
  182. # 固定尺寸(可能改变宽高比)
  183. if orig_w != self.max_image_size or orig_h != self.max_image_size:
  184. logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {self.max_image_size}x{self.max_image_size}")
  185. image = image.resize((self.max_image_size, self.max_image_size), Image.Resampling.LANCZOS)
  186. return image
  187. def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  188. """表格识别"""
  189. if self.vlm_model is None:
  190. raise RuntimeError("VL model not initialized")
  191. try:
  192. # 🔧 预处理图片
  193. image = self._preprocess_image(image)
  194. # 直接调用 content_extract,指定类型为 table
  195. table_content = self.vlm_model.content_extract(
  196. image=image,
  197. type="table"
  198. )
  199. if not table_content:
  200. return {'html': '', 'markdown': '', 'cells': []}
  201. # 解析表格内容(假设返回的是HTML格式)
  202. return {
  203. 'html': table_content,
  204. 'markdown': self._html_to_markdown(table_content),
  205. 'cells': self._extract_cells_from_html(table_content) if kwargs.get('return_cells_coordinate', False) else []
  206. }
  207. except Exception as e:
  208. logger.error(f"❌ Table recognition failed: {e}")
  209. return {'html': '', 'markdown': '', 'cells': []}
  210. def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  211. """识别公式"""
  212. if self.vlm_model is None:
  213. raise RuntimeError("VL model not initialized")
  214. try:
  215. # 🔧 预处理图片
  216. image = self._preprocess_image(image)
  217. # 直接调用 content_extract,指定类型为 equation
  218. formula_content = self.vlm_model.content_extract(
  219. image=image,
  220. type="equation"
  221. )
  222. if not formula_content:
  223. return {'latex': '', 'confidence': 0.0, 'raw': {}}
  224. # 清理LaTeX格式
  225. latex = self._clean_latex(formula_content)
  226. return {
  227. 'latex': latex,
  228. 'confidence': 0.9 if latex else 0.0,
  229. 'raw': {'raw_output': formula_content}
  230. }
  231. except Exception as e:
  232. logger.error(f"❌ Formula recognition failed: {e}")
  233. return {'latex': '', 'confidence': 0.0, 'raw': {}}
  234. def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  235. """识别文本区域"""
  236. if self.vlm_model is None:
  237. raise RuntimeError("VL model not initialized")
  238. try:
  239. # 🔧 预处理图片
  240. image = self._preprocess_image(image)
  241. # 直接调用 content_extract,指定类型为 text
  242. text_content = self.vlm_model.content_extract(
  243. image=image,
  244. type="text"
  245. )
  246. return {
  247. 'text': text_content or '',
  248. 'confidence': 0.9 if text_content else 0.0
  249. }
  250. except Exception as e:
  251. print(f"❌ Text recognition failed: {e}")
  252. return {'text': '', 'confidence': 0.0}
  253. def batch_recognize_table(
  254. self,
  255. images: List[Union[np.ndarray, Image.Image]],
  256. **kwargs
  257. ) -> List[Dict[str, Any]]:
  258. """批量表格识别"""
  259. if self.vlm_model is None:
  260. raise RuntimeError("VL model not initialized")
  261. try:
  262. # 🔧 批量预处理图片
  263. pil_images = [self._preprocess_image(img) for img in images]
  264. # 批量调用 batch_content_extract
  265. table_contents = self.vlm_model.batch_content_extract(
  266. images=pil_images,
  267. types="table"
  268. )
  269. # 格式化结果
  270. results = []
  271. for content in table_contents:
  272. if content:
  273. results.append({
  274. 'html': content,
  275. 'markdown': self._html_to_markdown(content),
  276. 'cells': self._extract_cells_from_html(content) if kwargs.get('return_cells_coordinate', False) else []
  277. })
  278. else:
  279. results.append({'html': '', 'markdown': '', 'cells': []})
  280. return results
  281. except Exception as e:
  282. logger.error(f"❌ Batch table recognition failed: {e}")
  283. return [{'html': '', 'markdown': '', 'cells': []} for _ in images]
  284. def batch_recognize_formula(
  285. self,
  286. images: List[Union[np.ndarray, Image.Image]],
  287. **kwargs
  288. ) -> List[Dict[str, Any]]:
  289. """批量公式识别"""
  290. if self.vlm_model is None:
  291. raise RuntimeError("VL model not initialized")
  292. # 转换为PIL图像列表
  293. pil_images = []
  294. for img in images:
  295. if isinstance(img, np.ndarray):
  296. pil_images.append(Image.fromarray(img))
  297. else:
  298. pil_images.append(img)
  299. try:
  300. # 批量调用 batch_content_extract,指定类型为 equation
  301. formula_contents = self.vlm_model.batch_content_extract(
  302. images=pil_images,
  303. types="equation"
  304. )
  305. # 格式化结果
  306. results = []
  307. for content in formula_contents:
  308. latex = self._clean_latex(content) if content else ''
  309. results.append({
  310. 'latex': latex,
  311. 'confidence': 0.9 if latex else 0.0,
  312. 'raw': {'raw_output': content}
  313. })
  314. return results
  315. except Exception as e:
  316. print(f"❌ Batch formula recognition failed: {e}")
  317. return [{'latex': '', 'confidence': 0.0, 'raw': {}} for _ in images]
  318. def _clean_latex(self, raw_latex: str) -> str:
  319. """清理LaTeX格式"""
  320. if not raw_latex:
  321. return ''
  322. # 移除外层的 $$ 或 $
  323. latex = raw_latex.strip()
  324. if latex.startswith('$$') and latex.endswith('$$'):
  325. latex = latex[2:-2].strip()
  326. elif latex.startswith('$') and latex.endswith('$'):
  327. latex = latex[1:-1].strip()
  328. return latex
  329. def _html_to_markdown(self, html: str) -> str:
  330. """将HTML表格转换为Markdown格式"""
  331. if not html:
  332. return ''
  333. return html
  334. # try:
  335. # # 简单的HTML到Markdown转换
  336. # # 实际应用中可以使用 markdownify 库
  337. # import re
  338. # # 移除HTML标签,保留内容
  339. # markdown = re.sub(r'<tr[^>]*>', '\n', html)
  340. # markdown = re.sub(r'</tr>', '', markdown)
  341. # markdown = re.sub(r'<t[dh][^>]*>', '| ', markdown)
  342. # markdown = re.sub(r'</t[dh]>', ' ', markdown)
  343. # markdown = re.sub(r'<[^>]+>', '', markdown)
  344. # return markdown.strip()
  345. # except Exception as e:
  346. # print(f"⚠️ HTML to Markdown conversion failed: {e}")
  347. # return html
  348. def _extract_cells_from_html(self, html: str) -> List[Dict[str, Any]]:
  349. """从HTML中提取单元格信息(简化版本)"""
  350. if not html:
  351. return []
  352. try:
  353. # 这里只是示例,实际需要解析HTML DOM
  354. # 可以使用 BeautifulSoup 等库
  355. cells = []
  356. # TODO: 实现HTML解析逻辑
  357. return cells
  358. except Exception as e:
  359. print(f"⚠️ Cell extraction failed: {e}")
  360. return []
  361. class MinerUOCRRecognizer(BaseOCRRecognizer):
  362. """MinerU OCR识别适配器"""
  363. def __init__(self, config: Dict[str, Any]):
  364. super().__init__(config)
  365. if not MINERU_AVAILABLE:
  366. raise ImportError("MinerU components not available")
  367. self.atom_model_manager = AtomModelSingleton()
  368. self.ocr_model = None
  369. def initialize(self):
  370. """初始化OCR模型"""
  371. try:
  372. # 初始化OCR模型
  373. self.ocr_model = self.atom_model_manager.get_atom_model(
  374. atom_model_name=AtomicModel.OCR,
  375. det_db_box_thresh=self.config.get('det_threshold', 0.3),
  376. lang=self.config.get('language', 'ch'),
  377. det_db_unclip_ratio=self.config.get('unclip_ratio', 1.8),
  378. )
  379. print(f"✅ OCR recognizer initialized: {self.config.get('language', 'ch')}")
  380. except Exception as e:
  381. print(f"❌ Failed to initialize OCR recognizer: {e}")
  382. raise
  383. def cleanup(self):
  384. """清理资源"""
  385. pass
  386. def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  387. """文本识别"""
  388. if self.ocr_model is None:
  389. raise RuntimeError("OCR model not initialized")
  390. # 转换为BGR格式
  391. if isinstance(image, Image.Image):
  392. image = np.array(image)
  393. bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  394. try:
  395. # OCR识别
  396. ocr_results = self.ocr_model.ocr(bgr_image, rec=True)
  397. # 格式化结果
  398. formatted_results = []
  399. if ocr_results and ocr_results[0]:
  400. for item in ocr_results[0]:
  401. if len(item) >= 2 and len(item[1]) >= 2:
  402. formatted_results.append({
  403. 'bbox': item[0], # 坐标
  404. 'text': item[1][0], # 识别文本
  405. 'confidence': item[1][1] # 置信度
  406. })
  407. return formatted_results
  408. except Exception as e:
  409. print(f"❌ OCR recognition failed: {e}")
  410. return []
  411. # 导出适配器类
  412. __all__ = [
  413. 'MinerUPreprocessor',
  414. 'MinerULayoutDetector',
  415. 'MinerUVLRecognizer',
  416. 'MinerUOCRRecognizer'
  417. ]