|
|
@@ -21,7 +21,7 @@ class MathDataset(Dataset):
|
|
|
class UnimernetModel(object):
|
|
|
def __init__(self, weight_dir, _device_="cpu"):
|
|
|
from .unimernet_hf import UnimernetModel
|
|
|
- if _device_.startswith("mps"):
|
|
|
+ if _device_.startswith("mps") or _device_.startswith("npu"):
|
|
|
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
|
|
|
else:
|
|
|
self.model = UnimernetModel.from_pretrained(weight_dir)
|