Ver Fonte

feat: implement model downloading functions with logging

myhloli há 4 meses atrás
pai
commit
c7c1e30e9f
1 ficheiros alterados com 28 adições e 25 exclusões
  1. 28 25
      mineru/cli/models_download.py

+ 28 - 25
mineru/cli/models_download.py

@@ -3,6 +3,7 @@ import os
 import sys
 import click
 import requests
+from loguru import logger
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -57,6 +58,31 @@ def configure_model(model_dir, model_type):
     print(f'The configuration file has been successfully configured, the path is: {config_file}')
 
 
+def download_pipeline_models():
+    """下载Pipeline模型"""
+    model_paths = [
+        ModelPath.doclayout_yolo,
+        ModelPath.yolo_v8_mfd,
+        ModelPath.unimernet_small,
+        ModelPath.pytorch_paddle,
+        ModelPath.layout_reader,
+        ModelPath.slanet_plus
+    ]
+    download_finish_path = ""
+    for model_path in model_paths:
+        click.echo(f"Downloading model: {model_path}")
+        download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
+    click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
+    configure_model(download_finish_path, "pipeline")
+
+
+def download_vlm_models():
+    """下载VLM模型"""
+    download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
+    click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
+    configure_model(download_finish_path, "vlm")
+
+
 @click.command()
 @click.option(
     '-s',
@@ -102,30 +128,7 @@ def download_models(model_source, model_type):
             default='all'
         )
 
-    click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
-
-    def download_pipeline_models():
-        """下载Pipeline模型"""
-        model_paths = [
-            ModelPath.doclayout_yolo,
-            ModelPath.yolo_v8_mfd,
-            ModelPath.unimernet_small,
-            ModelPath.pytorch_paddle,
-            ModelPath.layout_reader,
-            ModelPath.slanet_plus
-        ]
-        download_finish_path = ""
-        for model_path in model_paths:
-            click.echo(f"Downloading model: {model_path}")
-            download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
-        click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
-        configure_model(download_finish_path, "pipeline")
-
-    def download_vlm_models():
-        """下载VLM模型"""
-        download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
-        click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
-        configure_model(download_finish_path, "vlm")
+    logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
 
     try:
         if model_type == 'pipeline':
@@ -140,7 +143,7 @@ def download_models(model_source, model_type):
             sys.exit(1)
 
     except Exception as e:
-        click.echo(f"Download failed: {str(e)}", err=True)
+        logger.exception(f"An error occurred while downloading models: {str(e)}")
         sys.exit(1)
 
 if __name__ == '__main__':