
feat(model): add npu support and optimize table model

- Add NPU support for memory cleaning and model initialization
- Optimize table model initialization and prediction process
- Update memory utils to support NPU
- Add language parameter for table model
myhloli 10 months ago
commit 7990e7dfbb

+ 3 - 0
magic_pdf/libs/clean_memory.py

@@ -7,4 +7,7 @@ def clean_memory():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+    elif torch.npu.is_available():
+        torch.npu.empty_cache()
+        torch.npu.ipc_collect()
     gc.collect()
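
For context, the complete helper after this change reads roughly as follows (a minimal sketch; the hasattr guard is an extra defensive check not in the diff, since torch.npu only exists once torch_npu has been imported):

import gc

import torch


def clean_memory():
    # Release cached allocator blocks on whichever accelerator is present,
    # then run the Python garbage collector.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif hasattr(torch, 'npu') and torch.npu.is_available():
        torch.npu.empty_cache()
        torch.npu.ipc_collect()
    gc.collect()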

+ 7 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -87,6 +87,12 @@ class CustomPEKModel:
         )
         # Initialize the parsing configuration
         self.device = kwargs.get('device', 'cpu')
+
+        if str(self.device).startswith("npu"):
+            import torch_npu
+            os.environ['FLAGS_npu_jit_compile'] = '0'
+            os.environ['FLAGS_use_stride_kernel'] = '0'
+
         logger.info('using device: {}'.format(self.device))
         models_dir = kwargs.get(
             'models_dir', os.path.join(root_dir, 'resources', 'models')
@@ -164,6 +170,7 @@ class CustomPEKModel:
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device,
+                lang=self.lang,
             )
 
         logger.info('DocAnalysis init done!')
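
Taken on its own, the NPU branch added to CustomPEKModel.__init__ amounts to the following (a hedged sketch; configure_npu is a hypothetical helper name, while the environment variables and device prefix come straight from the diff):

import os


def configure_npu(device: str):
    # Ascend devices are addressed as 'npu' or 'npu:0'; importing torch_npu
    # registers the torch.npu namespace used elsewhere in this commit.
    if str(device).startswith('npu'):
        import torch_npu  # noqa: F401
        os.environ['FLAGS_npu_jit_compile'] = '0'    # disable NPU JIT compilation
        os.environ['FLAGS_use_stride_kernel'] = '0'  # disable stride kernels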

+ 18 - 4
magic_pdf/model/sub_modules/model_init.py

@@ -1,6 +1,8 @@
+import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
+from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
     DocLayoutYOLOModel
@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel()
+        table_model = RapidTableModel(lang)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
 
 
 def mfd_model_init(weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
 
@@ -53,6 +57,8 @@ def layout_model_init(weight, config_file, device):
 
 
 def doclayout_yolo_model_init(weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model
 
@@ -63,6 +69,12 @@ def ocr_model_init(show_log: bool = False,
                    use_dilation=True,
                    det_db_unclip_ratio=1.8,
                    ):
+
+    use_npu = False
+    device = get_device()
+    if str(device).startswith("npu"):
+        use_npu = True
+
     if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,
@@ -70,6 +82,7 @@ def ocr_model_init(show_log: bool = False,
             lang=lang,
             use_dilation=use_dilation,
             det_db_unclip_ratio=det_db_unclip_ratio,
+            use_npu=use_npu,
         )
     else:
         model = ModifiedPaddleOCR(
@@ -77,7 +90,7 @@ def ocr_model_init(show_log: bool = False,
             det_db_box_thresh=det_db_box_thresh,
             use_dilation=use_dilation,
             det_db_unclip_ratio=det_db_unclip_ratio,
-            # use_angle_cls=True,
+            use_npu=use_npu,
         )
     return model
 
@@ -146,7 +159,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_name'),
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
-            kwargs.get('device')
+            kwargs.get('device'),
+            kwargs.get('lang'),
         )
     else:
         logger.error('model name not allow')
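
A usage sketch of the extended table-model init path (keyword names mirror the diff; AtomicModel.Table and MODEL_NAME.RAPID_TABLE are assumed to be the relevant magic_pdf constants, and the path, timeout and language values are placeholders):

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import atom_model_init

table_model = atom_model_init(
    AtomicModel.Table,
    table_model_name=MODEL_NAME.RAPID_TABLE,
    table_model_path='/path/to/models/rapid_table',  # placeholder path
    table_max_time=400,                              # illustrative timeout
    device='npu:0',
    lang='ch',  # now forwarded through table_model_init to RapidTableModel
)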

+ 3 - 0
magic_pdf/model/sub_modules/model_utils.py

@@ -54,4 +54,7 @@ def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
         return total_memory
+    elif torch.npu.is_available() and device != 'cpu':
+        total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+        return total_memory
     return None
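
For reference, a short usage sketch of the extended helper (the device string is illustrative, and the NPU branch only works once torch_npu has been imported):

from magic_pdf.model.sub_modules.model_utils import get_vram

# Returns the device's total memory in GB, or None when running on CPU only.
vram_gb = get_vram('npu:0')
if vram_gb is not None:
    print(f'detected {vram_gb:.1f} GB of device memory')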

+ 34 - 6
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -1,16 +1,44 @@
+import os
+import cv2
 import numpy as np
 from rapid_table import RapidTable
 from rapidocr_paddle import RapidOCR
 
+try:
+    import torchtext
+
+    if torchtext.__version__ >= '0.18.0':
+        torchtext.disable_torchtext_deprecation_warning()
+except ImportError:
+    pass
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
+
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+
 
 class RapidTableModel(object):
-    def __init__(self):
+    def __init__(self, lang=None):
         self.table_model = RapidTable()
-        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+
+        atom_model_manager = AtomModelSingleton()
+        self.ocr_engine = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang=lang,
+        )
 
     def predict(self, image):
-        ocr_result, _ = self.ocr_engine(np.asarray(image))
-        if ocr_result is None:
+        # ocr_result, _ = self.ocr_engine(np.asarray(image))
+
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                      len(item) == 2 and isinstance(item[1], tuple)]
+
+        if ocr_result:
+            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
+            return html_code, table_cell_bboxes, elapse
+        else:
             return None, None, None
-        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
-        return html_code, table_cell_bboxes, elapse
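
The reshaping in the new predict() is easiest to see on a toy value: ModifiedPaddleOCR.ocr() returns one list per page of [box, (text, score)] pairs, while RapidTable expects flat [box, text, score] triples (the sample boxes and texts below are invented):

# One page of ModifiedPaddleOCR output: [ [box, (text, score)], ... ]
page_result = [
    [[[10, 10], [80, 10], [80, 30], [10, 30]], ('Cell A', 0.98)],
    [[[90, 10], [160, 10], [160, 30], [90, 30]], ('Cell B', 0.95)],
]

# Flatten into the [box, text, score] triples RapidTable consumes,
# skipping any malformed entries.
ocr_result = [[item[0], item[1][0], item[1][1]]
              for item in page_result
              if len(item) == 2 and isinstance(item[1], tuple)]

predict() also converts the incoming RGB array to BGR with cv2.cvtColor before OCR, since OpenCV-based PaddleOCR pipelines expect BGR input.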

+ 8 - 0
magic_pdf/pdf_parse_union_core_v2.py

@@ -284,6 +284,14 @@ def model_init(model_name: str):
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
+
+    elif torch.npu.is_available():
+        device = torch.device('npu')
+        if torch.npu.is_bf16_supported():
+            supports_bfloat16 = True
+        else:
+            supports_bfloat16 = False
+
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False
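
Pulled out of model_init, the device and precision selection now follows this shape (a minimal sketch; pick_device is a hypothetical name, the other branches are abbreviated from the surrounding file, and torch.npu.is_bf16_supported() comes from torch_npu as used in the diff):

import torch


def pick_device():
    # Prefer CUDA, then an Ascend NPU, then fall back to CPU. The bfloat16
    # flag decides whether half-precision weights can be used downstream.
    if torch.cuda.is_available():
        return torch.device('cuda'), torch.cuda.is_bf16_supported()
    if hasattr(torch, 'npu') and torch.npu.is_available():
        return torch.device('npu'), torch.npu.is_bf16_supported()
    return torch.device('cpu'), False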