瀏覽代碼

refactor(device): optimize memory cleaning and device selection

- Update clean_memory function to support both CUDA and NPU devices
- Implement get_device function to centralize device selection logic
- Modify model initialization and memory cleaning to use the selected device
- Update RapidTableModel to support both RapidOCR and PaddleOCR engines
myhloli 10 月之前
父節點
當前提交
50f4841716

+ 10 - 7
magic_pdf/libs/clean_memory.py

@@ -3,11 +3,14 @@ import torch
 import gc
 
 
-def clean_memory():
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-    elif torch.npu.is_available():
-        torch.npu.empty_cache()
-        torch.npu.ipc_collect()
+def clean_memory(device='cuda'):
+    if device == 'cuda':
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            torch_npu.empty_cache()
+            torch_npu.ipc_collect()
     gc.collect()

+ 2 - 1
magic_pdf/model/batch_analyze.py

@@ -10,6 +10,7 @@ from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
@@ -268,7 +269,7 @@ def doc_batch_analyze(
 
     # TODO: clean memory when gpu memory is not enough
     clean_memory_start_time = time.time()
-    clean_memory()
+    clean_memory(get_device())
     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
 
     return InferenceResult(model_json, dataset)

+ 1 - 1
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -183,7 +183,7 @@ def doc_analyze(
         model_json.append(page_dict)
 
     gc_start = time.time()
-    clean_memory()
+    clean_memory(get_device())
     gc_time = round(time.time() - gc_start, 2)
     logger.info(f'gc time: {gc_time}')
 

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -170,7 +170,7 @@ class CustomPEKModel:
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device,
-                lang=self.lang,
+                ocr_engine=self.ocr_model,
             )
 
         logger.info('DocAnalysis init done!')

+ 2 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lan
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(lang)
+        table_model = RapidTableModel(ocr_engine)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -160,7 +160,6 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('lang'),
         )
     else:
         logger.error('model name not allow')

+ 8 - 5
magic_pdf/model/sub_modules/model_utils.py

@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
     total_memory = get_vram(device)
     if total_memory and total_memory <= vram_threshold:
         gc_start = time.time()
-        clean_memory()
+        clean_memory(device)
         gc_time = round(time.time() - gc_start, 2)
         logger.info(f"gc time: {gc_time}")
 
@@ -54,7 +54,10 @@ def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # 将字节转换为 GB
         return total_memory
-    elif torch.npu.is_available() and device != 'cpu':
-        total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # 转为 GB
-        return total_memory
-    return None
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # 转为 GB
+            return total_memory
+    else:
+        return None

+ 18 - 27
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -1,41 +1,32 @@
-import os
 import cv2
 import numpy as np
+from loguru import logger
 from rapid_table import RapidTable
 from rapidocr_paddle import RapidOCR
 
-try:
-    import torchtext
-
-    if torchtext.__version__ >= '0.18.0':
-        torchtext.disable_torchtext_deprecation_warning()
-except ImportError:
-    pass
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
-
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-
 
 class RapidTableModel(object):
-    def __init__(self, lang=None):
+    def __init__(self, ocr_engine):
         self.table_model = RapidTable()
-        # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-
-        atom_model_manager = AtomModelSingleton()
-        self.ocr_engine = atom_model_manager.get_atom_model(
-            atom_model_name='ocr',
-            ocr_show_log=False,
-            det_db_box_thresh=0.3,
-            lang=lang,
-        )
+        if ocr_engine is None:
+            self.ocr_model_name = "RapidOCR"
+            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        else:
+            self.ocr_model_name = "PaddleOCR"
+            self.ocr_engine = ocr_engine
 
     def predict(self, image):
-        # ocr_result, _ = self.ocr_engine(np.asarray(image))
 
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
+        if self.ocr_model_name == "RapidOCR":
+            ocr_result, _ = self.ocr_engine(np.asarray(image))
+        elif self.ocr_model_name == "PaddleOCR":
+            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                          len(item) == 2 and isinstance(item[1], tuple)]
+        else:
+            logger.error("OCR model not supported")
+            ocr_result = None
 
         if ocr_result:
             html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)

+ 12 - 9
magic_pdf/pdf_parse_union_core_v2.py

@@ -14,7 +14,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
@@ -277,21 +277,24 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
 
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-
+    device = get_device()
     if torch.cuda.is_available():
         device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
-
-    elif torch.npu.is_available():
-        device = torch.device('npu')
-        if torch.npu.is_bf16_supported():
-            supports_bfloat16 = True
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            device = torch.device('npu')
+            if torch.npu.is_bf16_supported():
+                supports_bfloat16 = True
+            else:
+                supports_bfloat16 = False
         else:
+            device = torch.device('cpu')
             supports_bfloat16 = False
-
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False
@@ -865,7 +868,7 @@ def pdf_parse_union(
         'pdf_info': pdf_info_list,
     }
 
-    clean_memory()
+    clean_memory(get_device())
 
     return new_pdf_info_dict