Răsfoiți Sursa

Merge pull request #1672 from myhloli/dev

refactor(model): integrate Ascend plugin for NPU support
Xiaomeng Zhao 9 luni în urmă
părinte
comite
80fd937ffd

+ 7 - 20
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,21 +1,22 @@
 import os
 import time
+import torch
 
+os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
+os.environ['FLAGS_use_stride_kernel'] = '0'
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 # 关闭paddle的信号处理
 import paddle
-import torch
+paddle.disable_signal_handler()
+
 from loguru import logger
 
 from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram
 
-paddle.disable_signal_handler()
-
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
-
 try:
     import torchtext
-
     if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:
@@ -32,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
 from magic_pdf.operators.models import InferenceResult
 
 
-def dict_compare(d1, d2):
-    return d1.items() == d2.items()
-
-
-def remove_duplicates_dicts(lst):
-    unique_dicts = []
-    for dict_item in lst:
-        if not any(
-            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
-        ):
-            unique_dicts.append(dict_item)
-    return unique_dicts
-
-
 class ModelSingleton:
     _instance = None
     _models = {}

+ 0 - 7
magic_pdf/model/pdf_extract_kit.py

@@ -89,13 +89,6 @@ class CustomPEKModel:
         # 初始化解析方案
         self.device = kwargs.get('device', 'cpu')
 
-        if str(self.device).startswith("npu"):
-            import torch_npu
-            os.environ['FLAGS_npu_jit_compile'] = '0'
-            os.environ['FLAGS_use_stride_kernel'] = '0'
-        elif str(self.device).startswith("mps"):
-            os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-
         logger.info('using device: {}'.format(self.device))
         models_dir = kwargs.get(
             'models_dir', os.path.join(root_dir, 'resources', 'models')

+ 13 - 14
magic_pdf/model/sub_modules/model_init.py

@@ -4,22 +4,22 @@ from loguru import logger
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
-    DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
-    Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
-from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
-    ModifiedPaddleOCR
-from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
-    RapidTableModel
-# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
-    StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
-    TableMasterPaddleModel
 
+try:
+    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
+    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+    logger.info('Using Ascend Plugin')
+except ImportError:
+    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+    # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
@@ -76,7 +76,6 @@ def ocr_model_init(show_log: bool = False,
                    use_dilation=True,
                    det_db_unclip_ratio=1.8,
                    ):
-
     if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,