"""Adapters that expose MinerU components through this package's base interfaces.

Four adapters are exported: ``MinerUPreprocessor``, ``MinerULayoutDetector``,
``MinerUVLRecognizer`` and ``MinerUOCRRecognizer``.  Each one raises
``ImportError`` on construction when the MinerU imports below failed.
"""
import sys
from pathlib import Path
from typing import Dict, Any, List, Union, Optional

import numpy as np
import cv2
from PIL import Image
from loguru import logger

# Put the MinerU directory (relative to this file) on sys.path so that the
# `mineru` imports below can be resolved.
mineru_path = Path(__file__).parents[4] / "mineru"
if str(mineru_path) not in sys.path:
    sys.path.insert(0, str(mineru_path))

from .base import BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, BaseOCRRecognizer

# Import the MinerU components; record availability instead of failing at
# import time so the module itself can always be imported.
try:
    from mineru.backend.pipeline.model_init import AtomModelSingleton
    from mineru.backend.vlm.vlm_analyze import ModelSingleton as VLMModelSingleton
    from mineru.backend.pipeline.model_list import AtomicModel
    from mineru.utils.config_reader import get_device
    MINERU_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Warning: MinerU components not available: {e}")
    MINERU_AVAILABLE = False


class MinerUPreprocessor(BasePreprocessor):
    """MinerU preprocessing adapter (image orientation correction)."""

    def __init__(self, config: Dict[str, Any]):
        """Store the config and obtain the MinerU atom-model manager.

        Raises:
            ImportError: If the MinerU components could not be imported.
        """
        super().__init__(config)

        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.orientation_classifier = None

    def initialize(self):
        """Load the orientation classifier unless it is disabled in the config.

        The classifier is enabled by default
        (``config['orientation_classifier']['enabled']``).  A load failure is
        logged and leaves ``orientation_classifier`` as ``None``, in which
        case ``process`` skips orientation correction.
        """
        if not self.config.get('orientation_classifier', {}).get('enabled', True):
            return

        try:
            self.orientation_classifier = self.atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.ImgOrientationCls,
            )
            logger.info("✅ Orientation classifier initialized")
        except Exception as e:
            # Best effort: preprocessing still works without the classifier.
            logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")

    def cleanup(self):
        """Release resources (nothing to release for this adapter)."""

    def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
        """Preprocess an image, correcting its orientation when possible.

        Args:
            image: Input image as a numpy array or PIL image.

        Returns:
            ``(processed_image, rotate_angle)``.  ``rotate_angle`` is the angle
            that was actually applied; it is 0 when no classifier is loaded or
            when classification/rotation failed, in which case the image is
            returned unrotated.
        """
        # Work on a numpy array from here on.
        if isinstance(image, Image.Image):
            image = np.array(image)

        rotate_angle = 0
        processed_image = image

        if self.orientation_classifier is not None:
            try:
                predicted_angle = int(self.orientation_classifier.predict(image))
                # NOTE(review): _apply_rotation is not defined in this file;
                # presumably it is provided by BasePreprocessor - confirm.
                processed_image = self._apply_rotation(image, predicted_angle)
                # Commit the angle only after the rotation succeeded, so the
                # returned angle always matches the returned image.
                rotate_angle = predicted_angle
                logger.info(f"📐 Applied rotation: {rotate_angle}")
            except Exception as e:
                logger.error(f"⚠️ Orientation classification failed: {e}")

        return processed_image, rotate_angle


