@@ -0,0 +1,154 @@
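+"""Table recognition based on rapid_table's SLANet-plus model, for use in MinerU."""
+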
+import html
+import os
+import time
+from pathlib import Path
+from typing import List
+
+import cv2
+import numpy as np
+from loguru import logger
+from rapid_table import ModelType, RapidTable, RapidTableInput
+from rapid_table.utils import RapidTableOutput
+from tqdm import tqdm
+
+from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+def escape_html(input_string):
+    """Escape HTML entities."""
+    return html.escape(input_string)
+
+
+class CustomRapidTable(RapidTable):
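+    """Batch-capable wrapper around rapid_table's RapidTable.
+
+    __call__ is overridden to process images in batches with a tqdm progress
+    bar and to use OCR results supplied by the caller (the model is created
+    with use_ocr=False) instead of rapid_table's built-in OCR.
+    """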
+    def __init__(self, cfg: RapidTableInput):
+        import logging
+        # Silence log messages at INFO level and below from the underlying library
+        logging.disable(logging.INFO)
+        super().__init__(cfg)
+
+    def __call__(self, img_contents, ocr_results=None, batch_size=1):
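+        """Predict table structure for one image or a list of images.
+
+        ocr_results is expected to be pre-computed, one (boxes, texts, scores)
+        tuple per image, since the model runs with use_ocr=False. Returns a
+        RapidTableOutput whose pred_htmls collects the per-image HTML
+        predictions and whose elapse is the average time per image.
+        """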
+        if not isinstance(img_contents, list):
+            img_contents = [img_contents]
+
+        s = time.perf_counter()
+
+        results = RapidTableOutput()
+
+        total_nums = len(img_contents)
+
+        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
+            for start_i in range(0, total_nums, batch_size):
+                end_i = min(total_nums, start_i + batch_size)
+
+                imgs = self._load_imgs(img_contents[start_i:end_i])
+
+                pred_structures, cell_bboxes = self.table_structure(imgs)
+                logic_points = self.table_matcher.decode_logic_points(pred_structures)
+
+                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
+                pred_htmls = self.table_matcher(
+                    pred_structures, cell_bboxes, dt_boxes, rec_res
+                )
+
+                results.pred_htmls.extend(pred_htmls)
+                # Update the progress bar
+                pbar.update(end_i - start_i)
+
+        elapse = time.perf_counter() - s
+        results.elapse = elapse / total_nums
+        return results
+
+
+class RapidTableModel:
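+    """Table recognition model used by MinerU: SLANet-plus structure prediction plus an external OCR engine."""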
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type=ModelType.SLANETPLUS,
+            model_dir_or_path=slanet_plus_model_path,
+            use_ocr=False,
+        )
+        self.table_model = CustomRapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, ocr_result=None):
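+        """Recognize a single table image.
+
+        Expects an RGB image (e.g. a PIL image or RGB ndarray). If ocr_result
+        is not given, OCR is run with self.ocr_engine first.
+
+        Returns (html_code, cell_bboxes, logic_points, elapse), or
+        (None, None, None, None) on failure.
+        """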
+        # Convert the input (RGB, e.g. a PIL image) to BGR for the OCR engine
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+
+        if not ocr_result:
+            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            # Split the raw result into bounding boxes, texts and confidence scores
+            boxes = []
+            texts = []
+            scores = []
+            for item in raw_ocr_result:
+                if len(item) == 3:
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1]))
+                    scores.append(item[2])
+                elif len(item) == 2 and isinstance(item[1], tuple):
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1][0]))
+                    scores.append(item[1][1])
+            # Build ocr_result in the format rapid_table expects
+            ocr_result = [(boxes, texts, scores)]
+
+        if ocr_result:
+            try:
+                table_results = self.table_model(img_contents=np.asarray(image), ocr_results=ocr_result)
+                html_code = table_results.pred_htmls
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
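+        """Recognize all tables in table_res_list that carry an "ocr_result".
+
+        Recognition runs in batches of batch_size; results are written back in
+        place to table_res["table_res"]["html"], and entries without OCR
+        results are left untouched.
+        """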
+        not_none_table_res_list = []
+        for table_res in table_res_list:
+            if table_res.get("ocr_result", None):
+                not_none_table_res_list.append(table_res)
+
+        if not_none_table_res_list:
+            img_contents = [table_res["table_img"] for table_res in not_none_table_res_list]
+            ocr_results = []
+            # ocr_results must be built in the format rapid_table expects
+            for table_res in not_none_table_res_list:
+                raw_ocr_result = table_res["ocr_result"]
+                boxes = []
+                texts = []
+                scores = []
+                for item in raw_ocr_result:
+                    if len(item) == 3:
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1]))
+                        scores.append(item[2])
+                    elif len(item) == 2 and isinstance(item[1], tuple):
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1][0]))
+                        scores.append(item[1][1])
+                ocr_results.append((boxes, texts, scores))
+            table_results = self.table_model(img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size)
+
+            for i, result in enumerate(table_results.pred_htmls):
+                if result:
+                    not_none_table_res_list[i]['table_res']['html'] = result
+
+
+if __name__ == '__main__':
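+    # Minimal local smoke test; the image path below is machine-specific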
+    ocr_engine = PytorchPaddleOCR(
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        enable_merge_det_boxes=False,
+    )
+    table_model = RapidTableModel(ocr_engine)
+    img_path = Path(r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg")
+    # cv2.imread returns BGR; convert to RGB because predict() expects an RGB image
+    image = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
+    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
+    print(html_code)