# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path

import huggingface_hub as hf_hub

hf_hub.logging.set_verbosity_error()

import modelscope
import requests

# Must be set before `aistudio_sdk` is imported (kept in the original order).
os.environ["AISTUDIO_LOG"] = "critical"

from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download

from ...utils import logging
from ...utils.cache import CACHE_DIR
from ...utils.download import download_and_extract
from ...utils.flags import MODEL_SOURCE

# Every official model name known to this module (served by the BOS hoster).
ALL_MODELS = [
    "ResNet18",
    "ResNet18_vd",
    "ResNet34",
    "ResNet34_vd",
    "ResNet50",
    "ResNet50_vd",
    "ResNet101",
    "ResNet101_vd",
    "ResNet152",
    "ResNet152_vd",
    "ResNet200_vd",
    "PaddleOCR-VL-0.9B",
    "PP-LCNet_x0_25",
    "PP-LCNet_x0_25_textline_ori",
    "PP-LCNet_x0_35",
    "PP-LCNet_x0_5",
    "PP-LCNet_x0_75",
    "PP-LCNet_x1_0",
    "PP-LCNet_x1_0_doc_ori",
    "PP-LCNet_x1_0_textline_ori",
    "PP-LCNet_x1_5",
    "PP-LCNet_x2_5",
    "PP-LCNet_x2_0",
    "PP-LCNetV2_small",
    "PP-LCNetV2_base",
    "PP-LCNetV2_large",
    "MobileNetV3_large_x0_35",
    "MobileNetV3_large_x0_5",
    "MobileNetV3_large_x0_75",
    "MobileNetV3_large_x1_0",
    "MobileNetV3_large_x1_25",
    "MobileNetV3_small_x0_35",
    "MobileNetV3_small_x0_5",
    "MobileNetV3_small_x0_75",
    "MobileNetV3_small_x1_0",
    "MobileNetV3_small_x1_25",
    "ConvNeXt_tiny",
    "ConvNeXt_small",
    "ConvNeXt_base_224",
    "ConvNeXt_base_384",
    "ConvNeXt_large_224",
    "ConvNeXt_large_384",
    "MobileNetV2_x0_25",
    "MobileNetV2_x0_5",
    "MobileNetV2_x1_0",
    "MobileNetV2_x1_5",
    "MobileNetV2_x2_0",
    "MobileNetV1_x0_25",
    "MobileNetV1_x0_5",
    "MobileNetV1_x0_75",
    "MobileNetV1_x1_0",
    "SwinTransformer_tiny_patch4_window7_224",
    "SwinTransformer_small_patch4_window7_224",
    "SwinTransformer_base_patch4_window7_224",
    "SwinTransformer_base_patch4_window12_384",
    "SwinTransformer_large_patch4_window7_224",
    "SwinTransformer_large_patch4_window12_384",
    "PP-HGNet_tiny",
    "PP-HGNet_small",
    "PP-HGNet_base",
    "PP-HGNetV2-B0",
    "PP-HGNetV2-B1",
    "PP-HGNetV2-B2",
    "PP-HGNetV2-B3",
    "PP-HGNetV2-B4",
    "PP-HGNetV2-B5",
    "PP-HGNetV2-B6",
    "FasterNet-L",
    "FasterNet-M",
    "FasterNet-S",
    "FasterNet-T0",
    "FasterNet-T1",
    "FasterNet-T2",
    "StarNet-S1",
    "StarNet-S2",
    "StarNet-S3",
    "StarNet-S4",
    "MobileNetV4_conv_small",
    "MobileNetV4_conv_medium",
    "MobileNetV4_conv_large",
    "MobileNetV4_hybrid_medium",
    "MobileNetV4_hybrid_large",
    "CLIP_vit_base_patch16_224",
    "CLIP_vit_large_patch14_224",
    "PP-LCNet_x1_0_ML",
    "PP-HGNetV2-B0_ML",
    "PP-HGNetV2-B4_ML",
    "PP-HGNetV2-B6_ML",
    "ResNet50_ML",
    "CLIP_vit_base_patch16_448_ML",
    "PP-YOLOE_plus-X",
    "PP-YOLOE_plus-L",
    "PP-YOLOE_plus-M",
    "PP-YOLOE_plus-S",
    "RT-DETR-L",
    "RT-DETR-H",
    "RT-DETR-X",
    "YOLOv3-DarkNet53",
    "YOLOv3-MobileNetV3",
    "YOLOv3-ResNet50_vd_DCN",
    "YOLOX-L",
    "YOLOX-M",
    "YOLOX-N",
    "YOLOX-S",
    "YOLOX-T",
    "YOLOX-X",
    "RT-DETR-R18",
    "RT-DETR-R50",
    "PicoDet-S",
    "PicoDet-L",
    "Deeplabv3-R50",
    "Deeplabv3-R101",
    "Deeplabv3_Plus-R50",
    "Deeplabv3_Plus-R101",
    "PP-ShiTuV2_rec",
    "PP-ShiTuV2_rec_CLIP_vit_base",
    "PP-ShiTuV2_rec_CLIP_vit_large",
    "PP-LiteSeg-T",
    "PP-LiteSeg-B",
    "OCRNet_HRNet-W48",
    "OCRNet_HRNet-W18",
    "SegFormer-B0",
    "SegFormer-B1",
    "SegFormer-B2",
    "SegFormer-B3",
    "SegFormer-B4",
    "SegFormer-B5",
    "SeaFormer_tiny",
    "SeaFormer_small",
    "SeaFormer_base",
    "SeaFormer_large",
    "Mask-RT-DETR-H",
    "Mask-RT-DETR-L",
    "PP-OCRv4_server_rec",
    "Mask-RT-DETR-S",
    "Mask-RT-DETR-M",
    "Mask-RT-DETR-X",
    "SOLOv2",
    "MaskRCNN-ResNet50",
    "MaskRCNN-ResNet50-FPN",
    "MaskRCNN-ResNet50-vd-FPN",
    "MaskRCNN-ResNet101-FPN",
    "MaskRCNN-ResNet101-vd-FPN",
    "MaskRCNN-ResNeXt101-vd-FPN",
    "Cascade-MaskRCNN-ResNet50-FPN",
    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
    "PP-YOLOE_seg-S",
    "PP-OCRv3_mobile_rec",
    "en_PP-OCRv3_mobile_rec",
    "korean_PP-OCRv3_mobile_rec",
    "japan_PP-OCRv3_mobile_rec",
    "chinese_cht_PP-OCRv3_mobile_rec",
    "te_PP-OCRv3_mobile_rec",
    "ka_PP-OCRv3_mobile_rec",
    "ta_PP-OCRv3_mobile_rec",
    "latin_PP-OCRv3_mobile_rec",
    "arabic_PP-OCRv3_mobile_rec",
    "cyrillic_PP-OCRv3_mobile_rec",
    "devanagari_PP-OCRv3_mobile_rec",
    "en_PP-OCRv4_mobile_rec",
    "PP-OCRv4_server_rec_doc",
    "PP-OCRv4_mobile_rec",
    "PP-OCRv4_server_det",
    "PP-OCRv4_mobile_det",
    "PP-OCRv3_server_det",
    "PP-OCRv3_mobile_det",
    "PP-OCRv4_server_seal_det",
    "PP-OCRv4_mobile_seal_det",
    "ch_RepSVTR_rec",
    "ch_SVTRv2_rec",
    "PP-LCNet_x1_0_pedestrian_attribute",
    "PP-LCNet_x1_0_vehicle_attribute",
    "PicoDet_layout_1x",
    "PicoDet_layout_1x_table",
    "SLANet",
    "SLANet_plus",
    "LaTeX_OCR_rec",
    "UniMERNet",
    "PP-FormulaNet-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-L",
    "FasterRCNN-ResNet34-FPN",
    "FasterRCNN-ResNet50",
    "FasterRCNN-ResNet50-FPN",
    "FasterRCNN-ResNet50-vd-FPN",
    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
    "FasterRCNN-ResNet101",
    "FasterRCNN-ResNet101-FPN",
    "FasterRCNN-ResNeXt101-vd-FPN",
    "FasterRCNN-Swin-Tiny-FPN",
    "Cascade-FasterRCNN-ResNet50-FPN",
    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
    "UVDoc",
    "DLinear",
    "NLinear",
    "RLinear",
    "Nonstationary",
    "TimesNet",
    "TiDE",
    "PatchTST",
    "DLinear_ad",
    "AutoEncoder_ad",
    "Nonstationary_ad",
    "PatchTST_ad",
    "TimesNet_ad",
    "TimesNet_cls",
    "STFPM",
    "FCOS-ResNet50",
    "DETR-R50",
    "PP-YOLOE-L_vehicle",
    "PP-YOLOE-S_vehicle",
    "PP-ShiTuV2_det",
    "PP-YOLOE-S_human",
    "PP-YOLOE-L_human",
    "PicoDet-M",
    "PicoDet-XS",
    "PP-YOLOE_plus_SOD-L",
    "PP-YOLOE_plus_SOD-S",
    "PP-YOLOE_plus_SOD-largesize-L",
    "CenterNet-DLA-34",
    "CenterNet-ResNet50",
    "PicoDet-S_layout_3cls",
    "PicoDet-S_layout_17cls",
    "PicoDet-L_layout_3cls",
    "PicoDet-L_layout_17cls",
    "RT-DETR-H_layout_3cls",
    "RT-DETR-H_layout_17cls",
    "PicoDet_LCNet_x2_5_face",
    "BlazeFace",
    "BlazeFace-FPN-SSH",
    "PP-YOLOE_plus-S_face",
    "MobileFaceNet",
    "ResNet50_face",
    "PP-YOLOE-R-L",
    "Co-Deformable-DETR-R50",
    "Co-Deformable-DETR-Swin-T",
    "Co-DINO-R50",
    "Co-DINO-Swin-L",
    "whisper_large",
    "whisper_base",
    "whisper_medium",
    "whisper_small",
    "whisper_tiny",
    "PP-TSM-R50_8frames_uniform",
    "PP-TSMv2-LCNetV2_8frames_uniform",
    "PP-TSMv2-LCNetV2_16frames_uniform",
    "MaskFormer_tiny",
    "MaskFormer_small",
    "PP-LCNet_x1_0_table_cls",
    "SLANeXt_wired",
    "SLANeXt_wireless",
    "RT-DETR-L_wired_table_cell_det",
    "RT-DETR-L_wireless_table_cell_det",
    "YOWO",
    "PP-TinyPose_128x96",
    "PP-TinyPose_256x192",
    "GroundingDINO-T",
    "SAM-H_box",
    "SAM-H_point",
    "PP-DocLayoutV2",
    "PP-DocLayout-L",
    "PP-DocLayout-M",
    "PP-DocLayout-S",
    "PP-DocLayout_plus-L",
    "PP-DocBlockLayout",
    "BEVFusion",
    "YOLO-Worldv2-L",
    "PP-DocBee-2B",
    "PP-DocBee-7B",
    "PP-Chart2Table",
    "PP-OCRv5_server_det",
    "PP-OCRv5_mobile_det",
    "PP-OCRv5_server_rec",
    "PP-OCRv5_mobile_rec",
    "eslav_PP-OCRv5_mobile_rec",
    "PP-DocBee2-3B",
    "latin_PP-OCRv5_mobile_rec",
    "korean_PP-OCRv5_mobile_rec",
    "th_PP-OCRv5_mobile_rec",
    "el_PP-OCRv5_mobile_rec",
    "en_PP-OCRv5_mobile_rec",
    "arabic_PP-OCRv5_mobile_rec",
    "te_PP-OCRv5_mobile_rec",
    "ta_PP-OCRv5_mobile_rec",
    "devanagari_PP-OCRv5_mobile_rec",
    "cyrillic_PP-OCRv5_mobile_rec",
]

