浏览代码

download to temp dir and then move to dst, to avoid interruptions when downloading

gaotingquan 5 月之前
父节点
当前提交
b85b64fe27
共有 1 个文件被更改,包括 14 次插入3 次删除
  1. 14 3
      paddlex/inference/utils/official_models.py

+ 14 - 3
paddlex/inference/utils/official_models.py

@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import shutil
+import tempfile
 from functools import lru_cache
 from pathlib import Path
 
@@ -446,9 +449,17 @@ class OfficialModelsDict(dict):
         def _download_from_hf():
             local_dir = self._save_dir / f"{key}"
             try:
-                hf_hub.snapshot_download(
-                    repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
-                )
+                if os.path.exists(local_dir):
+                    hf_hub.snapshot_download(
+                        repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
+                    )
+                else:
+                    with tempfile.TemporaryDirectory() as td:
+                        temp_dir = os.path.join(td, "temp_dir")
+                        hf_hub.snapshot_download(
+                            repo_id=f"PaddlePaddle/{key}", local_dir=temp_dir
+                        )
+                        shutil.move(temp_dir, local_dir)
             except Exception as e:
                 logging.warning(
                     f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."