
refactor: enhance main function parameters and improve device handling logic

myhloli, 5 months ago
Parent commit: 2688e3f7d0
2 changed files with 65 additions and 7 deletions
  1. mineru/backend/pipeline/pipeline_analyze.py  (+5 -5)
  2. mineru/cli/client.py  (+60 -2)

+ 5 - 5
mineru/backend/pipeline/pipeline_analyze.py

@@ -92,14 +92,14 @@ def doc_analyze(
     ocr_enabled_list = []
     for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
         # Determine the OCR setting
-        _ocr = False
+        _ocr_enable = False
         if parse_method == 'auto':
             if classify(pdf_bytes) == 'ocr':
-                _ocr = True
+                _ocr_enable = True
         elif parse_method == 'ocr':
-            _ocr = True
+            _ocr_enable = True

-        ocr_enabled_list.append(_ocr)
+        ocr_enabled_list.append(_ocr_enable)
         _lang = lang_list[pdf_idx]

         # Collect the pages from each dataset
@@ -110,7 +110,7 @@ def doc_analyze(
             img_dict = images_list[page_idx]
             all_pages_info.append((
                 pdf_idx, page_idx,
-                img_dict['img_pil'], _ocr, _lang,
+                img_dict['img_pil'], _ocr_enable, _lang,
             ))

     # Prepare for batch processing

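For reference, the renamed flag's role can be read on its own. A minimal sketch, assuming classify() returns the string 'ocr' for scan-like PDFs as implied by the hunk above; build_ocr_enabled_list is a hypothetical helper, and classify is passed in as a parameter because its import path is not shown in this diff:

    def build_ocr_enabled_list(pdf_bytes_list, parse_method, classify):
        # Mirrors the loop in doc_analyze: decide per PDF whether OCR is enabled.
        ocr_enabled_list = []
        for pdf_bytes in pdf_bytes_list:
            _ocr_enable = False  # renamed from _ocr in this commit
            if parse_method == 'auto':
                # let the classifier decide for each document
                _ocr_enable = classify(pdf_bytes) == 'ocr'
            elif parse_method == 'ocr':
                _ocr_enable = True
            ocr_enabled_list.append(_ocr_enable)
        return ocr_enabled_list

The resulting list lines up index-for-index with pdf_bytes_list, which is what lets the later per-page loop pick up the right flag for each page.
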
+ 60 - 2
mineru/cli/client.py

@@ -2,7 +2,9 @@
 import os
 import click
 from pathlib import Path
+import torch
 from loguru import logger
+from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes

@@ -38,7 +40,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     vlm-huggingface: More general.
     vlm-sglang-engine: Faster(engine).
     vlm-sglang-client: Faster(client).
-    without method specified, huggingface will be used by default.""",
+    without method specified, pipeline will be used by default.""",
     default='pipeline',
 )
 @click.option(
@@ -49,6 +51,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
     Without languages specified, 'ch' will be used by default.
+    Adapted only for the case where the backend is set to "pipeline".
     """,
     """,
     default='ch',
     default='ch',
 )
 )
@@ -78,8 +81,63 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
+@click.option(
+    '-f',
+    '--formula',
+    'formula_enable',
+    type=bool,
+    help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-t',
+    '--table',
+    'table_enable',
+    type=bool,
+    help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-d',
+    '--device',
+    'device_mode',
+    type=str,
+    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline". ',
+    default=None,
+)
+@click.option(
+    '-vm',
+    '--virtual-vram',
+    'virtual_vram',
+    type=int,
+    help='Upper limit of GPU virtual memory available to a single process; auto-detected from the device when not specified. Adapted only for the case where the backend is set to "pipeline".',
+    default=None,
+)
+
+
+def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram):
+
+    os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
+    os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
+    def get_device_mode() -> str:
+        if device_mode is not None:
+            return device_mode
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
+
+    def get_virtual_vram_size() -> int:
+        if virtual_vram is not None:
+            return virtual_vram
+        if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
+            return round(get_vram(get_device_mode()))
+        return 1
+
+    os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = str(get_virtual_vram_size())

-def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id):
     os.makedirs(output_dir, exist_ok=True)

     def parse_doc(path_list: list[Path]):
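
The new main() passes its settings to the rest of the pipeline through environment variables rather than function arguments. A minimal consumer-side sketch, assuming downstream code reads them back with os.getenv; the variable names come from this commit, but the reading code (and the GB interpretation of the VRAM value) is illustrative, not the actual MinerU internals:

    import os

    # Illustrative reader for the variables exported by main() above.
    # Fallback values mirror the CLI defaults (formula/table enabled, CPU, size 1).
    formula_enable = os.getenv('MINERU_FORMULA_ENABLE', 'true') == 'true'
    table_enable = os.getenv('MINERU_TABLE_ENABLE', 'true') == 'true'
    device_mode = os.getenv('MINERU_DEVICE_MODE', 'cpu')  # e.g. "cuda", "cuda:0", "npu", "mps", "cpu"
    virtual_vram = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', '1'))  # assumed to be in GB

On the command line, the new flags combine with the existing input/output options, e.g. --device cuda:0 --virtual-vram 8 --formula true --table false; click's bool type accepts true/false literals for the -f and -t options.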