# Subset of models served by the HuggingFace, ModelScope and AIStudio hosters.
OCR_MODELS = [
    "arabic_PP-OCRv3_mobile_rec",
    "chinese_cht_PP-OCRv3_mobile_rec",
    "ch_RepSVTR_rec",
    "ch_SVTRv2_rec",
    "cyrillic_PP-OCRv3_mobile_rec",
    "devanagari_PP-OCRv3_mobile_rec",
    "en_PP-OCRv3_mobile_rec",
    "en_PP-OCRv4_mobile_rec",
    "eslav_PP-OCRv5_mobile_rec",
    "japan_PP-OCRv3_mobile_rec",
    "ka_PP-OCRv3_mobile_rec",
    "korean_PP-OCRv3_mobile_rec",
    "korean_PP-OCRv5_mobile_rec",
    "LaTeX_OCR_rec",
    "latin_PP-OCRv3_mobile_rec",
    "latin_PP-OCRv5_mobile_rec",
    "en_PP-OCRv5_mobile_rec",
    "th_PP-OCRv5_mobile_rec",
    "el_PP-OCRv5_mobile_rec",
    "PaddleOCR-VL-0.9B",
    "PicoDet_layout_1x",
    "PicoDet_layout_1x_table",
    "PicoDet-L_layout_17cls",
    "PicoDet-L_layout_3cls",
    "PicoDet-S_layout_17cls",
    "PicoDet-S_layout_3cls",
    "PP-DocBee2-3B",
    "PP-Chart2Table",
    "PP-DocBee-2B",
    "PP-DocBee-7B",
    "PP-DocBlockLayout",
    "PP-DocLayoutV2",
    "PP-DocLayout-L",
    "PP-DocLayout-M",
    "PP-DocLayout_plus-L",
    "PP-DocLayout-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-L",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet-S",
    "PP-LCNet_x0_25_textline_ori",
    "PP-LCNet_x1_0_doc_ori",
    "PP-LCNet_x1_0_table_cls",
    "PP-LCNet_x1_0_textline_ori",
    "PP-OCRv3_mobile_det",
    "PP-OCRv3_mobile_rec",
    "PP-OCRv3_server_det",
    "PP-OCRv4_mobile_det",
    "PP-OCRv4_mobile_rec",
    "PP-OCRv4_mobile_seal_det",
    "PP-OCRv4_server_det",
    "PP-OCRv4_server_rec_doc",
    "PP-OCRv4_server_rec",
    "PP-OCRv4_server_seal_det",
    "PP-OCRv5_mobile_det",
    "PP-OCRv5_mobile_rec",
    "PP-OCRv5_server_det",
    "PP-OCRv5_server_rec",
    "RT-DETR-H_layout_17cls",
    "RT-DETR-H_layout_3cls",
    "RT-DETR-L_wired_table_cell_det",
    "RT-DETR-L_wireless_table_cell_det",
    "SLANet",
    "SLANet_plus",
    "SLANeXt_wired",
    "SLANeXt_wireless",
    "ta_PP-OCRv3_mobile_rec",
    "te_PP-OCRv3_mobile_rec",
    "UniMERNet",
    "UVDoc",
    "arabic_PP-OCRv5_mobile_rec",
    "te_PP-OCRv5_mobile_rec",
    "ta_PP-OCRv5_mobile_rec",
    "devanagari_PP-OCRv5_mobile_rec",
    "cyrillic_PP-OCRv5_mobile_rec",
]


