|
|
@@ -1,11 +1,8 @@
|
|
|
import os
|
|
|
-import argparse
|
|
|
import copy
|
|
|
-import importlib
|
|
|
import time
|
|
|
import html
|
|
|
from dataclasses import asdict, dataclass
|
|
|
-from enum import Enum
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
@@ -19,32 +16,10 @@ from .table_structure import TableStructurer
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
|
|
-root_dir = Path(__file__).resolve().parent
|
|
|
-
|
|
|
-
|
|
|
-class ModelType(Enum):
|
|
|
- PPSTRUCTURE_EN = "ppstructure_en"
|
|
|
- PPSTRUCTURE_ZH = "ppstructure_zh"
|
|
|
- SLANETPLUS = "slanet_plus"
|
|
|
- UNITABLE = "unitable"
|
|
|
-
|
|
|
-
|
|
|
-ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
|
|
|
-KEY_TO_MODEL_URL = {
|
|
|
- ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
|
|
|
- ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
|
|
|
- ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
|
|
|
- ModelType.UNITABLE.value: {
|
|
|
- "encoder": f"{ROOT_URL}/unitable/encoder.pth",
|
|
|
- "decoder": f"{ROOT_URL}/unitable/decoder.pth",
|
|
|
- "vocab": f"{ROOT_URL}/unitable/vocab.json",
|
|
|
- },
|
|
|
-}
|
|
|
-
|
|
|
|
|
|
@dataclass
|
|
|
class RapidTableInput:
|
|
|
- model_type: Optional[str] = ModelType.SLANETPLUS.value
|
|
|
+ model_type: Optional[str] = "slanet_plus"
|
|
|
model_path: Union[str, Path, None, Dict[str, str]] = None
|
|
|
use_cuda: bool = False
|
|
|
device: str = "cpu"
|
|
|
@@ -60,18 +35,7 @@ class RapidTableOutput:
|
|
|
|
|
|
class RapidTable:
|
|
|
def __init__(self, config: RapidTableInput):
|
|
|
- self.model_type = config.model_type
|
|
|
- if self.model_type not in KEY_TO_MODEL_URL:
|
|
|
- model_list = ",".join(KEY_TO_MODEL_URL)
|
|
|
- raise ValueError(
|
|
|
- f"{self.model_type} is not supported. The currently supported models are {model_list}."
|
|
|
- )
|
|
|
-
|
|
|
- config.model_path = config.model_path
|
|
|
- if self.model_type == ModelType.SLANETPLUS.value:
|
|
|
- self.table_structure = TableStructurer(asdict(config))
|
|
|
- else:
|
|
|
- raise ValueError(f"{self.model_type} is not supported.")
|
|
|
+ self.table_structure = TableStructurer(asdict(config))
|
|
|
self.table_matcher = TableMatch()
|
|
|
|
|
|
def predict(
|
|
|
@@ -177,29 +141,6 @@ class RapidTable:
|
|
|
return cell_bboxes
|
|
|
|
|
|
|
|
|
-def parse_args(arg_list: Optional[List[str]] = None):
|
|
|
- parser = argparse.ArgumentParser()
|
|
|
- parser.add_argument(
|
|
|
- "-v",
|
|
|
- "--vis",
|
|
|
- action="store_true",
|
|
|
- default=False,
|
|
|
- help="Wheter to visualize the layout results.",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "-img", "--img_path", type=str, required=True, help="Path to image for layout."
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "-m",
|
|
|
- "--model_type",
|
|
|
- type=str,
|
|
|
- default=ModelType.SLANETPLUS.value,
|
|
|
- choices=list(KEY_TO_MODEL_URL),
|
|
|
- )
|
|
|
- args = parser.parse_args(arg_list)
|
|
|
- return args
|
|
|
-
|
|
|
-
|
|
|
def escape_html(input_string):
|
|
|
"""Escape HTML Entities."""
|
|
|
return html.escape(input_string)
|
|
|
@@ -245,22 +186,28 @@ class RapidTableModel(object):
|
|
|
|
|
|
def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
|
|
|
"""对传入的字典列表进行批量预测,无返回值"""
|
|
|
- with tqdm(total=len(table_res_list), desc="Table-wireless Predict") as pbar:
|
|
|
- for index in range(0, len(table_res_list), batch_size):
|
|
|
+
|
|
|
+ not_none_table_res_list = []
|
|
|
+ for table_res in table_res_list:
|
|
|
+ if table_res.get("ocr_result", None):
|
|
|
+ not_none_table_res_list.append(table_res)
|
|
|
+
|
|
|
+ with tqdm(total=len(not_none_table_res_list), desc="Table-wireless Predict") as pbar:
|
|
|
+ for index in range(0, len(not_none_table_res_list), batch_size):
|
|
|
batch_imgs = [
|
|
|
- cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
- for i in range(index, min(index + batch_size, len(table_res_list)))
|
|
|
+ cv2.cvtColor(np.asarray(not_none_table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
+ for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
|
|
|
]
|
|
|
batch_ocrs = [
|
|
|
- table_res_list[i]["ocr_result"]
|
|
|
- for i in range(index, min(index + batch_size, len(table_res_list)))
|
|
|
+ not_none_table_res_list[i]["ocr_result"]
|
|
|
+ for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
|
|
|
]
|
|
|
results = self.table_model.batch_predict(
|
|
|
batch_imgs, batch_ocrs, batch_size=batch_size
|
|
|
)
|
|
|
for i, result in enumerate(results):
|
|
|
if result.pred_html:
|
|
|
- table_res_list[index + i]['table_res']['html'] = result.pred_html
|
|
|
+ not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
|
|
|
|
|
|
# 更新进度条
|
|
|
- pbar.update(len(results))
|
|
|
+ pbar.update(len(results))
|