Эх сурвалжийг харах

refactor: improve environment variable initialization and enhance GPU memory handling

myhloli 5 сар өмнө
parent
commit
3eef1218e4
1 өөрчлөгдсөн 11 нэмэгдсэн , 7 устгасан
  1. 11 7
      mineru/cli/client.py

+ 11 - 7
mineru/cli/client.py

@@ -110,7 +110,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     '--virtual-vram',
     'virtual_vram',
     type=int,
-    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Default is "cpu". Adapted only for the case where the backend is set to "pipeline". ',
+    help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
     default=None,
 )
 @click.option(
@@ -127,8 +127,10 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 
 def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
 
-    os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
-    os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
+    if os.getenv('MINERU_FORMULA_ENABLE', None) is None:
+        os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
+    if os.getenv('MINERU_TABLE_ENABLE', None) is None:
+        os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
     def get_device_mode() -> str:
         if device_mode is not None:
             return device_mode
@@ -137,7 +139,8 @@ def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_p
         if torch.backends.mps.is_available():
             return "mps"
         return "cpu"
-    os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
+    if os.getenv('MINERU_DEVICE_MODE', None) is None:
+        os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
 
     def get_virtual_vram_size() -> int:
         if virtual_vram is not None:
@@ -145,10 +148,11 @@ def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_p
         if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
             return round(get_vram(get_device_mode()))
         return 1
+    if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
+        os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
 
-    os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
-
-    os.environ['MINERU_MODEL_SOURCE'] = model_source
+    if os.getenv('MINERU_BACKEND', None) is None:
+        os.environ['MINERU_MODEL_SOURCE'] = model_source
 
     os.makedirs(output_dir, exist_ok=True)