official_models.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shutil
  16. import tempfile
  17. from abc import ABC, abstractmethod
  18. from pathlib import Path
  19. import huggingface_hub as hf_hub
  20. hf_hub.logging.set_verbosity_error()
  21. import modelscope
  22. import requests
  23. os.environ["AISTUDIO_LOG"] = "critical"
  24. from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
  25. from ...utils import logging
  26. from ...utils.cache import CACHE_DIR
  27. from ...utils.download import download_and_extract
  28. from ...utils.flags import MODEL_SOURCE
# Every official model name known to this module. ``_BosModelHoster`` uses this
# list as its ``model_list`` and ``_ModelManager.__contains__`` tests membership
# against it. Entry order is kept as-is; lookups in this module only test
# membership. Note: ``_BaseModelHoster.get_model`` fetches "PaddleOCR-VL-0.9B"
# under the name "PaddleOCR-VL".
ALL_MODELS = [
    "ResNet18",
    "ResNet18_vd",
    "ResNet34",
    "ResNet34_vd",
    "ResNet50",
    "ResNet50_vd",
    "ResNet101",
    "ResNet101_vd",
    "ResNet152",
    "ResNet152_vd",
    "ResNet200_vd",
    "PaddleOCR-VL-0.9B",
    "PP-LCNet_x0_25",
    "PP-LCNet_x0_25_textline_ori",
    "PP-LCNet_x0_35",
    "PP-LCNet_x0_5",
    "PP-LCNet_x0_75",
    "PP-LCNet_x1_0",
    "PP-LCNet_x1_0_doc_ori",
    "PP-LCNet_x1_0_textline_ori",
    "PP-LCNet_x1_5",
    "PP-LCNet_x2_5",
    "PP-LCNet_x2_0",
    "PP-LCNetV2_small",
    "PP-LCNetV2_base",
    "PP-LCNetV2_large",
    "MobileNetV3_large_x0_35",
    "MobileNetV3_large_x0_5",
    "MobileNetV3_large_x0_75",
    "MobileNetV3_large_x1_0",
    "MobileNetV3_large_x1_25",
    "MobileNetV3_small_x0_35",
    "MobileNetV3_small_x0_5",
    "MobileNetV3_small_x0_75",
    "MobileNetV3_small_x1_0",
    "MobileNetV3_small_x1_25",
    "ConvNeXt_tiny",
    "ConvNeXt_small",
    "ConvNeXt_base_224",
    "ConvNeXt_base_384",
    "ConvNeXt_large_224",
    "ConvNeXt_large_384",
    "MobileNetV2_x0_25",
    "MobileNetV2_x0_5",
    "MobileNetV2_x1_0",
    "MobileNetV2_x1_5",
    "MobileNetV2_x2_0",
    "MobileNetV1_x0_25",
    "MobileNetV1_x0_5",
    "MobileNetV1_x0_75",
    "MobileNetV1_x1_0",
    "SwinTransformer_tiny_patch4_window7_224",
    "SwinTransformer_small_patch4_window7_224",
    "SwinTransformer_base_patch4_window7_224",
    "SwinTransformer_base_patch4_window12_384",
    "SwinTransformer_large_patch4_window7_224",
    "SwinTransformer_large_patch4_window12_384",
    "PP-HGNet_tiny",
    "PP-HGNet_small",
    "PP-HGNet_base",
    "PP-HGNetV2-B0",
    "PP-HGNetV2-B1",
    "PP-HGNetV2-B2",
    "PP-HGNetV2-B3",
    "PP-HGNetV2-B4",
    "PP-HGNetV2-B5",
    "PP-HGNetV2-B6",
    "FasterNet-L",
    "FasterNet-M",
    "FasterNet-S",
    "FasterNet-T0",
    "FasterNet-T1",
    "FasterNet-T2",
    "StarNet-S1",
    "StarNet-S2",
    "StarNet-S3",
    "StarNet-S4",
    "MobileNetV4_conv_small",
    "MobileNetV4_conv_medium",
    "MobileNetV4_conv_large",
    "MobileNetV4_hybrid_medium",
    "MobileNetV4_hybrid_large",
    "CLIP_vit_base_patch16_224",
    "CLIP_vit_large_patch14_224",
    "PP-LCNet_x1_0_ML",
    "PP-HGNetV2-B0_ML",
    "PP-HGNetV2-B4_ML",
    "PP-HGNetV2-B6_ML",
    "ResNet50_ML",
    "CLIP_vit_base_patch16_448_ML",
    "PP-YOLOE_plus-X",
    "PP-YOLOE_plus-L",
    "PP-YOLOE_plus-M",
    "PP-YOLOE_plus-S",
    "RT-DETR-L",
    "RT-DETR-H",
    "RT-DETR-X",
    "YOLOv3-DarkNet53",
    "YOLOv3-MobileNetV3",
    "YOLOv3-ResNet50_vd_DCN",
    "YOLOX-L",
    "YOLOX-M",
    "YOLOX-N",
    "YOLOX-S",
    "YOLOX-T",
    "YOLOX-X",
    "RT-DETR-R18",
    "RT-DETR-R50",
    "PicoDet-S",
    "PicoDet-L",
    "Deeplabv3-R50",
    "Deeplabv3-R101",
    "Deeplabv3_Plus-R50",
    "Deeplabv3_Plus-R101",
    "PP-ShiTuV2_rec",
    "PP-ShiTuV2_rec_CLIP_vit_base",
    "PP-ShiTuV2_rec_CLIP_vit_large",
    "PP-LiteSeg-T",
    "PP-LiteSeg-B",
    "OCRNet_HRNet-W48",
    "OCRNet_HRNet-W18",
    "SegFormer-B0",
    "SegFormer-B1",
    "SegFormer-B2",
    "SegFormer-B3",
    "SegFormer-B4",
    "SegFormer-B5",
    "SeaFormer_tiny",
    "SeaFormer_small",
    "SeaFormer_base",
    "SeaFormer_large",
    "Mask-RT-DETR-H",
    "Mask-RT-DETR-L",
    "PP-OCRv4_server_rec",
    "Mask-RT-DETR-S",
    "Mask-RT-DETR-M",
    "Mask-RT-DETR-X",
    "SOLOv2",
    "MaskRCNN-ResNet50",
    "MaskRCNN-ResNet50-FPN",
    "MaskRCNN-ResNet50-vd-FPN",
    "MaskRCNN-ResNet101-FPN",
    "MaskRCNN-ResNet101-vd-FPN",
    "MaskRCNN-ResNeXt101-vd-FPN",
    "Cascade-MaskRCNN-ResNet50-FPN",
    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
    "PP-YOLOE_seg-S",
    "PP-OCRv3_mobile_rec",
    "en_PP-OCRv3_mobile_rec",
    "korean_PP-OCRv3_mobile_rec",
    "japan_PP-OCRv3_mobile_rec",
    "chinese_cht_PP-OCRv3_mobile_rec",
    "te_PP-OCRv3_mobile_rec",
    "ka_PP-OCRv3_mobile_rec",
    "ta_PP-OCRv3_mobile_rec",
    "latin_PP-OCRv3_mobile_rec",
    "arabic_PP-OCRv3_mobile_rec",
    "cyrillic_PP-OCRv3_mobile_rec",
    "devanagari_PP-OCRv3_mobile_rec",
    "en_PP-OCRv4_mobile_rec",
    "PP-OCRv4_server_rec_doc",
    "PP-OCRv4_mobile_rec",
    "PP-OCRv4_server_det",
    "PP-OCRv4_mobile_det",
    "PP-OCRv3_server_det",
    "PP-OCRv3_mobile_det",
    "PP-OCRv4_server_seal_det",
    "PP-OCRv4_mobile_seal_det",
    "ch_RepSVTR_rec",
    "ch_SVTRv2_rec",
    "PP-LCNet_x1_0_pedestrian_attribute",
    "PP-LCNet_x1_0_vehicle_attribute",
    "PicoDet_layout_1x",
    "PicoDet_layout_1x_table",
    "SLANet",
    "SLANet_plus",
    "LaTeX_OCR_rec",
    "UniMERNet",
    "PP-FormulaNet-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-L",
    "FasterRCNN-ResNet34-FPN",
    "FasterRCNN-ResNet50",
    "FasterRCNN-ResNet50-FPN",
    "FasterRCNN-ResNet50-vd-FPN",
    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
    "FasterRCNN-ResNet101",
    "FasterRCNN-ResNet101-FPN",
    "FasterRCNN-ResNeXt101-vd-FPN",
    "FasterRCNN-Swin-Tiny-FPN",
    "Cascade-FasterRCNN-ResNet50-FPN",
    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
    "UVDoc",
    "DLinear",
    "NLinear",
    "RLinear",
    "Nonstationary",
    "TimesNet",
    "TiDE",
    "PatchTST",
    "DLinear_ad",
    "AutoEncoder_ad",
    "Nonstationary_ad",
    "PatchTST_ad",
    "TimesNet_ad",
    "TimesNet_cls",
    "STFPM",
    "FCOS-ResNet50",
    "DETR-R50",
    "PP-YOLOE-L_vehicle",
    "PP-YOLOE-S_vehicle",
    "PP-ShiTuV2_det",
    "PP-YOLOE-S_human",
    "PP-YOLOE-L_human",
    "PicoDet-M",
    "PicoDet-XS",
    "PP-YOLOE_plus_SOD-L",
    "PP-YOLOE_plus_SOD-S",
    "PP-YOLOE_plus_SOD-largesize-L",
    "CenterNet-DLA-34",
    "CenterNet-ResNet50",
    "PicoDet-S_layout_3cls",
    "PicoDet-S_layout_17cls",
    "PicoDet-L_layout_3cls",
    "PicoDet-L_layout_17cls",
    "RT-DETR-H_layout_3cls",
    "RT-DETR-H_layout_17cls",
    "PicoDet_LCNet_x2_5_face",
    "BlazeFace",
    "BlazeFace-FPN-SSH",
    "PP-YOLOE_plus-S_face",
    "MobileFaceNet",
    "ResNet50_face",
    "PP-YOLOE-R-L",
    "Co-Deformable-DETR-R50",
    "Co-Deformable-DETR-Swin-T",
    "Co-DINO-R50",
    "Co-DINO-Swin-L",
    "whisper_large",
    "whisper_base",
    "whisper_medium",
    "whisper_small",
    "whisper_tiny",
    "PP-TSM-R50_8frames_uniform",
    "PP-TSMv2-LCNetV2_8frames_uniform",
    "PP-TSMv2-LCNetV2_16frames_uniform",
    "MaskFormer_tiny",
    "MaskFormer_small",
    "PP-LCNet_x1_0_table_cls",
    "SLANeXt_wired",
    "SLANeXt_wireless",
    "RT-DETR-L_wired_table_cell_det",
    "RT-DETR-L_wireless_table_cell_det",
    "YOWO",
    "PP-TinyPose_128x96",
    "PP-TinyPose_256x192",
    "GroundingDINO-T",
    "SAM-H_box",
    "SAM-H_point",
    "PP-DocLayoutV2",
    "PP-DocLayout-L",
    "PP-DocLayout-M",
    "PP-DocLayout-S",
    "PP-DocLayout_plus-L",
    "PP-DocBlockLayout",
    "BEVFusion",
    "YOLO-Worldv2-L",
    "PP-DocBee-2B",
    "PP-DocBee-7B",
    "PP-Chart2Table",
    "PP-OCRv5_server_det",
    "PP-OCRv5_mobile_det",
    "PP-OCRv5_server_rec",
    "PP-OCRv5_mobile_rec",
    "eslav_PP-OCRv5_mobile_rec",
    "PP-DocBee2-3B",
    "latin_PP-OCRv5_mobile_rec",
    "korean_PP-OCRv5_mobile_rec",
    "th_PP-OCRv5_mobile_rec",
    "el_PP-OCRv5_mobile_rec",
    "en_PP-OCRv5_mobile_rec",
    "arabic_PP-OCRv5_mobile_rec",
    "te_PP-OCRv5_mobile_rec",
    "ta_PP-OCRv5_mobile_rec",
    "devanagari_PP-OCRv5_mobile_rec",
    "cyrillic_PP-OCRv5_mobile_rec",
]
  319. OCR_MODELS = [
  320. "arabic_PP-OCRv3_mobile_rec",
  321. "chinese_cht_PP-OCRv3_mobile_rec",
  322. "ch_RepSVTR_rec",
  323. "ch_SVTRv2_rec",
  324. "cyrillic_PP-OCRv3_mobile_rec",
  325. "devanagari_PP-OCRv3_mobile_rec",
  326. "en_PP-OCRv3_mobile_rec",
  327. "en_PP-OCRv4_mobile_rec",
  328. "eslav_PP-OCRv5_mobile_rec",
  329. "japan_PP-OCRv3_mobile_rec",
  330. "ka_PP-OCRv3_mobile_rec",
  331. "korean_PP-OCRv3_mobile_rec",
  332. "korean_PP-OCRv5_mobile_rec",
  333. "LaTeX_OCR_rec",
  334. "latin_PP-OCRv3_mobile_rec",
  335. "latin_PP-OCRv5_mobile_rec",
  336. "en_PP-OCRv5_mobile_rec",
  337. "th_PP-OCRv5_mobile_rec",
  338. "el_PP-OCRv5_mobile_rec",
  339. "PaddleOCR-VL-0.9B",
  340. "PicoDet_layout_1x",
  341. "PicoDet_layout_1x_table",
  342. "PicoDet-L_layout_17cls",
  343. "PicoDet-L_layout_3cls",
  344. "PicoDet-S_layout_17cls",
  345. "PicoDet-S_layout_3cls",
  346. "PP-DocBee2-3B",
  347. "PP-Chart2Table",
  348. "PP-DocBee-2B",
  349. "PP-DocBee-7B",
  350. "PP-DocBlockLayout",
  351. "PP-DocLayoutV2",
  352. "PP-DocLayout-L",
  353. "PP-DocLayout-M",
  354. "PP-DocLayout_plus-L",
  355. "PP-DocLayout-S",
  356. "PP-DocLayoutV2",
  357. "PP-FormulaNet-L",
  358. "PP-FormulaNet_plus-L",
  359. "PP-FormulaNet_plus-M",
  360. "PP-FormulaNet_plus-S",
  361. "PP-FormulaNet-S",
  362. "PP-LCNet_x0_25_textline_ori",
  363. "PP-LCNet_x1_0_doc_ori",
  364. "PP-LCNet_x1_0_table_cls",
  365. "PP-LCNet_x1_0_textline_ori",
  366. "PP-OCRv3_mobile_det",
  367. "PP-OCRv3_mobile_rec",
  368. "PP-OCRv3_server_det",
  369. "PP-OCRv4_mobile_det",
  370. "PP-OCRv4_mobile_rec",
  371. "PP-OCRv4_mobile_seal_det",
  372. "PP-OCRv4_server_det",
  373. "PP-OCRv4_server_rec_doc",
  374. "PP-OCRv4_server_rec",
  375. "PP-OCRv4_server_seal_det",
  376. "PP-OCRv5_mobile_det",
  377. "PP-OCRv5_mobile_rec",
  378. "PP-OCRv5_server_det",
  379. "PP-OCRv5_server_rec",
  380. "RT-DETR-H_layout_17cls",
  381. "RT-DETR-H_layout_3cls",
  382. "RT-DETR-L_wired_table_cell_det",
  383. "RT-DETR-L_wireless_table_cell_det",
  384. "SLANet",
  385. "SLANet_plus",
  386. "SLANeXt_wired",
  387. "SLANeXt_wireless",
  388. "ta_PP-OCRv3_mobile_rec",
  389. "te_PP-OCRv3_mobile_rec",
  390. "UniMERNet",
  391. "UVDoc",
  392. "arabic_PP-OCRv5_mobile_rec",
  393. "te_PP-OCRv5_mobile_rec",
  394. "ta_PP-OCRv5_mobile_rec",
  395. "devanagari_PP-OCRv5_mobile_rec",
  396. "cyrillic_PP-OCRv5_mobile_rec",
  397. ]
  398. class _BaseModelHoster(ABC):
  399. alias = ""
  400. model_list = []
  401. healthcheck_url = None
  402. _healthcheck_timeout = 1
  403. def __init__(self, save_dir):
  404. self._save_dir = save_dir
  405. def get_model(self, model_name):
  406. assert (
  407. model_name in self.model_list
  408. ), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
  409. if model_name == "PaddleOCR-VL-0.9B":
  410. model_name = "PaddleOCR-VL"
  411. model_dir = self._save_dir / f"{model_name}"
  412. if os.path.exists(model_dir):
  413. logging.info(
  414. f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
  415. )
  416. else:
  417. logging.info(
  418. f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
  419. )
  420. self._download(model_name, model_dir)
  421. return (
  422. model_dir / "PaddleOCR-VL-0.9B"
  423. if model_name == "PaddleOCR-VL"
  424. else model_dir
  425. )
  426. @abstractmethod
  427. def _download(self):
  428. raise NotImplementedError
  429. @classmethod
  430. def is_available(cls):
  431. if cls.healthcheck_url is None:
  432. return True
  433. try:
  434. response = requests.head(
  435. cls.healthcheck_url, timeout=cls._healthcheck_timeout
  436. )
  437. return response.ok == True
  438. except Exception:
  439. logging.debug(f"The model hosting platform({cls.__name__}) is unreachable!")
  440. return False
  441. class _BosModelHoster(_BaseModelHoster):
  442. model_list = ALL_MODELS
  443. alias = "bos"
  444. healthcheck_url = "https://paddle-model-ecology.bj.bcebos.com"
  445. version = "paddle3.0.0"
  446. base_url = (
  447. "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model"
  448. )
  449. special_model_fn = {
  450. "whisper_large": "whisper_large.tar",
  451. "whisper_base": "whisper_base.tar",
  452. "whisper_medium": "whisper_medium.tar",
  453. "whisper_small": "whisper_small.tar",
  454. "whisper_tiny": "whisper_tiny.tar",
  455. }
  456. def _download(self, model_name, save_dir):
  457. if model_name in self.special_model_fn:
  458. fn = self.special_model_fn[model_name]
  459. else:
  460. fn = f"{model_name}_infer.tar"
  461. url = f"{self.base_url}/{self.version}/{fn}"
  462. download_and_extract(url, save_dir.parent, model_name, overwrite=False)
  463. class _HuggingFaceModelHoster(_BaseModelHoster):
  464. model_list = OCR_MODELS
  465. alias = "huggingface"
  466. healthcheck_url = "https://huggingface.co"
  467. def _download(self, model_name, save_dir):
  468. def _clone(local_dir):
  469. hf_hub.snapshot_download(
  470. repo_id=f"PaddlePaddle/{model_name}", local_dir=local_dir
  471. )
  472. if os.path.exists(save_dir):
  473. _clone(save_dir)
  474. else:
  475. with tempfile.TemporaryDirectory() as td:
  476. temp_dir = os.path.join(td, "temp_dir")
  477. _clone(temp_dir)
  478. shutil.move(temp_dir, save_dir)
  479. class _ModelScopeModelHoster(_BaseModelHoster):
  480. model_list = OCR_MODELS
  481. alias = "modelscope"
  482. healthcheck_url = "https://modelscope.cn"
  483. def _download(self, model_name, save_dir):
  484. def _clone(local_dir):
  485. modelscope.snapshot_download(
  486. repo_id=f"PaddlePaddle/{model_name}", local_dir=local_dir
  487. )
  488. if os.path.exists(save_dir):
  489. _clone(save_dir)
  490. else:
  491. with tempfile.TemporaryDirectory() as td:
  492. temp_dir = os.path.join(td, "temp_dir")
  493. _clone(temp_dir)
  494. shutil.move(temp_dir, save_dir)
  495. class _AIStudioModelHoster(_BaseModelHoster):
  496. model_list = OCR_MODELS
  497. alias = "aistudio"
  498. healthcheck_url = "https://aistudio.baidu.com"
  499. def _download(self, model_name, save_dir):
  500. def _clone(local_dir):
  501. aistudio_download(repo_id=f"PaddleX/{model_name}", local_dir=local_dir)
  502. if os.path.exists(save_dir):
  503. _clone(save_dir)
  504. else:
  505. with tempfile.TemporaryDirectory() as td:
  506. temp_dir = os.path.join(td, "temp_dir")
  507. _clone(temp_dir)
  508. shutil.move(temp_dir, save_dir)
  509. class _ModelManager:
  510. model_list = ALL_MODELS
  511. _save_dir = Path(CACHE_DIR) / "official_models"
  512. def __init__(self) -> None:
  513. self._hosters = self._build_hosters()
  514. def _build_hosters(self):
  515. hosters = []
  516. for hoster_cls in [
  517. _HuggingFaceModelHoster,
  518. _AIStudioModelHoster,
  519. _ModelScopeModelHoster,
  520. _BosModelHoster,
  521. ]:
  522. if hoster_cls.alias == MODEL_SOURCE:
  523. if hoster_cls.is_available():
  524. hosters.insert(0, hoster_cls(self._save_dir))
  525. else:
  526. if hoster_cls.is_available():
  527. hosters.append(hoster_cls(self._save_dir))
  528. if len(hosters) == 0:
  529. logging.warning(
  530. f"""No model hoster is available! Please check your network connection to one of the following model hosts:
  531. HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}),
  532. ModelScope ({_ModelScopeModelHoster.healthcheck_url}),
  533. AIStudio ({_AIStudioModelHoster.healthcheck_url}), or
  534. BOS ({_BosModelHoster.healthcheck_url}).
  535. Otherwise, only local models can be used."""
  536. )
  537. return hosters
  538. def _get_model_local_path(self, model_name):
  539. if len(self._hosters) == 0:
  540. msg = "No available model hosting platforms detected. Please check your network connection."
  541. logging.error(msg)
  542. raise Exception(msg)
  543. return self._download_from_hoster(self._hosters, model_name)
  544. def _download_from_hoster(self, hosters, model_name):
  545. for idx, hoster in enumerate(hosters):
  546. if model_name in hoster.model_list:
  547. try:
  548. return hoster.get_model(model_name)
  549. except Exception as e:
  550. logging.warning(
  551. f"Encounter exception when download model from {hoster.alias}: \n{e}."
  552. )
  553. if len(hosters) <= 1:
  554. raise Exception(
  555. f"No model source is available! Please check network or use local model files!"
  556. )
  557. logging.warning(
  558. f"PaddleX would try to download from other model sources."
  559. )
  560. return self._download_from_hoster(hosters[idx + 1 :], model_name)
  561. def __contains__(self, model_name):
  562. return model_name in self.model_list
  563. def __getitem__(self, model_name):
  564. return self._get_model_local_path(model_name)
# Shared module-level manager. Constructing it runs ``_build_hosters``, which
# calls ``is_available()`` on every hoster class (an HTTP HEAD request for each
# one that has a health-check URL), so importing this module performs network
# probes.
official_models = _ModelManager()