ソースを参照

refactor: add support for additional keyword arguments in client and analysis functions

myhloli 4 ヶ月 前
コミット
9a3a314916
3 ファイル変更19 行追加8 行削除
  1. 4 2
      mineru/backend/vlm/vlm_analyze.py
  2. 7 2
      mineru/cli/client.py
  3. 8 4
      mineru/cli/common.py

+ 4 - 2
mineru/backend/vlm/vlm_analyze.py

@@ -47,9 +47,10 @@ def doc_analyze(
     backend="transformers",
     model_path: str | None = None,
     server_url: str | None = None,
+    **kwargs,
 ):
     if predictor is None:
-        predictor = ModelSingleton().get_model(backend, model_path, server_url)
+        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
     images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
@@ -73,9 +74,10 @@ async def aio_doc_analyze(
     backend="transformers",
     model_path: str | None = None,
     server_url: str | None = None,
+    **kwargs,
 ):
     if predictor is None:
-        predictor = ModelSingleton().get_model(backend, model_path, server_url)
+        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
     images_list, pdf_doc = load_images_from_pdf(pdf_bytes)

+ 7 - 2
mineru/cli/client.py

@@ -9,7 +9,7 @@ from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 
-@click.command()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
 @click.version_option(__version__,
                       '--version',
                       '-v',
@@ -137,7 +137,11 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 )
 
 
-def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
+def main(
+        input_path, output_dir, method, backend, lang, server_url,
+        start_page_id, end_page_id, formula_enable, table_enable,
+        device_mode, virtual_vram, model_source, **kwargs
+):
 
     if not backend.endswith('-client'):
         def get_device_mode() -> str:
@@ -185,6 +189,7 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
                 server_url=server_url,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id
+                **kwargs,
             )
         except Exception as e:
             logger.exception(e)

+ 8 - 4
mineru/cli/common.py

@@ -225,6 +225,7 @@ async def _async_process_vlm(
         f_dump_content_list,
         f_make_md_mode,
         server_url=None,
+        **kwargs,
 ):
     """异步处理VLM后端逻辑"""
     parse_method = "vlm"
@@ -238,7 +239,7 @@ async def _async_process_vlm(
         image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
 
         middle_json, infer_result = await aio_vlm_doc_analyze(
-            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
+            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
         )
 
         pdf_info = middle_json["pdf_info"]
@@ -265,6 +266,7 @@ def _process_vlm(
         f_dump_content_list,
         f_make_md_mode,
         server_url=None,
+        **kwargs,
 ):
     """同步处理VLM后端逻辑"""
     parse_method = "vlm"
@@ -278,7 +280,7 @@ def _process_vlm(
         image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
 
         middle_json, infer_result = vlm_doc_analyze(
-            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
+            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
         )
 
         pdf_info = middle_json["pdf_info"]
@@ -311,6 +313,7 @@ def do_parse(
         f_make_md_mode=MakeMode.MM_MD,
         start_page_id=0,
         end_page_id=None,
+        **kwargs,
 ):
     # 预处理PDF字节数据
     pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
@@ -333,7 +336,7 @@ def do_parse(
             output_dir, pdf_file_names, pdf_bytes_list, backend,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
             f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
-            server_url
+            server_url, **kwargs,
         )
 
 
@@ -357,6 +360,7 @@ async def aio_do_parse(
         f_make_md_mode=MakeMode.MM_MD,
         start_page_id=0,
         end_page_id=None,
+        **kwargs,
 ):
     # 预处理PDF字节数据
     pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
@@ -380,7 +384,7 @@ async def aio_do_parse(
             output_dir, pdf_file_names, pdf_bytes_list, backend,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
             f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
-            server_url
+            server_url, **kwargs,
         )