|
|
@@ -58,7 +58,7 @@ def mfd_model_init(weight):
|
|
|
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
|
|
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
|
|
cfg = Config(args)
|
|
|
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
|
|
|
+ cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
|
|
cfg.config.model.model_config.model_name = weight_dir
|
|
|
cfg.config.model.tokenizer_config.path = weight_dir
|
|
|
task = tasks.setup_task(cfg)
|