main.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. import argparse
  3. import copy
  4. import importlib
  5. import time
  6. import html
  7. from dataclasses import asdict, dataclass
  8. from enum import Enum
  9. from pathlib import Path
  10. from typing import Dict, List, Optional, Tuple, Union
  11. import cv2
  12. import numpy as np
  13. from loguru import logger
  14. from .matcher import TableMatch
  15. from .table_structure import TableStructurer
  16. from mineru.utils.enum_class import ModelPath
  17. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  18. root_dir = Path(__file__).resolve().parent
  19. class ModelType(Enum):
  20. PPSTRUCTURE_EN = "ppstructure_en"
  21. PPSTRUCTURE_ZH = "ppstructure_zh"
  22. SLANETPLUS = "slanet_plus"
  23. UNITABLE = "unitable"
  24. ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
  25. KEY_TO_MODEL_URL = {
  26. ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
  27. ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
  28. ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
  29. ModelType.UNITABLE.value: {
  30. "encoder": f"{ROOT_URL}/unitable/encoder.pth",
  31. "decoder": f"{ROOT_URL}/unitable/decoder.pth",
  32. "vocab": f"{ROOT_URL}/unitable/vocab.json",
  33. },
  34. }
  35. @dataclass
  36. class RapidTableInput:
  37. model_type: Optional[str] = ModelType.SLANETPLUS.value
  38. model_path: Union[str, Path, None, Dict[str, str]] = None
  39. use_cuda: bool = False
  40. device: str = "cpu"
  41. @dataclass
  42. class RapidTableOutput:
  43. pred_html: Optional[str] = None
  44. cell_bboxes: Optional[np.ndarray] = None
  45. logic_points: Optional[np.ndarray] = None
  46. elapse: Optional[float] = None
  47. class RapidTable:
  48. def __init__(self, config: RapidTableInput):
  49. self.model_type = config.model_type
  50. if self.model_type not in KEY_TO_MODEL_URL:
  51. model_list = ",".join(KEY_TO_MODEL_URL)
  52. raise ValueError(
  53. f"{self.model_type} is not supported. The currently supported models are {model_list}."
  54. )
  55. config.model_path = config.model_path
  56. if self.model_type == ModelType.SLANETPLUS.value:
  57. self.table_structure = TableStructurer(asdict(config))
  58. else:
  59. raise ValueError(f"{self.model_type} is not supported.")
  60. self.table_matcher = TableMatch()
  61. try:
  62. self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
  63. except ModuleNotFoundError:
  64. self.ocr_engine = None
  65. def __call__(
  66. self,
  67. img: np.ndarray,
  68. ocr_result: List[Union[List[List[float]], str, str]] = None,
  69. ) -> RapidTableOutput:
  70. if self.ocr_engine is None and ocr_result is None:
  71. raise ValueError(
  72. "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
  73. )
  74. s = time.perf_counter()
  75. h, w = img.shape[:2]
  76. if ocr_result is None:
  77. ocr_result = self.ocr_engine(img)
  78. ocr_result = list(
  79. zip(
  80. ocr_result.boxes,
  81. ocr_result.txts,
  82. ocr_result.scores,
  83. )
  84. )
  85. dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
  86. pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
  87. # 适配slanet-plus模型输出的box缩放还原
  88. if self.model_type == ModelType.SLANETPLUS.value:
  89. cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
  90. pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
  91. # 过滤掉占位的bbox
  92. mask = ~np.all(cell_bboxes == 0, axis=1)
  93. cell_bboxes = cell_bboxes[mask]
  94. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  95. elapse = time.perf_counter() - s
  96. return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
  97. def get_boxes_recs(
  98. self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
  99. ) -> Tuple[np.ndarray, Tuple[str, str]]:
  100. dt_boxes, rec_res, scores = list(zip(*ocr_result))
  101. rec_res = list(zip(rec_res, scores))
  102. r_boxes = []
  103. for box in dt_boxes:
  104. box = np.array(box)
  105. x_min = max(0, box[:, 0].min() - 1)
  106. x_max = min(w, box[:, 0].max() + 1)
  107. y_min = max(0, box[:, 1].min() - 1)
  108. y_max = min(h, box[:, 1].max() + 1)
  109. box = [x_min, y_min, x_max, y_max]
  110. r_boxes.append(box)
  111. dt_boxes = np.array(r_boxes)
  112. return dt_boxes, rec_res
  113. def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
  114. h, w = img.shape[:2]
  115. resized = 488
  116. ratio = min(resized / h, resized / w)
  117. w_ratio = resized / (w * ratio)
  118. h_ratio = resized / (h * ratio)
  119. cell_bboxes[:, 0::2] *= w_ratio
  120. cell_bboxes[:, 1::2] *= h_ratio
  121. return cell_bboxes
  122. def parse_args(arg_list: Optional[List[str]] = None):
  123. parser = argparse.ArgumentParser()
  124. parser.add_argument(
  125. "-v",
  126. "--vis",
  127. action="store_true",
  128. default=False,
  129. help="Wheter to visualize the layout results.",
  130. )
  131. parser.add_argument(
  132. "-img", "--img_path", type=str, required=True, help="Path to image for layout."
  133. )
  134. parser.add_argument(
  135. "-m",
  136. "--model_type",
  137. type=str,
  138. default=ModelType.SLANETPLUS.value,
  139. choices=list(KEY_TO_MODEL_URL),
  140. )
  141. args = parser.parse_args(arg_list)
  142. return args
  143. def escape_html(input_string):
  144. """Escape HTML Entities."""
  145. return html.escape(input_string)
  146. class RapidTableModel(object):
  147. def __init__(self, ocr_engine):
  148. slanet_plus_model_path = os.path.join(
  149. auto_download_and_get_model_root_path(ModelPath.slanet_plus),
  150. ModelPath.slanet_plus,
  151. )
  152. input_args = RapidTableInput(
  153. model_type="slanet_plus", model_path=slanet_plus_model_path
  154. )
  155. self.table_model = RapidTable(input_args)
  156. self.ocr_engine = ocr_engine
  157. def predict(self, image, table_cls_score):
  158. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  159. # Continue with OCR on potentially rotated image
  160. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  161. if ocr_result:
  162. ocr_result = [
  163. [item[0], escape_html(item[1][0]), item[1][1]]
  164. for item in ocr_result
  165. if len(item) == 2 and isinstance(item[1], tuple)
  166. ]
  167. else:
  168. ocr_result = None
  169. if ocr_result:
  170. try:
  171. table_results = self.table_model(np.asarray(image), ocr_result)
  172. html_code = table_results.pred_html
  173. table_cell_bboxes = table_results.cell_bboxes
  174. logic_points = table_results.logic_points
  175. elapse = table_results.elapse
  176. return html_code, table_cell_bboxes, logic_points, elapse
  177. except Exception as e:
  178. logger.exception(e)
  179. return None, None, None, None