ソースを参照

refactor: streamline formula and table enable configurations in the pipeline

myhloli 5 ヶ月 前
コミット
98b8c4a967

+ 13 - 14
mineru/backend/pipeline/pipeline_analyze.py

@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 
 from .model_init import MineruPipelineModel
-from mineru.utils.config_reader import get_device, get_formula_config, get_table_recog_config
+from mineru.utils.config_reader import get_device
 from ...utils.pdf_classify import classify
 from ...utils.pdf_image_tools import load_images_from_pdf
 
@@ -44,20 +44,15 @@ class ModelSingleton:
 
 def custom_model_init(
     lang=None,
-    formula_enable=None,
-    table_enable=None,
+    formula_enable=True,
+    table_enable=True,
 ):
     model_init_start = time.time()
     # 从配置文件读取model-dir和device
     device = get_device()
 
-    formula_config = get_formula_config()
-    if formula_enable is not None:
-        formula_config['enable'] = formula_enable
-
-    table_config = get_table_recog_config()
-    if table_enable is not None:
-        table_config['enable'] = table_enable
+    formula_config = {"enable": formula_enable}
+    table_config = {"enable": table_enable}
 
     model_input = {
         'device': device,
@@ -78,8 +73,8 @@ def doc_analyze(
         pdf_bytes_list,
         lang_list,
         parse_method: str = 'auto',
-        formula_enable=None,
-        table_enable=None,
+        formula_enable=True,
+        table_enable=True,
 ):
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
 
@@ -152,8 +147,8 @@ def doc_analyze(
 
 def batch_image_analyze(
         images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
-        formula_enable=None,
-        table_enable=None):
+        formula_enable=True,
+        table_enable=True):
     # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
 
     from .batch_analyze import BatchAnalyze
@@ -194,6 +189,10 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
+    if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
+        formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
+    if os.getenv('MINERU_TABLE_ENABLE', None) is not None:
+        table_enable = os.getenv('MINERU_TABLE_ENABLE').lower() == 'true'
     batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
     results = batch_model(images_with_extra_info)
 

+ 2 - 4
mineru/cli/client.py

@@ -140,10 +140,6 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 
 def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
 
-    if os.getenv('MINERU_FORMULA_ENABLE', None) is None:
-        os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
-    if os.getenv('MINERU_TABLE_ENABLE', None) is None:
-        os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
     def get_device_mode() -> str:
         if device_mode is not None:
             return device_mode
@@ -184,6 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
                 p_lang_list=lang_list,
                 backend=backend,
                 parse_method=method,
+                p_formula_enable=formula_enable,
+                p_table_enable=table_enable,
                 server_url=server_url,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id

+ 3 - 0
mineru/cli/common.py

@@ -115,6 +115,9 @@ def do_parse(
             pdf_doc = all_pdf_docs[idx]
             _lang = lang_list[idx]
             _ocr_enable = ocr_enabled_list[idx]
+
+            if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
+                p_formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
             middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
 
             pdf_info = middle_json["pdf_info"]

+ 0 - 18
mineru/utils/config_reader.py

@@ -86,24 +86,6 @@ def get_device():
         return "cpu"
 
 
-def get_table_recog_config():
-    table_enable = os.getenv('MINERU_TABLE_ENABLE', None)
-    if table_enable is not None:
-        return json.loads(f'{{"enable": {table_enable}}}')
-    else:
-        # logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
-        return json.loads(f'{{"enable": true}}')
-
-
-def get_formula_config():
-    formula_enable = os.getenv('MINERU_FORMULA_ENABLE', None)
-    if formula_enable is not None:
-        return json.loads(f'{{"enable": {formula_enable}}}')
-    else:
-        # logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
-        return json.loads(f'{{"enable": true}}')
-
-
 def get_latex_delimiter_config():
     config = read_config()
     if config is None: