| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import html
- import os
- import time
- from pathlib import Path
- from typing import List
- import cv2
- import numpy as np
- from loguru import logger
- from rapid_table import ModelType, RapidTable, RapidTableInput
- from rapid_table.utils import RapidTableOutput
- from tqdm import tqdm
- from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- def escape_html(input_string):
- """Escape HTML Entities."""
- return html.escape(input_string)
class CustomRapidTable(RapidTable):
    """RapidTable variant that adds batched inference with a progress bar.

    Accepts either a single image or a list of images, runs table structure
    recognition in mini-batches, and accumulates the predicted table HTML
    into a single ``RapidTableOutput``.
    """

    def __init__(self, cfg: RapidTableInput):
        import logging

        # Silence the wrapped library's INFO-level logging noise.
        logging.disable(logging.INFO)
        super().__init__(cfg)

    def __call__(self, img_contents, ocr_results=None, batch_size=1):
        """Run table structure recognition over one or more images.

        Args:
            img_contents: a single image (or list of images) accepted by
                ``self._load_imgs``.
            ocr_results: optional pre-computed OCR results, one entry per
                image, each shaped ``(boxes, texts, scores)``.
            batch_size: number of images fed to the structure model at once.

        Returns:
            ``RapidTableOutput`` whose ``pred_htmls`` holds one HTML string
            per input image and whose ``elapse`` is the average seconds per
            image.
        """
        if not isinstance(img_contents, list):
            img_contents = [img_contents]

        start = time.perf_counter()
        results = RapidTableOutput()
        total_nums = len(img_contents)

        # Guard against an empty input list so the average-time division
        # below cannot raise ZeroDivisionError.
        if total_nums == 0:
            results.elapse = 0.0
            return results

        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
            for start_i in range(0, total_nums, batch_size):
                end_i = min(total_nums, start_i + batch_size)
                imgs = self._load_imgs(img_contents[start_i:end_i])

                pred_structures, cell_bboxes = self.table_structure(imgs)
                # NOTE(review): the decoded logic points were computed and
                # discarded in the original code; the call is kept for
                # parity in case it has side effects — confirm upstream.
                self.table_matcher.decode_logic_points(pred_structures)

                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
                pred_htmls = self.table_matcher(
                    pred_structures, cell_bboxes, dt_boxes, rec_res
                )
                results.pred_htmls.extend(pred_htmls)
                # Advance the progress bar by the number of images just done.
                pbar.update(end_i - start_i)

        results.elapse = (time.perf_counter() - start) / total_nums
        return results
class RapidTableModel():
    """Wireless table recognition: SLANet-plus structure model + external OCR.

    Wraps :class:`CustomRapidTable` and an OCR engine; produces table HTML
    for single images (``predict``) or batches (``batch_predict``).
    """

    def __init__(self, ocr_engine):
        # Resolve (downloading if necessary) the SLANet-plus weights.
        slanet_plus_model_path = os.path.join(
            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
            ModelPath.slanet_plus,
        )
        input_args = RapidTableInput(
            model_type=ModelType.SLANETPLUS,
            model_dir_or_path=slanet_plus_model_path,
            use_ocr=False,
        )
        self.table_model = CustomRapidTable(input_args)
        self.ocr_engine = ocr_engine

    @staticmethod
    def _format_ocr_result(raw_ocr_result):
        """Convert raw OCR items into the ``(boxes, texts, scores)`` triple
        expected by rapid_table.

        Each raw item is either ``(box, text, score)`` or
        ``(box, (text, score))``; texts are HTML-escaped. Items in any other
        shape are silently skipped, matching the original inline loops.
        """
        boxes, texts, scores = [], [], []
        for item in raw_ocr_result:
            if len(item) == 3:
                boxes.append(item[0])
                texts.append(html.escape(item[1]))
                scores.append(item[2])
            elif len(item) == 2 and isinstance(item[1], tuple):
                boxes.append(item[0])
                texts.append(html.escape(item[1][0]))
                scores.append(item[1][1])
        return boxes, texts, scores

    def predict(self, image, ocr_result=None):
        """Recognize a single table image.

        Args:
            image: table image; treated as RGB (it is converted with
                ``COLOR_RGB2BGR`` before OCR) — confirm callers pass RGB.
            ocr_result: optional pre-formatted OCR results; when omitted,
                the wrapped OCR engine is run on the image.

        Returns:
            ``(html_code, cell_bboxes, logic_points, elapse)`` on success,
            or ``(None, None, None, None)`` on failure.
        """
        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)

        if not ocr_result:
            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            # rapid_table expects a list with one (boxes, texts, scores)
            # triple per image.
            ocr_result = [self._format_ocr_result(raw_ocr_result)]

        if not ocr_result:
            # Fix: the original fell through to an implicit bare ``None``
            # here, which callers unpacking four values could not handle.
            return None, None, None, None

        try:
            table_results = self.table_model(
                img_contents=np.asarray(image), ocr_results=ocr_result
            )
            html_code = table_results.pred_htmls
            table_cell_bboxes = table_results.cell_bboxes
            logic_points = table_results.logic_points
            elapse = table_results.elapse
            return html_code, table_cell_bboxes, logic_points, elapse
        except Exception as e:
            logger.exception(e)
            return None, None, None, None

    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
        """Recognize many tables at once, writing HTML back into each item.

        Args:
            table_res_list: dicts holding ``table_img``, ``ocr_result`` and
                a ``table_res`` dict that receives the ``html`` result.
                Entries without an ``ocr_result`` are skipped.
            batch_size: mini-batch size forwarded to the table model.
        """
        not_none_table_res_list = [
            table_res for table_res in table_res_list
            if table_res.get("ocr_result", None)
        ]
        if not not_none_table_res_list:
            return

        img_contents = [t["table_img"] for t in not_none_table_res_list]
        # Build OCR inputs in the format rapid_table expects, one triple
        # per image (same conversion as in ``predict``).
        ocr_results = [
            self._format_ocr_result(t["ocr_result"])
            for t in not_none_table_res_list
        ]

        table_results = self.table_model(
            img_contents=img_contents,
            ocr_results=ocr_results,
            batch_size=batch_size,
        )
        for i, result in enumerate(table_results.pred_htmls):
            if result:
                not_none_table_res_list[i]['table_res']['html'] = result
if __name__ == '__main__':
    # Demo entry point: recognize one local table image and print its HTML.
    engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    model = RapidTableModel(engine)
    source = Path(r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg")
    # NOTE(review): cv2.imread returns BGR, but predict() converts its input
    # with COLOR_RGB2BGR as if it were RGB — confirm the intended channel
    # order for this demo.
    frame = cv2.imread(str(source))
    html_code, table_cell_bboxes, logic_points, elapse = model.predict(frame)
    print(html_code)
|