unet_table.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import html
  2. import os
  3. import time
  4. import traceback
  5. from dataclasses import dataclass, asdict
  6. from typing import List, Optional, Union, Dict, Any
  7. import cv2
  8. import numpy as np
  9. from loguru import logger
  10. from mineru.utils.enum_class import ModelPath
  11. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  12. from .table_structure_unet import TSRUnet
  13. from .table_recover import TableRecover
  14. from .wired_table_rec_utils import InputType, LoadImage
  15. from .table_recover_utils import (
  16. match_ocr_cell,
  17. plot_html_table,
  18. box_4_2_poly_to_box_4_1,
  19. sorted_ocr_boxes,
  20. gather_ocr_list_by_row,
  21. )
@dataclass
class UnetTableInput:
    """Configuration passed to UnetTableRecognition / TSRUnet."""

    # Filesystem path to the UNet table-structure model weights.
    model_path: str
    # Inference device identifier, e.g. "cpu" (default) or a GPU device string.
    device: str = "cpu"
@dataclass
class UnetTableOutput:
    """Result of one table-recognition call; all fields None/empty on failure."""

    # Reconstructed table as an HTML string ("" when recognition failed).
    pred_html: Optional[str] = None
    # Cell polygons flattened to an (N, 8) array of pixel coordinates.
    cell_bboxes: Optional[np.ndarray] = None
    # Logical cell coordinates: row_start, row_end, col_start, col_end per cell.
    logic_points: Optional[np.ndarray] = None
    # Wall-clock seconds spent in recognition.
    elapse: Optional[float] = None
  32. class UnetTableRecognition:
  33. def __init__(self, config: UnetTableInput):
  34. self.table_structure = TSRUnet(asdict(config))
  35. self.load_img = LoadImage()
  36. self.table_recover = TableRecover()
  37. def __call__(
  38. self,
  39. img: InputType,
  40. ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
  41. **kwargs,
  42. ) -> UnetTableOutput:
  43. s = time.perf_counter()
  44. need_ocr = True
  45. col_threshold = 15
  46. row_threshold = 10
  47. if kwargs:
  48. need_ocr = kwargs.get("need_ocr", True)
  49. col_threshold = kwargs.get("col_threshold", 15)
  50. row_threshold = kwargs.get("row_threshold", 10)
  51. img = self.load_img(img)
  52. polygons, rotated_polygons = self.table_structure(img, **kwargs)
  53. if polygons is None:
  54. logger.warning("polygons is None.")
  55. return UnetTableOutput("", None, None, 0.0)
  56. try:
  57. table_res, logi_points = self.table_recover(
  58. rotated_polygons, row_threshold, col_threshold
  59. )
  60. # 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐
  61. polygons[:, 1, :], polygons[:, 3, :] = (
  62. polygons[:, 3, :].copy(),
  63. polygons[:, 1, :].copy(),
  64. )
  65. if not need_ocr:
  66. sorted_polygons, idx_list = sorted_ocr_boxes(
  67. [box_4_2_poly_to_box_4_1(box) for box in polygons]
  68. )
  69. return UnetTableOutput(
  70. "",
  71. sorted_polygons,
  72. logi_points[idx_list],
  73. time.perf_counter() - s,
  74. )
  75. cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
  76. # 如果有识别框没有ocr结果,直接进行rec补充
  77. cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
  78. # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
  79. t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
  80. # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
  81. t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
  82. # cell_box_map =
  83. logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
  84. cell_box_det_map = {
  85. i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
  86. for i, t_box_ocr in enumerate(t_rec_ocr_list)
  87. }
  88. pred_html = plot_html_table(logi_points, cell_box_det_map)
  89. polygons = np.array(polygons).reshape(-1, 8)
  90. logi_points = np.array(logi_points)
  91. elapse = time.perf_counter() - s
  92. except Exception:
  93. logger.warning(traceback.format_exc())
  94. return UnetTableOutput("", None, None, 0.0)
  95. return UnetTableOutput(pred_html, polygons, logi_points, elapse)
  96. def transform_res(
  97. self,
  98. cell_box_det_map: Dict[int, List[any]],
  99. polygons: np.ndarray,
  100. logi_points: List[np.ndarray],
  101. ) -> List[Dict[str, any]]:
  102. res = []
  103. for i in range(len(polygons)):
  104. ocr_res_list = cell_box_det_map.get(i)
  105. if not ocr_res_list:
  106. continue
  107. xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
  108. ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
  109. xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
  110. ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
  111. dict_res = {
  112. # xmin,xmax,ymin,ymax
  113. "t_box": [xmin, ymin, xmax, ymax],
  114. # row_start,row_end,col_start,col_end
  115. "t_logic_box": logi_points[i].tolist(),
  116. # [[xmin,xmax,ymin,ymax], text]
  117. "t_ocr_res": [
  118. [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
  119. for ocr_det in ocr_res_list
  120. ],
  121. }
  122. res.append(dict_res)
  123. return res
  124. def sort_and_gather_ocr_res(self, res):
  125. for i, dict_res in enumerate(res):
  126. _, sorted_idx = sorted_ocr_boxes(
  127. [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
  128. )
  129. dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
  130. dict_res["t_ocr_res"] = gather_ocr_list_by_row(
  131. dict_res["t_ocr_res"], threshold=0.3
  132. )
  133. return res
  134. def fill_blank_rec(
  135. self,
  136. img: np.ndarray,
  137. sorted_polygons: np.ndarray,
  138. cell_box_map: Dict[int, List[str]],
  139. ) -> Dict[int, List[Any]]:
  140. """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
  141. for i in range(sorted_polygons.shape[0]):
  142. if cell_box_map.get(i):
  143. continue
  144. box = sorted_polygons[i]
  145. cell_box_map[i] = [[box, "", 1]]
  146. continue
  147. return cell_box_map
  148. def escape_html(input_string):
  149. """Escape HTML Entities."""
  150. return html.escape(input_string)
  151. class UnetTableModel:
  152. def __init__(self, ocr_engine):
  153. model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
  154. input_args = UnetTableInput(model_path=model_path)
  155. self.table_model = UnetTableRecognition(input_args)
  156. self.ocr_engine = ocr_engine
  157. def predict(self, img):
  158. bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  159. ocr_result = self.ocr_engine.ocr(bgr_img)[0]
  160. if ocr_result:
  161. ocr_result = [
  162. [item[0], escape_html(item[1][0]), item[1][1]]
  163. for item in ocr_result
  164. if len(item) == 2 and isinstance(item[1], tuple)
  165. ]
  166. else:
  167. ocr_result = None
  168. if ocr_result:
  169. try:
  170. table_results = self.table_model(np.asarray(img), ocr_result)
  171. html_code = table_results.pred_html
  172. table_cell_bboxes = table_results.cell_bboxes
  173. logic_points = table_results.logic_points
  174. elapse = table_results.elapse
  175. return html_code, table_cell_bboxes, logic_points, elapse
  176. except Exception as e:
  177. logger.exception(e)
  178. return None, None, None, None