main.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import os
  2. import argparse
  3. import copy
  4. import importlib
  5. import time
  6. import html
  7. from dataclasses import asdict, dataclass
  8. from enum import Enum
  9. from pathlib import Path
  10. from typing import Dict, List, Optional, Tuple, Union
  11. import cv2
  12. import numpy as np
  13. from loguru import logger
  14. from tqdm import tqdm
  15. from .matcher import TableMatch
  16. from .table_structure import TableStructurer
  17. from mineru.utils.enum_class import ModelPath
  18. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  19. root_dir = Path(__file__).resolve().parent
  20. class ModelType(Enum):
  21. PPSTRUCTURE_EN = "ppstructure_en"
  22. PPSTRUCTURE_ZH = "ppstructure_zh"
  23. SLANETPLUS = "slanet_plus"
  24. UNITABLE = "unitable"
  25. ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
  26. KEY_TO_MODEL_URL = {
  27. ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
  28. ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
  29. ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
  30. ModelType.UNITABLE.value: {
  31. "encoder": f"{ROOT_URL}/unitable/encoder.pth",
  32. "decoder": f"{ROOT_URL}/unitable/decoder.pth",
  33. "vocab": f"{ROOT_URL}/unitable/vocab.json",
  34. },
  35. }
  36. @dataclass
  37. class RapidTableInput:
  38. model_type: Optional[str] = ModelType.SLANETPLUS.value
  39. model_path: Union[str, Path, None, Dict[str, str]] = None
  40. use_cuda: bool = False
  41. device: str = "cpu"
  42. @dataclass
  43. class RapidTableOutput:
  44. pred_html: Optional[str] = None
  45. cell_bboxes: Optional[np.ndarray] = None
  46. logic_points: Optional[np.ndarray] = None
  47. elapse: Optional[float] = None
  48. class RapidTable:
  49. def __init__(self, config: RapidTableInput):
  50. self.model_type = config.model_type
  51. if self.model_type not in KEY_TO_MODEL_URL:
  52. model_list = ",".join(KEY_TO_MODEL_URL)
  53. raise ValueError(
  54. f"{self.model_type} is not supported. The currently supported models are {model_list}."
  55. )
  56. config.model_path = config.model_path
  57. if self.model_type == ModelType.SLANETPLUS.value:
  58. self.table_structure = TableStructurer(asdict(config))
  59. else:
  60. raise ValueError(f"{self.model_type} is not supported.")
  61. self.table_matcher = TableMatch()
  62. def predict(
  63. self,
  64. img: np.ndarray,
  65. ocr_result: List[Union[List[List[float]], str, str]] = None,
  66. ) -> RapidTableOutput:
  67. if ocr_result is None:
  68. raise ValueError("OCR result is None")
  69. s = time.perf_counter()
  70. h, w = img.shape[:2]
  71. dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
  72. pred_structures, cell_bboxes, _ = self.table_structure.process(
  73. copy.deepcopy(img)
  74. )
  75. # 适配slanet-plus模型输出的box缩放还原
  76. cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
  77. pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
  78. # 过滤掉占位的bbox
  79. mask = ~np.all(cell_bboxes == 0, axis=1)
  80. cell_bboxes = cell_bboxes[mask]
  81. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  82. elapse = time.perf_counter() - s
  83. return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
  84. def batch_predict(
  85. self,
  86. images: List[np.ndarray],
  87. ocr_results: List[List[Union[List[List[float]], str, str]]],
  88. batch_size: int = 4,
  89. ) -> List[RapidTableOutput]:
  90. """批量处理图像"""
  91. s = time.perf_counter()
  92. batch_dt_boxes = []
  93. batch_rec_res = []
  94. for i, img in enumerate(images):
  95. h, w = img.shape[:2]
  96. dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
  97. batch_dt_boxes.append(dt_boxes)
  98. batch_rec_res.append(rec_res)
  99. # 批量表格结构识别
  100. batch_results = self.table_structure.batch_process(images)
  101. output_results = []
  102. for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
  103. zip(images, ocr_results, batch_results)
  104. ):
  105. # 适配slanet-plus模型输出的box缩放还原
  106. cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
  107. pred_html = self.table_matcher(
  108. pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
  109. )
  110. # 过滤掉占位的bbox
  111. mask = ~np.all(cell_bboxes == 0, axis=1)
  112. cell_bboxes = cell_bboxes[mask]
  113. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  114. result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
  115. output_results.append(result)
  116. total_elapse = time.perf_counter() - s
  117. for result in output_results:
  118. result.elapse = total_elapse / len(output_results)
  119. return output_results
  120. def get_boxes_recs(
  121. self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
  122. ) -> Tuple[np.ndarray, Tuple[str, str]]:
  123. dt_boxes, rec_res, scores = list(zip(*ocr_result))
  124. rec_res = list(zip(rec_res, scores))
  125. r_boxes = []
  126. for box in dt_boxes:
  127. box = np.array(box)
  128. x_min = max(0, box[:, 0].min() - 1)
  129. x_max = min(w, box[:, 0].max() + 1)
  130. y_min = max(0, box[:, 1].min() - 1)
  131. y_max = min(h, box[:, 1].max() + 1)
  132. box = [x_min, y_min, x_max, y_max]
  133. r_boxes.append(box)
  134. dt_boxes = np.array(r_boxes)
  135. return dt_boxes, rec_res
  136. def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
  137. h, w = img.shape[:2]
  138. resized = 488
  139. ratio = min(resized / h, resized / w)
  140. w_ratio = resized / (w * ratio)
  141. h_ratio = resized / (h * ratio)
  142. cell_bboxes[:, 0::2] *= w_ratio
  143. cell_bboxes[:, 1::2] *= h_ratio
  144. return cell_bboxes
  145. def parse_args(arg_list: Optional[List[str]] = None):
  146. parser = argparse.ArgumentParser()
  147. parser.add_argument(
  148. "-v",
  149. "--vis",
  150. action="store_true",
  151. default=False,
  152. help="Wheter to visualize the layout results.",
  153. )
  154. parser.add_argument(
  155. "-img", "--img_path", type=str, required=True, help="Path to image for layout."
  156. )
  157. parser.add_argument(
  158. "-m",
  159. "--model_type",
  160. type=str,
  161. default=ModelType.SLANETPLUS.value,
  162. choices=list(KEY_TO_MODEL_URL),
  163. )
  164. args = parser.parse_args(arg_list)
  165. return args
  166. def escape_html(input_string):
  167. """Escape HTML Entities."""
  168. return html.escape(input_string)
  169. class RapidTableModel(object):
  170. def __init__(self, ocr_engine):
  171. slanet_plus_model_path = os.path.join(
  172. auto_download_and_get_model_root_path(ModelPath.slanet_plus),
  173. ModelPath.slanet_plus,
  174. )
  175. input_args = RapidTableInput(
  176. model_type="slanet_plus", model_path=slanet_plus_model_path
  177. )
  178. self.table_model = RapidTable(input_args)
  179. self.ocr_engine = ocr_engine
  180. def predict(self, image, table_cls_score):
  181. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  182. # Continue with OCR on potentially rotated image
  183. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  184. if ocr_result:
  185. ocr_result = [
  186. [item[0], escape_html(item[1][0]), item[1][1]]
  187. for item in ocr_result
  188. if len(item) == 2 and isinstance(item[1], tuple)
  189. ]
  190. else:
  191. ocr_result = None
  192. if ocr_result:
  193. try:
  194. table_results = self.table_model.predict(np.asarray(image), ocr_result)
  195. html_code = table_results.pred_html
  196. table_cell_bboxes = table_results.cell_bboxes
  197. logic_points = table_results.logic_points
  198. elapse = table_results.elapse
  199. return html_code, table_cell_bboxes, logic_points, elapse
  200. except Exception as e:
  201. logger.exception(e)
  202. return None, None, None, None
  203. def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
  204. """对传入的字典列表进行批量预测,无返回值"""
  205. for index in tqdm(
  206. range(0, len(table_res_list), batch_size),
  207. desc=f"Wireless Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
  208. ):
  209. batch_imgs = [
  210. cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
  211. for i in range(index, min(index + batch_size, len(table_res_list)))
  212. ]
  213. batch_ocrs = [
  214. table_res_list[i]["ocr_result"]
  215. for i in range(index, min(index + batch_size, len(table_res_list)))
  216. ]
  217. results = self.table_model.batch_predict(
  218. batch_imgs, batch_ocrs, batch_size=batch_size
  219. )
  220. for i, result in enumerate(results):
  221. if result.pred_html:
  222. # 检查html_code是否包含'<table>'和'</table>'
  223. if '<table>' in result.pred_html and '</table>' in result.pred_html:
  224. # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
  225. start_index = result.pred_html.find('<table>')
  226. end_index = result.pred_html.rfind('</table>') + len('</table>')
  227. table_res_list[index + i]['table_res']['html'] = result.pred_html[start_index:end_index]
  228. else:
  229. logger.warning(
  230. 'table recognition processing fails, not found expected HTML table end'
  231. )
  232. else:
  233. logger.warning(
  234. "table recognition processing fails, not get html return"
  235. )