Эх сурвалжийг харах

feat: add command-line options for SgLang engine configuration in main function

myhloli 4 сар өмнө
parent
commit
8a5e1e3447
1 өөрчлөгдсөн 67 нэмэгдсэн , 26 устгасан
  1. 67 26
      mineru/cli/gradio_app.py

+ 67 - 26
mineru/cli/gradio_app.py

@@ -7,6 +7,7 @@ import time
 import zipfile
 from pathlib import Path
 
+import click
 import gradio as gr
 from gradio_pdf import PDF
 from loguru import logger
@@ -174,41 +175,75 @@ def to_pdf(file_path):
 
     return tmp_file_path
 
-
-def main():
-    example_enable = False
-
-    # try:
-    #     print("Start init SgLang engine...")
-    #     from mineru.backend.vlm.vlm_analyze import ModelSingleton
-    #     modelsingleton = ModelSingleton()
-    #     predictor = modelsingleton.get_model(
-    #         "sglang-engine",
-    #         None,
-    #         None,
-    #         mem_fraction_static=0.5,
-    #         enable_torch_compile=True,
-    #     )
-    #     print("SgLang engine init successfully.")
-    # except Exception as e:
-    #     logger.exception(e)
-
+@click.command()
+@click.option(
+    '--enable-example',
+    'example_enable',
+    type=bool,
+    help="Enable example files for input."
+         "The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
+    default=False,
+)
+@click.option(
+    '--enable-sglang-engine',
+    'sglang_engine_enable',
+    type=bool,
+    help="Enable SgLang engine backend for faster processing.",
+    default=False,
+)
+@click.option(
+    '--mem-fraction-static',
+    'mem_fraction_static',
+    type=float,
+    help="Set the static memory fraction for SgLang engine. ",
+    default=0.5,
+)
+@click.option(
+    '--enable-torch-compile',
+    'enable_torch_compile',
+    type=bool,
+    help="Enable torch compile for SgLang engine. ",
+    default=True,
+)
+def main(example_enable, sglang_engine_enable, mem_fraction_static, enable_torch_compile):
+    if sglang_engine_enable:
+        try:
+            print("Start init SgLang engine...")
+            from mineru.backend.vlm.vlm_analyze import ModelSingleton
+            modelsingleton = ModelSingleton()
+            predictor = modelsingleton.get_model(
+                "sglang-engine",
+                None,
+                None,
+                mem_fraction_static=mem_fraction_static,
+                enable_torch_compile=enable_torch_compile,
+            )
+            print("SgLang engine init successfully.")
+        except Exception as e:
+            logger.exception(e)
+
+    suffixes = pdf_suffixes + image_suffixes
     with gr.Blocks() as demo:
         gr.HTML(header)
         with gr.Row():
             with gr.Column(variant='panel', scale=5):
                 with gr.Row():
-                    suffixes = pdf_suffixes + image_suffixes
                     input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
                 with gr.Row():
                     max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages')
                 with gr.Row():
-                    backend = gr.Dropdown(["pipeline", "vlm-transformers", "vlm-sglang-client"], label="Backend", value="pipeline")
-                with gr.Row(visible=True) as ocr_options:
+                    if sglang_engine_enable:
+                        drop_list = ["pipeline", "vlm-sglang-engine"]
+                        preferred_option = "vlm-sglang-engine"
+                    else:
+                        drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
+                        preferred_option = "pipeline"
+                    backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
+                with gr.Row(visible=False) as ocr_options:
                     language = gr.Dropdown(all_lang, label='Language', value='ch')
                 with gr.Row(visible=False) as client_options:
                     url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
-                with gr.Row(visible=True) as pipeline_options:
+                with gr.Row(visible=False) as pipeline_options:
                     is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
                     formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
                     table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True)
@@ -217,12 +252,12 @@ def main():
                     clear_bu = gr.ClearButton(value='Clear')
                 pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
                 if example_enable:
-                    example_root = os.path.join(os.path.dirname(__file__), 'examples')
+                    example_root = os.path.join(os.getcwd(), 'examples')
                     if os.path.exists(example_root):
                         with gr.Accordion('Examples:'):
                             gr.Examples(
                                 examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
-                                          _.endswith('pdf')],
+                                          _.endswith(tuple(suffixes))],
                                 inputs=input_file
                             )
 
@@ -255,13 +290,19 @@ def main():
             inputs=[backend],
             outputs=[client_options, ocr_options, pipeline_options]
         )
+        # 添加demo.load事件,在页面加载时触发一次界面更新
+        demo.load(
+            fn=update_interface,
+            inputs=[backend],
+            outputs=[client_options, ocr_options, pipeline_options]
+        )
 
         input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show)
         change_bu.click(fn=to_markdown, inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
                         outputs=[md, md_text, output_file, pdf_show])
         clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
 
-    demo.launch(server_name='localhost')
+    demo.launch()
 
 
 if __name__ == '__main__':