Browse Source

refactor: extract command-line argument parsing to cli_parser.py and update usage in main functions

myhloli 4 months ago
parent
commit
e6f817fe6f
4 changed files with 68 additions and 86 deletions
  1. 2 36
      mineru/cli/client.py
  2. 25 14
      mineru/cli/fast_api.py
  3. 3 36
      mineru/cli/gradio_app.py
  4. 38 0
      mineru/utils/cli_parser.py

+ 2 - 36
mineru/cli/client.py

@@ -4,6 +4,7 @@ import click
 from pathlib import Path
 from pathlib import Path
 from loguru import logger
 from loguru import logger
 
 
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.config_reader import get_device
 from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from ..version import __version__
@@ -145,42 +146,7 @@ def main(
         device_mode, virtual_vram, model_source, **kwargs
         device_mode, virtual_vram, model_source, **kwargs
 ):
 ):
 
 
-    # 解析额外参数
-    extra_kwargs = {}
-    i = 0
-    while i < len(ctx.args):
-        arg = ctx.args[i]
-        if arg.startswith('--'):
-            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
-            i += 1
-            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
-                # 参数有值
-                try:
-                    # 尝试转换为适当的类型
-                    if ctx.args[i].lower() == 'true':
-                        extra_kwargs[param_name] = True
-                    elif ctx.args[i].lower() == 'false':
-                        extra_kwargs[param_name] = False
-                    elif '.' in ctx.args[i]:
-                        try:
-                            extra_kwargs[param_name] = float(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                    else:
-                        try:
-                            extra_kwargs[param_name] = int(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                except:
-                    extra_kwargs[param_name] = ctx.args[i]
-            else:
-                # 布尔型标志参数
-                extra_kwargs[param_name] = True
-                i -= 1
-        i += 1
-
-    # 将解析出的参数合并到kwargs
-    kwargs.update(extra_kwargs)
+    kwargs.update(arg_parse(ctx))
 
 
     if not backend.endswith('-client'):
     if not backend.endswith('-client'):
         def get_device_mode() -> str:
         def get_device_mode() -> str:

+ 25 - 14
mineru/cli/fast_api.py

@@ -1,7 +1,7 @@
 import uuid
 import uuid
 import os
 import os
 import uvicorn
 import uvicorn
-import argparse
+import click
 from pathlib import Path
 from pathlib import Path
 from glob import glob
 from glob import glob
 from fastapi import FastAPI, UploadFile, File, Form
 from fastapi import FastAPI, UploadFile, File, Form
@@ -12,6 +12,7 @@ from loguru import logger
 from base64 import b64encode
 from base64 import b64encode
 
 
 from mineru.cli.common import aio_do_parse, read_fn
 from mineru.cli.common import aio_do_parse, read_fn
+from mineru.utils.cli_parser import arg_parse
 from mineru.version import __version__
 from mineru.version import __version__
 
 
 app = FastAPI()
 app = FastAPI()
@@ -50,6 +51,10 @@ async def parse_pdf(
         start_page_id: int = Form(0),
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
         end_page_id: int = Form(99999),
 ):
 ):
+
+    # 获取命令行配置参数
+    config = getattr(app.state, "config", {})
+
     try:
     try:
         # 创建唯一的输出目录
         # 创建唯一的输出目录
         unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
         unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
@@ -113,6 +118,7 @@ async def parse_pdf(
             f_dump_content_list=return_content_list,
             f_dump_content_list=return_content_list,
             start_page_id=start_page_id,
             start_page_id=start_page_id,
             end_page_id=end_page_id,
             end_page_id=end_page_id,
+            **config
         )
         )
 
 
         # 构建结果路径
         # 构建结果路径
@@ -162,24 +168,29 @@ async def parse_pdf(
         )
         )
 
 
 
 