class MinerULayoutDetector(BaseLayoutDetector):
    """MinerU layout detection adapter."""

    def __init__(self, config: Dict[str, Any]):
        """Store the config and obtain the MinerU atom-model manager.

        Raises:
            ImportError: If the MinerU components could not be imported.
        """
        super().__init__(config)

        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.layout_model = None

    def initialize(self):
        """Load the layout detection model.

        Config keys read: ``model_name`` (used only in the log message),
        ``model_dir`` (custom weights path) and ``device`` (default ``'cpu'``).
        Without ``model_dir`` the weights path is obtained from MinerU's
        ``auto_download_and_get_model_root_path``.

        Raises:
            Exception: Any error raised while loading the model is logged and
                re-raised.
        """
        try:
            model_name = self.config.get('model_name', AtomicModel.Layout)
            model_dir = self.config.get('model_dir')
            device = self.config.get('device', 'cpu')

            if model_dir:
                # Use the custom weights path from the config.
                weights_path = model_dir
            else:
                import os
                from mineru.utils.enum_class import ModelPath
                from mineru.utils.models_download_utils import auto_download_and_get_model_root_path

                weights_path = os.path.join(
                    auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
                    ModelPath.doclayout_yolo,
                )

            self.layout_model = self.atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                doclayout_yolo_weights=weights_path,
                device=device,
            )

            logger.info(f"✅ Layout detector initialized: {model_name}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize layout detector: {e}")
            raise

    def cleanup(self):
        """Release resources (nothing to release for this adapter)."""

    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Run layout detection on one image.

        Args:
            image: Input image as a numpy array or PIL image.

        Returns:
            A list of dicts with keys ``category``, ``bbox``, ``confidence``
            and ``raw`` (the unmodified model result).  An empty list is
            returned when detection fails.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.layout_model is None:
            raise RuntimeError("Layout model not initialized")

        # The model is called with PIL images.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        try:
            layout_results = self.layout_model.predict([image])

            formatted_results = []
            for result in layout_results:
                poly = result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])
                if len(poly) >= 8:
                    # Reduce the 8-value polygon to [x1, y1, x2, y2] using its
                    # 1st and 3rd points (assumes those are opposite corners
                    # - TODO confirm against the model's output order).
                    bbox = [poly[0], poly[1], poly[4], poly[5]]
                elif len(poly) >= 4:
                    bbox = poly[:4]
                else:
                    bbox = [0, 0, 0, 0]

                formatted_results.append({
                    # NOTE(review): _map_category_id is not defined in this
                    # file; presumably inherited from BaseLayoutDetector.
                    'category': self._map_category_id(result.get('category_id', 1)),
                    'bbox': bbox,
                    'confidence': result.get('score', 0.0),
                    'raw': result,
                })

            return formatted_results

        except Exception as e:
            logger.error(f"❌ Layout detection failed: {e}")
            return []