class _BaseModelHoster(ABC):
    """Base class for a platform that official models can be fetched from.

    Subclasses set ``alias``, ``model_list`` and optionally
    ``healthcheck_url``, and implement :meth:`_download`.
    """

    alias = ""
    model_list = []
    healthcheck_url = None
    _healthcheck_timeout = 1

    def __init__(self, save_dir):
        """Remember ``save_dir``, the directory model folders are stored in."""
        self._save_dir = save_dir

    def get_model(self, model_name):
        """Return the local directory of ``model_name``, downloading if needed.

        An already existing model directory is used as-is (treated as cache).

        Args:
            model_name: A name listed in ``self.model_list``.

        Returns:
            Path of the directory holding the model files.

        Raises:
            AssertionError: If ``model_name`` is not in ``self.model_list``.
        """
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert (
            model_name in self.model_list
        ), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"

        # "PaddleOCR-VL-0.9B" is stored and downloaded under "PaddleOCR-VL".
        if model_name == "PaddleOCR-VL-0.9B":
            model_name = "PaddleOCR-VL"

        model_dir = self._save_dir / f"{model_name}"
        if os.path.exists(model_dir):
            logging.info(
                f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
            )
        else:
            logging.info(
                f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
            )
            self._download(model_name, model_dir)
            logging.debug(
                f"`{model_name}` model files has been download from model source: `{self.alias}`!"
            )

        if model_name == "PaddleOCR-VL":
            # Prefer the nested "PaddleOCR-VL-0.9B" sub-directory when present.
            vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
            if vl_model_dir.exists() and vl_model_dir.is_dir():
                return vl_model_dir

        return model_dir

    @abstractmethod
    def _download(self, model_name, save_dir):
        """Download ``model_name`` so that its files end up in ``save_dir``."""
        raise NotImplementedError

    def _snapshot_into(self, clone, save_dir):
        """Run ``clone(local_dir)`` so that the result lands in ``save_dir``.

        When ``save_dir`` does not exist yet, the snapshot is first written to
        a temporary directory and only moved into place once ``clone``
        returned. ``get_model`` treats any existing directory as a cached
        model, so this avoids leaving a partial download at ``save_dir``.

        Args:
            clone: Callable taking the target directory to download into.
            save_dir: Final directory of the model.
        """
        if os.path.exists(save_dir):
            clone(save_dir)
            return
        with tempfile.TemporaryDirectory() as td:
            temp_dir = os.path.join(td, "temp_dir")
            clone(temp_dir)
            shutil.move(temp_dir, save_dir)

    @classmethod
    def is_available(cls):
        """Return whether the platform answers a HEAD request successfully.

        Hosters without a ``healthcheck_url`` are always reported as available.
        Any exception during the request is treated as "unreachable".
        """
        if cls.healthcheck_url is None:
            return True
        try:
            response = requests.head(
                cls.healthcheck_url, timeout=cls._healthcheck_timeout
            )
            return response.ok
        except Exception:
            # Best-effort probe: any failure simply means "do not use it".
            logging.debug(f"The model hosting platform({cls.__name__}) is unreachable!")
            return False


