|
|
@@ -92,6 +92,8 @@ class CustomPEKModel:
|
|
|
import torch_npu
|
|
|
os.environ['FLAGS_npu_jit_compile'] = '0'
|
|
|
os.environ['FLAGS_use_stride_kernel'] = '0'
|
|
|
+ elif str(self.device).startswith("mps"):
|
|
|
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
|
|
|
|
|
logger.info('using device: {}'.format(self.device))
|
|
|
models_dir = kwargs.get(
|
|
|
@@ -119,11 +121,12 @@ class CustomPEKModel:
|
|
|
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
|
|
|
)
|
|
|
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
|
|
|
+
|
|
|
self.mfr_model = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.MFR,
|
|
|
mfr_weight_dir=mfr_weight_dir,
|
|
|
mfr_cfg_path=mfr_cfg_path,
|
|
|
- device=self.device,
|
|
|
+ device='cpu' if str(self.device).startswith("mps") else self.device,
|
|
|
)
|
|
|
|
|
|
# 初始化layout模型
|