Explorar o código

support PaddleOCR-VL (#4609)

* support PaddleOCR-VL

* remove confusing number characters, 0, 1, 9

* support PP-DocLayoutV2
Tingquan Gao hai 1 mes
pai
achega
a3b9c4b7f6

+ 1 - 1
.precommit/check_imports.py

@@ -30,7 +30,7 @@ from setup import REQUIRED_DEPS
 # 1. It is supported only in Python 3.10+.
 # 2. It requires the packages to be installed, but we are doing a static check.
 MOD_TO_DEP = {
-    "aistudio-sdk": "aistudio_sdk",
+    "aistudio_sdk": "aistudio-sdk",
     "aiohttp": "aiohttp",
     "baidubce": "bce-python-sdk",
     "bs4": "beautifulsoup4",

+ 12 - 1
paddlex/inference/pipelines/paddleocr_vl/uilts.py

@@ -395,13 +395,24 @@ def tokenize_figure_of_table(table_block_img, table_box, figures):
             - Token-to-img HTML map,
             - List of figure paths dropped.
     """
+
+    def gen_random_map(num):
+        exclude_digits = {"0", "1", "9"}
+        seq = []
+        i = 0
+        while len(seq) < num:
+            if not (set(str(i)) & exclude_digits):
+                seq.append(i)
+            i += 1
+        return seq
+
     import random
 
     random.seed(1024)
     token_map = {}
     table_x_min, table_y_min, table_x_max, table_y_max = table_box
     drop_idxes = []
-    random_map = list(range(len(figures)))
+    random_map = gen_random_map(len(figures))
     random.shuffle(random_map)
     for figure_id, figure in enumerate(figures):
         figure_x_min, figure_y_min, figure_x_max, figure_y_max = figure["coordinate"]

+ 8 - 1
paddlex/inference/utils/official_models.py

@@ -45,6 +45,7 @@ ALL_MODELS = [
     "ResNet152",
     "ResNet152_vd",
     "ResNet200_vd",
+    "PaddleOCR-VL-0.9B",
     "PP-LCNet_x0_25",
     "PP-LCNet_x0_25_textline_ori",
     "PP-LCNet_x0_35",
@@ -294,6 +295,7 @@ ALL_MODELS = [
     "GroundingDINO-T",
     "SAM-H_box",
     "SAM-H_point",
+    "PP-DocLayoutV2",
     "PP-DocLayout-L",
     "PP-DocLayout-M",
     "PP-DocLayout-S",
@@ -424,7 +426,12 @@ class _BaseModelHoster(ABC):
                 f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
             )
             self._download(model_name, model_dir)
-        return model_dir
+
+        return (
+            model_dir / "PaddleOCR-VL-0.9B"
+            if model_name == "PaddleOCR-VL-0.9B"
+            else model_dir
+        )
 
     @abstractmethod
     def _download(self):