Browse Source

feat: enhance model download options with source and type parameters

myhloli 5 months ago
parent
commit
705f83319f
2 changed files with 75 additions and 30 deletions
  1. 3 5
      mineru/cli/client.py
  2. 72 25
      mineru/cli/models_download.py

+ 3 - 5
mineru/cli/client.py

@@ -24,7 +24,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 )
 @click.option(
     '-o',
-    '--output-dir',
+    '--output',
     'output_dir',
     type=click.Path(),
     required=True,
@@ -118,16 +118,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     default=None,
 )
 @click.option(
-    '-vm',
-    '--virtual-vram',
+    '--vram',
     'virtual_vram',
     type=int,
     help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
     default=None,
 )
 @click.option(
-    '-r',
-    '--repo',
+    '--source',
     'model_source',
     type=click.Choice(['huggingface', 'modelscope', 'local']),
     help="""

+ 72 - 25
mineru/cli/models_download.py

@@ -58,43 +58,90 @@ def configure_model(model_dir, model_type):
 
 
 @click.command()
-def download_models():
+@click.option(
+    '-s',
+    '--source',
+    'model_source',
+    type=click.Choice(['huggingface', 'modelscope']),
+    help="""
+        The source of the model repository. 
+        """,
+    default=None,
+)
+@click.option(
+    '-m',
+    '--model_type',
+    'model_type',
+    type=click.Choice(['pipeline', 'vlm', 'all']),
+    help="""
+        The type of the model to download.
+        """,
+    default=None,
+)
+def download_models(model_source, model_type):
     """Download MinerU model files.
 
     Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
     """
-    # 交互式输入下载来源
-    source = click.prompt(
-        "Please select the model download source: ",
-        type=click.Choice(['huggingface', 'modelscope']),
-        default='huggingface'
-    )
-
-    os.environ['MINERU_MODEL_SOURCE'] = source
-
-    # 交互式输入模型类型
-    model_type = click.prompt(
-        "Please select the model type to download: ",
-        type=click.Choice(['pipeline', 'vlm']),
-        default='pipeline'
-    )
+    # 如果未显式指定则交互式输入下载来源
+    if model_source is None:
+        model_source = click.prompt(
+            "Please select the model download source: ",
+            type=click.Choice(['huggingface', 'modelscope']),
+            default='huggingface'
+        )
+
+    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
+        os.environ['MINERU_MODEL_SOURCE'] = model_source
+
+    # 如果未显式指定则交互式输入模型类型
+    if model_type is None:
+        model_type = click.prompt(
+            "Please select the model type to download: ",
+            type=click.Choice(['pipeline', 'vlm', 'all']),
+            default='all'
+        )
+
+    click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
+
+    def download_pipeline_models():
+        """下载Pipeline模型"""
+        model_paths = [
+            ModelPath.doclayout_yolo,
+            ModelPath.yolo_v8_mfd,
+            ModelPath.unimernet_small,
+            ModelPath.pytorch_paddle,
+            ModelPath.layout_reader,
+            ModelPath.slanet_plus
+        ]
+        download_finish_path = ""
+        for model_path in model_paths:
+            click.echo(f"Downloading model: {model_path}")
+            download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
+        click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
+        configure_model(download_finish_path, model_type)
 
-    click.echo(f"Downloading {model_type} model from {source}...")
+    def download_vlm_models():
+        """下载VLM模型"""
+        download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
+        click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
+        configure_model(download_finish_path, model_type)
 
     try:
-        download_finish_path = ""
         if model_type == 'pipeline':
-            for model_path in [ModelPath.doclayout_yolo, ModelPath.yolo_v8_mfd, ModelPath.unimernet_small, ModelPath.pytorch_paddle, ModelPath.layout_reader, ModelPath.slanet_plus]:
-                click.echo(f"Downloading model: {model_path}")
-                download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode=model_type)
+            download_pipeline_models()
         elif model_type == 'vlm':
-            download_finish_path = auto_download_and_get_model_root_path("/", repo_mode=model_type)
-        click.echo(f"Models downloaded successfully to: {download_finish_path}")
-        configure_model(download_finish_path, model_type)
+            download_vlm_models()
+        elif model_type == 'all':
+            download_pipeline_models()
+            download_vlm_models()
+        else:
+            click.echo(f"Unsupported model type: {model_type}", err=True)
+            sys.exit(1)
+
     except Exception as e:
         click.echo(f"Download failed: {str(e)}", err=True)
         sys.exit(1)
 
-
 if __name__ == '__main__':
     download_models()