# download_model.py (903 B)

  1. from argparse import ArgumentParser
  2. import os
  3. if __name__ == '__main__':
  4. parser = ArgumentParser()
  5. parser.add_argument('--type', '-t', type=str, default="huggingface")
  6. parser.add_argument('--name', '-n', type=str, default="rednote-hilab/dots.ocr")
  7. args = parser.parse_args()
  8. script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  9. print(f"Attention: The model save dir dots.ocr should be replace by a name without `.` like DotsOCR, util we merge our code to transformers.")
  10. model_dir = os.path.join(script_dir, "weights/DotsOCR")
  11. if not os.path.exists(model_dir):
  12. os.makedirs(model_dir)
  13. if args.type == "huggingface":
  14. from huggingface_hub import snapshot_download
  15. snapshot_download(repo_id=args.name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
  16. print(f"model downloaded to {model_dir}")