浏览代码

fix(device): enable MPS support and fix related issues

- Add MPS support for Apple Silicon devices
- Implement empty_cache() for MPS devices
- Set PYTORCH_ENABLE_MPS_FALLBACK environment variable
- Adjust MFR model device allocation for MPS
myhloli 10 月之前
父节点
当前提交
203b8f9004
共有 2 个文件被更改,包括 6 次插入1 次删除
  1. 2 0
      magic_pdf/libs/clean_memory.py
  2. 4 1
      magic_pdf/model/pdf_extract_kit.py

+ 2 - 0
magic_pdf/libs/clean_memory.py

@@ -12,4 +12,6 @@ def clean_memory(device='cuda'):
         import torch_npu
         if torch_npu.npu.is_available():
             torch_npu.npu.empty_cache()
+    elif str(device).startswith("mps"):
+        torch.mps.empty_cache()
     gc.collect()

+ 4 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -92,6 +92,8 @@ class CustomPEKModel:
             import torch_npu
             os.environ['FLAGS_npu_jit_compile'] = '0'
             os.environ['FLAGS_use_stride_kernel'] = '0'
+        elif str(self.device).startswith("mps"):
+            os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
 
         logger.info('using device: {}'.format(self.device))
         models_dir = kwargs.get(
@@ -119,11 +121,12 @@ class CustomPEKModel:
                 os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
             )
             mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
+
             self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
             )
 
         # 初始化layout模型