|
|
@@ -12,6 +12,9 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import os
|
|
|
+import shutil
|
|
|
+import tempfile
|
|
|
from functools import lru_cache
|
|
|
from pathlib import Path
|
|
|
|
|
|
@@ -446,9 +449,17 @@ class OfficialModelsDict(dict):
|
|
|
def _download_from_hf():
|
|
|
local_dir = self._save_dir / f"{key}"
|
|
|
try:
|
|
|
- hf_hub.snapshot_download(
|
|
|
- repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
|
|
|
- )
|
|
|
+ if os.path.exists(local_dir):
|
|
|
+ hf_hub.snapshot_download(
|
|
|
+ repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ with tempfile.TemporaryDirectory() as td:
|
|
|
+ temp_dir = os.path.join(td, "temp_dir")
|
|
|
+ hf_hub.snapshot_download(
|
|
|
+ repo_id=f"PaddlePaddle/{key}", local_dir=temp_dir
|
|
|
+ )
|
|
|
+ shutil.move(temp_dir, local_dir)
|
|
|
except Exception as e:
|
|
|
logging.warning(
|
|
|
f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."
|