class _BosModelHoster(_BaseModelHoster):
    """Hoster downloading model tarballs from the BOS bucket."""

    model_list = ALL_MODELS
    alias = "bos"
    healthcheck_url = "https://paddle-model-ecology.bj.bcebos.com"
    version = "paddle3.0.0"
    base_url = (
        "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model"
    )
    # Models whose archive name does not follow the "<name>_infer.tar" scheme.
    special_model_fn = {
        "whisper_large": "whisper_large.tar",
        "whisper_base": "whisper_base.tar",
        "whisper_medium": "whisper_medium.tar",
        "whisper_small": "whisper_small.tar",
        "whisper_tiny": "whisper_tiny.tar",
    }

    def _download(self, model_name, save_dir):
        """Download and extract the archive of ``model_name`` next to ``save_dir``."""
        fn = self.special_model_fn.get(model_name, f"{model_name}_infer.tar")
        url = f"{self.base_url}/{self.version}/{fn}"
        download_and_extract(url, save_dir.parent, model_name, overwrite=False)


class _HuggingFaceModelHoster(_BaseModelHoster):
    """Hoster downloading ``PaddlePaddle/<model>`` snapshots from HuggingFace."""

    model_list = OCR_MODELS
    alias = "huggingface"
    healthcheck_url = "https://huggingface.co"

    def _download(self, model_name, save_dir):
        """Fetch the HuggingFace snapshot of ``model_name`` into ``save_dir``."""

        def _clone(local_dir):
            hf_hub.snapshot_download(
                repo_id=f"PaddlePaddle/{model_name}", local_dir=local_dir
            )

        self._snapshot_into(_clone, save_dir)


