main.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. import copy
  3. import time
  4. import html
  5. from dataclasses import asdict, dataclass
  6. from pathlib import Path
  7. from typing import Dict, List, Optional, Tuple, Union
  8. import cv2
  9. import numpy as np
  10. from loguru import logger
  11. from tqdm import tqdm
  12. from .matcher import TableMatch
  13. from .table_structure import TableStructurer
  14. from mineru.utils.enum_class import ModelPath
  15. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  16. @dataclass
  17. class RapidTableInput:
  18. model_type: Optional[str] = "slanet_plus"
  19. model_path: Union[str, Path, None, Dict[str, str]] = None
  20. use_cuda: bool = False
  21. device: str = "cpu"
  22. @dataclass
  23. class RapidTableOutput:
  24. pred_html: Optional[str] = None
  25. cell_bboxes: Optional[np.ndarray] = None
  26. logic_points: Optional[np.ndarray] = None
  27. elapse: Optional[float] = None
  28. class RapidTable:
  29. def __init__(self, config: RapidTableInput):
  30. self.table_structure = TableStructurer(asdict(config))
  31. self.table_matcher = TableMatch()
  32. def predict(
  33. self,
  34. img: np.ndarray,
  35. ocr_result: List[Union[List[List[float]], str, str]] = None,
  36. ) -> RapidTableOutput:
  37. if ocr_result is None:
  38. raise ValueError("OCR result is None")
  39. s = time.perf_counter()
  40. h, w = img.shape[:2]
  41. dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
  42. pred_structures, cell_bboxes, _ = self.table_structure.process(
  43. copy.deepcopy(img)
  44. )
  45. # 适配slanet-plus模型输出的box缩放还原
  46. cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
  47. pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
  48. # 过滤掉占位的bbox
  49. mask = ~np.all(cell_bboxes == 0, axis=1)
  50. cell_bboxes = cell_bboxes[mask]
  51. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  52. elapse = time.perf_counter() - s
  53. return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
  54. def batch_predict(
  55. self,
  56. images: List[np.ndarray],
  57. ocr_results: List[List[Union[List[List[float]], str, str]]],
  58. batch_size: int = 4,
  59. ) -> List[RapidTableOutput]:
  60. """批量处理图像"""
  61. s = time.perf_counter()
  62. batch_dt_boxes = []
  63. batch_rec_res = []
  64. for i, img in enumerate(images):
  65. h, w = img.shape[:2]
  66. dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
  67. batch_dt_boxes.append(dt_boxes)
  68. batch_rec_res.append(rec_res)
  69. # 批量表格结构识别
  70. batch_results = self.table_structure.batch_process(images)
  71. output_results = []
  72. for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
  73. zip(images, ocr_results, batch_results)
  74. ):
  75. # 适配slanet-plus模型输出的box缩放还原
  76. cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
  77. pred_html = self.table_matcher(
  78. pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
  79. )
  80. # 过滤掉占位的bbox
  81. mask = ~np.all(cell_bboxes == 0, axis=1)
  82. cell_bboxes = cell_bboxes[mask]
  83. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  84. result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
  85. output_results.append(result)
  86. total_elapse = time.perf_counter() - s
  87. for result in output_results:
  88. result.elapse = total_elapse / len(output_results)
  89. return output_results
  90. def get_boxes_recs(
  91. self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
  92. ) -> Tuple[np.ndarray, Tuple[str, str]]:
  93. dt_boxes, rec_res, scores = list(zip(*ocr_result))
  94. rec_res = list(zip(rec_res, scores))
  95. r_boxes = []
  96. for box in dt_boxes:
  97. box = np.array(box)
  98. x_min = max(0, box[:, 0].min() - 1)
  99. x_max = min(w, box[:, 0].max() + 1)
  100. y_min = max(0, box[:, 1].min() - 1)
  101. y_max = min(h, box[:, 1].max() + 1)
  102. box = [x_min, y_min, x_max, y_max]
  103. r_boxes.append(box)
  104. dt_boxes = np.array(r_boxes)
  105. return dt_boxes, rec_res
  106. def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
  107. h, w = img.shape[:2]
  108. resized = 488
  109. ratio = min(resized / h, resized / w)
  110. w_ratio = resized / (w * ratio)
  111. h_ratio = resized / (h * ratio)
  112. cell_bboxes[:, 0::2] *= w_ratio
  113. cell_bboxes[:, 1::2] *= h_ratio
  114. return cell_bboxes
  115. def escape_html(input_string):
  116. """Escape HTML Entities."""
  117. return html.escape(input_string)
  118. class RapidTableModel(object):
  119. def __init__(self, ocr_engine):
  120. slanet_plus_model_path = os.path.join(
  121. auto_download_and_get_model_root_path(ModelPath.slanet_plus),
  122. ModelPath.slanet_plus,
  123. )
  124. input_args = RapidTableInput(
  125. model_type="slanet_plus", model_path=slanet_plus_model_path
  126. )
  127. self.table_model = RapidTable(input_args)
  128. self.ocr_engine = ocr_engine
  129. def predict(self, image, table_cls_score):
  130. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  131. # Continue with OCR on potentially rotated image
  132. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  133. if ocr_result:
  134. ocr_result = [
  135. [item[0], escape_html(item[1][0]), item[1][1]]
  136. for item in ocr_result
  137. if len(item) == 2 and isinstance(item[1], tuple)
  138. ]
  139. else:
  140. ocr_result = None
  141. if ocr_result:
  142. try:
  143. table_results = self.table_model.predict(np.asarray(image), ocr_result)
  144. html_code = table_results.pred_html
  145. table_cell_bboxes = table_results.cell_bboxes
  146. logic_points = table_results.logic_points
  147. elapse = table_results.elapse
  148. return html_code, table_cell_bboxes, logic_points, elapse
  149. except Exception as e:
  150. logger.exception(e)
  151. return None, None, None, None
  152. def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
  153. """对传入的字典列表进行批量预测,无返回值"""
  154. not_none_table_res_list = []
  155. for table_res in table_res_list:
  156. if table_res.get("ocr_result", None):
  157. not_none_table_res_list.append(table_res)
  158. with tqdm(total=len(not_none_table_res_list), desc="Table-wireless Predict") as pbar:
  159. for index in range(0, len(not_none_table_res_list), batch_size):
  160. batch_imgs = [
  161. cv2.cvtColor(np.asarray(not_none_table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
  162. for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
  163. ]
  164. batch_ocrs = [
  165. not_none_table_res_list[i]["ocr_result"]
  166. for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
  167. ]
  168. results = self.table_model.batch_predict(
  169. batch_imgs, batch_ocrs, batch_size=batch_size
  170. )
  171. for i, result in enumerate(results):
  172. if result.pred_html:
  173. not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
  174. # 更新进度条
  175. pbar.update(len(results))