浏览代码

fix: refactor formula and table enable handling to use environment variables

myhloli 5 月之前
父节点
当前提交
1383787bad

+ 3 - 2
mineru/backend/pipeline/batch_analyze.py

@@ -5,6 +5,7 @@ from collections import defaultdict
 import numpy as np
 
 from .model_init import AtomModelSingleton
+from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
 
@@ -16,8 +17,8 @@ MFR_BASE_BATCH_SIZE = 16
 class BatchAnalyze:
     def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
         self.batch_ratio = batch_ratio
-        self.formula_enable = formula_enable
-        self.table_enable = table_enable
+        self.formula_enable = get_formula_enable(formula_enable)
+        self.table_enable = get_table_enable(table_enable)
         self.model_manager = model_manager
         self.enable_ocr_det_batch = enable_ocr_det_batch
 

+ 2 - 2
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -4,7 +4,7 @@ import time
 from loguru import logger
 from tqdm import tqdm
 
-from mineru.utils.config_reader import get_device, get_llm_aided_config
+from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
 from mineru.backend.pipeline.model_init import AtomModelSingleton
 from mineru.backend.pipeline.para_split import para_split
 from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
@@ -78,7 +78,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
 
 
     """将所有区块的bbox整理到一起"""
-    if formula_enabled:
+    if get_formula_enable(formula_enabled):
         interline_equation_blocks = []
 
     if len(interline_equation_blocks) > 0:

+ 0 - 4
mineru/backend/pipeline/pipeline_analyze.py

@@ -189,10 +189,6 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
-        formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
-    if os.getenv('MINERU_TABLE_ENABLE', None) is not None:
-        table_enable = os.getenv('MINERU_TABLE_ENABLE').lower() == 'true'
     batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
     results = batch_model(images_with_extra_info)
 

+ 0 - 2
mineru/cli/common.py

@@ -116,8 +116,6 @@ def do_parse(
             _lang = lang_list[idx]
             _ocr_enable = ocr_enabled_list[idx]
 
-            if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
-                p_formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
             middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
 
             pdf_info = middle_json["pdf_info"]

+ 12 - 0
mineru/utils/config_reader.py

@@ -86,6 +86,18 @@ def get_device():
         return "cpu"
 
 
+def get_formula_enable(formula_enable):
+    formula_enable_env = os.getenv('MINERU_FORMULA_ENABLE')
+    formula_enable = formula_enable if formula_enable_env is None else formula_enable_env.lower() == 'true'
+    return formula_enable
+
+
+def get_table_enable(table_enable):
+    table_enable_env = os.getenv('MINERU_TABLE_ENABLE')
+    table_enable = table_enable if table_enable_env is None else table_enable_env.lower() == 'true'
+    return table_enable
+
+
 def get_latex_delimiter_config():
     config = read_config()
     if config is None: