@@ -258,8 +258,6 @@ def load_state_dict(
     tensor_parallel_split_mapping,
     fliter_dict_keys,
     "expected",
-    quantization_linear_list=None,
-    quantization_config=None,
     dtype=None,
     return_numpy=False,
     convert_from_hf=convert_from_hf,
@@ -277,7 +275,7 @@ def _load_state_dict_into_model(
     model_to_load, state_dict, start_prefix, convert_from_hf
 ):
     # torch will cast dtype in load_state_dict, but paddle strictly check dtype
-    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)
+    _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)

     error_msgs = []

@@ -307,12 +305,16 @@ def _load_state_dict_into_model(
     return error_msgs


-def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
+def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
     # convert the dtype of state dict
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]

-    for key, value in model_to_load.state_dict().items():
+    if convert_from_hf:
+        model_state_dict = model_to_load.get_hf_state_dict()
+    else:
+        model_state_dict = model_to_load.state_dict()
+    for key, value in model_state_dict.items():
         if key in list(state_dict.keys()):
             if isinstance(state_dict[key], np.ndarray):
                 raise ValueError(
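Reviewer note: a minimal, self-contained sketch of what the new `convert_from_hf` branch does. `ToyModel` is a stand-in, not PaddleNLP API; `get_hf_state_dict` is assumed to exist on the model, as the diff implies, and the `paddle.cast` step only illustrates the dtype alignment the comment above the call site refers to.

```python
import paddle


class ToyModel:
    """Stand-in exposing both key views of the same parameters."""

    def state_dict(self):
        return {"linear.weight": paddle.zeros([4, 4], dtype="float16")}

    def get_hf_state_dict(self):
        # Hypothetical HF-keyed view; the real method lives on PaddleNLP models.
        return {"model.linear.weight": paddle.zeros([4, 4], dtype="float16")}


def reference_state_dict(model, convert_from_hf):
    # Mirrors the new selection logic in _convert_state_dict_dtype_and_shape:
    # compare checkpoint keys against the HF-keyed view when converting from HF.
    return model.get_hf_state_dict() if convert_from_hf else model.state_dict()


# Paddle's set_state_dict checks dtype strictly, so checkpoint tensors are
# cast to the model's dtype up front rather than at assignment time.
ckpt = {"linear.weight": paddle.zeros([4, 4], dtype="float32")}
for key, param in reference_state_dict(ToyModel(), convert_from_hf=False).items():
    if key in ckpt and ckpt[key].dtype != param.dtype:
        ckpt[key] = paddle.cast(ckpt[key], param.dtype)
```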
@@ -631,34 +633,6 @@ class PretrainedModel(
         config.weightonly_group_size = predictor_args.weightonly_group_size
         config.weight_block_size = predictor_args.weight_block_size
         config.moe_quant_type = predictor_args.moe_quant_type
-        if config.quantization_config.quant_method is not None:
-            predictor_args.weight_block_size = (
-                config.quantization_config.weight_block_size
-            )
-            config.weight_block_size = predictor_args.weight_block_size
-
-        if config.quantization_config.quant_type is not None:
-            if predictor_args.mode == "dynamic":
-                predictor_args.quant_type = config.quantization_config.quant_type
-                config.quant_type = config.quantization_config.quant_type
-            if "c8" in config.quant_type:
-                predictor_args.cachekv_int8_type = "static"
-                if predictor_args.mode == "dynamic":
-                    config.cachekv_int8_type = "static"
-
-        if predictor_args.mode == "dynamic":
-            ptq_multicards_num = 0
-            if os.path.exists(config.model_name_or_path):
-                prefix = "act_scales_"
-                for filename in os.listdir(config.model_name_or_path):
-                    if filename.startswith(prefix):
-                        ptq_multicards_num += 1
-
-            logging.info(
-                f"PTQ from {ptq_multicards_num} cards, so we will not split"
-            )
-            if ptq_multicards_num > 1:
-                config.single_card_ptq = False

         if predictor_args.block_attn:
             config.block_size = predictor_args.block_size
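Reviewer note: the deleted block above auto-detected multi-card PTQ by counting `act_scales_*` files next to the checkpoint and then disabled `single_card_ptq`. For reference, a minimal reproduction of that heuristic (the helper name is illustrative, not PaddleNLP API):

```python
import os


def detect_multicard_ptq(model_dir: str) -> bool:
    """More than one act_scales_* file meant PTQ scales from multiple cards,
    which the removed code translated into config.single_card_ptq = False."""
    if not os.path.isdir(model_dir):
        return False
    num_scale_files = sum(
        1 for name in os.listdir(model_dir) if name.startswith("act_scales_")
    )
    return num_scale_files > 1
```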
@@ -1323,45 +1297,6 @@ class PretrainedModel(
                     ".".join([prefix, s]) for s in quantization_linear_list
                 ]

-            # Weight quantization if not yet quantized & update loaded_keys
-            if (
-                hasattr(config, "quantization_config")
-                and config.quantization_config.is_weight_quantize()
-            ):
-                try:
-                    from ..quantization.quantization_utils import (
-                        convert_to_quantize_state_dict,
-                        update_loaded_state_dict_keys,
-                    )
-                except ImportError:
-                    raise ImportError(
-                        "Quantization features require `paddlepaddle >= 2.5.2`"
-                    )
-                if state_dict is not None:
-                    state_dict = convert_to_quantize_state_dict(
-                        state_dict,
-                        quantization_linear_list,
-                        config.quantization_config,
-                        dtype,
-                    )
-                    loaded_keys = [k for k in state_dict.keys()]
-                else:
-                    loaded_keys = update_loaded_state_dict_keys(
-                        loaded_keys, quantization_linear_list, config.quantization_config
-                    )
-                if keep_in_fp32_modules is None:
-                    keep_in_fp32_modules = (
-                        ["quant_scale"]
-                        if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                        else None
-                    )
-                else:
-                    keep_in_fp32_modules = (
-                        keep_in_fp32_modules + ["quant_scale"]
-                        if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                        else keep_in_fp32_modules
-                    )
-
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))

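With the quantization-aware rewriting of `loaded_keys` and `keep_in_fp32_modules` gone, the missing/unexpected bookkeeping is plain set arithmetic. A tiny worked example with illustrative key names:

```python
expected_keys = ["embed.weight", "linear.weight", "linear.weight_scale"]
loaded_keys = ["embed.weight", "linear.weight", "lm_head.bias"]

missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

print(missing_keys)     # ['linear.weight_scale']
print(unexpected_keys)  # ['lm_head.bias']
```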
@@ -1525,27 +1460,12 @@ class PretrainedModel(
                     ignore_mismatched_sizes,
                 )

-                if (
-                    hasattr(config, "quantization_config")
-                    and config.quantization_config.is_weight_quantize()
-                ):
-                    error_msgs = _load_state_dict_into_meta_model(
-                        model_to_load,
-                        state_dict,
-                        loaded_keys,
-                        start_prefix,
-                        expected_keys,
-                        dtype=dtype,
-                        is_safetensors=is_safetensors,
-                        keep_in_fp32_modules=keep_in_fp32_modules,
-                    )
-                else:
-                    error_msgs = _load_state_dict_into_model(
-                        model_to_load,
-                        state_dict,
-                        start_prefix,
-                        convert_from_hf=convert_from_hf,
-                    )
+                error_msgs = _load_state_dict_into_model(
+                    model_to_load,
+                    state_dict,
+                    start_prefix,
+                    convert_from_hf=convert_from_hf,
+                )
             else:
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True

@@ -1600,8 +1520,6 @@ class PretrainedModel(
                 if k[-1] in tp_actions:
                     fuse_actions.pop(k[-1], None)

-            if config.quantization_config.is_weight_quantize():
-                filter_dict_keys = None
             try:
                 transpose_weight_keys = model.get_transpose_weight_keys()
             except NotImplementedError:
@@ -1630,14 +1548,6 @@ class PretrainedModel(
             missing_keys = list(set(missing_keys) - set(new_keys))
             unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

-            if config.quantization_config.is_weight_quantize():
-                state_dict = convert_to_quantize_state_dict(
-                    state_dict,
-                    quantization_linear_list,
-                    config.quantization_config,
-                    dtype,
-                )
-
             # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
             # matching the weights in the model.
             mismatched_keys += _find_mismatched_keys(
@@ -1664,7 +1574,7 @@ class PretrainedModel(
                 )
                 logging.info("Converted state_dict to Tensor Parallel Format")

-                if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+                if low_cpu_mem_usage:
                     new_error_msgs = _load_state_dict_into_meta_model(
                         model_to_load,
                         state_dict,
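Same simplification at this call site: meta-device loading is now gated on `low_cpu_mem_usage` alone. A toy sketch of the resulting dispatch (stub functions with trimmed signatures; the real loaders take many more arguments):

```python
def _load_eager(model, state_dict, convert_from_hf=False):
    # Stub for _load_state_dict_into_model: weights already materialized.
    return []  # error messages


def _load_meta(model, state_dict):
    # Stub for _load_state_dict_into_meta_model: params start on the meta device.
    return []  # error messages


def dispatch_load(model, state_dict, low_cpu_mem_usage, convert_from_hf):
    # Before this change the condition here was
    # `low_cpu_mem_usage or config.quantization_config.is_weight_quantize()`.
    if low_cpu_mem_usage:
        return _load_meta(model, state_dict)
    return _load_eager(model, state_dict, convert_from_hf=convert_from_hf)
```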