| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- import sys
- from pathlib import Path
- from typing import Dict, Any, List, Union, Optional
- import numpy as np
- import cv2
- from PIL import Image
- from loguru import logger
# Make a vendored MinerU checkout importable. It is expected at
# "<4th ancestor directory of this file>/mineru" and is inserted at the front
# of sys.path so it takes precedence over other entries.
try:
    mineru_path = Path(__file__).parents[4] / "mineru"
except IndexError:
    # This file is not nested deeply enough for the expected repository
    # layout (e.g. it was installed or copied elsewhere). Rely on whatever
    # MinerU is importable from the environment instead of crashing at
    # module import time.
    mineru_path = None

if mineru_path is not None and str(mineru_path) not in sys.path:
    sys.path.insert(0, str(mineru_path))
- from .base import BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, BaseOCRRecognizer
# Import the MinerU components used by the adapters below. MinerU is treated
# as optional: if any import fails, a warning is printed and MINERU_AVAILABLE
# is False, so each adapter's constructor can raise a clear ImportError
# instead of this module failing to import.
try:
    from mineru.backend.pipeline.model_init import AtomModelSingleton
    from mineru.backend.vlm.vlm_analyze import ModelSingleton as VLMModelSingleton
    from mineru.backend.pipeline.model_list import AtomicModel
    from mineru.utils.config_reader import get_device
except ImportError as import_error:
    print(f"Warning: MinerU components not available: {import_error}")
    MINERU_AVAILABLE = False
else:
    MINERU_AVAILABLE = True
