| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- from abc import ABC, abstractmethod
- from typing import Dict, Any, List, Union
- import numpy as np
- from PIL import Image
class BaseAdapter(ABC):
    """Common interface shared by every adapter.

    Holds the configuration mapping supplied at construction time and
    requires subclasses to implement ``initialize`` and ``cleanup``.
    """

    def __init__(self, config: Dict[str, Any]):
        # Stored by reference; the mapping is neither copied nor validated.
        self.config = config

    @abstractmethod
    def initialize(self):
        """Initialize the model."""

    @abstractmethod
    def cleanup(self):
        """Release resources."""
class BasePreprocessor(BaseAdapter):
    """Base class for image preprocessors."""

    @abstractmethod
    def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
        """Process an image.

        Returns the processed image together with its rotation angle.

        NOTE(review): the returned int may be a rotation label (as consumed
        by ``_apply_rotation``) rather than degrees -- confirm against the
        concrete implementations.
        """

    def _apply_rotation(self, image: np.ndarray, rotation_label: int) -> np.ndarray:
        """Rotate ``image`` according to ``rotation_label``.

        Label 1 rotates 90 degrees clockwise, label 2 rotates 180 degrees,
        and label 3 rotates 90 degrees counter-clockwise (i.e. 270 degrees).
        Any other label returns ``image`` unchanged.
        """
        import cv2

        # Ordered (label, OpenCV rotate flag) pairs, checked first to last.
        rotation_flags = (
            (1, cv2.ROTATE_90_CLOCKWISE),          # 90 degrees
            (2, cv2.ROTATE_180),                   # 180 degrees
            (3, cv2.ROTATE_90_COUNTERCLOCKWISE),   # 270 degrees
        )
        for label, flag in rotation_flags:
            if rotation_label == label:
                return cv2.rotate(image, flag)
        return image
class BaseLayoutDetector(BaseAdapter):
    """Base class for page-layout detectors."""

    # Numeric category id -> category name used by ``_map_category_id``.
    _CATEGORY_NAMES = {
        0: 'title',
        1: 'text',
        2: 'abandon',
        3: 'image_body',
        4: 'image_caption',
        5: 'table_body',
        6: 'table_caption',
        7: 'table_footnote',
        8: 'interline_equation',
        9: 'interline_equation_number',
        13: 'inline_equation',
        14: 'interline_equation_yolo',
        15: 'ocr_text',
        16: 'low_score_text',
        101: 'image_footnote',
    }

    @abstractmethod
    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Detect the layout of ``image``."""

    def _map_category_id(self, category_id: int) -> str:
        """Map a numeric category id to its string name.

        Ids missing from the table yield ``'unknown_<id>'``.
        """
        # Looked up through the class so the method does not depend on
        # instance state.
        names = BaseLayoutDetector._CATEGORY_NAMES
        fallback = f'unknown_{category_id}'
        return names.get(category_id, fallback)
class BaseVLRecognizer(BaseAdapter):
    """Base class for VL recognizers."""

    @abstractmethod
    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a table in ``image``.

        Extra keyword arguments are accepted; their meaning is defined by
        the concrete implementation.
        """

    @abstractmethod
    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a formula in ``image``.

        Extra keyword arguments are accepted; their meaning is defined by
        the concrete implementation.
        """
class BaseOCRRecognizer(BaseAdapter):
    """Base class for OCR recognizers."""

    @abstractmethod
    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Recognize text in ``image``."""
|