|
|
@@ -275,7 +275,7 @@ def _load_state_dict_into_model(
|
|
|
model_to_load, state_dict, start_prefix, convert_from_hf
|
|
|
):
|
|
|
# torch will cast dtype in load_state_dict, but paddle strictly check dtype
|
|
|
- _convert_state_dict_dtype_and_shape(state_dict, model_to_load)
|
|
|
+ _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)
|
|
|
|
|
|
error_msgs = []
|
|
|
|
|
|
@@ -305,12 +305,16 @@ def _load_state_dict_into_model(
|
|
|
return error_msgs
|
|
|
|
|
|
|
|
|
-def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
|
|
|
+def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
|
|
|
# convert the dtype of state dict
|
|
|
def is_0d_or_1d(tensor):
|
|
|
return len(tensor.shape) == 0 or list(tensor.shape) == [1]
|
|
|
|
|
|
- for key, value in model_to_load.state_dict().items():
|
|
|
+ if convert_from_hf:
|
|
|
+ model_state_dict = model_to_load.get_hf_state_dict()
|
|
|
+ else:
|
|
|
+ model_state_dict = model_to_load.state_dict()
|
|
|
+ for key, value in model_state_dict.items():
|
|
|
if key in list(state_dict.keys()):
|
|
|
if isinstance(state_dict[key], np.ndarray):
|
|
|
raise ValueError(
|