Selaa lähdekoodia

fix: integrate clean_vram function to manage GPU memory usage during predictions

myhloli 2 kuukautta sitten
vanhempi
commit
c65fb7de8a

+ 7 - 2
mineru/backend/pipeline/batch_analyze.py

@@ -9,7 +9,7 @@ import numpy as np
 from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
-from ...utils.model_utils import crop_img, get_res_list_from_layout_res
+from ...utils.model_utils import crop_img, get_res_list_from_layout_res, clean_vram
 from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_np_img
@@ -71,7 +71,7 @@ class BatchAnalyze:
                 mfr_count += len(images_formula_list[image_index])
 
         # 清理显存
-        # clean_vram(self.model.device, vram_threshold=8)
+        clean_vram(self.model.device, vram_threshold=8)
 
         ocr_res_list_all_page = []
         table_res_list_all_page = []
@@ -183,6 +183,8 @@ class BatchAnalyze:
                             [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
                         ]
 
+            clean_vram(self.model.device, vram_threshold=8)
+
             # 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
             wireless_table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.WirelessTable,
@@ -198,6 +200,9 @@ class BatchAnalyze:
                 for table_res_dict in tqdm(
                         wired_table_res_list, desc="Table-wired Predict"
                 ):
+                    if not table_res_dict.get("ocr_result", None):
+                        continue
+
                     wired_table_model = atom_model_manager.get_atom_model(
                         atom_model_name=AtomicModel.WiredTable,
                         lang=table_res_dict["lang"],

+ 16 - 69
mineru/model/table/rec/slanet_plus/main.py

@@ -1,11 +1,8 @@
 import os
-import argparse
 import copy
-import importlib
 import time
 import html
 from dataclasses import asdict, dataclass
-from enum import Enum
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -19,32 +16,10 @@ from .table_structure import TableStructurer
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
-root_dir = Path(__file__).resolve().parent
-
-
-class ModelType(Enum):
-    PPSTRUCTURE_EN = "ppstructure_en"
-    PPSTRUCTURE_ZH = "ppstructure_zh"
-    SLANETPLUS = "slanet_plus"
-    UNITABLE = "unitable"
-
-
-ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
-KEY_TO_MODEL_URL = {
-    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
-    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
-    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
-    ModelType.UNITABLE.value: {
-        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
-        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
-        "vocab": f"{ROOT_URL}/unitable/vocab.json",
-    },
-}
-
 
 @dataclass
 class RapidTableInput:
-    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_type: Optional[str] = "slanet_plus"
     model_path: Union[str, Path, None, Dict[str, str]] = None
     use_cuda: bool = False
     device: str = "cpu"
@@ -60,18 +35,7 @@ class RapidTableOutput:
 
 class RapidTable:
     def __init__(self, config: RapidTableInput):
-        self.model_type = config.model_type
-        if self.model_type not in KEY_TO_MODEL_URL:
-            model_list = ",".join(KEY_TO_MODEL_URL)
-            raise ValueError(
-                f"{self.model_type} is not supported. The currently supported models are {model_list}."
-            )
-
-        config.model_path = config.model_path
-        if self.model_type == ModelType.SLANETPLUS.value:
-            self.table_structure = TableStructurer(asdict(config))
-        else:
-            raise ValueError(f"{self.model_type} is not supported.")
+        self.table_structure = TableStructurer(asdict(config))
         self.table_matcher = TableMatch()
 
     def predict(
@@ -177,29 +141,6 @@ class RapidTable:
         return cell_bboxes
 
 
-def parse_args(arg_list: Optional[List[str]] = None):
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v",
-        "--vis",
-        action="store_true",
-        default=False,
-        help="Wheter to visualize the layout results.",
-    )
-    parser.add_argument(
-        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
-    )
-    parser.add_argument(
-        "-m",
-        "--model_type",
-        type=str,
-        default=ModelType.SLANETPLUS.value,
-        choices=list(KEY_TO_MODEL_URL),
-    )
-    args = parser.parse_args(arg_list)
-    return args
-
-
 def escape_html(input_string):
     """Escape HTML Entities."""
     return html.escape(input_string)
@@ -245,22 +186,28 @@ class RapidTableModel(object):
 
     def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
         """对传入的字典列表进行批量预测,无返回值"""
-        with tqdm(total=len(table_res_list), desc="Table-wireless Predict") as pbar:
-            for index in range(0, len(table_res_list), batch_size):
+
+        not_none_table_res_list = []
+        for table_res in table_res_list:
+            if table_res.get("ocr_result", None):
+                not_none_table_res_list.append(table_res)
+
+        with tqdm(total=len(not_none_table_res_list), desc="Table-wireless Predict") as pbar:
+            for index in range(0, len(not_none_table_res_list), batch_size):
                 batch_imgs = [
-                    cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
-                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                    cv2.cvtColor(np.asarray(not_none_table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
+                    for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
                 ]
                 batch_ocrs = [
-                    table_res_list[i]["ocr_result"]
-                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                    not_none_table_res_list[i]["ocr_result"]
+                    for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
                 ]
                 results = self.table_model.batch_predict(
                     batch_imgs, batch_ocrs, batch_size=batch_size
                 )
                 for i, result in enumerate(results):
                     if result.pred_html:
-                        table_res_list[index + i]['table_res']['html'] = result.pred_html
+                        not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
 
                 # 更新进度条
-                pbar.update(len(results))
+                pbar.update(len(results))

+ 3 - 0
mineru/utils/model_utils.py

@@ -1,3 +1,4 @@
+import os
 import time
 import gc
 from PIL import Image
@@ -427,6 +428,8 @@ def clean_memory(device='cuda'):
 
 def clean_vram(device, vram_threshold=8):
     total_memory = get_vram(device)
+    if total_memory is not None:
+        total_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(total_memory)))
     if total_memory and total_memory <= vram_threshold:
         gc_start = time.time()
         clean_memory(device)