import html
import os
import time
from pathlib import Path
from typing import List
import cv2
import numpy as np
from loguru import logger
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapid_table.utils import RapidTableOutput
from tqdm import tqdm
from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def escape_html(input_string):
    """Return *input_string* with ``& < > " '`` replaced by their HTML entities.

    Used so OCR text can be embedded safely inside the generated table HTML.
    """
    escaped = html.escape(input_string, quote=True)
    return escaped
class CustomRapidTable(RapidTable):
    """RapidTable variant whose ``__call__`` always matches against supplied OCR results.

    ``RapidTableModel`` builds this with ``use_ocr=False`` and always passes
    ``ocr_results`` explicitly. Each batch goes through the structure model and then the
    table matcher, and a tqdm progress bar is shown.
    """

    def __init__(self, cfg: RapidTableInput):
        import logging

        # NOTE: this disables every stdlib ``logging`` call at INFO level and below for the
        # whole process (not only rapid_table), presumably to silence the library's
        # initialisation output. No environment variable is consulted.
        logging.disable(logging.INFO)
        super().__init__(cfg)

    def __call__(self, img_contents, ocr_results=None, batch_size=1):
        """Predict table HTML for one image or a list of images.

        Args:
            img_contents: a single image or a list of images, in any form accepted by
                ``self._load_imgs``.
            ocr_results: OCR results indexed in parallel with ``img_contents``; they are
                sliced per batch by ``self.get_ocr_results``.
            batch_size: number of images passed to the structure model at once; values
                below 1 are treated as 1.

        Returns:
            A ``RapidTableOutput`` whose ``pred_htmls``, ``cell_bboxes`` and
            ``logic_points`` are extended with every batch's outputs, and whose ``elapse``
            is the mean wall-clock seconds per image (0.0 for an empty input).
        """
        if not isinstance(img_contents, list):
            img_contents = [img_contents]
        # A zero step would make ``range`` raise; a negative one would silently skip work.
        batch_size = max(1, batch_size)

        start = time.perf_counter()
        results = RapidTableOutput()
        total_nums = len(img_contents)

        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
            for start_i in range(0, total_nums, batch_size):
                end_i = min(total_nums, start_i + batch_size)
                imgs = self._load_imgs(img_contents[start_i:end_i])

                pred_structures, cell_bboxes = self.table_structure(imgs)
                logic_points = self.table_matcher.decode_logic_points(pred_structures)
                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
                pred_htmls = self.table_matcher(
                    pred_structures, cell_bboxes, dt_boxes, rec_res
                )

                results.pred_htmls.extend(pred_htmls)
                # These were computed but dropped before; RapidTableModel.predict returns
                # ``cell_bboxes`` and ``logic_points`` to its callers, so keep them.
                results.cell_bboxes.extend(cell_bboxes)
                results.logic_points.extend(logic_points)

                pbar.update(end_i - start_i)

        elapsed = time.perf_counter() - start
        # Avoid ZeroDivisionError when called with an empty list.
        results.elapse = elapsed / total_nums if total_nums else 0.0
        return results
class RapidTableModel:
    """Wireless-table recognizer combining an OCR engine with the SLANet-plus model."""

    def __init__(self, ocr_engine):
        slanet_plus_model_path = os.path.join(
            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
            ModelPath.slanet_plus,
        )
        input_args = RapidTableInput(
            model_type=ModelType.SLANETPLUS,
            model_dir_or_path=slanet_plus_model_path,
            # OCR is always supplied by this wrapper (see predict / batch_predict).
            use_ocr=False,
        )
        self.table_model = CustomRapidTable(input_args)
        self.ocr_engine = ocr_engine

    @staticmethod
    def _to_rapid_table_ocr(raw_ocr_result):
        """Convert raw OCR items into the ``(boxes, texts, scores)`` tuple rapid_table takes.

        Two item layouts are accepted: ``(box, text, score)`` and
        ``(box, (text, score))``. Any other item is skipped. Texts are HTML-escaped because
        they end up inside the generated table HTML. A ``None`` input is treated as
        "no OCR items".
        """
        boxes, texts, scores = [], [], []
        for item in raw_ocr_result or []:
            if len(item) == 3:
                box, text, score = item[0], item[1], item[2]
            elif len(item) == 2 and isinstance(item[1], tuple):
                box, text, score = item[0], item[1][0], item[1][1]
            else:
                continue
            boxes.append(box)
            texts.append(escape_html(text))
            scores.append(score)
        return boxes, texts, scores

    def predict(self, image, ocr_result=None):
        """Recognise a single table image.

        Args:
            image: table image. It is converted with ``cv2.COLOR_RGB2BGR`` before OCR, so
                it is expected to be RGB.
            ocr_result: optional pre-computed OCR in rapid_table form, i.e. a one-element
                list ``[(boxes, texts, scores)]``. When falsy, OCR is run with
                ``self.ocr_engine``.

        Returns:
            ``(pred_htmls, cell_bboxes, logic_points, elapse)`` taken from the table
            model's output, or ``(None, None, None, None)`` if table recognition raised.
        """
        if not ocr_result:
            # Only pay for the colour conversion when OCR actually has to run.
            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            # NOTE(review): the per-image entry may be None when nothing is detected;
            # _to_rapid_table_ocr treats None as an empty result instead of crashing.
            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            ocr_result = [self._to_rapid_table_ocr(raw_ocr_result)]

        try:
            table_results = self.table_model(
                img_contents=np.asarray(image), ocr_results=ocr_result
            )
            return (
                table_results.pred_htmls,
                table_results.cell_bboxes,
                table_results.logic_points,
                table_results.elapse,
            )
        except Exception as e:
            logger.exception(e)
        return None, None, None, None

    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
        """Recognise many tables in batches, writing HTML back into each dict in place.

        Only entries with a truthy ``"ocr_result"`` are processed. For those, the
        ``"table_img"`` and raw ``"ocr_result"`` are fed to the table model. Each non-empty
        predicted HTML string is stored at ``entry["table_res"]["html"]``.
        """
        pending = [res for res in table_res_list if res.get("ocr_result", None)]
        if not pending:
            return

        img_contents = [res["table_img"] for res in pending]
        ocr_results = [self._to_rapid_table_ocr(res["ocr_result"]) for res in pending]

        table_results = self.table_model(
            img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size
        )
        for res, html_code in zip(pending, table_results.pred_htmls):
            if html_code:
                res['table_res']['html'] = html_code
if __name__ == '__main__':
    import sys

    ocr_engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)

    # The sample path can be overridden as the first command-line argument.
    default_img_path = r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg"
    img_path = Path(sys.argv[1] if len(sys.argv) > 1 else default_img_path)

    bgr_image = cv2.imread(str(img_path))
    if bgr_image is None:
        raise SystemExit(f"Could not read image: {img_path}")
    # cv2.imread yields BGR, but predict() expects RGB (it converts RGB->BGR for OCR).
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
    print(html_code)