unet_table.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import html
  2. import os
  3. import time
  4. import traceback
  5. from dataclasses import dataclass, asdict
  6. from typing import List, Optional, Union, Dict, Any
  7. import cv2
  8. import numpy as np
  9. from loguru import logger
  10. from rapid_table import RapidTableInput, RapidTable
  11. from mineru.utils.enum_class import ModelPath
  12. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  13. from .table_structure_unet import TSRUnet
  14. from .table_recover import TableRecover
  15. from .wired_table_rec_utils import InputType, LoadImage
  16. from .table_recover_utils import (
  17. match_ocr_cell,
  18. plot_html_table,
  19. box_4_2_poly_to_box_4_1,
  20. sorted_ocr_boxes,
  21. gather_ocr_list_by_row,
  22. )
  @dataclass
  class UnetTableInput:
      """Configuration for ``UnetTableRecognition``.

      The whole object is converted with ``dataclasses.asdict`` and handed to
      the structure model (``TSRUnet``).
      """

      # Path to the UNet table-structure model.
      model_path: str
      # Device identifier forwarded to the structure model.
      device: str = "cpu"
  @dataclass
  class UnetTableOutput:
      """Result container produced by ``UnetTableRecognition``.

      All fields default to ``None``. Failure paths build it positionally as
      ``UnetTableOutput("", None, None, 0.0)``, so field order matters.
      """

      # Rendered HTML of the table ("" when recognition produced nothing).
      pred_html: Union[str, None] = None
      # Detected cell boxes.
      cell_bboxes: Union[np.ndarray, None] = None
      # Per-cell logical positions; presumably (row_start, row_end,
      # col_start, col_end) per cell -- TODO confirm against TableRecover.
      logic_points: Union[np.ndarray, None] = None
      # Elapsed seconds for the recognition call.
      elapse: Union[float, None] = None
  class UnetTableRecognition:
      """Wired-table recogniser built on a UNet cell-structure model.

      Pipeline: detect cell polygons with ``TSRUnet``, recover the logical
      grid with ``TableRecover``, attach OCR text to each cell, and render the
      result as HTML with ``plot_html_table``.
      """

      def __init__(self, config: "UnetTableInput"):
          # The structure model is configured from a plain dict.
          self.table_structure = TSRUnet(asdict(config))
          self.load_img = LoadImage()
          self.table_recover = TableRecover()

      def __call__(
          self,
          img: "InputType",
          ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
          ocr_engine=None,
          **kwargs,
      ) -> "UnetTableOutput":
          """Recognise the table in ``img`` and return it as HTML plus geometry.

          Args:
              img: Anything accepted by ``LoadImage``.
              ocr_result: Whole-image OCR entries of the form
                  ``[box, text, score]``, matched to cells by ``match_ocr_cell``.
              ocr_engine: Optional OCR engine used by ``fill_blank_rec`` to
                  recognise cells that matched no OCR entry.
              **kwargs: ``need_ocr`` (default True), ``col_threshold``
                  (default 15) and ``row_threshold`` (default 10) are read
                  here; all kwargs are also forwarded to the structure model.

          Returns:
              A ``UnetTableOutput``. If no structure is detected, or any step
              raises (the traceback is logged as a warning), the empty output
              ``("", None, None, 0.0)`` is returned.
          """
          start = time.perf_counter()
          need_ocr = kwargs.get("need_ocr", True)
          col_threshold = kwargs.get("col_threshold", 15)
          row_threshold = kwargs.get("row_threshold", 10)

          img = self.load_img(img)
          polygons, rotated_polygons = self.table_structure(img, **kwargs)
          if polygons is None:
              return UnetTableOutput("", None, None, 0.0)

          try:
              _, logi_points = self.table_recover(
                  rotated_polygons, row_threshold, col_threshold
              )
              # Swap vertices 1 and 3 to turn the counter-clockwise polygons
              # clockwise, matching the wireless-table processing downstream.
              polygons[:, 1, :], polygons[:, 3, :] = (
                  polygons[:, 3, :].copy(),
                  polygons[:, 1, :].copy(),
              )

              if not need_ocr:
                  sorted_polygons, idx_list = sorted_ocr_boxes(
                      [box_4_2_poly_to_box_4_1(box) for box in polygons]
                  )
                  return UnetTableOutput(
                      "",
                      sorted_polygons,
                      logi_points[idx_list],
                      time.perf_counter() - start,
                  )

              cell_box_det_map, _ = match_ocr_cell(ocr_result, polygons)
              # Cells that received no OCR text are cropped and recognised directly.
              cell_box_det_map = self.fill_blank_rec(
                  img, polygons, cell_box_det_map, ocr_engine
              )
              # Merge the physical cell box, logical position and OCR results
              # of every cell into one dict for the following steps.
              t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
              # Sort OCR fragments inside each cell and merge those on the same
              # row, so the HTML keeps the cell's line structure.
              t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)

              logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
              cell_texts = {
                  i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
                  for i, t_box_ocr in enumerate(t_rec_ocr_list)
              }
              pred_html = plot_html_table(logi_points, cell_texts)
              polygons = np.array(polygons).reshape(-1, 8)
              logi_points = np.array(logi_points)
              elapse = time.perf_counter() - start
          except Exception:
              logger.warning(traceback.format_exc())
              return UnetTableOutput("", None, None, 0.0)
          return UnetTableOutput(pred_html, polygons, logi_points, elapse)

      def transform_res(
          self,
          cell_box_det_map: Dict[int, List[Any]],
          polygons: np.ndarray,
          logi_points: List[np.ndarray],
      ) -> List[Dict[str, Any]]:
          """Build one record per cell that has OCR results.

          Cells with no OCR entries are skipped, so the result can be shorter
          than ``polygons``. Each record contains:

          - ``t_box``: ``[xmin, ymin, xmax, ymax]`` spanning the cell's OCR
            boxes (min over vertex 0, max over vertex 2 of each box).
          - ``t_logic_box``: ``logi_points[i]`` as a list
            (row_start, row_end, col_start, col_end).
          - ``t_ocr_res``: ``[[box, text], ...]`` with each box converted by
            ``box_4_2_poly_to_box_4_1``.
          """
          res = []
          for i in range(len(polygons)):
              ocr_res_list = cell_box_det_map.get(i)
              if not ocr_res_list:
                  continue
              xmin = min(ocr_box[0][0][0] for ocr_box in ocr_res_list)
              ymin = min(ocr_box[0][0][1] for ocr_box in ocr_res_list)
              xmax = max(ocr_box[0][2][0] for ocr_box in ocr_res_list)
              ymax = max(ocr_box[0][2][1] for ocr_box in ocr_res_list)
              res.append(
                  {
                      "t_box": [xmin, ymin, xmax, ymax],
                      "t_logic_box": logi_points[i].tolist(),
                      "t_ocr_res": [
                          [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
                          for ocr_det in ocr_res_list
                      ],
                  }
              )
          return res

      def sort_and_gather_ocr_res(self, res):
          """Order each cell's OCR fragments and merge fragments on the same row.

          Ordering comes from ``sorted_ocr_boxes`` and merging from
          ``gather_ocr_list_by_row`` (both with threshold 0.3). ``res`` (the
          records from ``transform_res``) is updated in place and returned.
          """
          for dict_res in res:
              _, sorted_idx = sorted_ocr_boxes(
                  [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
              )
              ordered = [dict_res["t_ocr_res"][k] for k in sorted_idx]
              dict_res["t_ocr_res"] = gather_ocr_list_by_row(ordered, threshold=0.3)
          return res

      def fill_blank_rec(
          self,
          img: np.ndarray,
          sorted_polygons: np.ndarray,
          cell_box_map: Dict[int, List[Any]],
          ocr_engine,
      ) -> Dict[int, List[Any]]:
          """Recognise text directly in cells that matched no OCR entry.

          Each unmatched cell is cropped from ``img``. Vertices 0 and 2 of its
          polygon are used as the (assumed) top-left / bottom-right corners,
          clamped to the image bounds. All crops are sent in one batch to
          ``ocr_engine.ocr(crops, det=False)``. A result with score >= 0.9 is
          stored as ``cell_box_map[i] = [[polygon, text, score]]``. Cells that
          cannot be cropped, or are recognised with lower confidence, stay
          unfilled.

          Args:
              img: Image the polygons refer to, indexed as ``img[y, x]``.
              sorted_polygons: One 4-vertex polygon per cell.
              cell_box_map: Cell index -> matched OCR entries; updated in place.
              ocr_engine: Engine exposing ``ocr(images, det=False)``; when None,
                  unmatched cells are only logged.

          Returns:
              ``cell_box_map``, the same dict that was passed in.
          """
          img_h, img_w = img.shape[:2]
          crop_infos = []  # (cell index, polygon) for every crop sent to OCR
          crops = []
          for i, box in enumerate(sorted_polygons):
              if cell_box_map.get(i):
                  continue
              if ocr_engine is None:
                  logger.warning(f"No OCR engine provided for box {i}: {box}")
                  continue
              x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
              if x1 >= x2 or y1 >= y2:
                  logger.warning(f"Invalid box coordinates: {box}")
                  continue
              # Clamp to the image: a negative index would wrap around to the
              # far edge, and an empty crop can make the recognizer raise,
              # which would discard the whole table in __call__.
              left = min(max(int(x1), 0), img_w)
              right = min(max(int(x2), 0), img_w)
              top = min(max(int(y1), 0), img_h)
              bottom = min(max(int(y2), 0), img_h)
              if right <= left or bottom <= top:
                  logger.warning(f"Box yields an empty crop inside the image: {box}")
                  continue
              crops.append(img[top:bottom, left:right])
              crop_infos.append((i, box))

          if not crops:
              return cell_box_map

          ocr_result = ocr_engine.ocr(crops, det=False)
          if not ocr_result or not isinstance(ocr_result, list):
              logger.warning("OCR engine returned no results or invalid result for image crops.")
              return cell_box_map
          rec_results = ocr_result[0]
          if not isinstance(rec_results, list) or len(rec_results) != len(crops):
              logger.warning("OCR result list length does not match image crop list length.")
              return cell_box_map

          for (i, box), (ocr_text, ocr_score) in zip(crop_infos, rec_results):
              # Low-confidence recognitions are dropped rather than guessed.
              if ocr_score < 0.9:
                  continue
              cell_box_map[i] = [[box, ocr_text, ocr_score]]
          return cell_box_map
  def escape_html(input_string):
      """Return ``input_string`` with HTML special characters replaced by entities.

      Delegates to ``html.escape`` with ``quote=True`` (its default), so ``&``,
      ``<``, ``>``, ``"`` and ``'`` are all escaped.
      """
      escaped = html.escape(input_string, quote=True)
      return escaped
  class UnetTableModel:
      """Table recogniser that runs a wired (UNet) and a wireless (SLANet+) model.

      Both models are run on every table that has OCR text. A cell-count
      heuristic (``_prefer_wireless``) decides which model's result is
      returned.
      """

      def __init__(self, ocr_engine):
          model_path = os.path.join(
              auto_download_and_get_model_root_path(ModelPath.unet_structure),
              ModelPath.unet_structure,
          )
          wired_input_args = UnetTableInput(model_path=model_path)
          self.wired_table_model = UnetTableRecognition(wired_input_args)

          slanet_plus_model_path = os.path.join(
              auto_download_and_get_model_root_path(ModelPath.slanet_plus),
              ModelPath.slanet_plus,
          )
          wireless_input_args = RapidTableInput(
              model_type='slanet_plus', model_path=slanet_plus_model_path
          )
          self.wireless_table_model = RapidTable(wireless_input_args)
          self.ocr_engine = ocr_engine

      @staticmethod
      def _prefer_wireless(wired_len, wireless_len, table_cls_score):
          """Return True when the wireless model's result should be used instead of the wired one."""
          gap_of_len = wireless_len - wired_len
          return bool(
              # Wired model found too few cells (at most ~50% of the wireless count).
              wired_len <= round(wireless_len * 0.5)
              # Wired model found more cells (but fewer than double) and the
              # table classification score is not high.
              or (wireless_len < wired_len < 2 * wireless_len and table_cls_score <= 0.949)
              # Counts are close, but wired still found noticeably fewer.
              or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))
              # Equal counts on a tiny table (4 cells or fewer).
              or (gap_of_len == 0 and wired_len <= 4)
          )

      def predict(self, img, table_cls_score):
          """Recognise a table image.

          Args:
              img: RGB image (anything ``np.array`` accepts); it is converted
                  to BGR for the OCR engine.
              table_cls_score: Score used by the model-selection heuristic.

          Returns:
              ``(html_code, table_cell_bboxes, logic_points, elapse)``. When OCR
              finds no usable text, or recognition raises (the exception is
              logged), ``(None, None, None, None)`` is returned.
          """
          bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
          ocr_result = self.ocr_engine.ocr(bgr_img)[0]
          if ocr_result:
              # Keep [box, escaped text, score]; entries not shaped like
              # (box, (text, score)) are dropped.
              ocr_result = [
                  [item[0], escape_html(item[1][0]), item[1][1]]
                  for item in ocr_result
                  if len(item) == 2 and isinstance(item[1], tuple)
              ]
          if not ocr_result:
              # Previously this path fell through and returned a bare None,
              # breaking callers that unpack the 4-tuple.
              return None, None, None, None

          img_arr = np.asarray(img)
          try:
              wired_results = self.wired_table_model(img_arr, ocr_result, self.ocr_engine)
              wireless_results = self.wireless_table_model(img_arr, ocr_result)

              wired_bboxes = wired_results.cell_bboxes
              wireless_bboxes = wireless_results.cell_bboxes
              wired_len = len(wired_bboxes) if wired_bboxes is not None else 0
              wireless_len = len(wireless_bboxes) if wireless_bboxes is not None else 0

              if self._prefer_wireless(wired_len, wireless_len, table_cls_score):
                  chosen = wireless_results
              else:
                  chosen = wired_results
              elapse = wired_results.elapse + wireless_results.elapse
              return chosen.pred_html, chosen.cell_bboxes, chosen.logic_points, elapse
          except Exception as e:
              logger.exception(e)
              return None, None, None, None