RapidTable.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import html
  2. import os
  3. import time
  4. from pathlib import Path
  5. from typing import List
  6. import cv2
  7. import numpy as np
  8. from loguru import logger
  9. from rapid_table import ModelType, RapidTable, RapidTableInput
  10. from rapid_table.utils import RapidTableOutput
  11. from tqdm import tqdm
  12. from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
  13. from mineru.utils.enum_class import ModelPath
  14. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  15. def escape_html(input_string):
  16. """Escape HTML Entities."""
  17. return html.escape(input_string)
  18. class CustomRapidTable(RapidTable):
  19. def __init__(self, cfg: RapidTableInput):
  20. import logging
  21. # 通过环境变量控制日志级别
  22. logging.disable(logging.INFO)
  23. super().__init__(cfg)
  24. def __call__(self, img_contents, ocr_results=None, batch_size=1):
  25. if not isinstance(img_contents, list):
  26. img_contents = [img_contents]
  27. s = time.perf_counter()
  28. results = RapidTableOutput()
  29. total_nums = len(img_contents)
  30. with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
  31. for start_i in range(0, total_nums, batch_size):
  32. end_i = min(total_nums, start_i + batch_size)
  33. imgs = self._load_imgs(img_contents[start_i:end_i])
  34. pred_structures, cell_bboxes = self.table_structure(imgs)
  35. logic_points = self.table_matcher.decode_logic_points(pred_structures)
  36. dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
  37. pred_htmls = self.table_matcher(
  38. pred_structures, cell_bboxes, dt_boxes, rec_res
  39. )
  40. results.pred_htmls.extend(pred_htmls)
  41. # 更新进度条
  42. pbar.update(end_i - start_i)
  43. elapse = time.perf_counter() - s
  44. results.elapse = elapse / total_nums
  45. return results
  46. class RapidTableModel():
  47. def __init__(self, ocr_engine):
  48. slanet_plus_model_path = os.path.join(
  49. auto_download_and_get_model_root_path(ModelPath.slanet_plus),
  50. ModelPath.slanet_plus,
  51. )
  52. input_args = RapidTableInput(
  53. model_type=ModelType.SLANETPLUS,
  54. model_dir_or_path=slanet_plus_model_path,
  55. use_ocr=False
  56. )
  57. self.table_model = CustomRapidTable(input_args)
  58. self.ocr_engine = ocr_engine
  59. def predict(self, image, ocr_result=None):
  60. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  61. # Continue with OCR on potentially rotated image
  62. if not ocr_result:
  63. raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  64. # 分离边界框、文本和置信度
  65. boxes = []
  66. texts = []
  67. scores = []
  68. for item in raw_ocr_result:
  69. if len(item) == 3:
  70. boxes.append(item[0])
  71. texts.append(escape_html(item[1]))
  72. scores.append(item[2])
  73. elif len(item) == 2 and isinstance(item[1], tuple):
  74. boxes.append(item[0])
  75. texts.append(escape_html(item[1][0]))
  76. scores.append(item[1][1])
  77. # 按照 rapid_table 期望的格式构建 ocr_results
  78. ocr_result = [(boxes, texts, scores)]
  79. if ocr_result:
  80. try:
  81. table_results = self.table_model(img_contents=np.asarray(image), ocr_results=ocr_result)
  82. html_code = table_results.pred_htmls
  83. table_cell_bboxes = table_results.cell_bboxes
  84. logic_points = table_results.logic_points
  85. elapse = table_results.elapse
  86. return html_code, table_cell_bboxes, logic_points, elapse
  87. except Exception as e:
  88. logger.exception(e)
  89. return None, None, None, None
  90. def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
  91. not_none_table_res_list = []
  92. for table_res in table_res_list:
  93. if table_res.get("ocr_result", None):
  94. not_none_table_res_list.append(table_res)
  95. if not_none_table_res_list:
  96. img_contents = [table_res["table_img"] for table_res in not_none_table_res_list]
  97. ocr_results = []
  98. # ocr_results需要按照rapid_table期望的格式构建
  99. for table_res in not_none_table_res_list:
  100. raw_ocr_result = table_res["ocr_result"]
  101. boxes = []
  102. texts = []
  103. scores = []
  104. for item in raw_ocr_result:
  105. if len(item) == 3:
  106. boxes.append(item[0])
  107. texts.append(escape_html(item[1]))
  108. scores.append(item[2])
  109. elif len(item) == 2 and isinstance(item[1], tuple):
  110. boxes.append(item[0])
  111. texts.append(escape_html(item[1][0]))
  112. scores.append(item[1][1])
  113. ocr_results.append((boxes, texts, scores))
  114. table_results = self.table_model(img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size)
  115. for i, result in enumerate(table_results.pred_htmls):
  116. if result:
  117. not_none_table_res_list[i]['table_res']['html'] = result
  118. if __name__ == '__main__':
  119. ocr_engine= PytorchPaddleOCR(
  120. det_db_box_thresh=0.5,
  121. det_db_unclip_ratio=1.6,
  122. enable_merge_det_boxes=False,
  123. )
  124. table_model = RapidTableModel(ocr_engine)
  125. img_path = Path(r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg")
  126. image = cv2.imread(str(img_path))
  127. html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
  128. print(html_code)