Procházet zdrojové kódy

Merge pull request #2889 from myhloli/dev

Dev
Xiaomeng Zhao před 4 měsíci
rodič
revize
5dbdbf023c
4 změnil soubory, kde provedl 75 přidání a 77 odebrání
  1. 2 36
      mineru/cli/client.py
  2. 25 14
      mineru/cli/fast_api.py
  3. 10 27
      mineru/cli/gradio_app.py
  4. 38 0
      mineru/utils/cli_parser.py

+ 2 - 36
mineru/cli/client.py

@@ -4,6 +4,7 @@ import click
 from pathlib import Path
 from pathlib import Path
 from loguru import logger
 from loguru import logger
 
 
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.config_reader import get_device
 from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from ..version import __version__
@@ -145,42 +146,7 @@ def main(
         device_mode, virtual_vram, model_source, **kwargs
         device_mode, virtual_vram, model_source, **kwargs
 ):
 ):
 
 
-    # 解析额外参数
-    extra_kwargs = {}
-    i = 0
-    while i < len(ctx.args):
-        arg = ctx.args[i]
-        if arg.startswith('--'):
-            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
-            i += 1
-            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
-                # 参数有值
-                try:
-                    # 尝试转换为适当的类型
-                    if ctx.args[i].lower() == 'true':
-                        extra_kwargs[param_name] = True
-                    elif ctx.args[i].lower() == 'false':
-                        extra_kwargs[param_name] = False
-                    elif '.' in ctx.args[i]:
-                        try:
-                            extra_kwargs[param_name] = float(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                    else:
-                        try:
-                            extra_kwargs[param_name] = int(ctx.args[i])
-                        except ValueError:
-                            extra_kwargs[param_name] = ctx.args[i]
-                except:
-                    extra_kwargs[param_name] = ctx.args[i]
-            else:
-                # 布尔型标志参数
-                extra_kwargs[param_name] = True
-                i -= 1
-        i += 1
-
-    # 将解析出的参数合并到kwargs
-    kwargs.update(extra_kwargs)
+    kwargs.update(arg_parse(ctx))
 
 
     if not backend.endswith('-client'):
     if not backend.endswith('-client'):
         def get_device_mode() -> str:
         def get_device_mode() -> str:

+ 25 - 14
mineru/cli/fast_api.py

@@ -1,7 +1,7 @@
 import uuid
 import uuid
 import os
 import os
 import uvicorn
 import uvicorn
-import argparse
+import click
 from pathlib import Path
 from pathlib import Path
 from glob import glob
 from glob import glob
 from fastapi import FastAPI, UploadFile, File, Form
 from fastapi import FastAPI, UploadFile, File, Form
@@ -12,6 +12,7 @@ from loguru import logger
 from base64 import b64encode
 from base64 import b64encode
 
 
 from mineru.cli.common import aio_do_parse, read_fn
 from mineru.cli.common import aio_do_parse, read_fn
+from mineru.utils.cli_parser import arg_parse
 from mineru.version import __version__
 from mineru.version import __version__
 
 
 app = FastAPI()
 app = FastAPI()
@@ -50,6 +51,10 @@ async def parse_pdf(
         start_page_id: int = Form(0),
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
         end_page_id: int = Form(99999),
 ):
 ):
+
+    # 获取命令行配置参数
+    config = getattr(app.state, "config", {})
+
     try:
     try:
         # 创建唯一的输出目录
         # 创建唯一的输出目录
         unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
         unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
@@ -113,6 +118,7 @@ async def parse_pdf(
             f_dump_content_list=return_content_list,
             f_dump_content_list=return_content_list,
             start_page_id=start_page_id,
             start_page_id=start_page_id,
             end_page_id=end_page_id,
             end_page_id=end_page_id,
+            **config
         )
         )
 
 
         # 构建结果路径
         # 构建结果路径
@@ -162,24 +168,29 @@ async def parse_pdf(
         )
         )
 
 
 
 
-def main():
-    """启动MinerU FastAPI服务器的命令行入口"""
-    parser = argparse.ArgumentParser(description='Start MinerU FastAPI Service')
-    parser.add_argument('--host', type=str, default='127.0.0.1', help='Server host (default: 127.0.0.1)')
-    parser.add_argument('--port', type=int, default=8000, help='Server port (default: 8000)')
-    parser.add_argument('--reload', action='store_true', help='Enable auto-reload (development mode)')
-    args = parser.parse_args()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
+@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
+@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
+@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
+def main(ctx, host, port, reload, **kwargs):
 
 
-    print(f"Start MinerU FastAPI Service: http://{args.host}:{args.port}")
+    kwargs.update(arg_parse(ctx))
+
+    # 将配置参数存储到应用状态中
+    app.state.config = kwargs
+
+    """启动MinerU FastAPI服务器的命令行入口"""
+    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
     print("The API documentation can be accessed at the following address:")
     print("The API documentation can be accessed at the following address:")
