|
|
@@ -58,43 +58,90 @@ def configure_model(model_dir, model_type):
|
|
|
|
|
|
|
|
|
@click.command()
|
|
|
-def download_models():
|
|
|
+@click.option(
|
|
|
+ '-s',
|
|
|
+ '--source',
|
|
|
+ 'model_source',
|
|
|
+ type=click.Choice(['huggingface', 'modelscope']),
|
|
|
+ help="""
|
|
|
+ The source of the model repository.
|
|
|
+ """,
|
|
|
+ default=None,
|
|
|
+)
|
|
|
+@click.option(
|
|
|
+ '-m',
|
|
|
+ '--model_type',
|
|
|
+ 'model_type',
|
|
|
+ type=click.Choice(['pipeline', 'vlm', 'all']),
|
|
|
+ help="""
|
|
|
+ The type of the model to download.
|
|
|
+ """,
|
|
|
+ default=None,
|
|
|
+)
|
|
|
+def download_models(model_source, model_type):
|
|
|
"""Download MinerU model files.
|
|
|
|
|
|
Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
|
|
|
"""
|
|
|
- # 交互式输入下载来源
|
|
|
- source = click.prompt(
|
|
|
- "Please select the model download source: ",
|
|
|
- type=click.Choice(['huggingface', 'modelscope']),
|
|
|
- default='huggingface'
|
|
|
- )
|
|
|
-
|
|
|
- os.environ['MINERU_MODEL_SOURCE'] = source
|
|
|
-
|
|
|
- # 交互式输入模型类型
|
|
|
- model_type = click.prompt(
|
|
|
- "Please select the model type to download: ",
|
|
|
- type=click.Choice(['pipeline', 'vlm']),
|
|
|
- default='pipeline'
|
|
|
- )
|
|
|
+ # 如果未显式指定则交互式输入下载来源
|
|
|
+ if model_source is None:
|
|
|
+ model_source = click.prompt(
|
|
|
+ "Please select the model download source: ",
|
|
|
+ type=click.Choice(['huggingface', 'modelscope']),
|
|
|
+ default='huggingface'
|
|
|
+ )
|
|
|
+
|
|
|
+ if os.getenv('MINERU_MODEL_SOURCE', None) is None:
|
|
|
+ os.environ['MINERU_MODEL_SOURCE'] = model_source
|
|
|
+
|
|
|
+ # 如果未显式指定则交互式输入模型类型
|
|
|
+ if model_type is None:
|
|
|
+ model_type = click.prompt(
|
|
|
+ "Please select the model type to download: ",
|
|
|
+ type=click.Choice(['pipeline', 'vlm', 'all']),
|
|
|
+ default='all'
|
|
|
+ )
|
|
|
+
|
|
|
+ click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
|
|
|
+
|
|
|
+ def download_pipeline_models():
|
|
|
+ """下载Pipeline模型"""
|
|
|
+ model_paths = [
|
|
|
+ ModelPath.doclayout_yolo,
|
|
|
+ ModelPath.yolo_v8_mfd,
|
|
|
+ ModelPath.unimernet_small,
|
|
|
+ ModelPath.pytorch_paddle,
|
|
|
+ ModelPath.layout_reader,
|
|
|
+ ModelPath.slanet_plus
|
|
|
+ ]
|
|
|
+ download_finish_path = ""
|
|
|
+ for model_path in model_paths:
|
|
|
+ click.echo(f"Downloading model: {model_path}")
|
|
|
+ download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
|
|
|
+ click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
|
|
|
+ configure_model(download_finish_path, model_type)
|
|
|
|
|
|
- click.echo(f"Downloading {model_type} model from {source}...")
|
|
|
+ def download_vlm_models():
|
|
|
+ """下载VLM模型"""
|
|
|
+ download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
|
|
|
+ click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
|
|
|
+ configure_model(download_finish_path, model_type)
|
|
|
|
|
|
try:
|
|
|
- download_finish_path = ""
|
|
|
if model_type == 'pipeline':
|
|
|
- for model_path in [ModelPath.doclayout_yolo, ModelPath.yolo_v8_mfd, ModelPath.unimernet_small, ModelPath.pytorch_paddle, ModelPath.layout_reader, ModelPath.slanet_plus]:
|
|
|
- click.echo(f"Downloading model: {model_path}")
|
|
|
- download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode=model_type)
|
|
|
+ download_pipeline_models()
|
|
|
elif model_type == 'vlm':
|
|
|
- download_finish_path = auto_download_and_get_model_root_path("/", repo_mode=model_type)
|
|
|
- click.echo(f"Models downloaded successfully to: {download_finish_path}")
|
|
|
- configure_model(download_finish_path, model_type)
|
|
|
+ download_vlm_models()
|
|
|
+ elif model_type == 'all':
|
|
|
+ download_pipeline_models()
|
|
|
+ download_vlm_models()
|
|
|
+ else:
|
|
|
+ click.echo(f"Unsupported model type: {model_type}", err=True)
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
click.echo(f"Download failed: {str(e)}", err=True)
|
|
|
sys.exit(1)
|
|
|
|
|
|
-
|
|
|
if __name__ == '__main__':
|
|
|
download_models()
|