class MinerUVLRecognizer(BaseVLRecognizer):
    """MinerU vision-language (VL) recognition adapter."""

    def __init__(self, config: Dict[str, Any]):
        """Store the config and the image-size limits.

        Config keys read here: ``max_image_size`` (default 1568) and
        ``resize_mode`` (``'max'`` or ``'fixed'``, default ``'max'``).

        Raises:
            ImportError: If the MinerU components could not be imported.
        """
        super().__init__(config)

        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.vlm_model = None

        # Image size limits applied by _preprocess_image before every VLM call.
        self.max_image_size = config.get('max_image_size', 1568)
        self.resize_mode = config.get('resize_mode', 'max')  # 'max' or 'fixed'

    def initialize(self):
        """Obtain the VLM model from MinerU's model singleton.

        Config keys read: ``backend`` (default ``'http-client'``),
        ``server_url`` and ``model_params`` (extra keyword arguments).

        Raises:
            Exception: Any error raised while obtaining the model is logged
                and re-raised.
        """
        try:
            backend = self.config.get('backend', 'http-client')
            server_url = self.config.get('server_url')
            model_params = self.config.get('model_params', {})

            self.vlm_model = VLMModelSingleton().get_model(
                backend=backend,
                model_path=None,
                server_url=server_url,
                **model_params
            )
            logger.info(f"✅ VL recognizer initialized: {backend}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize VL recognizer: {e}")
            raise

    def cleanup(self):
        """Release resources (nothing to release for this adapter)."""

    def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
        """Convert the input to a PIL image and limit its size.

        The size limit exists to keep the model's sequence length within
        bounds.

        Args:
            image: Input image as a numpy array or PIL image.

        Returns:
            The PIL image, resized according to ``resize_mode``:
            ``'max'`` scales the image down (keeping the aspect ratio) when
            its longest side exceeds ``max_image_size``; ``'fixed'`` resizes
            to ``max_image_size`` x ``max_image_size`` (may change the aspect
            ratio).  Any other mode leaves the size unchanged.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        orig_w, orig_h = image.size

        if self.resize_mode == 'max':
            max_dim = max(orig_w, orig_h)
            if max_dim > self.max_image_size:
                scale = self.max_image_size / max_dim
                # Clamp to 1 pixel: for extreme aspect ratios the short side
                # would otherwise truncate to 0, which Image.resize rejects.
                new_w = max(1, int(orig_w * scale))
                new_h = max(1, int(orig_h * scale))

                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
                image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

        elif self.resize_mode == 'fixed':
            if orig_w != self.max_image_size or orig_h != self.max_image_size:
                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {self.max_image_size}x{self.max_image_size}")
                image = image.resize((self.max_image_size, self.max_image_size), Image.Resampling.LANCZOS)

        return image

    def _build_table_result(self, content: Any, return_cells: bool) -> Dict[str, Any]:
        """Format one table extraction output as an html/markdown/cells dict."""
        if not content:
            return {'html': '', 'markdown': '', 'cells': []}

        # The model output is treated as HTML.
        return {
            'html': content,
            'markdown': self._html_to_markdown(content),
            'cells': self._extract_cells_from_html(content) if return_cells else [],
        }

    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a table image.

        Args:
            image: Table image as a numpy array or PIL image.
            **kwargs: ``return_cells_coordinate`` (default False) requests the
                ``cells`` entry to be filled from the HTML.

        Returns:
            Dict with keys ``html``, ``markdown`` and ``cells``; all empty
            when the model returns nothing or recognition fails.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

        try:
            image = self._preprocess_image(image)

            table_content = self.vlm_model.content_extract(
                image=image,
                type="table"
            )

            return self._build_table_result(
                table_content, kwargs.get('return_cells_coordinate', False)
            )

        except Exception as e:
            logger.error(f"❌ Table recognition failed: {e}")
            return {'html': '', 'markdown': '', 'cells': []}

    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a formula image.

        Returns:
            Dict with keys ``latex``, ``confidence`` (0.9 when LaTeX was
            produced, else 0.0) and ``raw``.  All empty when the model returns
            nothing or recognition fails.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

        try:
            image = self._preprocess_image(image)

            formula_content = self.vlm_model.content_extract(
                image=image,
                type="equation"
            )

            if not formula_content:
                return {'latex': '', 'confidence': 0.0, 'raw': {}}

            latex = self._clean_latex(formula_content)

            return {
                'latex': latex,
                'confidence': 0.9 if latex else 0.0,
                'raw': {'raw_output': formula_content}
            }

        except Exception as e:
            logger.error(f"❌ Formula recognition failed: {e}")
            return {'latex': '', 'confidence': 0.0, 'raw': {}}

    def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a text region.

        Returns:
            Dict with keys ``text`` and ``confidence`` (0.9 when text was
            produced, else 0.0).  Empty when recognition fails.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

        try:
            image = self._preprocess_image(image)

            text_content = self.vlm_model.content_extract(
                image=image,
                type="text"
            )

            return {
                'text': text_content or '',
                'confidence': 0.9 if text_content else 0.0
            }

        except Exception as e:
            logger.error(f"❌ Text recognition failed: {e}")
            return {'text': '', 'confidence': 0.0}

    def batch_recognize_table(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Recognize several table images in one batched model call.

        Args:
            images: Table images as numpy arrays or PIL images.
            **kwargs: ``return_cells_coordinate`` (default False), as in
                ``recognize_table``.

        Returns:
            One html/markdown/cells dict per model output; on failure, one
            empty dict per input image.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

        try:
            pil_images = [self._preprocess_image(img) for img in images]

            table_contents = self.vlm_model.batch_content_extract(
                images=pil_images,
                types="table"
            )

            return_cells = kwargs.get('return_cells_coordinate', False)
            return [
                self._build_table_result(content, return_cells)
                for content in table_contents
            ]

        except Exception as e:
            logger.error(f"❌ Batch table recognition failed: {e}")
            return [{'html': '', 'markdown': '', 'cells': []} for _ in images]

    def batch_recognize_formula(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Recognize several formula images in one batched model call.

        Returns:
            One latex/confidence/raw dict per model output; on failure, one
            empty dict per input image.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

        try:
            # Apply the same size limit as every other recognition method.
            pil_images = [self._preprocess_image(img) for img in images]

            formula_contents = self.vlm_model.batch_content_extract(
                images=pil_images,
                types="equation"
            )

            results = []
            for content in formula_contents:
                latex = self._clean_latex(content) if content else ''
                results.append({
                    'latex': latex,
                    'confidence': 0.9 if latex else 0.0,
                    'raw': {'raw_output': content}
                })

            return results

        except Exception as e:
            logger.error(f"❌ Batch formula recognition failed: {e}")
            return [{'latex': '', 'confidence': 0.0, 'raw': {}} for _ in images]

    def _clean_latex(self, raw_latex: str) -> str:
        """Strip surrounding whitespace and one pair of outer ``$$`` or ``$``."""
        if not raw_latex:
            return ''

        latex = raw_latex.strip()
        if latex.startswith('$$') and latex.endswith('$$'):
            latex = latex[2:-2].strip()
        elif latex.startswith('$') and latex.endswith('$'):
            latex = latex[1:-1].strip()

        return latex

    def _html_to_markdown(self, html: str) -> str:
        """Return the Markdown form of an HTML table.

        NOTE: no conversion is implemented; the HTML is returned unchanged
        (an empty string for empty input).
        """
        if not html:
            return ''
        return html

    def _extract_cells_from_html(self, html: str) -> List[Dict[str, Any]]:
        """Extract cell information from table HTML (placeholder).

        Currently always returns an empty list.
        """
        # TODO: implement the HTML parsing logic.
        return []


class MinerUOCRRecognizer(BaseOCRRecognizer):
    """MinerU OCR recognition adapter."""

    def __init__(self, config: Dict[str, Any]):
        """Store the config and obtain the MinerU atom-model manager.

        Raises:
            ImportError: If the MinerU components could not be imported.
        """
        super().__init__(config)

        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.ocr_model = None

    def initialize(self):
        """Load the OCR model.

        Config keys read: ``det_threshold`` (default 0.3), ``language``
        (default ``'ch'``) and ``unclip_ratio`` (default 1.8).

        Raises:
            Exception: Any error raised while loading the model is logged and
                re-raised.
        """
        try:
            self.ocr_model = self.atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                det_db_box_thresh=self.config.get('det_threshold', 0.3),
                lang=self.config.get('language', 'ch'),
                det_db_unclip_ratio=self.config.get('unclip_ratio', 1.8),
            )
            logger.info(f"✅ OCR recognizer initialized: {self.config.get('language', 'ch')}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize OCR recognizer: {e}")
            raise

    def cleanup(self):
        """Release resources (nothing to release for this adapter)."""

    @staticmethod
    def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
        """Convert a PIL image or numpy array to a 3-channel BGR array.

        PIL images are converted to RGB first.  Arrays are handled by shape:
        2-D or single-channel as grayscale, 4-channel as RGBA, anything else
        as RGB.
        """
        if isinstance(image, Image.Image):
            image = np.array(image.convert("RGB"))

        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.ndim == 3 and image.shape[2] == 1:
            return cv2.cvtColor(np.ascontiguousarray(image[:, :, 0]), cv2.COLOR_GRAY2BGR)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Run OCR on one image.

        Args:
            image: Input image as a numpy array or PIL image.

        Returns:
            A list of dicts with keys ``bbox``, ``text`` and ``confidence``;
            an empty list when nothing is recognized or OCR fails.

        Raises:
            RuntimeError: If ``initialize`` has not loaded the model.
        """
        if self.ocr_model is None:
            raise RuntimeError("OCR model not initialized")

        # The OCR model is called with a BGR array.
        bgr_image = self._to_bgr(image)

        try:
            ocr_results = self.ocr_model.ocr(bgr_image, rec=True)

            formatted_results = []
            if ocr_results and ocr_results[0]:
                for item in ocr_results[0]:
                    # Each item is read as [bbox, (text, confidence)].
                    if len(item) >= 2 and len(item[1]) >= 2:
                        formatted_results.append({
                            'bbox': item[0],
                            'text': item[1][0],
                            'confidence': item[1][1],
                        })

            return formatted_results

        except Exception as e:
            logger.error(f"❌ OCR recognition failed: {e}")
            return []


# Exported adapter classes
__all__ = [
    'MinerUPreprocessor',
    'MinerULayoutDetector',
    'MinerUVLRecognizer',
    'MinerUOCRRecognizer'
]