瀏覽代碼

Merge pull request #2889 from myhloli/dev

Dev
Xiaomeng Zhao 4 月之前
父節點
當前提交
5dbdbf023c
共有 4 個文件被更改,包括 75 次插入77 次删除
  1. 2 36
      mineru/cli/client.py
  2. 25 14
      mineru/cli/fast_api.py
  3. 10 27
      mineru/cli/gradio_app.py
  4. 38 0
      mineru/utils/cli_parser.py

+ 2 - 36
mineru/cli/client.py

@@ -4,6 +4,7 @@ import click
 from pathlib import Path
 from loguru import logger
 
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
@@ -145,42 +146,7 @@ def main(
         device_mode, virtual_vram, model_source, **kwargs
 ):
 
-    # 解析额外参数
-    extra_kwargs = {}
-    i = 0
-    while i < len(ctx.args):
-        arg = ctx.args[i]
-        if arg.startswith('--'):
-            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
-            i += 1
-            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
-                # 参数有值
-                try:
-                    # 尝试转换为适当的类型
-                    if ctx.args[i].lower() == 'true':
-                        extra_kwargs[param_name] = True
-                    elif ctx.args[i].lower() == 'false':
-                        extra_kwargs[param_name] = False
-                    elif '.' in ctx.args[i]:
-                        try:
-                            extra_kwargs[param_name] = float(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                    else:
-                        try:
-                            extra_kwargs[param_name] = int(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                except:
-                    extra_kwargs[param_name] = ctx.args[i]
-            else:
-                # 布尔型标志参数
-                extra_kwargs[param_name] = True
-                i -= 1
-        i += 1
-
-    # 将解析出的参数合并到kwargs
-    kwargs.update(extra_kwargs)
+    kwargs.update(arg_parse(ctx))
 
     if not backend.endswith('-client'):
         def get_device_mode() -> str:

+ 25 - 14
mineru/cli/fast_api.py

@@ -1,7 +1,7 @@
 import uuid
 import os
 import uvicorn
-import argparse
+import click
 from pathlib import Path
 from glob import glob
 from fastapi import FastAPI, UploadFile, File, Form
@@ -12,6 +12,7 @@ from loguru import logger
 from base64 import b64encode
 
 from mineru.cli.common import aio_do_parse, read_fn
+from mineru.utils.cli_parser import arg_parse
 from mineru.version import __version__
 
 app = FastAPI()
@@ -50,6 +51,10 @@ async def parse_pdf(
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
 ):
+
+    # 获取命令行配置参数
+    config = getattr(app.state, "config", {})
+
     try:
         # 创建唯一的输出目录
         unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
@@ -113,6 +118,7 @@ async def parse_pdf(
             f_dump_content_list=return_content_list,
             start_page_id=start_page_id,
             end_page_id=end_page_id,
+            **config
         )
 
         # 构建结果路径
@@ -162,24 +168,29 @@ async def parse_pdf(
         )
 
 
-def main():
-    """启动MinerU FastAPI服务器的命令行入口"""
-    parser = argparse.ArgumentParser(description='Start MinerU FastAPI Service')
-    parser.add_argument('--host', type=str, default='127.0.0.1', help='Server host (default: 127.0.0.1)')
-    parser.add_argument('--port', type=int, default=8000, help='Server port (default: 8000)')
-    parser.add_argument('--reload', action='store_true', help='Enable auto-reload (development mode)')
-    args = parser.parse_args()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
+@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
+@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
+@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
+def main(ctx, host, port, reload, **kwargs):
 
-    print(f"Start MinerU FastAPI Service: http://{args.host}:{args.port}")
+    kwargs.update(arg_parse(ctx))
+
+    # 将配置参数存储到应用状态中
+    app.state.config = kwargs
+
+    """启动MinerU FastAPI服务器的命令行入口"""
+    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
     print("The API documentation can be accessed at the following address:")
-    print(f"- Swagger UI: http://{args.host}:{args.port}/docs")
-    print(f"- ReDoc: http://{args.host}:{args.port}/redoc")
+    print(f"- Swagger UI: http://{host}:{port}/docs")
+    print(f"- ReDoc: http://{host}:{port}/redoc")
 
     uvicorn.run(
         "mineru.cli.fast_api:app",
-        host=args.host,
-        port=args.port,
-        reload=args.reload
+        host=host,
+        port=port,
+        reload=reload
     )
 
 

+ 10 - 27
mineru/cli/gradio_app.py

@@ -13,6 +13,7 @@ from gradio_pdf import PDF
 from loguru import logger
 
 from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.hash_utils import str_sha256
 
 
@@ -188,7 +189,8 @@ def update_interface(backend_choice):
         pass
 
 
-@click.command()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
 @click.option(
     '--enable-example',
     'example_enable',
@@ -205,20 +207,6 @@ def update_interface(backend_choice):
     default=False,
 )
 @click.option(
-    '--mem-fraction-static',
-    'mem_fraction_static',
-    type=float,
-    help="Set the static memory fraction for SgLang engine. ",
-    default=None,
-)
-@click.option(
-    '--enable-torch-compile',
-    'torch_compile_enable',
-    type=bool,
-    help="Enable torch compile for SgLang engine. ",
-    default=False,
-)
-@click.option(
     '--enable-api',
     'api_enable',
     type=bool,
@@ -246,28 +234,23 @@ def update_interface(backend_choice):
     help="Set the server port for the Gradio app.",
     default=None,
 )
-def main(
-        example_enable, sglang_engine_enable, mem_fraction_static, torch_compile_enable, api_enable, max_convert_pages,
-        server_name, server_port
+def main(ctx,
+        example_enable, sglang_engine_enable, api_enable, max_convert_pages,
+        server_name, server_port, **kwargs
 ):
+
+    kwargs.update(arg_parse(ctx))
+
     if sglang_engine_enable:
         try:
             print("Start init SgLang engine...")
             from mineru.backend.vlm.vlm_analyze import ModelSingleton
             model_singleton = ModelSingleton()
-
-            model_params = {
-                "enable_torch_compile": torch_compile_enable
-            }
-            # 只有当mem_fraction_static不为None时才添加该参数
-            if mem_fraction_static is not None:
-                model_params["mem_fraction_static"] = mem_fraction_static
-
             predictor = model_singleton.get_model(
                 "sglang-engine",
                 None,
                 None,
-                **model_params
+                **kwargs
             )
             print("SgLang engine init successfully.")
         except Exception as e:

+ 38 - 0
mineru/utils/cli_parser.py

@@ -0,0 +1,38 @@
+import click
+
+
+def arg_parse(ctx: 'click.Context') -> dict:
+    # 解析额外参数
+    extra_kwargs = {}
+    i = 0
+    while i < len(ctx.args):
+        arg = ctx.args[i]
+        if arg.startswith('--'):
+            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
+            i += 1
+            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
+                # 参数有值
+                try:
+                    # 尝试转换为适当的类型
+                    if ctx.args[i].lower() == 'true':
+                        extra_kwargs[param_name] = True
+                    elif ctx.args[i].lower() == 'false':
+                        extra_kwargs[param_name] = False
+                    elif '.' in ctx.args[i]:
+                        try:
+                            extra_kwargs[param_name] = float(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                    else:
+                        try:
+                            extra_kwargs[param_name] = int(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                except:
+                    extra_kwargs[param_name] = ctx.args[i]
+            else:
+                # 布尔型标志参数
+                extra_kwargs[param_name] = True
+                i -= 1
+        i += 1
+    return extra_kwargs