Selaa lähdekoodia

bugfix: map PaddleOCR-VL-0.9B to PaddleOCR-VL

gaotingquan 1 kuukausi sitten
vanhempi
commit
439ce72b80
1 muutettua tiedostoa jossa 7 lisäystä ja 4 poistoa
  1. 7 4
      paddlex/inference/utils/official_models.py

+ 7 - 4
paddlex/inference/utils/official_models.py

@@ -45,7 +45,7 @@ ALL_MODELS = [
     "ResNet152",
     "ResNet152_vd",
     "ResNet200_vd",
-    "PaddleOCR-VL",
+    "PaddleOCR-VL-0.9B",
     "PP-LCNet_x0_25",
     "PP-LCNet_x0_25_textline_ori",
     "PP-LCNet_x0_35",
@@ -345,7 +345,7 @@ OCR_MODELS = [
     "en_PP-OCRv5_mobile_rec",
     "th_PP-OCRv5_mobile_rec",
     "el_PP-OCRv5_mobile_rec",
-    "PaddleOCR-VL",
+    "PaddleOCR-VL-0.9B",
     "PicoDet_layout_1x",
     "PicoDet_layout_1x_table",
     "PicoDet-L_layout_17cls",
@@ -419,6 +419,9 @@ class _BaseModelHoster(ABC):
         assert (
             model_name in self.model_list
         ), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
+        if model_name == "PaddleOCR-VL-0.9B":
+            model_name = "PaddleOCR-VL"
+
         model_dir = self._save_dir / f"{model_name}"
         if os.path.exists(model_dir):
             logging.info(
@@ -431,8 +434,8 @@ class _BaseModelHoster(ABC):
             self._download(model_name, model_dir)
 
         return (
-            model_dir / "PaddleOCR-VL"
-            if model_name == "PaddleOCR-VL-0.9B"
+            model_dir / "PaddleOCR-VL-0.9B"
+            if model_name == "PaddleOCR-VL"
             else model_dir
         )