class MinerUPreprocessor(BasePreprocessor):
    """Preprocessor adapter that orientation-corrects images with MinerU's
    image-orientation classifier."""

    # Classifier class index -> rotation angle in degrees (used when the
    # classifier reports an integer class index).
    _INDEX_TO_ANGLE = {0: 0, 1: 90, 2: 180, 3: 270}
    # Angles a classifier may report directly (e.g. as the string "90").
    _VALID_ANGLES = frozenset(_INDEX_TO_ANGLE.values())

    def __init__(self, config: Dict[str, Any]):
        """Create the adapter.

        Args:
            config: Adapter configuration; ``orientation_classifier.enabled``
                (default True) is read by :meth:`initialize`.

        Raises:
            ImportError: If the MinerU package could not be imported.
        """
        super().__init__(config)
        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.orientation_classifier = None

    def initialize(self):
        """Load the orientation classifier unless disabled in the config.

        A loading failure is logged and leaves the classifier unset, in which
        case :meth:`process` returns images unrotated.
        """
        if self.config.get('orientation_classifier', {}).get('enabled', True):
            try:
                self.orientation_classifier = self.atom_model_manager.get_atom_model(
                    atom_model_name=AtomicModel.ImgOrientationCls,
                )
                logger.info("✅ Orientation classifier initialized")
            except Exception as e:
                logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")

    def cleanup(self):
        """Release resources (no-op: nothing is released here)."""
        pass

    @classmethod
    def _label_to_angle(cls, label: Any) -> int:
        """Convert an orientation-classifier label to an angle in degrees.

        Non-string labels are treated as class indices (0-3 -> 0/90/180/270).
        String labels that spell an angle (e.g. ``"90"``) map to that angle.
        Anything unrecognised maps to 0.
        """
        if isinstance(label, str):
            # NOTE(review): some classifier versions presumably report the
            # angle itself as a string ("0"/"90"/"180"/"270") rather than a
            # class index; with the index-only map every such label became 0.
            try:
                value = int(label.strip())
            except ValueError:
                return 0
            return value if value in cls._VALID_ANGLES else 0
        return cls._INDEX_TO_ANGLE.get(label, 0)

    def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
        """Orientation-correct an image.

        Args:
            image: Input image; PIL images are converted with ``np.array``.

        Returns:
            ``(processed_image, angle)``. ``angle`` is the rotation in degrees
            (0, 90, 180 or 270) derived from the classifier label that was
            passed to ``_apply_rotation``. When no classifier is loaded, or
            classification/rotation fails, the input is returned with 0.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        if self.orientation_classifier is None:
            return image, 0

        try:
            rotate_label = self.orientation_classifier.predict(image)
            # _apply_rotation is not defined in this class; presumably it is
            # provided by BasePreprocessor — confirm it accepts the raw label.
            rotated_image = self._apply_rotation(image, rotate_label)
        except Exception as e:
            # Best effort: keep the unrotated image and report angle 0 so the
            # returned angle always describes the returned image (previously a
            # failing rotation still reported the predicted angle).
            logger.error(f"⚠️ Orientation classification failed: {e}")
            return image, 0

        logger.info(f"📐 Applied rotation: {rotate_label}")
        return rotated_image, self._label_to_angle(rotate_label)
class MinerULayoutDetector(BaseLayoutDetector):
    """Layout-detection adapter backed by MinerU's layout model."""

    def __init__(self, config: Dict[str, Any]):
        """Create the adapter.

        Args:
            config: Adapter configuration (read by :meth:`initialize`).

        Raises:
            ImportError: If the MinerU package could not be imported.
        """
        super().__init__(config)
        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.layout_model = None

    def initialize(self):
        """Load the layout model.

        Config keys:
            model_name: Name shown in the log message only (default
                ``'RT-DETR-H_layout_17cls'``); it does not select the model.
            model_dir: Optional weights path, passed as
                ``doclayout_yolo_weights`` when set.
            device: Device string (default ``'cpu'``).

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            model_name = self.config.get('model_name', 'RT-DETR-H_layout_17cls')
            model_dir = self.config.get('model_dir')
            device = self.config.get('device', 'cpu')

            model_kwargs = {'device': device}
            if model_dir:
                # Custom weights; otherwise MinerU's default model is used.
                model_kwargs['doclayout_yolo_weights'] = model_dir

            self.layout_model = self.atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                **model_kwargs,
            )
            logger.info(f"✅ Layout detector initialized: {model_name}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize layout detector: {e}")
            raise

    def cleanup(self):
        """Release resources (no-op)."""
        pass

    @staticmethod
    def _poly_to_bbox(poly: Any) -> List[Any]:
        """Convert a detection ``poly`` to an axis-aligned ``[x1, y1, x2, y2]``.

        With 8 or more values, the first 8 are read as four (x, y) corners and
        reduced to their min/max. This equals the previous corner-index pick
        for axis-aligned boxes and stays correct for rotated or reordered
        corners. With 4-7 values the first four are taken as a flat box.
        Anything shorter (or ``None``) yields ``[0, 0, 0, 0]``.
        """
        if poly is None:
            return [0, 0, 0, 0]
        values = list(poly)
        if len(values) >= 8:
            xs = values[0:8:2]
            ys = values[1:8:2]
            return [min(xs), min(ys), max(xs), max(ys)]
        if len(values) >= 4:
            return values[:4]
        return [0, 0, 0, 0]

    @staticmethod
    def _first_page(layout_results: Any) -> List[Dict[str, Any]]:
        """Return the detections of the single page passed to ``predict``.

        The result is normally a list of per-page detection lists, and the
        first page is returned. If its first element is already a detection
        dict, the result is treated as a flat single-page list.
        """
        if not layout_results:
            return []
        first = layout_results[0]
        # NOTE(review): some layout models may return a flat List[Dict] for a
        # single page; iterating one dict would yield its keys and make the
        # whole detection fail — confirm the installed model's return shape.
        if isinstance(first, dict):
            return list(layout_results)
        return list(first) if first else []

    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Detect layout regions on one page image.

        Args:
            image: Page image; numpy arrays are converted with
                ``Image.fromarray``.

        Returns:
            One dict per detection with ``category`` (via
            ``_map_category_id``), ``bbox`` ``[x1, y1, x2, y2]``,
            ``confidence`` (the detection ``score``, default 0.0) and ``raw``
            (the unmodified detection). Returns ``[]`` if detection fails.

        Raises:
            RuntimeError: If :meth:`initialize` has not loaded the model.
        """
        if self.layout_model is None:
            raise RuntimeError("Layout model not initialized")

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        try:
            layout_results = self.layout_model.predict([image])
            return [
                {
                    # _map_category_id is not defined in this class; presumably
                    # provided by BaseLayoutDetector.
                    'category': self._map_category_id(result.get('category_id', 1)),
                    'bbox': self._poly_to_bbox(result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])),
                    'confidence': result.get('score', 0.0),
                    'raw': result,
                }
                for result in self._first_page(layout_results)
            ]

        except Exception as e:
            logger.error(f"❌ Layout detection failed: {e}")
            return []
