base.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, Any, List, Union, Optional, Tuple
  3. import numpy as np
  4. from PIL import Image
  5. from loguru import logger
  6. from pathlib import Path
  7. import cv2
  8. import json
  9. class BaseAdapter(ABC):
  10. """基础适配器接口"""
  11. def __init__(self, config: Dict[str, Any]):
  12. self.config = config
  13. @abstractmethod
  14. def initialize(self):
  15. """初始化模型"""
  16. pass
  17. @abstractmethod
  18. def cleanup(self):
  19. """清理资源"""
  20. pass
  21. class BasePreprocessor(BaseAdapter):
  22. """预处理器基类"""
  23. def __init__(self, config: Dict[str, Any]):
  24. super().__init__(config)
  25. # 运行时由 pipeline 按页注入(与 layout_detection 一致)
  26. self.debug_mode: Optional[bool] = None
  27. self.output_dir: Optional[str] = None
  28. self.page_name: Optional[str] = None
  29. def _watermark_debug_options(self) -> Dict[str, Any]:
  30. wm_cfg = self.config.get('watermark_removal', {})
  31. opts = wm_cfg.get('debug_options', {})
  32. return opts if isinstance(opts, dict) else {}
  33. def _is_watermark_debug_enabled(self) -> bool:
  34. debug_mode = getattr(self, 'debug_mode', None)
  35. if debug_mode is not None:
  36. return bool(debug_mode)
  37. return bool(self._watermark_debug_options().get('enabled', False))
  38. def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
  39. output_dir = getattr(self, 'output_dir', None)
  40. if output_dir is None:
  41. output_dir = self._watermark_debug_options().get('output_dir')
  42. page_name = getattr(self, 'page_name', None)
  43. if not page_name:
  44. page_name = self._watermark_debug_options().get('prefix') or 'watermark'
  45. prefix = self._watermark_debug_options().get('prefix', '')
  46. if prefix and page_name and not str(page_name).startswith(str(prefix)):
  47. page_name = f"{prefix}_{page_name}"
  48. return output_dir, str(page_name)
  49. def _save_watermark_debug_images(
  50. self,
  51. before: np.ndarray,
  52. after: np.ndarray,
  53. threshold: int,
  54. morph_close_kernel: int,
  55. contrast_cfg: Optional[Dict[str, Any]] = None,
  56. ) -> None:
  57. """保存水印调试图(委托 ocr_utils.watermark_utils)。"""
  58. from ocr_utils.watermark_utils import save_watermark_removal_debug
  59. output_dir, page_name = self._resolve_watermark_debug_paths()
  60. if not output_dir:
  61. return
  62. opts = self._watermark_debug_options()
  63. params: Dict[str, Any] = {
  64. "threshold": threshold,
  65. "morph_close_kernel": morph_close_kernel,
  66. }
  67. if contrast_cfg:
  68. params["contrast_enhancement"] = contrast_cfg
  69. try:
  70. save_watermark_removal_debug(
  71. before,
  72. after,
  73. output_dir,
  74. page_name,
  75. processing_params=params,
  76. image_format=opts.get("image_format") or "png",
  77. save_compare=opts.get("save_compare", True),
  78. )
  79. except Exception as e:
  80. logger.warning(f"Watermark debug save failed: {e}")
  81. def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
  82. """页级水印去除(默认无操作,子类可覆盖)。"""
  83. if isinstance(image, Image.Image):
  84. return np.array(image)
  85. return image
  86. def _preprocess_order(self) -> str:
  87. """预处理步骤顺序:orient_first(默认)| watermark_first。"""
  88. order = str(self.config.get('order', 'orient_first')).strip().lower()
  89. if order not in ('orient_first', 'watermark_first'):
  90. logger.warning(
  91. f"Unknown preprocessor.order={order!r}, fallback to orient_first"
  92. )
  93. return 'orient_first'
  94. return order
  95. def correct_orientation(
  96. self,
  97. image: Union[np.ndarray, Image.Image],
  98. *,
  99. pdf_rotate_angle: Optional[int] = None,
  100. use_orientation_classifier: bool = True,
  101. ) -> tuple[np.ndarray, int]:
  102. """
  103. 仅方向校正,不去水印。用于表格裁剪等页级已预处理场景。
  104. Args:
  105. pdf_rotate_angle: 文字 PDF 页级旋转(逆时针角度,与 pipeline 一致)
  106. use_orientation_classifier: 是否使用方向分类器(扫描件为 True)
  107. """
  108. if isinstance(image, Image.Image):
  109. image = np.array(image)
  110. if pdf_rotate_angle:
  111. pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
  112. return np.array(pil_rotated), int(pdf_rotate_angle)
  113. return image, 0
  114. def prepare_detection_image(
  115. self,
  116. image: Union[np.ndarray, Image.Image],
  117. *,
  118. pdf_rotate_angle: Optional[int] = None,
  119. use_orientation_classifier: bool = True,
  120. ) -> tuple[np.ndarray, int]:
  121. """
  122. 页级完整预处理:按 preprocessor.order 执行方向校正与水印去除。
  123. Returns:
  124. (detection_image, rotate_angle)
  125. """
  126. if isinstance(image, Image.Image):
  127. image = np.array(image)
  128. order = self._preprocess_order()
  129. def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
  130. return self.correct_orientation(
  131. img,
  132. pdf_rotate_angle=pdf_rotate_angle,
  133. use_orientation_classifier=use_orientation_classifier,
  134. )
  135. if order == 'watermark_first':
  136. cleaned = self.remove_watermark(image)
  137. return _orient(cleaned)
  138. oriented, rotate_angle = _orient(image)
  139. return self.remove_watermark(oriented), rotate_angle
  140. def process(
  141. self,
  142. image: Union[np.ndarray, Image.Image],
  143. skip_watermark: bool = False,
  144. ) -> tuple[np.ndarray, int]:
  145. """
  146. 裁剪块:仅方向校正(skip_watermark=True)。
  147. 页级请使用 prepare_detection_image()。
  148. """
  149. if skip_watermark:
  150. return self.correct_orientation(image, use_orientation_classifier=True)
  151. return self.prepare_detection_image(image, use_orientation_classifier=True)
  152. def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
  153. """应用旋转"""
  154. import cv2
  155. if rotation_angle == 90: # 90度
  156. return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  157. elif rotation_angle == 180: # 180度
  158. return cv2.rotate(image, cv2.ROTATE_180)
  159. elif rotation_angle == 270: # 270度
  160. return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  161. return image
  162. class BaseLayoutDetector(BaseAdapter):
  163. """版式检测器基类"""
  164. def __init__(self, config: Dict[str, Any]):
  165. """初始化版式检测器
  166. Args:
  167. config: 配置字典
  168. """
  169. super().__init__(config)
  170. # 初始化 debug 相关属性(支持从配置或运行时设置)
  171. self.debug_mode = None # 将在 detect() 方法中从配置读取
  172. self.output_dir = None # 将在 detect() 方法中从配置读取
  173. self.page_name = None # 将在 detect() 方法中从配置读取
  174. def detect(
  175. self,
  176. image: Union[np.ndarray, Image.Image],
  177. ocr_spans: Optional[List[Dict[str, Any]]] = None
  178. ) -> List[Dict[str, Any]]:
  179. """
  180. 检测版式(模板方法,自动执行后处理)
  181. 此方法会:
  182. 1. 调用子类实现的 _detect_raw() 进行原始检测
  183. 2. 自动执行后处理(去除重叠框、文本转表格等)
  184. Args:
  185. image: 输入图像
  186. ocr_spans: OCR结果(可选,某些detector可能需要)
  187. Returns:
  188. 后处理后的布局检测结果
  189. """
  190. # 调用子类实现的原始检测方法
  191. layout_results = self._detect_raw(image, ocr_spans)
  192. debug_mode = self._is_layout_debug_enabled()
  193. output_dir, page_name = self._resolve_layout_debug_paths()
  194. dbg_opts = self._layout_debug_options()
  195. if debug_mode:
  196. logger.debug(
  197. f"Layout detection raw results (before post-processing): "
  198. f"{len(layout_results)} elements"
  199. )
  200. if output_dir and dbg_opts.get('save_raw', True):
  201. self._visualize_layout_results(
  202. image, layout_results, output_dir, page_name, suffix='raw'
  203. )
  204. # 自动执行后处理
  205. if layout_results:
  206. layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  207. layout_results = self.post_process(layout_results, image, layout_config)
  208. if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
  209. self._visualize_layout_results(
  210. image, layout_results, output_dir, page_name, suffix='post'
  211. )
  212. return layout_results
  213. @abstractmethod
  214. def _detect_raw(
  215. self,
  216. image: Union[np.ndarray, Image.Image],
  217. ocr_spans: Optional[List[Dict[str, Any]]] = None
  218. ) -> List[Dict[str, Any]]:
  219. """
  220. 原始检测方法(子类必须实现)
  221. Args:
  222. image: 输入图像
  223. ocr_spans: OCR结果(可选)
  224. Returns:
  225. 原始检测结果(未后处理)
  226. """
  227. pass
  228. def post_process(
  229. self,
  230. layout_results: List[Dict[str, Any]],
  231. image: Union[np.ndarray, Image.Image],
  232. config: Optional[Dict[str, Any]] = None
  233. ) -> List[Dict[str, Any]]:
  234. """
  235. 后处理布局检测结果
  236. 默认实现包括:
  237. 1. 去除重叠框
  238. 2. 将大面积文本块转换为表格(如果配置启用)
  239. 子类可以重写此方法以自定义后处理逻辑
  240. Args:
  241. layout_results: 原始检测结果
  242. image: 输入图像
  243. config: 后处理配置(可选),如果为None则使用self.config中的post_process配置
  244. Returns:
  245. 后处理后的布局结果
  246. """
  247. if not layout_results:
  248. return layout_results
  249. # 获取配置
  250. if config is None:
  251. config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
  252. # 导入 CoordinateUtils(适配器可以访问)
  253. try:
  254. from ocr_utils.coordinate_utils import CoordinateUtils
  255. except ImportError:
  256. try:
  257. from ocr_utils import CoordinateUtils
  258. except ImportError:
  259. # 如果无法导入,返回原始结果
  260. return layout_results
  261. # 1. 去除重叠框
  262. layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
  263. # 2. 将大面积文本块转换为表格(如果配置启用)
  264. layout_config = config if config is not None else {}
  265. if layout_config.get('convert_large_text_to_table', False):
  266. # 获取图像尺寸
  267. if isinstance(image, Image.Image):
  268. h, w = image.size[1], image.size[0]
  269. else:
  270. h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
  271. layout_results_converted_large_text = self._convert_large_text_to_table(
  272. layout_results_removed_overlapping,
  273. (h, w),
  274. min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
  275. min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
  276. min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
  277. )
  278. return layout_results_converted_large_text
  279. else:
  280. return layout_results_removed_overlapping
  281. def _convert_large_text_to_table(
  282. self,
  283. layout_results: List[Dict[str, Any]],
  284. image_shape: Tuple[int, int],
  285. min_area_ratio: float = 0.25,
  286. min_width_ratio: float = 0.4,
  287. min_height_ratio: float = 0.3
  288. ) -> List[Dict[str, Any]]:
  289. """
  290. 将大面积的文本块转换为表格
  291. 判断规则:
  292. 1. 面积占比:占页面面积超过 min_area_ratio(默认25%)
  293. 2. 尺寸比例:宽度和高度都超过一定比例(避免细长条)
  294. 3. 不与其他表格重叠:如果已有表格,不转换
  295. """
  296. if not layout_results:
  297. return layout_results
  298. img_height, img_width = image_shape
  299. img_area = img_height * img_width
  300. if img_area == 0:
  301. return layout_results
  302. # 检查是否已有表格
  303. has_table = any(
  304. item.get('category', '').lower() in ['table', 'table_body']
  305. for item in layout_results
  306. )
  307. # 如果已有表格,不进行转换(避免误判)
  308. if has_table:
  309. return layout_results
  310. # 复制列表避免修改原数据
  311. results = [item.copy() for item in layout_results]
  312. converted_count = 0
  313. for item in results:
  314. category = item.get('category', '').lower()
  315. # 只处理文本类型的元素
  316. if category not in ['text', 'ocr_text']:
  317. continue
  318. bbox = item.get('bbox', [0, 0, 0, 0])
  319. if len(bbox) < 4:
  320. continue
  321. x1, y1, x2, y2 = bbox[:4]
  322. width = x2 - x1
  323. height = y2 - y1
  324. area = width * height
  325. # 计算占比
  326. area_ratio = area / img_area if img_area > 0 else 0
  327. width_ratio = width / img_width if img_width > 0 else 0
  328. height_ratio = height / img_height if img_height > 0 else 0
  329. # 判断是否满足转换条件
  330. if (area_ratio >= min_area_ratio and
  331. width_ratio >= min_width_ratio and
  332. height_ratio >= min_height_ratio):
  333. # 转换为表格
  334. item['category'] = 'table'
  335. item['original_category'] = category # 保留原始类别
  336. converted_count += 1
  337. return results
  338. def _map_category_id(self, category_id: int) -> str:
  339. """映射类别ID到字符串"""
  340. category_map = {
  341. 0: 'title',
  342. 1: 'text',
  343. 2: 'abandon',
  344. 3: 'image_body',
  345. 4: 'image_caption',
  346. 5: 'table_body',
  347. 6: 'table_caption',
  348. 7: 'table_footnote',
  349. 8: 'interline_equation',
  350. 9: 'interline_equation_number',
  351. 13: 'inline_equation',
  352. 14: 'interline_equation_yolo',
  353. 15: 'ocr_text',
  354. 16: 'low_score_text',
  355. 101: 'image_footnote'
  356. }
  357. return category_map.get(category_id, f'unknown_{category_id}')
  358. def _layout_debug_options(self) -> Dict[str, Any]:
  359. opts = self.config.get('debug_options', {})
  360. return opts if isinstance(opts, dict) else {}
  361. def _is_layout_debug_enabled(self) -> bool:
  362. debug_mode = getattr(self, 'debug_mode', None)
  363. if debug_mode is not None:
  364. return bool(debug_mode)
  365. if self.config.get('debug_mode', False):
  366. return True
  367. return bool(self._layout_debug_options().get('enabled', False))
  368. def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
  369. output_dir = getattr(self, 'output_dir', None)
  370. if output_dir is None:
  371. output_dir = self.config.get('output_dir')
  372. if output_dir is None:
  373. output_dir = self._layout_debug_options().get('output_dir')
  374. page_name = getattr(self, 'page_name', None)
  375. if page_name is None:
  376. page_name = self.config.get('page_name')
  377. if not page_name:
  378. prefix = self._layout_debug_options().get('prefix', '')
  379. page_name = prefix if prefix else 'layout_detection'
  380. return output_dir, str(page_name)
  381. def _visualize_layout_results(
  382. self,
  383. image: Union[np.ndarray, Image.Image],
  384. layout_results: List[Dict[str, Any]],
  385. output_dir: str,
  386. page_name: str,
  387. suffix: str = 'raw',
  388. ) -> None:
  389. """保存 layout 模块 debug(底图为 inference / detection 输入)。"""
  390. if not layout_results:
  391. return
  392. from ocr_utils.module_debug_viz import save_layout_debug
  393. opts = self._layout_debug_options()
  394. save_layout_debug(
  395. image,
  396. layout_results,
  397. output_dir,
  398. page_name,
  399. suffix=suffix,
  400. subdir=opts.get('subdir', 'layout_detection'),
  401. image_format=opts.get('image_format', 'jpg'),
  402. save_json=bool(opts.get('save_json', True)),
  403. )
  404. def _remove_overlapping_boxes(
  405. self,
  406. layout_results: List[Dict[str, Any]],
  407. coordinate_utils: Any,
  408. iou_threshold: float = 0.8,
  409. overlap_ratio_threshold: float = 0.8
  410. ) -> List[Dict[str, Any]]:
  411. """
  412. 改进版重叠框处理算法(基于优先级和决策规则的清晰算法)
  413. 策略:
  414. 1. 定义类别优先级(abandon < text/image < table_body)
  415. 2. 使用统一的决策规则
  416. 3. 按综合评分排序处理,优先保留大的聚合框
  417. Args:
  418. layout_results: 布局检测结果
  419. coordinate_utils: 坐标工具类
  420. iou_threshold: IoU阈值(默认0.8)
  421. overlap_ratio_threshold: 重叠比例阈值(默认0.8)
  422. Returns:
  423. 去重后的布局结果
  424. """
  425. if not layout_results or len(layout_results) <= 1:
  426. return layout_results
  427. # 常量定义
  428. CATEGORY_PRIORITY = {
  429. 'abandon': 0,
  430. 'text': 1,
  431. 'image_body': 1,
  432. 'title': 2,
  433. 'footer': 2,
  434. 'header': 2,
  435. 'table_body': 3,
  436. }
  437. AGGREGATE_LABELS = {'key-value region', 'form'}
  438. MAX_AREA = 4000000.0 # 用于面积归一化
  439. AREA_WEIGHT = 0.5
  440. CONFIDENCE_WEIGHT = 0.5
  441. AGGREGATE_BONUS = 0.1
  442. AREA_RATIO_THRESHOLD = 3.0 # 大框面积需大于小框的倍数
  443. def get_bbox_area(bbox: List[float]) -> float:
  444. """计算bbox面积"""
  445. if len(bbox) < 4:
  446. return 0.0
  447. return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
  448. def is_aggregate_type(box: Dict[str, Any]) -> bool:
  449. """检查是否是聚合类型"""
  450. original_label = box.get('raw', {}).get('original_label', '').lower()
  451. return original_label in AGGREGATE_LABELS
  452. def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
  453. """检查inner是否完全包含在outer内"""
  454. if len(inner) < 4 or len(outer) < 4:
  455. return False
  456. return (inner[0] >= outer[0] and inner[1] >= outer[1] and
  457. inner[2] <= outer[2] and inner[3] <= outer[3])
  458. def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
  459. """计算text类型的综合评分(面积+置信度)"""
  460. if box.get('category') != 'text':
  461. return box.get('confidence', box.get('score', 0))
  462. normalized_area = min(area / MAX_AREA, 1.0)
  463. area_score = (normalized_area ** 0.5) * AREA_WEIGHT
  464. confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
  465. bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
  466. return area_score + confidence_score + bonus
  467. def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
  468. iou: float, overlap_ratio: float,
  469. contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
  470. """判断是否应该保留box1"""
  471. # 提取基本信息
  472. cat1, cat2 = box1.get('category', ''), box2.get('category', '')
  473. score1 = box1.get('confidence', box1.get('score', 0))
  474. score2 = box2.get('confidence', box2.get('score', 0))
  475. bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
  476. area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
  477. is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
  478. # 规则1: 类别优先级
  479. priority1 = CATEGORY_PRIORITY.get(cat1, 1)
  480. priority2 = CATEGORY_PRIORITY.get(cat2, 1)
  481. if priority1 != priority2:
  482. return priority1 > priority2
  483. # 规则2: 包含关系 + 聚合类型优先
  484. if contained_2_in_1 and is_agg1 and not is_agg2:
  485. return True
  486. if contained_1_in_2 and is_agg2 and not is_agg1:
  487. return False
  488. # 规则3: 包含关系 + 面积比例
  489. if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
  490. return True
  491. if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
  492. return False
  493. # 规则4: text类型使用综合评分
  494. if cat1 == 'text' or cat2 == 'text':
  495. comp_score1 = calculate_composite_score(box1, area1)
  496. comp_score2 = calculate_composite_score(box2, area2)
  497. if abs(comp_score1 - comp_score2) > 0.05:
  498. return comp_score1 > comp_score2
  499. # 规则5: 置信度比较
  500. if abs(score1 - score2) > 0.1:
  501. return score1 > score2
  502. # 规则6: 面积比较
  503. return area1 >= area2
  504. # 主处理逻辑
  505. results = [item.copy() for item in layout_results]
  506. need_remove = set()
  507. # 按综合评分排序(高分优先)
  508. def get_sort_key(i: int) -> float:
  509. item = results[i]
  510. if item.get('category') == 'text':
  511. return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
  512. return -item.get('confidence', item.get('score', 0))
  513. sorted_indices = sorted(range(len(results)), key=get_sort_key)
  514. # 比较每对框
  515. for idx_i, i in enumerate(sorted_indices):
  516. if i in need_remove:
  517. continue
  518. for idx_j, j in enumerate(sorted_indices):
  519. if j == i or j in need_remove or idx_j >= idx_i:
  520. continue
  521. bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
  522. if len(bbox1) < 4 or len(bbox2) < 4:
  523. continue
  524. # 计算重叠指标
  525. iou = coordinate_utils.calculate_iou(bbox1, bbox2)
  526. overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
  527. contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
  528. contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
  529. # 检查是否有显著重叠
  530. if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
  531. contained_1_in_2 or contained_2_in_1):
  532. continue
  533. # 应用决策规则
  534. if should_keep_box1(results[i], results[j], iou, overlap_ratio,
  535. contained_1_in_2, contained_2_in_1):
  536. need_remove.add(j)
  537. else:
  538. need_remove.add(i)
  539. break
  540. return [results[i] for i in range(len(results)) if i not in need_remove]
  541. class BaseVLRecognizer(BaseAdapter):
  542. """VL识别器基类"""
  543. @abstractmethod
  544. def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  545. """识别表格"""
  546. pass
  547. @abstractmethod
  548. def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
  549. """识别公式"""
  550. pass
  551. class BaseOCRRecognizer(BaseAdapter):
  552. """OCR识别器基类"""
  553. @abstractmethod
  554. def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  555. """识别文本"""
  556. pass
  557. @abstractmethod
  558. def detect_text_boxes(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  559. """
  560. 只检测文本框(不识别文字内容)
  561. 子类必须实现此方法。建议使用只运行检测模型的方式(不运行识别模型)以优化性能。
  562. 如果无法优化,至少实现一个调用 recognize_text() 的版本作为兜底。
  563. Returns:
  564. 文本框列表,每项包含 'bbox', 'poly',可能包含 'confidence'
  565. """
  566. pass