Эх сурвалжийг харах

fix(model): move environment variable settings to global scope

- Move environment variable settings for NPU, MPS, and other configurations to the global scope in doc_analyze_by_custom_model.py
- Remove redundant environment variable settings in pdf_extract_kit.py
- This change ensures consistent configuration across the application and avoids potential conflicts or duplicate settings
myhloli 9 сар өмнө
parent
commit
f5112e2157

+ 4 - 4
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,8 +1,11 @@
 import os
 import time
-
 import torch
 
+os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
+os.environ['FLAGS_use_stride_kernel'] = '0'
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 # 关闭paddle的信号处理
 import paddle
 paddle.disable_signal_handler()
@@ -12,11 +15,8 @@ from loguru import logger
 from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram
 
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
-
 try:
     import torchtext
-
     if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:

+ 0 - 7
magic_pdf/model/pdf_extract_kit.py

@@ -89,13 +89,6 @@ class CustomPEKModel:
         # 初始化解析方案
         self.device = kwargs.get('device', 'cpu')
 
-        if str(self.device).startswith("npu"):
-            import torch_npu
-            os.environ['FLAGS_npu_jit_compile'] = '0'
-            os.environ['FLAGS_use_stride_kernel'] = '0'
-        elif str(self.device).startswith("mps"):
-            os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-
         logger.info('using device: {}'.format(self.device))
         models_dir = kwargs.get(
             'models_dir', os.path.join(root_dir, 'resources', 'models')