base.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, Any, List, Union, Optional, Tuple
  3. import numpy as np
  4. from PIL import Image
  5. from loguru import logger
  6. from pathlib import Path
  7. import cv2
  8. import json
  9. class BaseAdapter(ABC):
  10. """基础适配器接口"""
  11. def __init__(self, config: Dict[str, Any]):
  12. self.config = config
  13. @abstractmethod
  14. def initialize(self):
  15. """初始化模型"""
  16. pass
  17. @abstractmethod
  18. def cleanup(self):
  19. """清理资源"""
  20. pass
  21. class BasePreprocessor(BaseAdapter):
  22. """预处理器基类"""
  23. def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
  24. """页级水印去除(默认无操作,子类可覆盖)。"""
  25. if isinstance(image, Image.Image):
  26. return np.array(image)
  27. return image
  28. @abstractmethod
  29. def process(
  30. self,
  31. image: Union[np.ndarray, Image.Image],
  32. skip_watermark: bool = False,
  33. ) -> tuple[np.ndarray, int]:
  34. """
  35. 处理图像
  36. 返回处理后的图像和旋转角度
  37. """
  38. pass
  39. def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
  40. """应用旋转"""
  41. import cv2
  42. if rotation_angle == 90: # 90度
  43. return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  44. elif rotation_angle == 180: # 180度
  45. return cv2.rotate(image, cv2.ROTATE_180)
  46. elif rotation_angle == 270: # 270度
  47. return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  48. return image
  49. class BaseLayoutDetector(BaseAdapter):
  50. """版式检测器基类"""
  51. def __init__(self, config: Dict[str, Any]):
  52. """初始化版式检测器
  53. Args:
  54. config: 配置字典
  55. """
  56. super().__init__(config)
  57. # 初始化 debug 相关属性(支持从配置或运行时设置)
  58. self.debug_mode = None # 将在 detect() 方法中从配置读取
  59. self.output_dir = None # 将在 detect() 方法中从配置读取
  60. self.page_name = None # 将在 detect() 方法中从配置读取
  61. def detect(
  62. self,
  63. image: Union[np.ndarray, Image.Image],
  64. ocr_spans: Optional[List[Dict[str, Any]]] = None
  65. ) -> List[Dict[str, Any]]:
  66. """
  67. 检测版式(模板方法,自动执行后处理)
  68. 此方法会:
  69. 1. 调用子类实现的 _detect_raw() 进行原始检测
  70. 2. 自动执行后处理(去除重叠框、文本转表格等)
  71. Args:
  72. image: 输入图像
  73. ocr_spans: OCR结果(可选,某些detector可能需要)
  74. Returns:
  75. 后处理后的布局检测结果
  76. """
  77. # 调用子类实现的原始检测方法
  78. layout_results = self._detect_raw(image, ocr_spans)
  79. # Debug 模式:打印和可视化后处理前的检测结果
  80. # 优先从实例属性读取(如果存在),否则从配置读取
  81. # 支持两种配置方式:debug_mode 或 debug_options.enabled
  82. debug_mode = getattr(self, 'debug_mode', None)
  83. if debug_mode is None:
  84. if hasattr(self, 'config'):
  85. # 优先从 debug_mode 读取
  86. debug_mode = self.config.get('debug_mode', False)
  87. # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
  88. if not debug_mode:
  89. debug_options = self.config.get('debug_options', {})
  90. if isinstance(debug_options, dict):
  91. debug_mode = debug_options.get('enabled', False)
  92. else:
  93. debug_mode = False
  94. if debug_mode:
  95. logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
  96. # logger.debug(f"Raw layout_results: {layout_results}")
  97. # 可视化 layout 结果
  98. output_dir = getattr(self, 'output_dir', None)
  99. if output_dir is None:
  100. if hasattr(self, 'config'):
  101. # 优先从 output_dir 读取
  102. output_dir = self.config.get('output_dir', None)
  103. # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
  104. if output_dir is None:
  105. debug_options = self.config.get('debug_options', {})
  106. if isinstance(debug_options, dict):
  107. output_dir = debug_options.get('output_dir', None)
  108. else:
  109. output_dir = None
  110. page_name = getattr(self, 'page_name', None)
  111. if page_name is None:
  112. if hasattr(self, 'config'):
  113. # 优先从 page_name 读取
  114. page_name = self.config.get('page_name', None)
  115. # 如果没有 page_name,尝试从 debug_options.prefix 读取
  116. if page_name is None:
  117. debug_options = self.config.get('debug_options', {})
  118. if isinstance(debug_options, dict):
  119. prefix = debug_options.get('prefix', '')
  120. page_name = prefix if prefix else 'layout_detection'
  121. if page_name is None:
  122. page_name = 'layout_detection'
  123. else:
  124. page_name = 'layout_detection'
  125. if output_dir:
  126. self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
  127. # 自动执行后处理
  128. if layout_results:
  129. layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  130. layout_results = self.post_process(layout_results, image, layout_config)
  131. return layout_results
  132. @abstractmethod
  133. def _detect_raw(
  134. self,
  135. image: Union[np.ndarray, Image.Image],
  136. ocr_spans: Optional[List[Dict[str, Any]]] = None
  137. ) -> List[Dict[str, Any]]:
  138. """
  139. 原始检测方法(子类必须实现)
  140. Args:
  141. image: 输入图像
  142. ocr_spans: OCR结果(可选)
  143. Returns:
  144. 原始检测结果(未后处理)
  145. """
  146. pass
  147. def post_process(
  148. self,
  149. layout_results: List[Dict[str, Any]],
  150. image: Union[np.ndarray, Image.Image],
  151. config: Optional[Dict[str, Any]] = None
  152. ) -> List[Dict[str, Any]]:
  153. """
  154. 后处理布局检测结果
  155. 默认实现包括:
  156. 1. 去除重叠框
  157. 2. 将大面积文本块转换为表格(如果配置启用)
  158. 子类可以重写此方法以自定义后处理逻辑
  159. Args:
  160. layout_results: 原始检测结果
  161. image: 输入图像
  162. config: 后处理配置(可选),如果为None则使用self.config中的post_process配置
  163. Returns:
  164. 后处理后的布局结果
  165. """
  166. if not layout_results:
  167. return layout_results
  168. # 获取配置
  169. if config is None:
  170. config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  171. # 导入 CoordinateUtils(适配器可以访问)
  172. try:
  173. from ocr_utils.coordinate_utils import CoordinateUtils
  174. except ImportError:
  175. try:
  176. from ocr_utils import CoordinateUtils
  177. except ImportError:
  178. # 如果无法导入,返回原始结果
  179. return layout_results
  180. # 1. 去除重叠框
  181. layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
  182. # 2. 将大面积文本块转换为表格(如果配置启用)
  183. layout_config = config if config is not None else {}
  184. if layout_config.get('convert_large_text_to_table', False):
  185. # 获取图像尺寸
  186. if isinstance(image, Image.Image):
  187. h, w = image.size[1], image.size[0]
  188. else:
  189. h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
  190. layout_results_converted_large_text = self._convert_large_text_to_table(
  191. layout_results_removed_overlapping,
  192. (h, w),
  193. min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
  194. min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
  195. min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
  196. )
  197. return layout_results_converted_large_text
  198. else:
  199. return layout_results_removed_overlapping
  200. def _convert_large_text_to_table(
  201. self,
  202. layout_results: List[Dict[str, Any]],
  203. image_shape: Tuple[int, int],
  204. min_area_ratio: float = 0.25,
  205. min_width_ratio: float = 0.4,
  206. min_height_ratio: float = 0.3
  207. ) -> List[Dict[str, Any]]:
  208. """
  209. 将大面积的文本块转换为表格
  210. 判断规则:
  211. 1. 面积占比:占页面面积超过 min_area_ratio(默认25%)
  212. 2. 尺寸比例:宽度和高度都超过一定比例(避免细长条)
  213. 3. 不与其他表格重叠:如果已有表格,不转换
  214. """
  215. if not layout_results:
  216. return layout_results
  217. img_height, img_width = image_shape
  218. img_area = img_height * img_width
  219. if img_area == 0:
  220. return layout_results
  221. # 检查是否已有表格
  222. has_table = any(
  223. item.get('category', '').lower() in ['table', 'table_body']
  224. for item in layout_results
  225. )
  226. # 如果已有表格,不进行转换(避免误判)
  227. if has_table:
  228. return layout_results
  229. # 复制列表避免修改原数据
  230. results = [item.copy() for item in layout_results]
  231. converted_count = 0
  232. for item in results:
  233. category = item.get('category', '').lower()
  234. # 只处理文本类型的元素
  235. if category not in ['text', 'ocr_text']:
  236. continue
  237. bbox = item.get('bbox', [0, 0, 0, 0])
  238. if len(bbox) < 4:
  239. continue
  240. x1, y1, x2, y2 = bbox[:4]
  241. width = x2 - x1
  242. height = y2 - y1
  243. area = width * height
  244. # 计算占比
  245. area_ratio = area / img_area if img_area > 0 else 0
  246. width_ratio = width / img_width if img_width > 0 else 0
  247. height_ratio = height / img_height if img_height > 0 else 0
  248. # 判断是否满足转换条件
  249. if (area_ratio >= min_area_ratio and
  250. width_ratio >= min_width_ratio and
  251. height_ratio >= min_height_ratio):
  252. # 转换为表格
  253. item['category'] = 'table'
  254. item['original_category'] = category # 保留原始类别
  255. converted_count += 1
  256. return results
  257. def _map_category_id(self, category_id: int) -> str:
  258. """映射类别ID到字符串"""
  259. category_map = {
  260. 0: 'title',
  261. 1: 'text',
  262. 2: 'abandon',
  263. 3: 'image_body',
  264. 4: 'image_caption',
  265. 5: 'table_body',
  266. 6: 'table_caption',
  267. 7: 'table_footnote',
  268. 8: 'interline_equation',
  269. 9: 'interline_equation_number',
  270. 13: 'inline_equation',
  271. 14: 'interline_equation_yolo',
  272. 15: 'ocr_text',
  273. 16: 'low_score_text',
  274. 101: 'image_footnote'
  275. }
  276. return category_map.get(category_id, f'unknown_{category_id}')
  277. def _visualize_layout_results(
  278. self,
  279. image: Union[np.ndarray, Image.Image],
  280. layout_results: List[Dict[str, Any]],
  281. output_dir: str,
  282. page_name: str,
  283. suffix: str = 'raw'
  284. ) -> None:
  285. """
  286. 可视化 layout 检测结果
  287. Args:
  288. image: 输入图像
  289. layout_results: 布局检测结果
  290. output_dir: 输出目录
  291. page_name: 页面名称
  292. suffix: 文件名后缀(如 'raw', 'postprocessed')
  293. """
  294. if not layout_results:
  295. return
  296. try:
  297. # 转换为 numpy 数组
  298. if isinstance(image, Image.Image):
  299. vis_image = np.array(image)
  300. if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
  301. # PIL RGB -> OpenCV BGR
  302. vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
  303. else:
  304. vis_image = image.copy()
  305. if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
  306. # 如果是 RGB,转换为 BGR
  307. vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
  308. # 定义类别颜色映射 (BGR格式)
  309. category_colors = {
  310. 'table_body': (0, 0, 255), # 红色
  311. 'table_caption': (0, 0, 200), # 暗红色
  312. 'table_footnote': (0, 0, 150), # 更暗的红色
  313. 'text': (255, 0, 0), # 蓝色
  314. 'title': (0, 255, 255), # 黄色
  315. 'header': (255, 0, 255), # 紫色
  316. 'footer': (0, 165, 255), # 橙色
  317. 'image_body': (0, 255, 0), # 绿色
  318. 'image_caption': (0, 200, 0), # 暗绿色
  319. 'image_footnote': (0, 150, 0), # 更暗的绿色
  320. 'abandon': (128, 128, 128), # 灰色
  321. }
  322. # 绘制检测框
  323. for result in layout_results:
  324. bbox = result.get('bbox', [])
  325. if not bbox or len(bbox) < 4:
  326. continue
  327. category = result.get('category', 'unknown')
  328. color = category_colors.get(category, (128, 128, 128)) # 默认灰色
  329. thickness = 2
  330. x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
  331. cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)
  332. # 添加类别标签
  333. label = f"{category}"
  334. confidence = result.get('confidence', result.get('score', 0))
  335. if confidence:
  336. label += f":{confidence:.2f}"
  337. # 计算文本大小
  338. font = cv2.FONT_HERSHEY_SIMPLEX
  339. font_scale = 0.4
  340. text_thickness = 1
  341. (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
  342. # 在框的上方绘制文本背景
  343. text_y = max(y1 - baseline - 1, text_height + baseline)
  344. cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2),
  345. (x1 + text_width, text_y), color, -1)
  346. # 绘制文本
  347. cv2.putText(vis_image, label, (x1, text_y - baseline - 1),
  348. font, font_scale, (255, 255, 255), text_thickness)
  349. # 保存图像
  350. debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
  351. debug_dir.mkdir(parents=True, exist_ok=True)
  352. output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
  353. cv2.imwrite(str(output_path), vis_image)
  354. logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
  355. # 保存 JSON 数据
  356. json_data = {
  357. 'page_name': page_name,
  358. 'suffix': suffix,
  359. 'count': len(layout_results),
  360. 'results': [
  361. {
  362. 'category': r.get('category'),
  363. 'bbox': r.get('bbox'),
  364. 'confidence': r.get('confidence', r.get('score', 0.0))
  365. }
  366. for r in layout_results
  367. ]
  368. }
  369. json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
  370. with open(json_path, 'w', encoding='utf-8') as f:
  371. json.dump(json_data, f, ensure_ascii=False, indent=2)
  372. logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")
  373. except Exception as e:
  374. logger.warning(f"⚠️ Failed to visualize layout results: {e}")
  375. def _remove_overlapping_boxes(
  376. self,
  377. layout_results: List[Dict[str, Any]],
  378. coordinate_utils: Any,
  379. iou_threshold: float = 0.8,
  380. overlap_ratio_threshold: float = 0.8
  381. ) -> List[Dict[str, Any]]:
  382. """
  383. 改进版重叠框处理算法(基于优先级和决策规则的清晰算法)
  384. 策略:
  385. 1. 定义类别优先级(abandon < text/image < table_body)
  386. 2. 使用统一的决策规则
  387. 3. 按综合评分排序处理,优先保留大的聚合框
  388. Args:
  389. layout_results: 布局检测结果
  390. coordinate_utils: 坐标工具类
  391. iou_threshold: IoU阈值(默认0.8)
  392. overlap_ratio_threshold: 重叠比例阈值(默认0.8)
  393. Returns:
  394. 去重后的布局结果
  395. """
  396. if not layout_results or len(layout_results) <= 1:
  397. return layout_results
  398. # 常量定义
  399. CATEGORY_PRIORITY = {
  400. 'abandon': 0,
  401. 'text': 1,
  402. 'image_body': 1,
  403. 'title': 2,
  404. 'footer': 2,
  405. 'header': 2,
  406. 'table_body': 3,
  407. }
  408. AGGREGATE_LABELS = {'key-value region', 'form'}
  409. MAX_AREA = 4000000.0 # 用于面积归一化
  410. AREA_WEIGHT = 0.5
  411. CONFIDENCE_WEIGHT = 0.5
  412. AGGREGATE_BONUS = 0.1
  413. AREA_RATIO_THRESHOLD = 3.0 # 大框面积需大于小框的倍数
  414. def get_bbox_area(bbox: List[float]) -> float:
  415. """计算bbox面积"""
  416. if len(bbox) < 4:
  417. return 0.0
  418. return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
  419. def is_aggregate_type(box: Dict[str, Any]) -> bool:
  420. """检查是否是聚合类型"""
  421. original_label = box.get('raw', {}).get('original_label', '').lower()
  422. return original_label in AGGREGATE_LABELS
  423. def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
  424. """检查inner是否完全包含在outer内"""
  425. if len(inner) < 4 or len(outer) < 4:
  426. return False
  427. return (inner[0] >= outer[0] and inner[1] >= outer[1] and
  428. inner[2] <= outer[2] and inner[3] <= outer[3])
  429. def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
  430. """计算text类型的综合评分(面积+置信度)"""
  431. if box.get('category') != 'text':
  432. return box.get('confidence', box.get('score', 0))
  433. normalized_area = min(area / MAX_AREA, 1.0)
  434. area_score = (normalized_area ** 0.5) * AREA_WEIGHT
  435. confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
  436. bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
  437. return area_score + confidence_score + bonus
  438. def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
  439. iou: float, overlap_ratio: float,
  440. contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
  441. """判断是否应该保留box1"""
  442. # 提取基本信息
  443. cat1, cat2 = box1.get('category', ''), box2.get('category', '')
  444. score1 = box1.get('confidence', box1.get('score', 0))
  445. score2 = box2.get('confidence', box2.get('score', 0))
  446. bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
  447. area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
  448. is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
  449. # 规则1: 类别优先级
  450. priority1 = CATEGORY_PRIORITY.get(cat1, 1)
  451. priority2 = CATEGORY_PRIORITY.get(cat2, 1)
  452. if priority1 != priority2:
  453. return priority1 > priority2
  454. # 规则2: 包含关系 + 聚合类型优先
  455. if contained_2_in_1 and is_agg1 and not is_agg2:
  456. return True
  457. if contained_1_in_2 and is_agg2 and not is_agg1:
  458. return False
  459. # 规则3: 包含关系 + 面积比例
  460. if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
  461. return True
  462. if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
  463. return False
  464. # 规则4: text类型使用综合评分
  465. if cat1 == 'text' or cat2 == 'text':
  466. comp_score1 = calculate_composite_score(box1, area1)
  467. comp_score2 = calculate_composite_score(box2, area2)
  468. if abs(comp_score1 - comp_score2) > 0.05:
  469. return comp_score1 > comp_score2
  470. # 规则5: 置信度比较
  471. if abs(score1 - score2) > 0.1:
  472. return score1 > score2
  473. # 规则6: 面积比较
  474. return area1 >= area2
  475. # 主处理逻辑
  476. results = [item.copy() for item in layout_results]
  477. need_remove = set()
  478. # 按综合评分排序(高分优先)
  479. def get_sort_key(i: int) -> float:
  480. item = results[i]
  481. if item.get('category') == 'text':
  482. return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
  483. return -item.get('confidence', item.get('score', 0))
  484. sorted_indices = sorted(range(len(results)), key=get_sort_key)
  485. # 比较每对框
  486. for idx_i, i in enumerate(sorted_indices):
  487. if i in need_remove:
  488. continue
  489. for idx_j, j in enumerate(sorted_indices):
  490. if j == i or j in need_remove or idx_j >= idx_i:
  491. continue
  492. bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
  493. if len(bbox1) < 4 or len(bbox2) < 4:
  494. continue
  495. # 计算重叠指标
  496. iou = coordinate_utils.calculate_iou(bbox1, bbox2)
  497. overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
  498. contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
  499. contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
  500. # 检查是否有显著重叠
  501. if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
  502. contained_1_in_2 or contained_2_in_1):
  503. continue
  504. # 应用决策规则
  505. if should_keep_box1(results[i], results[j], iou, overlap_ratio,
  506. contained_1_in_2, contained_2_in_1):
  507. need_remove.add(j)
  508. else:
  509. need_remove.add(i)
  510. break
  511. return [results[i] for i in range(len(results)) if i not in need_remove]
  512. class BaseVLRecognizer(BaseAdapter):
  513. """VL识别器基类"""
  514. @abstractmethod
  515. def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  516. """识别表格"""
  517. pass
  518. @abstractmethod
  519. def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  520. """识别公式"""
  521. pass
  522. class BaseOCRRecognizer(BaseAdapter):
  523. """OCR识别器基类"""
  524. @abstractmethod
  525. def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  526. """识别文本"""
  527. pass
  528. @abstractmethod
  529. def detect_text_boxes(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  530. """
  531. 只检测文本框(不识别文字内容)
  532. 子类必须实现此方法。建议使用只运行检测模型的方式(不运行识别模型)以优化性能。
  533. 如果无法优化,至少实现一个调用 recognize_text() 的版本作为兜底。
  534. Returns:
  535. 文本框列表,每项包含 'bbox', 'poly',可能包含 'confidence'
  536. """
  537. pass