class _ModelScopeModelHoster(_BaseModelHoster):
    """Hoster downloading ``PaddlePaddle/<model>`` snapshots from ModelScope."""

    model_list = OCR_MODELS
    alias = "modelscope"
    healthcheck_url = "https://modelscope.cn"

    def _download(self, model_name, save_dir):
        """Fetch the ModelScope snapshot of ``model_name`` into ``save_dir``."""

        def _clone(local_dir):
            modelscope.snapshot_download(
                repo_id=f"PaddlePaddle/{model_name}", local_dir=local_dir
            )

        self._snapshot_into(_clone, save_dir)


class _AIStudioModelHoster(_BaseModelHoster):
    """Hoster downloading model snapshots from AIStudio."""

    model_list = OCR_MODELS
    alias = "aistudio"
    healthcheck_url = "https://aistudio.baidu.com"

    def _download(self, model_name, save_dir):
        """Fetch the AIStudio snapshot of ``model_name`` into ``save_dir``."""
        # "PaddleOCR-VL" is requested from "PaddlePaddle/", all others from "PaddleX/".
        namespace = "PaddlePaddle" if model_name == "PaddleOCR-VL" else "PaddleX"

        def _clone(local_dir):
            aistudio_download(repo_id=f"{namespace}/{model_name}", local_dir=local_dir)

        self._snapshot_into(_clone, save_dir)


