Browse Source

fix: remove rapid-table

Sidney233 3 months ago
parent
commit
aad384f2e7

+ 2 - 37
mineru/model/table/rec/slanet_plus/main.py

@@ -13,13 +13,10 @@ from typing import Dict, List, Optional, Tuple, Union
 import cv2
 import numpy as np
 
-from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
-
 from .matcher import TableMatch
 from .table_structure import TableStructurer
 from .table_structure_unitable import TableStructureUnitable
 
-logger = Logger(logger_name=__name__).get_log()
 root_dir = Path(__file__).resolve().parent
 
 
@@ -68,7 +65,7 @@ class RapidTable:
                 f"{self.model_type} is not supported. The currently supported models are {model_list}."
             )
 
-        config.model_path = self.get_model_path(config.model_type, config.model_path)
+        config.model_path = config.model_path
         if self.model_type == ModelType.UNITABLE.value:
             self.table_structure = TableStructureUnitable(asdict(config))
         else:
@@ -81,11 +78,10 @@ class RapidTable:
         except ModuleNotFoundError:
             self.ocr_engine = None
 
-        self.load_img = LoadImage()
 
     def __call__(
         self,
-        img_content: Union[str, np.ndarray, bytes, Path],
+        img: np.ndarray,
         ocr_result: List[Union[List[List[float]], str, str]] = None,
     ) -> RapidTableOutput:
         if self.ocr_engine is None and ocr_result is None:
@@ -93,8 +89,6 @@ class RapidTable:
                 "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
             )
 
-        img = self.load_img(img_content)
-
         s = time.perf_counter()
         h, w = img.shape[:2]
 
@@ -153,27 +147,6 @@ class RapidTable:
         cell_bboxes[:, 1::2] *= h_ratio
         return cell_bboxes
 
-    @staticmethod
-    def get_model_path(
-        model_type: str, model_path: Union[str, Path, None]
-    ) -> Union[str, Dict[str, str]]:
-        if model_path is not None:
-            return model_path
-
-        model_url = KEY_TO_MODEL_URL.get(model_type, None)
-        if isinstance(model_url, str):
-            model_path = DownloadModel.download(model_url)
-            return model_path
-
-        if isinstance(model_url, dict):
-            model_paths = {}
-            for k, url in model_url.items():
-                model_paths[k] = DownloadModel.download(
-                    url, save_model_name=f"{model_type}_{Path(url).name}"
-                )
-            return model_paths
-
-        raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
 
 
 def parse_args(arg_list: Optional[List[str]] = None):
@@ -221,14 +194,6 @@ def main(arg_list: Optional[List[str]] = None):
     table_results = table_engine(img, ocr_result)
     print(table_results.pred_html)
 
-    viser = VisTable()
-    if args.vis:
-        img_path = Path(args.img_path)
-
-        save_dir = img_path.resolve().parent
-        save_html_path = save_dir / f"{Path(img_path).stem}.html"
-        save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
-        viser(img_path, table_results, save_html_path, save_drawed_path)
 
 
 if __name__ == "__main__":

+ 6 - 1
mineru/model/table/rec/slanet_plus/rapid_table.py

@@ -1,5 +1,7 @@
 import os
 import html
+from typing import List
+
 import cv2
 import numpy as np
 from loguru import logger
@@ -21,7 +23,6 @@ class RapidTableModel(object):
         self.table_model = RapidTable(input_args)
         self.ocr_engine = ocr_engine
 
-
     def predict(self, image, table_cls_score):
         bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
         # Continue with OCR on potentially rotated image
@@ -44,3 +45,7 @@ class RapidTableModel(object):
                 logger.exception(e)
 
         return None, None, None, None
+
+    def batch_predict(self, images: List[np.ndarray], batch_size: int = 1):
+        # TODO: ocr也需要batch
+        pass

+ 2 - 2
mineru/model/table/rec/slanet_plus/table_stucture_utils.py

@@ -31,7 +31,7 @@ from onnxruntime import (
     get_device,
 )
 
-from rapid_table.utils import Logger
+from loguru import logger
 
 
 class EP(Enum):
@@ -42,7 +42,7 @@ class EP(Enum):
 
 class OrtInferSession:
     def __init__(self, config: Dict[str, Any]):
-        self.logger = Logger(logger_name=__name__).get_log()
+        self.logger = logger
 
         model_path = config.get("model_path", None)
         self._verify_model(model_path)

+ 1 - 1
mineru/model/table/rec/unet_table/main.py

@@ -10,7 +10,7 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
-from rapid_table import RapidTableInput, RapidTable
+from  ..slanet_plus.rapid_table import RapidTableInput, RapidTable
 
 from .table_structure_unet import TSRUnet