-def main():
-    """启动MinerU FastAPI服务器的命令行入口"""
-    parser = argparse.ArgumentParser(description='Start MinerU FastAPI Service')
-    parser.add_argument('--host', type=str, default='127.0.0.1', help='Server host (default: 127.0.0.1)')
-    parser.add_argument('--port', type=int, default=8000, help='Server port (default: 8000)')
-    parser.add_argument('--reload', action='store_true', help='Enable auto-reload (development mode)')
-    args = parser.parse_args()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
+@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
+@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
+@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
+def main(ctx, host, port, reload, **kwargs):
 
 
-    print(f"Start MinerU FastAPI Service: http://{args.host}:{args.port}")
+    kwargs.update(arg_parse(ctx))
+
+    # 将配置参数存储到应用状态中
+    app.state.config = kwargs
+
+    """启动MinerU FastAPI服务器的命令行入口"""
+    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
     print("The API documentation can be accessed at the following address:")
     print("The API documentation can be accessed at the following address:")
-    print(f"- Swagger UI: http://{args.host}:{args.port}/docs")
-    print(f"- ReDoc: http://{args.host}:{args.port}/redoc")
+    print(f"- Swagger UI: http://{host}:{port}/docs")
+    print(f"- ReDoc: http://{host}:{port}/redoc")
 
 
     uvicorn.run(
     uvicorn.run(
         "mineru.cli.fast_api:app",
         "mineru.cli.fast_api:app",
-        host=args.host,
-        port=args.port,
-        reload=args.reload
+        host=host,
+        port=port,
+        reload=reload
     )
     )
 
 
 
 

+ 3 - 36
mineru/cli/gradio_app.py

@@ -13,6 +13,7 @@ from gradio_pdf import PDF
 from loguru import logger
 from loguru import logger
 
 
 from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
 from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.hash_utils import str_sha256
 from mineru.utils.hash_utils import str_sha256
 
 
 
 
@@ -237,42 +238,8 @@ def main(ctx,
         example_enable, sglang_engine_enable, api_enable, max_convert_pages,
         example_enable, sglang_engine_enable, api_enable, max_convert_pages,
         server_name, server_port, **kwargs
         server_name, server_port, **kwargs
 ):
 ):
-    # 解析额外参数
-    extra_kwargs = {}
-    i = 0
-    while i < len(ctx.args):
-        arg = ctx.args[i]
-        if arg.startswith('--'):
-            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
-            i += 1
-            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
-                # 参数有值
-                try:
-                    # 尝试转换为适当的类型
-                    if ctx.args[i].lower() == 'true':
-                        extra_kwargs[param_name] = True
-                    elif ctx.args[i].lower() == 'false':
-                        extra_kwargs[param_name] = False
-                    elif '.' in ctx.args[i]:
-                        try:
-                            extra_kwargs[param_name] = float(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                    else:
-                        try:
-                            extra_kwargs[param_name] = int(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                except:
-                    extra_kwargs[param_name] = ctx.args[i]
-            else:
-                # 布尔型标志参数
-                extra_kwargs[param_name] = True
-                i -= 1
-        i += 1
-
-    # 将解析出的参数合并到kwargs
-    kwargs.update(extra_kwargs)
+
+    kwargs.update(arg_parse(ctx))
 
 
     if sglang_engine_enable:
     if sglang_engine_enable:
         try:
         try:

+ 38 - 0
mineru/utils/cli_parser.py

@@ -0,0 +1,38 @@
+import click
+
+
+def arg_parse(ctx: 'click.Context') -> dict:
+    # 解析额外参数
+    extra_kwargs = {}
+    i = 0
+    while i < len(ctx.args):
+        arg = ctx.args[i]
+        if arg.startswith('--'):
+            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
+            i += 1
+            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
+                # 参数有值
+                try:
+                    # 尝试转换为适当的类型
+                    if ctx.args[i].lower() == 'true':
+                        extra_kwargs[param_name] = True
+                    elif ctx.args[i].lower() == 'false':
+                        extra_kwargs[param_name] = False
+                    elif '.' in ctx.args[i]:
+                        try:
+                            extra_kwargs[param_name] = float(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                    else:
+                        try:
+                            extra_kwargs[param_name] = int(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                except:
+                    extra_kwargs[param_name] = ctx.args[i]
+            else:
+                # 布尔型标志参数
+                extra_kwargs[param_name] = True
+                i -= 1
+        i += 1
+    return extra_kwargs