class _ModelManager:
    """Resolve official model names to local directories.

    Supports ``name in manager`` and ``manager[name]``; the latter downloads
    the model from the first reachable hoster that succeeds.
    """

    model_list = ALL_MODELS
    _save_dir = Path(CACHE_DIR) / "official_models"

    def __init__(self) -> None:
        self._hosters = self._build_hosters()

    def _build_hosters(self):
        """Instantiate every reachable hoster, the ``MODEL_SOURCE`` one first."""
        hosters = []
        for hoster_cls in (
            _HuggingFaceModelHoster,
            _AIStudioModelHoster,
            _ModelScopeModelHoster,
            _BosModelHoster,
        ):
            if not hoster_cls.is_available():
                continue
            hoster = hoster_cls(self._save_dir)
            if hoster_cls.alias == MODEL_SOURCE:
                # The configured source is tried before all others.
                hosters.insert(0, hoster)
            else:
                hosters.append(hoster)

        if not hosters:
            logging.warning(
                "No model hoster is available! Please check your network connection to one of the following model hosts: "
                f"HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}), "
                f"ModelScope ({_ModelScopeModelHoster.healthcheck_url}), "
                f"AIStudio ({_AIStudioModelHoster.healthcheck_url}), or "
                f"BOS ({_BosModelHoster.healthcheck_url}). "
                "Otherwise, only local models can be used."
            )
        return hosters

    def _get_model_local_path(self, model_name):
        """Return the local directory of ``model_name``.

        Raises:
            Exception: If no hoster was reachable, or every download failed.
        """
        if not self._hosters:
            msg = "No available model hosting platforms detected. Please check your network connection."
            logging.error(msg)
            raise Exception(msg)
        return self._download_from_hoster(self._hosters, model_name)

    def _download_from_hoster(self, hosters, model_name):
        """Try each hoster in ``hosters`` that lists ``model_name``, in order.

        Args:
            hosters: Hoster instances, in priority order.
            model_name: Name of the model to fetch.

        Returns:
            Path of the model directory returned by the first successful hoster.

        Raises:
            Exception: If no hoster lists the model, or the last candidate
                failed (chained to the underlying error).
        """
        # Only hosters that actually list the model are candidates, so the
        # "next source" named in the warning is one that will really be tried.
        candidates = [h for h in hosters if model_name in h.model_list]
        if not candidates:
            raise Exception(
                f"The model `{model_name}` is not provided by any available model source!"
            )

        for idx, hoster in enumerate(candidates):
            try:
                return hoster.get_model(model_name)
            except Exception as e:
                if idx + 1 >= len(candidates):
                    raise Exception(
                        f"Encounter exception when download model from {hoster.alias}. No model source is available! Please check network or use local model files!"
                    ) from e
                logging.warning(
                    f"Encountering exception when download model from {hoster.alias}: \n{e}, will try to download from other model sources: `{candidates[idx + 1].alias}`."
                )

    def __contains__(self, model_name):
        return model_name in self.model_list

    def __getitem__(self, model_name):
        return self._get_model_local_path(model_name)


official_models = _ModelManager()