
use model cache files when network is unavailable (#4676)

Tingquan Gao, 2 weeks ago
Parent
Commit e0c509eef1
2 files changed, 35 insertions(+), 35 deletions(-)
  1. paddlex/inference/models/__init__.py (+0, -3)
  2. paddlex/inference/utils/official_models.py (+35, -32)

paddlex/inference/models/__init__.py (+0, -3)

@@ -70,9 +70,6 @@ def create_predictor(
 
     if need_local_model(genai_config):
         if model_dir is None:
-            assert (
-                model_name in official_models
-            ), f"The model ({model_name}) is not supported! Please using directory of local model files or model name supported by PaddleX!"
             model_dir = official_models[model_name]
         else:
             assert Path(model_dir).exists(), f"{model_dir} is not exists!"
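Dropping the assertion means `create_predictor` no longer pre-validates the name against `official_models`; the lookup itself (see the `_ModelManager` changes below) is now responsible for resolving a local directory and for rejecting unknown names. A minimal sketch of the resulting resolution logic, assuming `official_models` behaves as a mapping that returns a local path; the helper name here is hypothetical:

from pathlib import Path

def _resolve_model_dir(official_models, model_name, model_dir=None):
    # Sketch of the branch shown in the hunk above, not the full create_predictor.
    if model_dir is None:
        # Delegate caching, downloading and name validation to the lookup;
        # an unsupported name is expected to fail inside official_models itself.
        return official_models[model_name]
    # A user-supplied directory must already exist on disk.
    assert Path(model_dir).exists(), f"{model_dir} does not exist!"
    return Path(model_dir)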

paddlex/inference/utils/official_models.py (+35, -32)

@@ -45,7 +45,7 @@ ALL_MODELS = [
     "ResNet152",
     "ResNet152_vd",
     "ResNet200_vd",
-    "PaddleOCR-VL-0.9B",
+    "PaddleOCR-VL",
     "PP-LCNet_x0_25",
     "PP-LCNet_x0_25_textline_ori",
     "PP-LCNet_x0_35",
@@ -345,7 +345,7 @@ OCR_MODELS = [
     "en_PP-OCRv5_mobile_rec",
     "th_PP-OCRv5_mobile_rec",
     "el_PP-OCRv5_mobile_rec",
-    "PaddleOCR-VL-0.9B",
+    "PaddleOCR-VL",
     "PicoDet_layout_1x",
     "PicoDet_layout_1x_table",
     "PicoDet-L_layout_17cls",
@@ -419,27 +419,15 @@ class _BaseModelHoster(ABC):
         assert (
             model_name in self.model_list
         ), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
-        if model_name == "PaddleOCR-VL-0.9B":
-            model_name = "PaddleOCR-VL"
 
         model_dir = self._save_dir / f"{model_name}"
-        if os.path.exists(model_dir):
-            logging.info(
-                f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
-            )
-        else:
-            logging.info(
-                f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
-            )
-            self._download(model_name, model_dir)
-            logging.debug(
-                f"`{model_name}` model files has been download from model source: `{self.alias}`!"
-            )
-
-        if model_name == "PaddleOCR-VL":
-            vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
-            if vl_model_dir.exists() and vl_model_dir.is_dir():
-                return vl_model_dir
+        logging.info(
+            f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
+        )
+        self._download(model_name, model_dir)
+        logging.debug(
+            f"`{model_name}` model files has been download from model source: `{self.alias}`!"
+        )
 
         return model_dir
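After this hunk, a hoster's `get_model` is download-only: it always fetches into `self._save_dir / model_name`, and the cache check plus the PaddleOCR-VL alias handling move up into the manager (next hunk). A rough standalone sketch of the reduced contract; the class name, `alias` value, `model_list` entries, and the `_download` body are assumptions:

import logging
from pathlib import Path

class _SketchHoster:
    """Download-only hoster: no cache check, no alias handling."""

    alias = "example-source"  # assumed value; the hunk only references self.alias
    model_list = ["ResNet152", "PaddleOCR-VL"]

    def __init__(self, save_dir: Path):
        self._save_dir = save_dir

    def get_model(self, model_name: str) -> Path:
        assert model_name in self.model_list, (
            f"The model {model_name} is not supported on hosting {type(self).__name__}!"
        )
        model_dir = self._save_dir / model_name
        logging.info(
            f"Using official model ({model_name}), the model files will be "
            f"automatically downloaded and saved in `{model_dir}`."
        )
        self._download(model_name, model_dir)  # always fetches; caching now lives in the manager
        return model_dir

    def _download(self, model_name: str, model_dir: Path) -> None:
        # Placeholder for the real hoster-specific download.
        model_dir.mkdir(parents=True, exist_ok=True)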
 
@@ -573,21 +561,33 @@ class _ModelManager:
                     hosters.append(hoster_cls(self._save_dir))
         if len(hosters) == 0:
             logging.warning(
-                f"""No model hoster is available! Please check your network connection to one of the following model hosts:
-HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}),
-ModelScope ({_ModelScopeModelHoster.healthcheck_url}),
-AIStudio ({_AIStudioModelHoster.healthcheck_url}), or
-BOS ({_BosModelHoster.healthcheck_url}).
-Otherwise, only local models can be used."""
+                f"No model hoster is available! Please check your network connection to one of the following model hosts: HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}), ModelScope ({_ModelScopeModelHoster.healthcheck_url}), AIStudio ({_AIStudioModelHoster.healthcheck_url}), or BOS ({_BosModelHoster.healthcheck_url}). Otherwise, only local models can be used."
             )
         return hosters
 
     def _get_model_local_path(self, model_name):
-        if len(self._hosters) == 0:
-            msg = "No available model hosting platforms detected. Please check your network connection."
-            logging.error(msg)
-            raise Exception(msg)
-        return self._download_from_hoster(self._hosters, model_name)
+        if model_name == "PaddleOCR-VL-0.9B":
+            model_name = "PaddleOCR-VL"
+
+        model_dir = self._save_dir / f"{model_name}"
+        if os.path.exists(model_dir):
+            logging.info(
+                f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
+            )
+        else:
+            if len(self._hosters) == 0:
+                msg = "No available model hosting platforms detected. Please check your network connection."
+                logging.error(msg)
+                raise Exception(msg)
+
+            model_dir = self._download_from_hoster(self._hosters, model_name)
+
+        if model_name == "PaddleOCR-VL":
+            vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
+            if vl_model_dir.exists() and vl_model_dir.is_dir():
+                return vl_model_dir
+
+        return model_dir
 
     def _download_from_hoster(self, hosters, model_name):
         for idx, hoster in enumerate(hosters):
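The rewritten `_get_model_local_path` above is the core of the fix: the cache check now runs before the hoster-availability check, so a model that was downloaded earlier keeps working even when every hoster healthcheck fails. A standalone, hedged re-statement of that logic (the function signature and the `download_from_hoster` callback are illustrative; only the control flow comes from the hunk):

import logging
from pathlib import Path

def get_model_local_path(save_dir: Path, hosters: list, model_name: str,
                         download_from_hoster) -> Path:
    # Legacy-name alias kept for backward compatibility.
    if model_name == "PaddleOCR-VL-0.9B":
        model_name = "PaddleOCR-VL"

    model_dir = save_dir / model_name
    if model_dir.exists():
        # Offline-friendly path: cached files are used without touching the network.
        logging.info(f"Model files already exist. Using cached files in `{model_dir}`.")
    else:
        if not hosters:
            raise Exception(
                "No available model hosting platforms detected. "
                "Please check your network connection."
            )
        model_dir = download_from_hoster(hosters, model_name)

    # PaddleOCR-VL repositories may nest the weights one directory deeper.
    if model_name == "PaddleOCR-VL":
        vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
        if vl_model_dir.is_dir():
            return vl_model_dir
    return model_dir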
@@ -605,6 +605,9 @@ Otherwise, only local models can be used."""
                         f"Encountering exception when download model from {hoster.alias}: \n{e}, will try to download from other model sources: `{hosters[idx + 1].alias}`."
                     )
                     return self._download_from_hoster(hosters[idx + 1 :], model_name)
+        raise Exception(
+            f"No model source is available for model `{model_name}`! Please check model name and network, or use local model files!"
+        )
 
     def __contains__(self, model_name):
         return model_name in self.model_list
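Finally, the new `raise` in `_download_from_hoster` turns an exhausted fallback chain into an explicit error instead of a silent `None` return. A control-flow sketch of the fallback behavior; the per-hoster call and the surrounding nesting are simplified relative to the real method:

import logging

def download_from_hoster(hosters, model_name):
    # Try hosters in order; on failure, log and recurse over the remaining ones.
    for idx, hoster in enumerate(hosters):
        try:
            return hoster.get_model(model_name)
        except Exception as e:
            if idx + 1 < len(hosters):
                logging.warning(
                    f"Encountering exception when download model from {hoster.alias}: "
                    f"\n{e}, will try to download from other model sources: "
                    f"`{hosters[idx + 1].alias}`."
                )
                return download_from_hoster(hosters[idx + 1:], model_name)
    # New in this commit: all sources failed (or none matched), so fail loudly.
    raise Exception(
        f"No model source is available for model `{model_name}`! "
        "Please check model name and network, or use local model files!"
    )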