瀏覽代碼

fix: refactor device mode and virtual VRAM size handling in client.py and common.py

myhloli 5 月之前
父節點
當前提交
bfaf07c69f
共有 2 個文件被更改,包括 27 次插入25 次删除
  1. 19 20
      mineru/cli/client.py
  2. 8 5
      mineru/cli/common.py

+ 19 - 20
mineru/cli/client.py

@@ -7,7 +7,7 @@ from loguru import logger
 from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
-
+from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 
 @click.command()
 @click.version_option(__version__,
@@ -138,27 +138,26 @@ from ..version import __version__
 
 def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
 
-    from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
-
-    def get_device_mode() -> str:
-        if device_mode is not None:
-            return device_mode
-        else:
-            return get_device()
-    if os.getenv('MINERU_DEVICE_MODE', None) is None:
-        os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
+    if not backend.endswith('-client'):
+        def get_device_mode() -> str:
+            if device_mode is not None:
+                return device_mode
+            else:
+                return get_device()
+        if os.getenv('MINERU_DEVICE_MODE', None) is None:
+            os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
 
-    def get_virtual_vram_size() -> int:
-        if virtual_vram is not None:
-            return virtual_vram
-        if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
-            return round(get_vram(get_device_mode()))
-        return 1
-    if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
-        os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
+        def get_virtual_vram_size() -> int:
+            if virtual_vram is not None:
+                return virtual_vram
+            if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
+                return round(get_vram(get_device_mode()))
+            return 1
+        if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
+            os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
 
-    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
-        os.environ['MINERU_MODEL_SOURCE'] = model_source
+        if os.getenv('MINERU_MODEL_SOURCE', None) is None:
+            os.environ['MINERU_MODEL_SOURCE'] = model_source
 
     os.makedirs(output_dir, exist_ok=True)
 

+ 8 - 5
mineru/cli/common.py

@@ -8,15 +8,12 @@ from pathlib import Path
 import pypdfium2 as pdfium
 from loguru import logger
 
-from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
-from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
-from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
-from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
-from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
 from mineru.utils.enum_class import MakeMode
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
+from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
+from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
 
 pdf_suffixes = [".pdf"]
 image_suffixes = [".png", ".jpeg", ".jpg"]
@@ -99,6 +96,11 @@ def do_parse(
 ):
 
     if backend == "pipeline":
+
+        from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
+        from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
+        from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
+
         for idx, pdf_bytes in enumerate(pdf_bytes_list):
             new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
             pdf_bytes_list[idx] = new_pdf_bytes
@@ -163,6 +165,7 @@ def do_parse(
 
             logger.info(f"local output dir is {local_md_dir}")
     else:
+
         if backend.startswith("vlm-"):
             backend = backend[4:]