base.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, Any, List, Union
  3. import numpy as np
  4. from PIL import Image
  5. class BaseAdapter(ABC):
  6. """基础适配器接口"""
  7. def __init__(self, config: Dict[str, Any]):
  8. self.config = config
  9. @abstractmethod
  10. def initialize(self):
  11. """初始化模型"""
  12. pass
  13. @abstractmethod
  14. def cleanup(self):
  15. """清理资源"""
  16. pass
  17. class BasePreprocessor(BaseAdapter):
  18. """预处理器基类"""
  19. @abstractmethod
  20. def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
  21. """
  22. 处理图像
  23. 返回处理后的图像和旋转角度
  24. """
  25. pass
  26. def _apply_rotation(self, image: np.ndarray, rotation_label: int) -> np.ndarray:
  27. """应用旋转"""
  28. import cv2
  29. if rotation_label == 1: # 90度
  30. return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  31. elif rotation_label == 2: # 180度
  32. return cv2.rotate(image, cv2.ROTATE_180)
  33. elif rotation_label == 3: # 270度
  34. return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  35. return image
  36. class BaseLayoutDetector(BaseAdapter):
  37. """版式检测器基类"""
  38. @abstractmethod
  39. def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  40. """检测版式"""
  41. pass
  42. def _map_category_id(self, category_id: int) -> str:
  43. """映射类别ID到字符串"""
  44. category_map = {
  45. 0: 'title',
  46. 1: 'text',
  47. 2: 'abandon',
  48. 3: 'image_body',
  49. 4: 'image_caption',
  50. 5: 'table_body',
  51. 6: 'table_caption',
  52. 7: 'table_footnote',
  53. 8: 'interline_equation',
  54. 9: 'interline_equation_number',
  55. 13: 'inline_equation',
  56. 14: 'interline_equation_yolo',
  57. 15: 'ocr_text',
  58. 16: 'low_score_text',
  59. 101: 'image_footnote'
  60. }
  61. return category_map.get(category_id, f'unknown_{category_id}')
  62. class BaseVLRecognizer(BaseAdapter):
  63. """VL识别器基类"""
  64. @abstractmethod
  65. def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  66. """识别表格"""
  67. pass
  68. @abstractmethod
  69. def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  70. """识别公式"""
  71. pass
  72. class BaseOCRRecognizer(BaseAdapter):
  73. """OCR识别器基类"""
  74. @abstractmethod
  75. def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  76. """识别文本"""
  77. pass