base.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, Any, List, Union, Optional, Tuple
  3. import numpy as np
  4. from PIL import Image
  5. from loguru import logger
  6. from pathlib import Path
  7. import cv2
  8. import json
  9. class BaseAdapter(ABC):
  10. """基础适配器接口"""
  11. def __init__(self, config: Dict[str, Any]):
  12. self.config = config
  13. @abstractmethod
  14. def initialize(self):
  15. """初始化模型"""
  16. pass
  17. @abstractmethod
  18. def cleanup(self):
  19. """清理资源"""
  20. pass
  21. class BasePreprocessor(BaseAdapter):
  22. """预处理器基类"""
  23. def __init__(self, config: Dict[str, Any]):
  24. super().__init__(config)
  25. # 运行时由 pipeline 按页注入(与 layout_detection 一致)
  26. self.debug_mode: Optional[bool] = None
  27. self.output_dir: Optional[str] = None
  28. self.page_name: Optional[str] = None
  29. def _watermark_debug_options(self) -> Dict[str, Any]:
  30. wm_cfg = self.config.get('watermark_removal', {})
  31. opts = wm_cfg.get('debug_options', {})
  32. return opts if isinstance(opts, dict) else {}
  33. def _is_watermark_debug_enabled(self) -> bool:
  34. debug_mode = getattr(self, 'debug_mode', None)
  35. if debug_mode is not None:
  36. return bool(debug_mode)
  37. return bool(self._watermark_debug_options().get('enabled', False))
  38. def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
  39. output_dir = getattr(self, 'output_dir', None)
  40. if output_dir is None:
  41. output_dir = self._watermark_debug_options().get('output_dir')
  42. page_name = getattr(self, 'page_name', None)
  43. if not page_name:
  44. page_name = self._watermark_debug_options().get('prefix') or 'watermark'
  45. prefix = self._watermark_debug_options().get('prefix', '')
  46. if prefix and page_name and not str(page_name).startswith(str(prefix)):
  47. page_name = f"{prefix}_{page_name}"
  48. return output_dir, str(page_name)
  49. def _save_watermark_debug_images(
  50. self,
  51. before: np.ndarray,
  52. after: np.ndarray,
  53. threshold: int,
  54. morph_close_kernel: int,
  55. contrast_cfg: Optional[Dict[str, Any]] = None,
  56. ) -> None:
  57. """保存水印调试图(委托 ocr_utils.watermark_utils)。"""
  58. from ocr_utils.watermark_utils import save_watermark_removal_debug
  59. output_dir, page_name = self._resolve_watermark_debug_paths()
  60. if not output_dir:
  61. return
  62. opts = self._watermark_debug_options()
  63. params: Dict[str, Any] = {
  64. "threshold": threshold,
  65. "morph_close_kernel": morph_close_kernel,
  66. }
  67. if contrast_cfg:
  68. params["contrast_enhancement"] = contrast_cfg
  69. try:
  70. save_watermark_removal_debug(
  71. before,
  72. after,
  73. output_dir,
  74. page_name,
  75. processing_params=params,
  76. image_format=opts.get("image_format") or "png",
  77. save_compare=opts.get("save_compare", True),
  78. subdir=opts.get("subdir", "watermark_removal"),
  79. )
  80. except Exception as e:
  81. logger.warning(f"Watermark debug save failed: {e}")
  82. def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
  83. """页级水印去除(默认无操作,子类可覆盖)。"""
  84. if isinstance(image, Image.Image):
  85. return np.array(image)
  86. return image
  87. def _preprocess_order(self) -> str:
  88. """预处理步骤顺序:orient_first(默认)| watermark_first。"""
  89. order = str(self.config.get('order', 'orient_first')).strip().lower()
  90. if order not in ('orient_first', 'watermark_first'):
  91. logger.warning(
  92. f"Unknown preprocessor.order={order!r}, fallback to orient_first"
  93. )
  94. return 'orient_first'
  95. return order
  96. def correct_orientation(
  97. self,
  98. image: Union[np.ndarray, Image.Image],
  99. *,
  100. pdf_rotate_angle: Optional[int] = None,
  101. use_orientation_classifier: bool = True,
  102. ) -> tuple[np.ndarray, int]:
  103. """
  104. 仅方向校正,不去水印。用于表格裁剪等页级已预处理场景。
  105. Args:
  106. pdf_rotate_angle: 文字 PDF 页级旋转(逆时针角度,与 pipeline 一致)
  107. use_orientation_classifier: 是否使用方向分类器(扫描件为 True)
  108. """
  109. if isinstance(image, Image.Image):
  110. image = np.array(image)
  111. if pdf_rotate_angle:
  112. pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
  113. return np.array(pil_rotated), int(pdf_rotate_angle)
  114. return image, 0
  115. def prepare_detection_image(
  116. self,
  117. image: Union[np.ndarray, Image.Image],
  118. *,
  119. pdf_rotate_angle: Optional[int] = None,
  120. use_orientation_classifier: bool = True,
  121. ) -> tuple[np.ndarray, int]:
  122. """
  123. 页级完整预处理:按 preprocessor.order 执行方向校正与水印去除。
  124. Returns:
  125. (detection_image, rotate_angle)
  126. """
  127. if isinstance(image, Image.Image):
  128. image = np.array(image)
  129. order = self._preprocess_order()
  130. def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
  131. return self.correct_orientation(
  132. img,
  133. pdf_rotate_angle=pdf_rotate_angle,
  134. use_orientation_classifier=use_orientation_classifier,
  135. )
  136. if order == 'watermark_first':
  137. cleaned = self.remove_watermark(image)
  138. return _orient(cleaned)
  139. oriented, rotate_angle = _orient(image)
  140. return self.remove_watermark(oriented), rotate_angle
  141. def process(
  142. self,
  143. image: Union[np.ndarray, Image.Image],
  144. skip_watermark: bool = False,
  145. ) -> tuple[np.ndarray, int]:
  146. """
  147. 裁剪块:仅方向校正(skip_watermark=True)。
  148. 页级请使用 prepare_detection_image()。
  149. """
  150. if skip_watermark:
  151. return self.correct_orientation(image, use_orientation_classifier=True)
  152. return self.prepare_detection_image(image, use_orientation_classifier=True)
  153. def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
  154. """应用旋转"""
  155. import cv2
  156. if rotation_angle == 90: # 90度
  157. return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  158. elif rotation_angle == 180: # 180度
  159. return cv2.rotate(image, cv2.ROTATE_180)
  160. elif rotation_angle == 270: # 270度
  161. return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  162. return image
  163. class BaseLayoutDetector(BaseAdapter):
  164. """版式检测器基类"""
  165. def __init__(self, config: Dict[str, Any]):
  166. """初始化版式检测器
  167. Args:
  168. config: 配置字典
  169. """
  170. super().__init__(config)
  171. # 初始化 debug 相关属性(支持从配置或运行时设置)
  172. self.debug_mode = None # 将在 detect() 方法中从配置读取
  173. self.output_dir = None # 将在 detect() 方法中从配置读取
  174. self.page_name = None # 将在 detect() 方法中从配置读取
  175. def detect(
  176. self,
  177. image: Union[np.ndarray, Image.Image],
  178. ocr_spans: Optional[List[Dict[str, Any]]] = None
  179. ) -> List[Dict[str, Any]]:
  180. """
  181. 检测版式(模板方法,自动执行后处理)
  182. 此方法会:
  183. 1. 调用子类实现的 _detect_raw() 进行原始检测
  184. 2. 自动执行后处理(去除重叠框、文本转表格等)
  185. Args:
  186. image: 输入图像
  187. ocr_spans: OCR结果(可选,某些detector可能需要)
  188. Returns:
  189. 后处理后的布局检测结果
  190. """
  191. # 调用子类实现的原始检测方法
  192. layout_results = self._detect_raw(image, ocr_spans)
  193. debug_mode = self._is_layout_debug_enabled()
  194. output_dir, page_name = self._resolve_layout_debug_paths()
  195. dbg_opts = self._layout_debug_options()
  196. if debug_mode:
  197. logger.debug(
  198. f"Layout detection raw results (before post-processing): "
  199. f"{len(layout_results)} elements"
  200. )
  201. if output_dir and dbg_opts.get('save_raw', True):
  202. self._visualize_layout_results(
  203. image, layout_results, output_dir, page_name, suffix='raw'
  204. )
  205. # 自动执行后处理
  206. if layout_results:
  207. layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  208. layout_results = self.post_process(layout_results, image, layout_config)
  209. if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
  210. self._visualize_layout_results(
  211. image, layout_results, output_dir, page_name, suffix='post'
  212. )
  213. return layout_results
  214. @abstractmethod
  215. def _detect_raw(
  216. self,
  217. image: Union[np.ndarray, Image.Image],
  218. ocr_spans: Optional[List[Dict[str, Any]]] = None
  219. ) -> List[Dict[str, Any]]:
  220. """
  221. 原始检测方法(子类必须实现)
  222. Args:
  223. image: 输入图像
  224. ocr_spans: OCR结果(可选)
  225. Returns:
  226. 原始检测结果(未后处理)
  227. """
  228. pass
  229. def post_process(
  230. self,
  231. layout_results: List[Dict[str, Any]],
  232. image: Union[np.ndarray, Image.Image],
  233. config: Optional[Dict[str, Any]] = None
  234. ) -> List[Dict[str, Any]]:
  235. """
  236. 后处理布局检测结果
  237. 默认实现包括:
  238. 1. 去除重叠框
  239. 2. 将大面积文本块转换为表格(如果配置启用)
  240. 子类可以重写此方法以自定义后处理逻辑
  241. Args:
  242. layout_results: 原始检测结果
  243. image: 输入图像
  244. config: 后处理配置(可选),如果为None则使用self.config中的post_process配置
  245. Returns:
  246. 后处理后的布局结果
  247. """
  248. if not layout_results:
  249. return layout_results
  250. # 获取配置
  251. if config is None:
  252. config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  253. # 导入 CoordinateUtils(适配器可以访问)
  254. try:
  255. from ocr_utils.coordinate_utils import CoordinateUtils
  256. except ImportError:
  257. try:
  258. from ocr_utils import CoordinateUtils
  259. except ImportError:
  260. # 如果无法导入,返回原始结果
  261. return layout_results
  262. # 1. 去除重叠框
  263. layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
  264. # 2. 将大面积文本块转换为表格(如果配置启用)
  265. layout_config = config if config is not None else {}
  266. if layout_config.get('convert_large_text_to_table', False):
  267. # 获取图像尺寸
  268. if isinstance(image, Image.Image):
  269. h, w = image.size[1], image.size[0]
  270. else:
  271. h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
  272. layout_results_converted_large_text = self._convert_large_text_to_table(
  273. layout_results_removed_overlapping,
  274. (h, w),
  275. min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
  276. min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
  277. min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
  278. )
  279. return layout_results_converted_large_text
  280. else:
  281. return layout_results_removed_overlapping
  282. def _convert_large_text_to_table(
  283. self,
  284. layout_results: List[Dict[str, Any]],
  285. image_shape: Tuple[int, int],
  286. min_area_ratio: float = 0.25,
  287. min_width_ratio: float = 0.4,
  288. min_height_ratio: float = 0.3
  289. ) -> List[Dict[str, Any]]:
  290. """
  291. 将大面积的文本块转换为表格
  292. 判断规则:
  293. 1. 面积占比:占页面面积超过 min_area_ratio(默认25%)
  294. 2. 尺寸比例:宽度和高度都超过一定比例(避免细长条)
  295. 3. 不与其他表格重叠:如果已有表格,不转换
  296. """
  297. if not layout_results:
  298. return layout_results
  299. img_height, img_width = image_shape
  300. img_area = img_height * img_width
  301. if img_area == 0:
  302. return layout_results
  303. # 检查是否已有表格
  304. has_table = any(
  305. item.get('category', '').lower() in ['table', 'table_body']
  306. for item in layout_results
  307. )
  308. # 如果已有表格,不进行转换(避免误判)
  309. if has_table:
  310. return layout_results
  311. # 复制列表避免修改原数据
  312. results = [item.copy() for item in layout_results]
  313. converted_count = 0
  314. for item in results:
  315. category = item.get('category', '').lower()
  316. # 只处理文本类型的元素
  317. if category not in ['text', 'ocr_text']:
  318. continue
  319. bbox = item.get('bbox', [0, 0, 0, 0])
  320. if len(bbox) < 4:
  321. continue
  322. x1, y1, x2, y2 = bbox[:4]
  323. width = x2 - x1
  324. height = y2 - y1
  325. area = width * height
  326. # 计算占比
  327. area_ratio = area / img_area if img_area > 0 else 0
  328. width_ratio = width / img_width if img_width > 0 else 0
  329. height_ratio = height / img_height if img_height > 0 else 0
  330. # 判断是否满足转换条件
  331. if (area_ratio >= min_area_ratio and
  332. width_ratio >= min_width_ratio and
  333. height_ratio >= min_height_ratio):
  334. # 转换为表格
  335. item['category'] = 'table'
  336. item['original_category'] = category # 保留原始类别
  337. converted_count += 1
  338. return results
  339. def _map_category_id(self, category_id: int) -> str:
  340. """映射类别ID到字符串"""
  341. category_map = {
  342. 0: 'title',
  343. 1: 'text',
  344. 2: 'abandon',
  345. 3: 'image_body',
  346. 4: 'image_caption',
  347. 5: 'table_body',
  348. 6: 'table_caption',
  349. 7: 'table_footnote',
  350. 8: 'interline_equation',
  351. 9: 'interline_equation_number',
  352. 13: 'inline_equation',
  353. 14: 'interline_equation_yolo',
  354. 15: 'ocr_text',
  355. 16: 'low_score_text',
  356. 101: 'image_footnote'
  357. }
  358. return category_map.get(category_id, f'unknown_{category_id}')
  359. def _layout_debug_options(self) -> Dict[str, Any]:
  360. opts = self.config.get('debug_options', {})
  361. return opts if isinstance(opts, dict) else {}
  362. def _is_layout_debug_enabled(self) -> bool:
  363. debug_mode = getattr(self, 'debug_mode', None)
  364. if debug_mode is not None:
  365. return bool(debug_mode)
  366. if self.config.get('debug_mode', False):
  367. return True
  368. return bool(self._layout_debug_options().get('enabled', False))
  369. def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
  370. output_dir = getattr(self, 'output_dir', None)
  371. if output_dir is None:
  372. output_dir = self.config.get('output_dir')
  373. if output_dir is None:
  374. output_dir = self._layout_debug_options().get('output_dir')
  375. if output_dir is not None:
  376. output_dir = str(output_dir)
  377. page_name = getattr(self, 'page_name', None)
  378. if page_name is None:
  379. page_name = self.config.get('page_name')
  380. if not page_name:
  381. prefix = self._layout_debug_options().get('prefix', '')
  382. page_name = prefix if prefix else 'layout_detection'
  383. return output_dir, str(page_name)
  384. def _visualize_layout_results(
  385. self,
  386. image: Union[np.ndarray, Image.Image],
  387. layout_results: List[Dict[str, Any]],
  388. output_dir: str,
  389. page_name: str,
  390. suffix: str = 'raw',
  391. ) -> None:
  392. """保存 layout 模块 debug(底图为 inference / detection 输入)。"""
  393. if not layout_results:
  394. return
  395. from ocr_utils.module_debug_viz import save_layout_debug
  396. opts = self._layout_debug_options()
  397. save_layout_debug(
  398. image,
  399. layout_results,
  400. output_dir,
  401. page_name,
  402. suffix=suffix,
  403. subdir=opts.get('subdir', 'layout_detection'),
  404. image_format=opts.get('image_format', 'png'),
  405. save_json=bool(opts.get('save_json', True)),
  406. )
  407. def _remove_overlapping_boxes(
  408. self,
  409. layout_results: List[Dict[str, Any]],
  410. coordinate_utils: Any,
  411. iou_threshold: float = 0.8,
  412. overlap_ratio_threshold: float = 0.8
  413. ) -> List[Dict[str, Any]]:
  414. """
  415. 改进版重叠框处理算法(基于优先级和决策规则的清晰算法)
  416. 策略:
  417. 1. 定义类别优先级(abandon < text/image < table_body)
  418. 2. 使用统一的决策规则
  419. 3. 按综合评分排序处理,优先保留大的聚合框
  420. Args:
  421. layout_results: 布局检测结果
  422. coordinate_utils: 坐标工具类
  423. iou_threshold: IoU阈值(默认0.8)
  424. overlap_ratio_threshold: 重叠比例阈值(默认0.8)
  425. Returns:
  426. 去重后的布局结果
  427. """
  428. if not layout_results or len(layout_results) <= 1:
  429. return layout_results
  430. # 常量定义
  431. CATEGORY_PRIORITY = {
  432. 'abandon': 0,
  433. 'text': 1,
  434. 'image_body': 1,
  435. 'title': 2,
  436. 'footer': 2,
  437. 'header': 2,
  438. 'table_body': 3,
  439. }
  440. AGGREGATE_LABELS = {'key-value region', 'form'}
  441. MAX_AREA = 4000000.0 # 用于面积归一化
  442. AREA_WEIGHT = 0.5
  443. CONFIDENCE_WEIGHT = 0.5
  444. AGGREGATE_BONUS = 0.1
  445. AREA_RATIO_THRESHOLD = 3.0 # 大框面积需大于小框的倍数
  446. def get_bbox_area(bbox: List[float]) -> float:
  447. """计算bbox面积"""
  448. if len(bbox) < 4:
  449. return 0.0
  450. return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
  451. def is_aggregate_type(box: Dict[str, Any]) -> bool:
  452. """检查是否是聚合类型"""
  453. original_label = box.get('raw', {}).get('original_label', '').lower()
  454. return original_label in AGGREGATE_LABELS
  455. def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
  456. """检查inner是否完全包含在outer内"""
  457. if len(inner) < 4 or len(outer) < 4:
  458. return False
  459. return (inner[0] >= outer[0] and inner[1] >= outer[1] and
  460. inner[2] <= outer[2] and inner[3] <= outer[3])
  461. def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
  462. """计算text类型的综合评分(面积+置信度)"""
  463. if box.get('category') != 'text':
  464. return box.get('confidence', box.get('score', 0))
  465. normalized_area = min(area / MAX_AREA, 1.0)
  466. area_score = (normalized_area ** 0.5) * AREA_WEIGHT
  467. confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
  468. bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
  469. return area_score + confidence_score + bonus
  470. def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
  471. iou: float, overlap_ratio: float,
  472. contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
  473. """判断是否应该保留box1"""
  474. # 提取基本信息
  475. cat1, cat2 = box1.get('category', ''), box2.get('category', '')
  476. score1 = box1.get('confidence', box1.get('score', 0))
  477. score2 = box2.get('confidence', box2.get('score', 0))
  478. bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
  479. area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
  480. is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
  481. # 规则1: 类别优先级
  482. priority1 = CATEGORY_PRIORITY.get(cat1, 1)
  483. priority2 = CATEGORY_PRIORITY.get(cat2, 1)
  484. if priority1 != priority2:
  485. return priority1 > priority2
  486. # 规则2: 包含关系 + 聚合类型优先
  487. if contained_2_in_1 and is_agg1 and not is_agg2:
  488. return True
  489. if contained_1_in_2 and is_agg2 and not is_agg1:
  490. return False
  491. # 规则3: 包含关系 + 面积比例
  492. if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
  493. return True
  494. if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
  495. return False
  496. # 规则4: text类型使用综合评分
  497. if cat1 == 'text' or cat2 == 'text':
  498. comp_score1 = calculate_composite_score(box1, area1)
  499. comp_score2 = calculate_composite_score(box2, area2)
  500. if abs(comp_score1 - comp_score2) > 0.05:
  501. return comp_score1 > comp_score2
  502. # 规则5: 置信度比较
  503. if abs(score1 - score2) > 0.1:
  504. return score1 > score2
  505. # 规则6: 面积比较
  506. return area1 >= area2
  507. # 主处理逻辑
  508. results = [item.copy() for item in layout_results]
  509. need_remove = set()
  510. # 按综合评分排序(高分优先)
  511. def get_sort_key(i: int) -> float:
  512. item = results[i]
  513. if item.get('category') == 'text':
  514. return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
  515. return -item.get('confidence', item.get('score', 0))
  516. sorted_indices = sorted(range(len(results)), key=get_sort_key)
  517. # 比较每对框
  518. for idx_i, i in enumerate(sorted_indices):
  519. if i in need_remove:
  520. continue
  521. for idx_j, j in enumerate(sorted_indices):
  522. if j == i or j in need_remove or idx_j >= idx_i:
  523. continue
  524. bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
  525. if len(bbox1) < 4 or len(bbox2) < 4:
  526. continue
  527. # 计算重叠指标
  528. iou = coordinate_utils.calculate_iou(bbox1, bbox2)
  529. overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
  530. contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
  531. contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
  532. # 检查是否有显著重叠
  533. if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
  534. contained_1_in_2 or contained_2_in_1):
  535. continue
  536. # 应用决策规则
  537. if should_keep_box1(results[i], results[j], iou, overlap_ratio,
  538. contained_1_in_2, contained_2_in_1):
  539. need_remove.add(j)
  540. else:
  541. need_remove.add(i)
  542. break
  543. return [results[i] for i in range(len(results)) if i not in need_remove]
  544. class BaseVLRecognizer(BaseAdapter):
  545. """VL识别器基类"""
  546. @abstractmethod
  547. def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  548. """识别表格"""
  549. pass
  550. @abstractmethod
  551. def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  552. """识别公式"""
  553. pass
  554. class BaseOCRRecognizer(BaseAdapter):
  555. """OCR识别器基类"""
  556. @abstractmethod
  557. def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  558. """识别文本"""
  559. pass
  560. @abstractmethod
  561. def detect_text_boxes(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  562. """
  563. 只检测文本框(不识别文字内容)
  564. 子类必须实现此方法。建议使用只运行检测模型的方式(不运行识别模型)以优化性能。
  565. 如果无法优化,至少实现一个调用 recognize_text() 的版本作为兜底。
  566. Returns:
  567. 文本框列表,每项包含 'bbox', 'poly',可能包含 'confidence'
  568. """
  569. pass