class MinerUVLRecognizer(BaseVLRecognizer):
    """Vision-language recognition adapter backed by a MinerU VLM.

    Tables, formulas and text regions are sent to the VLM's
    ``content_extract`` / ``batch_content_extract`` with the content type
    ``"table"``, ``"equation"`` or ``"text"``. Every image first goes through
    :meth:`_preprocess_image`, which bounds its size so the input stays
    within the model's sequence-length limit.
    """

    # Outer math delimiters removed by _clean_latex, checked in this order;
    # "$$" must precede "$" so display math is not stripped as inline math.
    _LATEX_DELIMITERS = (('$$', '$$'), ('\\[', '\\]'), ('\\(', '\\)'), ('$', '$'))

    def __init__(self, config: Dict[str, Any]):
        """Create the adapter.

        Args:
            config: Adapter configuration. Read here: ``max_image_size``
                (default 1568) and ``resize_mode`` (``'max'`` or ``'fixed'``,
                default ``'max'``). :meth:`initialize` reads ``backend``,
                ``server_url``, ``model_path`` and ``model_params``.

        Raises:
            ImportError: If the MinerU package could not be imported.
        """
        super().__init__(config)
        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.vlm_model = None
        # Size limit applied to every image before it is sent to the VLM.
        self.max_image_size = config.get('max_image_size', 1568)
        self.resize_mode = config.get('resize_mode', 'max')
        if self.resize_mode not in ('max', 'fixed'):
            logger.warning(f"Unknown resize_mode {self.resize_mode!r}; images will not be resized")

    def initialize(self):
        """Load the VLM through MinerU's model singleton.

        Config keys: ``backend`` (default ``'http-client'``), ``server_url``,
        ``model_path`` (default None) and ``model_params`` (extra keyword
        arguments forwarded to ``get_model``).

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            backend = self.config.get('backend', 'http-client')
            server_url = self.config.get('server_url')
            model_path = self.config.get('model_path')
            model_params = self.config.get('model_params', {})

            self.vlm_model = VLMModelSingleton().get_model(
                backend=backend,
                model_path=model_path,
                server_url=server_url,
                **model_params
            )
            logger.info(f"✅ VL recognizer initialized: {backend}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize VL recognizer: {e}")
            raise

    def cleanup(self):
        """Release resources (no-op)."""
        pass

    def _ensure_model(self):
        """Raise ``RuntimeError`` if :meth:`initialize` has not loaded the VLM."""
        if self.vlm_model is None:
            raise RuntimeError("VL model not initialized")

    def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
        """Convert ``image`` to PIL and bound its size for the VLM.

        ``resize_mode == 'max'`` downscales (aspect ratio preserved) so the
        longest side is at most ``max_image_size``; smaller images are left
        as they are. ``resize_mode == 'fixed'`` resizes to a
        ``max_image_size`` x ``max_image_size`` square (aspect ratio may
        change). Any other mode leaves the size unchanged.

        Args:
            image: Input image (numpy array or PIL image).

        Returns:
            The (possibly resized) PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        orig_w, orig_h = image.size
        target = self.max_image_size

        if self.resize_mode == 'max':
            max_dim = max(orig_w, orig_h)
            if max_dim > target:
                scale = target / max_dim
                # round() keeps the long side at exactly `target` despite float
                # error (int() could truncate it to target - 1); max(1, ...)
                # stops very thin images collapsing to a 0-pixel side, which
                # PIL's resize rejects.
                new_w = max(1, round(orig_w * scale))
                new_h = max(1, round(orig_h * scale))

                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
                image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

        elif self.resize_mode == 'fixed':
            if orig_w != target or orig_h != target:
                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {target}x{target}")
                image = image.resize((target, target), Image.Resampling.LANCZOS)

        return image

    @staticmethod
    def _empty_table_result() -> Dict[str, Any]:
        """Return a fresh empty table result."""
        return {'html': '', 'markdown': '', 'cells': []}

    @staticmethod
    def _empty_formula_result() -> Dict[str, Any]:
        """Return a fresh empty formula result."""
        return {'latex': '', 'confidence': 0.0, 'raw': {}}

    def _format_table(self, content: Optional[str], return_cells: bool) -> Dict[str, Any]:
        """Build a table result from VLM output (assumed to be HTML).

        Empty output yields an empty result. Cells are only extracted when
        ``return_cells`` is true.
        """
        if not content:
            return self._empty_table_result()
        return {
            'html': content,
            'markdown': self._html_to_markdown(content),
            'cells': self._extract_cells_from_html(content) if return_cells else [],
        }

    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize one table image.

        Args:
            image: Table crop (numpy array or PIL image).
            **kwargs: ``return_cells_coordinate`` (default False) requests
                cell extraction.

        Returns:
            ``{'html', 'markdown', 'cells'}``; all empty when the model
            returns nothing or recognition fails.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        self._ensure_model()

        try:
            image = self._preprocess_image(image)
            table_content = self.vlm_model.content_extract(
                image=image,
                type="table"
            )
            return self._format_table(table_content, kwargs.get('return_cells_coordinate', False))

        except Exception as e:
            logger.error(f"❌ Table recognition failed: {e}")
            return self._empty_table_result()

    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize one formula image.

        Args:
            image: Formula crop (numpy array or PIL image).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``{'latex', 'confidence', 'raw'}``. ``confidence`` is a fixed 0.9
            when LaTeX was produced, else 0.0. ``raw['raw_output']`` holds the
            unprocessed model output. All fields are empty on empty output
            or failure.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        self._ensure_model()

        try:
            image = self._preprocess_image(image)
            formula_content = self.vlm_model.content_extract(
                image=image,
                type="equation"
            )

            if not formula_content:
                return self._empty_formula_result()

            latex = self._clean_latex(formula_content)
            return {
                'latex': latex,
                'confidence': 0.9 if latex else 0.0,
                'raw': {'raw_output': formula_content}
            }

        except Exception as e:
            logger.error(f"❌ Formula recognition failed: {e}")
            return self._empty_formula_result()

    def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize the text in one region image.

        Args:
            image: Text-region crop (numpy array or PIL image).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``{'text', 'confidence'}``. ``confidence`` is a fixed 0.9 when
            text was produced, else 0.0.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        self._ensure_model()

        try:
            image = self._preprocess_image(image)
            text_content = self.vlm_model.content_extract(
                image=image,
                type="text"
            )
            return {
                'text': text_content or '',
                'confidence': 0.9 if text_content else 0.0
            }

        except Exception as e:
            logger.error(f"❌ Text recognition failed: {e}")
            return {'text': '', 'confidence': 0.0}

    def batch_recognize_table(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Recognize several table images in one VLM batch call.

        Args:
            images: Table crops.
            **kwargs: ``return_cells_coordinate`` as in :meth:`recognize_table`.

        Returns:
            One table result per model output. On failure, one empty result
            per input image.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        self._ensure_model()
        if not images:
            return []

        return_cells = kwargs.get('return_cells_coordinate', False)
        try:
            pil_images = [self._preprocess_image(img) for img in images]
            table_contents = self.vlm_model.batch_content_extract(
                images=pil_images,
                types="table"
            )
            return [self._format_table(content, return_cells) for content in table_contents]

        except Exception as e:
            logger.error(f"❌ Batch table recognition failed: {e}")
            return [self._empty_table_result() for _ in images]

    def batch_recognize_formula(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Recognize several formula images in one VLM batch call.

        Images now go through :meth:`_preprocess_image` like in every other
        recognize method. Previously they were sent unresized, so large crops
        could exceed the model's sequence-length limit.

        Args:
            images: Formula crops.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            One ``{'latex', 'confidence', 'raw'}`` dict per model output
            (``raw['raw_output']`` is the unprocessed output, possibly empty).
            On failure, one empty result per input image.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        self._ensure_model()
        if not images:
            return []

        try:
            pil_images = [self._preprocess_image(img) for img in images]
            formula_contents = self.vlm_model.batch_content_extract(
                images=pil_images,
                types="equation"
            )

            results = []
            for content in formula_contents:
                latex = self._clean_latex(content) if content else ''
                results.append({
                    'latex': latex,
                    'confidence': 0.9 if latex else 0.0,
                    'raw': {'raw_output': content}
                })
            return results

        except Exception as e:
            logger.error(f"❌ Batch formula recognition failed: {e}")
            return [self._empty_formula_result() for _ in images]

    def _clean_latex(self, raw_latex: str) -> str:
        """Strip whitespace and one pair of outer math delimiters.

        Recognised pairs, checked in order: ``$$...$$``, ``\\[...\\]``,
        ``\\(...\\)``, ``$...$``. Only the first matching pair is removed.
        Empty input yields ``''``.
        """
        if not raw_latex:
            return ''

        latex = raw_latex.strip()
        for opening, closing in self._LATEX_DELIMITERS:
            if latex.startswith(opening) and latex.endswith(closing):
                return latex[len(opening):len(latex) - len(closing)].strip()
        return latex

    def _html_to_markdown(self, html: str) -> str:
        """Return the value stored in the ``markdown`` field of table results.

        No conversion is performed: the HTML is returned unchanged, and empty
        input yields ``''``.
        """
        # TODO: convert HTML tables to real Markdown (e.g. with markdownify)
        # if consumers need it.
        return html or ''

    def _extract_cells_from_html(self, html: str) -> List[Dict[str, Any]]:
        """Placeholder for cell extraction; always returns ``[]``."""
        # TODO: parse the table DOM (e.g. with an HTML parser) into cell dicts.
        return []
class MinerUOCRRecognizer(BaseOCRRecognizer):
    """OCR adapter backed by MinerU's OCR model."""

    def __init__(self, config: Dict[str, Any]):
        """Create the adapter.

        Args:
            config: Adapter configuration (read by :meth:`initialize`).

        Raises:
            ImportError: If the MinerU package could not be imported.
        """
        super().__init__(config)
        if not MINERU_AVAILABLE:
            raise ImportError("MinerU components not available")

        self.atom_model_manager = AtomModelSingleton()
        self.ocr_model = None

    def initialize(self):
        """Load the OCR model.

        Config keys: ``det_threshold`` -> ``det_db_box_thresh`` (default 0.3),
        ``language`` -> ``lang`` (default ``'ch'``), ``unclip_ratio`` ->
        ``det_db_unclip_ratio`` (default 1.8).

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            language = self.config.get('language', 'ch')
            self.ocr_model = self.atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                det_db_box_thresh=self.config.get('det_threshold', 0.3),
                lang=language,
                det_db_unclip_ratio=self.config.get('unclip_ratio', 1.8),
            )
            logger.info(f"✅ OCR recognizer initialized: {language}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize OCR recognizer: {e}")
            raise

    def cleanup(self):
        """Release resources (no-op)."""
        pass

    @staticmethod
    def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
        """Convert an input image to a 3-channel BGR array for the OCR model.

        PIL images are first converted to RGB, which handles grayscale,
        palette and alpha modes. Numpy arrays are interpreted by channel
        count: 2-D or single-channel as grayscale, 4 channels as RGBA,
        anything else as RGB. Previously grayscale input made
        ``cv2.cvtColor`` raise, and palette images yielded raw palette indices.
        """
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = np.array(image)

        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Detect and recognize text lines in an image.

        Args:
            image: RGB(-like) numpy array or PIL image.

        Returns:
            One ``{'bbox', 'text', 'confidence'}`` dict per recognized line,
            taken from the first entry of the OCR output. Entries without
            both a box and a (text, score) pair are skipped. Returns ``[]``
            if OCR fails.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        if self.ocr_model is None:
            raise RuntimeError("OCR model not initialized")

        bgr_image = self._to_bgr(image)

        try:
            ocr_results = self.ocr_model.ocr(bgr_image, rec=True)
            if not ocr_results or not ocr_results[0]:
                return []

            return [
                {
                    'bbox': item[0],
                    'text': item[1][0],
                    'confidence': item[1][1],
                }
                for item in ocr_results[0]
                if len(item) >= 2 and len(item[1]) >= 2
            ]

        except Exception as e:
            logger.error(f"❌ OCR recognition failed: {e}")
            return []
# Public API of this module: the four MinerU adapter classes.
__all__ = ['MinerUPreprocessor', 'MinerULayoutDetector',
           'MinerUVLRecognizer', 'MinerUOCRRecognizer']
|