| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- import os
- import argparse
- import copy
- import importlib
- import time
- import html
- from dataclasses import asdict, dataclass
- from enum import Enum
- from pathlib import Path
- from typing import Dict, List, Optional, Tuple, Union
- import cv2
- import numpy as np
- from loguru import logger
- from .matcher import TableMatch
- from .table_structure import TableStructurer
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
# Absolute directory that contains this module file (symlinks resolved first).
root_dir = Path(__file__).resolve().parents[0]
class ModelType(Enum):
    """Closed set of table-structure model identifiers.

    Each member's ``value`` is the plain string that the rest of this module
    uses as a lookup key (for example in ``KEY_TO_MODEL_URL``) and as the
    ``model_type`` setting.
    """

    # PP-Structure variants.
    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"

    # SLANet-plus: the default model type used by this module.
    SLANETPLUS = "slanet_plus"

    # UniTable: registered with separate encoder / decoder / vocab files.
    UNITABLE = "unitable"
# Base location of the published model files. NOTE: the value already ends
# with a "/", so file names are appended directly below; inserting another
# "/" would produce ".../master//<file>" URLs.
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"

# Maps every known model-type string to its download location(s). Most models
# are a single file; UniTable is split into three named files. The key set of
# this mapping is also what the module treats as the list of known model types.
KEY_TO_MODEL_URL = {
    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}en_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}ch_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.SLANETPLUS.value: f"{ROOT_URL}slanet-plus.onnx",
    ModelType.UNITABLE.value: {
        "encoder": f"{ROOT_URL}unitable/encoder.pth",
        "decoder": f"{ROOT_URL}unitable/decoder.pth",
        "vocab": f"{ROOT_URL}unitable/vocab.json",
    },
}
@dataclass
class RapidTableInput:
    """Configuration handed to ``RapidTable``.

    ``RapidTable`` converts the whole instance to a plain dict (via
    ``dataclasses.asdict``) and passes it to the table-structure backend, so
    every field here becomes a key of that dict.
    """

    # String value of a ``ModelType`` member; defaults to SLANet-plus.
    model_type: Optional[str] = ModelType.SLANETPLUS.value

    # Location of the model file(s): a single path, a name -> path mapping,
    # or None when no explicit path is given.
    model_path: Union[str, Path, None, Dict[str, str]] = None

    # Execution-target settings. They are only forwarded to the backend from
    # this module; their exact interpretation is defined there.
    use_cuda: bool = False
    device: str = "cpu"
@dataclass
class RapidTableOutput:
    """Result bundle returned by ``RapidTable.__call__``.

    Every field defaults to ``None`` so that an "empty" result can be built
    without arguments.
    """

    # HTML markup of the recognised table.
    pred_html: Optional[str] = None

    # Cell bounding boxes, one row per cell, with all-zero rows removed.
    cell_bboxes: Optional[np.ndarray] = None

    # Logical cell positions decoded from the predicted structure tokens.
    logic_points: Optional[np.ndarray] = None

    # Wall-clock duration of the recognition call, in seconds.
    elapse: Optional[float] = None
