|
|
@@ -269,11 +269,9 @@ def load_pretrain_weights(exe,
|
|
|
vars_to_load.append(var)
|
|
|
logging.debug("Weight {} will be load".format(var.name))
|
|
|
|
|
|
- fluid.io.load_vars(
|
|
|
- executor=exe,
|
|
|
- dirname=weights_dir,
|
|
|
- main_program=main_prog,
|
|
|
- vars=vars_to_load)
|
|
|
+ params_dict = fluid.io.load_program_state(
|
|
|
+ weights_dir, var_list=vars_to_load)
|
|
|
+ fluid.io.set_program_state(main_prog, params_dict)
|
|
|
if len(vars_to_load) == 0:
|
|
|
logging.warning(
|
|
|
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
|