Эх сурвалжийг харах

refactor: enhance main function parameters and improve device handling logic

myhloli 5 сар өмнө
parent
commit
2688e3f7d0

+ 5 - 5
mineru/backend/pipeline/pipeline_analyze.py

@@ -92,14 +92,14 @@ def doc_analyze(
     ocr_enabled_list = []
     for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
         # 确定OCR设置
-        _ocr = False
+        _ocr_enable = False
         if parse_method == 'auto':
             if classify(pdf_bytes) == 'ocr':
-                _ocr = True
+                _ocr_enable = True
         elif parse_method == 'ocr':
-            _ocr = True
+            _ocr_enable = True
 
-        ocr_enabled_list.append(_ocr)
+        ocr_enabled_list.append(_ocr_enable)
         _lang = lang_list[pdf_idx]
 
         # 收集每个数据集中的页面
@@ -110,7 +110,7 @@ def doc_analyze(
             img_dict = images_list[page_idx]
             all_pages_info.append((
                 pdf_idx, page_idx,
-                img_dict['img_pil'], _ocr, _lang,
+                img_dict['img_pil'], _ocr_enable, _lang,
             ))
 
     # 准备批处理

+ 60 - 2
mineru/cli/client.py

@@ -2,7 +2,9 @@
 import os
 import click
 from pathlib import Path
+import torch
 from loguru import logger
+from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 
@@ -38,7 +40,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     vlm-huggingface: More general.
     vlm-sglang-engine: Faster(engine).
     vlm-sglang-client: Faster(client).
-    without method specified, huggingface will be used by default.""",
+    without method specified, pipeline will be used by default.""",
     default='pipeline',
 )
 @click.option(
@@ -49,6 +51,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
     Without languages specified, 'ch' will be used by default.
+    Adapted only for the case where the backend is set to "pipeline".
     """,
     default='ch',
 )
@@ -78,8 +81,63 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
+@click.option(
+    '-f',
+    '--formula',
+    'formula_enable',
+    type=bool,
+    help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-t',
+    '--table',
+    'table_enable',
+    type=bool,
+    help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-d',
+    '--device',
+    'device_mode',
+    type=str,
+    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline". ',
+    default=None,
+)
+@click.option(
+    '-vm',
+    '--virtual-vram',
+    'virtual_vram',
+    type=int,
+    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Default is "cpu". Adapted only for the case where the backend is set to "pipeline". ',
+    default=None,
+)
+
+
+def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram):
+
+    os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
+    os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
+    def get_device_mode() -> str:
+        if device_mode is not None:
+            return device_mode
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
+
+    def get_virtual_vram_size() -> int:
+        if virtual_vram is not None:
+            return virtual_vram
+        if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
+            return round(get_vram(get_device_mode()))
+        return 1
+
+    os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
 
-def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id):
     os.makedirs(output_dir, exist_ok=True)
 
     def parse_doc(path_list: list[Path]):