-    print(f"- Swagger UI: http://{args.host}:{args.port}/docs")
-    print(f"- ReDoc: http://{args.host}:{args.port}/redoc")
+    print(f"- Swagger UI: http://{host}:{port}/docs")
+    print(f"- ReDoc: http://{host}:{port}/redoc")
 
 
     uvicorn.run(
     uvicorn.run(
         "mineru.cli.fast_api:app",
         "mineru.cli.fast_api:app",
-        host=args.host,
-        port=args.port,
-        reload=args.reload
+        host=host,
+        port=port,
+        reload=reload
     )
     )
 
 
 
 

+ 10 - 27
mineru/cli/gradio_app.py

@@ -13,6 +13,7 @@ from gradio_pdf import PDF
 from loguru import logger
 from loguru import logger
 
 
 from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
 from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.hash_utils import str_sha256
 from mineru.utils.hash_utils import str_sha256
 
 
 
 
@@ -188,7 +189,8 @@ def update_interface(backend_choice):
         pass
         pass
 
 
 
 
-@click.command()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
 @click.option(
 @click.option(
     '--enable-example',
     '--enable-example',
     'example_enable',
     'example_enable',
@@ -205,20 +207,6 @@ def update_interface(backend_choice):
     default=False,
     default=False,
 )
 )
 @click.option(
 @click.option(
-    '--mem-fraction-static',
-    'mem_fraction_static',
-    type=float,
-    help="Set the static memory fraction for SgLang engine. ",
-    default=None,
-)
-@click.option(
-    '--enable-torch-compile',
-    'torch_compile_enable',
-    type=bool,
-    help="Enable torch compile for SgLang engine. ",
-    default=False,
-)
-@click.option(
     '--enable-api',
     '--enable-api',
     'api_enable',
     'api_enable',
     type=bool,
     type=bool,
@@ -246,28 +234,23 @@ def update_interface(backend_choice):
     help="Set the server port for the Gradio app.",
     help="Set the server port for the Gradio app.",
     default=None,
     default=None,
 )
 )
-def main(
-        example_enable, sglang_engine_enable, mem_fraction_static, torch_compile_enable, api_enable, max_convert_pages,
-        server_name, server_port
+def main(ctx,
+        example_enable, sglang_engine_enable, api_enable, max_convert_pages,
+        server_name, server_port, **kwargs
 ):
 ):
+
+    kwargs.update(arg_parse(ctx))
+
     if sglang_engine_enable:
     if sglang_engine_enable:
         try:
         try:
             print("Start init SgLang engine...")
             print("Start init SgLang engine...")
             from mineru.backend.vlm.vlm_analyze import ModelSingleton
             from mineru.backend.vlm.vlm_analyze import ModelSingleton
             model_singleton = ModelSingleton()
             model_singleton = ModelSingleton()
-
-            model_params = {
-                "enable_torch_compile": torch_compile_enable
-            }
-            # 只有当mem_fraction_static不为None时才添加该参数
-            if mem_fraction_static is not None:
-                model_params["mem_fraction_static"] = mem_fraction_static
-
             predictor = model_singleton.get_model(
             predictor = model_singleton.get_model(
                 "sglang-engine",
                 "sglang-engine",
                 None,
                 None,
                 None,
                 None,
-                **model_params
+                **kwargs
             )
             )
             print("SgLang engine init successfully.")
             print("SgLang engine init successfully.")
         except Exception as e:
         except Exception as e:

+ 38 - 0
mineru/utils/cli_parser.py

@@ -0,0 +1,38 @@
+import click
+
+
+def arg_parse(ctx: 'click.Context') -> dict:
+    # 解析额外参数
+    extra_kwargs = {}
+    i = 0
+    while i < len(ctx.args):
+        arg = ctx.args[i]
+        if arg.startswith('--'):
+            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
+            i += 1
+            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
+                # 参数有值
+                try:
+                    # 尝试转换为适当的类型
+                    if ctx.args[i].lower() == 'true':
+                        extra_kwargs[param_name] = True
+                    elif ctx.args[i].lower() == 'false':
+                        extra_kwargs[param_name] = False
+                    elif '.' in ctx.args[i]:
+                        try:
+                            extra_kwargs[param_name] = float(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                    else:
+                        try:
+                            extra_kwargs[param_name] = int(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                except:
+                    extra_kwargs[param_name] = ctx.args[i]
+            else:
+                # 布尔型标志参数
+                extra_kwargs[param_name] = True
+                i -= 1
+        i += 1
+    return extra_kwargs