|
|
@@ -13,13 +13,10 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
|
|
|
-from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
|
|
|
-
|
|
|
from .matcher import TableMatch
|
|
|
from .table_structure import TableStructurer
|
|
|
from .table_structure_unitable import TableStructureUnitable
|
|
|
|
|
|
-logger = Logger(logger_name=__name__).get_log()
|
|
|
root_dir = Path(__file__).resolve().parent
|
|
|
|
|
|
|
|
|
@@ -68,7 +65,7 @@ class RapidTable:
|
|
|
f"{self.model_type} is not supported. The currently supported models are {model_list}."
|
|
|
)
|
|
|
|
|
|
- config.model_path = self.get_model_path(config.model_type, config.model_path)
|
|
|
+ config.model_path = config.model_path
|
|
|
if self.model_type == ModelType.UNITABLE.value:
|
|
|
self.table_structure = TableStructureUnitable(asdict(config))
|
|
|
else:
|
|
|
@@ -81,11 +78,10 @@ class RapidTable:
|
|
|
except ModuleNotFoundError:
|
|
|
self.ocr_engine = None
|
|
|
|
|
|
- self.load_img = LoadImage()
|
|
|
|
|
|
def __call__(
|
|
|
self,
|
|
|
- img_content: Union[str, np.ndarray, bytes, Path],
|
|
|
+ img: np.ndarray,
|
|
|
ocr_result: List[Union[List[List[float]], str, str]] = None,
|
|
|
) -> RapidTableOutput:
|
|
|
if self.ocr_engine is None and ocr_result is None:
|
|
|
@@ -93,8 +89,6 @@ class RapidTable:
|
|
|
"One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
|
|
|
)
|
|
|
|
|
|
- img = self.load_img(img_content)
|
|
|
-
|
|
|
s = time.perf_counter()
|
|
|
h, w = img.shape[:2]
|
|
|
|
|
|
@@ -153,27 +147,6 @@ class RapidTable:
|
|
|
cell_bboxes[:, 1::2] *= h_ratio
|
|
|
return cell_bboxes
|
|
|
|
|
|
- @staticmethod
|
|
|
- def get_model_path(
|
|
|
- model_type: str, model_path: Union[str, Path, None]
|
|
|
- ) -> Union[str, Dict[str, str]]:
|
|
|
- if model_path is not None:
|
|
|
- return model_path
|
|
|
-
|
|
|
- model_url = KEY_TO_MODEL_URL.get(model_type, None)
|
|
|
- if isinstance(model_url, str):
|
|
|
- model_path = DownloadModel.download(model_url)
|
|
|
- return model_path
|
|
|
-
|
|
|
- if isinstance(model_url, dict):
|
|
|
- model_paths = {}
|
|
|
- for k, url in model_url.items():
|
|
|
- model_paths[k] = DownloadModel.download(
|
|
|
- url, save_model_name=f"{model_type}_{Path(url).name}"
|
|
|
- )
|
|
|
- return model_paths
|
|
|
-
|
|
|
- raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
|
|
|
|
|
|
|
|
|
def parse_args(arg_list: Optional[List[str]] = None):
|
|
|
@@ -221,14 +194,6 @@ def main(arg_list: Optional[List[str]] = None):
|
|
|
table_results = table_engine(img, ocr_result)
|
|
|
print(table_results.pred_html)
|
|
|
|
|
|
- viser = VisTable()
|
|
|
- if args.vis:
|
|
|
- img_path = Path(args.img_path)
|
|
|
-
|
|
|
- save_dir = img_path.resolve().parent
|
|
|
- save_html_path = save_dir / f"{Path(img_path).stem}.html"
|
|
|
- save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
|
|
|
- viser(img_path, table_results, save_html_path, save_drawed_path)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|