class RapidTable:
    """Recognise a table's structure in an image and fill it with OCR text.

    The structure model predicts structure tokens and cell boxes; the matcher
    then combines them with OCR boxes/texts into an HTML table. Although
    several model types are registered in ``KEY_TO_MODEL_URL``, only
    ``slanet_plus`` can be instantiated by this class.
    """

    def __init__(self, config: "RapidTableInput"):
        """Build the structure model, the matcher and an optional OCR engine.

        Args:
            config: Settings for the structure model; it is converted to a
                dict with ``dataclasses.asdict`` and passed to
                ``TableStructurer``.

        Raises:
            ValueError: If ``config.model_type`` is not a registered model
                type, or is registered but is not ``slanet_plus``.
        """
        self.model_type = config.model_type
        if self.model_type not in KEY_TO_MODEL_URL:
            model_list = ",".join(KEY_TO_MODEL_URL)
            raise ValueError(
                f"{self.model_type} is not supported. The currently supported models are {model_list}."
            )

        # Only SLANet-plus has a structure backend wired up in this module.
        if self.model_type != ModelType.SLANETPLUS.value:
            raise ValueError(f"{self.model_type} is not supported.")
        self.table_structure = TableStructurer(asdict(config))

        self.table_matcher = TableMatch()

        # rapidocr is an optional dependency: without it, callers must pass
        # their own ``ocr_result`` to ``__call__``.
        try:
            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
        except ModuleNotFoundError:
            self.ocr_engine = None

    def __call__(
        self,
        img: np.ndarray,
        ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
    ) -> "RapidTableOutput":
        """Run table recognition on ``img``.

        Args:
            img: Image array; only its first two dimensions (height, width)
                are read directly here.
            ocr_result: Optional pre-computed OCR output as a sequence of
                ``(box, text, score)`` entries. When ``None``, the bundled
                rapidocr engine is run on ``img`` instead.

        Returns:
            A ``RapidTableOutput`` with the HTML, the non-placeholder cell
            boxes, the logic points and the elapsed time in seconds.

        Raises:
            ValueError: If no ``ocr_result`` is given and rapidocr is not
                installed, or if the OCR result is empty/malformed.
        """
        if self.ocr_engine is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
            )

        start = time.perf_counter()
        h, w = img.shape[:2]

        if ocr_result is None:
            engine_out = self.ocr_engine(img)
            ocr_result = list(
                zip(engine_out.boxes, engine_out.txts, engine_out.scores)
            )
        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)

        # The structure model receives its own copy so it cannot modify the
        # caller's image.
        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))

        # Undo the box scaling applied to the slanet-plus model output.
        if self.model_type == ModelType.SLANETPLUS.value:
            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)

        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)

        # Drop placeholder boxes (rows that are entirely zero). This happens
        # after matching, so the matcher still sees the full box array.
        keep = ~np.all(cell_bboxes == 0, axis=1)
        cell_bboxes = cell_bboxes[keep]

        logic_points = self.table_matcher.decode_logic_points(pred_structures)
        elapse = time.perf_counter() - start
        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)

    def get_boxes_recs(
        self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
    ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
        """Split OCR entries into axis-aligned boxes and (text, score) pairs.

        Each OCR polygon is reduced to ``[x_min, y_min, x_max, y_max]``,
        grown by one pixel on every side and clipped to the image bounds.

        Args:
            ocr_result: Sequence of ``(box, text, score)`` entries, where
                ``box`` is a sequence of ``(x, y)`` points.
            h: Image height used as the upper clip for y coordinates.
            w: Image width used as the upper clip for x coordinates.

        Returns:
            ``(boxes, rec_res)`` where ``boxes`` has one row per entry and
            ``rec_res`` is the list of ``(text, score)`` pairs in the same
            order.

        Raises:
            ValueError: If ``ocr_result`` is empty or its entries do not have
                exactly three fields.
        """
        try:
            dt_boxes, texts, scores = zip(*ocr_result)
        except ValueError as exc:
            # An empty input (or entries of the wrong arity) would otherwise
            # surface as an opaque tuple-unpacking error.
            raise ValueError(
                "ocr_result must be a non-empty sequence of (box, text, score) entries."
            ) from exc
        rec_res = list(zip(texts, scores))

        clipped = []
        for box in dt_boxes:
            points = np.asarray(box)
            xs, ys = points[:, 0], points[:, 1]
            clipped.append(
                [
                    max(0, xs.min() - 1),
                    max(0, ys.min() - 1),
                    min(w, xs.max() + 1),
                    min(h, ys.max() + 1),
                ]
            )
        return np.array(clipped), rec_res

    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
        """Rescale slanet-plus cell boxes, modifying ``cell_bboxes`` in place.

        Even-indexed columns (x coordinates) are multiplied by
        ``resized / (w * ratio)`` and odd-indexed columns (y coordinates) by
        ``resized / (h * ratio)``, where ``ratio = min(resized / h,
        resized / w)``.

        Args:
            img: Image whose height and width define the scale factors.
            cell_bboxes: 2-D array of box coordinates; it is scaled in place.

        Returns:
            The same ``cell_bboxes`` array object, after scaling.
        """
        h, w = img.shape[:2]
        # NOTE(review): presumably the side length the slanet-plus
        # preprocessing resizes images to - confirm against TableStructurer.
        resized = 488
        ratio = min(resized / h, resized / w)
        w_ratio = resized / (w * ratio)
        h_ratio = resized / (h * ratio)
        cell_bboxes[:, 0::2] *= w_ratio
        cell_bboxes[:, 1::2] *= h_ratio
        return cell_bboxes
def parse_args(arg_list: Optional[List[str]] = None):
    """Parse the command-line options of this module.

    Args:
        arg_list: Explicit argument list to parse; ``None`` lets argparse
            read ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``vis``, ``img_path`` and
        ``model_type`` attributes.
    """
    # (flags, keyword arguments) for every supported option, registered in
    # this order.
    option_specs = (
        (
            ("-v", "--vis"),
            {
                "action": "store_true",
                "default": False,
                "help": "Wheter to visualize the layout results.",
            },
        ),
        (
            ("-img", "--img_path"),
            {
                "type": str,
                "required": True,
                "help": "Path to image for layout.",
            },
        ),
        (
            ("-m", "--model_type"),
            {
                "type": str,
                "default": ModelType.SLANETPLUS.value,
                "choices": list(KEY_TO_MODEL_URL),
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args(arg_list)
def escape_html(input_string):
    """Return ``input_string`` with HTML-special characters escaped.

    ``&``, ``<`` and ``>`` are replaced by their entities, and so are both
    quote characters (``"`` and ``'``).
    """
    escaped = html.escape(input_string, quote=True)
    return escaped
class RapidTableModel:
    """Table recogniser that pairs a caller-supplied OCR engine with RapidTable."""

    def __init__(self, ocr_engine):
        """Locate the slanet-plus model and build the underlying ``RapidTable``.

        Args:
            ocr_engine: Object exposing an ``ocr(image)`` method; it is used
                by ``predict`` to obtain text boxes for the table image.
        """
        model_root = auto_download_and_get_model_root_path(ModelPath.slanet_plus)
        slanet_plus_model_path = os.path.join(model_root, ModelPath.slanet_plus)
        self.table_model = RapidTable(
            RapidTableInput(model_type="slanet_plus", model_path=slanet_plus_model_path)
        )
        self.ocr_engine = ocr_engine

    def predict(self, image, table_cls_score):
        """Recognise the table in ``image``.

        Args:
            image: RGB image convertible with ``np.asarray``. OCR runs on a
                BGR-converted copy, while the table model receives the
                unconverted array.
            table_cls_score: Accepted for interface compatibility; not used
                by this method.

        Returns:
            ``(html, cell_bboxes, logic_points, elapse)`` on success, or
            ``(None, None, None, None)`` when OCR yields no usable lines or
            the table model raises (the exception is logged, not propagated).
        """
        nothing = (None, None, None, None)

        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        # Only the first element of the OCR engine's return value is used.
        raw_lines = self.ocr_engine.ocr(bgr_image)[0]
        if not raw_lines:
            return nothing

        # Keep only well-formed (box, (text, score)) entries and HTML-escape
        # the text so it can be embedded in the generated table markup.
        ocr_result = []
        for item in raw_lines:
            if len(item) != 2 or not isinstance(item[1], tuple):
                continue
            ocr_result.append([item[0], escape_html(item[1][0]), item[1][1]])
        if not ocr_result:
            return nothing

        try:
            table_results = self.table_model(np.asarray(image), ocr_result)
            return (
                table_results.pred_html,
                table_results.cell_bboxes,
                table_results.logic_points,
                table_results.elapse,
            )
        except Exception as e:
            # Best-effort: a failing table model is logged and reported as
            # "no result" rather than aborting the caller.
            logger.exception(e)
        return nothing
|