Ver Fonte

refactor: enhance command-line argument parsing to support additional parameters in gradio_app.py

myhloli há 4 meses atrás
pai
commit
54913fc205
1 ficheiros alterados com 43 adições e 27 exclusões
  1. 43 27
      mineru/cli/gradio_app.py

+ 43 - 27
mineru/cli/gradio_app.py

@@ -188,7 +188,8 @@ def update_interface(backend_choice):
         pass
 
 
-@click.command()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
 @click.option(
     '--enable-example',
     'example_enable',
@@ -205,20 +206,6 @@ def update_interface(backend_choice):
     default=False,
 )
 @click.option(
-    '--mem-fraction-static',
-    'mem_fraction_static',
-    type=float,
-    help="Set the static memory fraction for SgLang engine. ",
-    default=None,
-)
-@click.option(
-    '--enable-torch-compile',
-    'torch_compile_enable',
-    type=bool,
-    help="Enable torch compile for SgLang engine. ",
-    default=False,
-)
-@click.option(
     '--enable-api',
     'api_enable',
     type=bool,
@@ -246,28 +233,57 @@ def update_interface(backend_choice):
     help="Set the server port for the Gradio app.",
     default=None,
 )
-def main(
-        example_enable, sglang_engine_enable, mem_fraction_static, torch_compile_enable, api_enable, max_convert_pages,
-        server_name, server_port
+def main(ctx,
+        example_enable, sglang_engine_enable, api_enable, max_convert_pages,
+        server_name, server_port, **kwargs
 ):
+    # 解析额外参数
+    extra_kwargs = {}
+    i = 0
+    while i < len(ctx.args):
+        arg = ctx.args[i]
+        if arg.startswith('--'):
+            param_name = arg[2:].replace('-', '_')  # 转换参数名格式
+            i += 1
+            if i < len(ctx.args) and not ctx.args[i].startswith('--'):
+                # 参数有值
+                try:
+                    # 尝试转换为适当的类型
+                    if ctx.args[i].lower() == 'true':
+                        extra_kwargs[param_name] = True
+                    elif ctx.args[i].lower() == 'false':
+                        extra_kwargs[param_name] = False
+                    elif '.' in ctx.args[i]:
+                        try:
+                            extra_kwargs[param_name] = float(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                    else:
+                        try:
+                            extra_kwargs[param_name] = int(ctx.args[i])
+                        except ValueError:
+                            extra_kwargs[param_name] = ctx.args[i]
+                except:
+                    extra_kwargs[param_name] = ctx.args[i]
+            else:
+                # 布尔型标志参数
+                extra_kwargs[param_name] = True
+                i -= 1
+        i += 1
+
+    # 将解析出的参数合并到kwargs
+    kwargs.update(extra_kwargs)
+
     if sglang_engine_enable:
         try:
             print("Start init SgLang engine...")
             from mineru.backend.vlm.vlm_analyze import ModelSingleton
             model_singleton = ModelSingleton()
-
-            model_params = {
-                "enable_torch_compile": torch_compile_enable
-            }
-            # 只有当mem_fraction_static不为None时才添加该参数
-            if mem_fraction_static is not None:
-                model_params["mem_fraction_static"] = mem_fraction_static
-
             predictor = model_singleton.get_model(
                 "sglang-engine",
                 None,
                 None,
-                **model_params
+                **kwargs
             )
             print("SgLang engine init successfully.")
         except Exception as e: