
Remove broken quantization_config logic (#4654)

Lin Manhui, 3 weeks ago
Parent commit: f6bb816a22

+ 0 - 20
paddlex/inference/models/common/vlm/transformers/configuration_utils.py

@@ -823,11 +823,6 @@ class PretrainedConfig:
                 )
         to_remove = []
         for key, value in kwargs.items():
-            if key == "quantization_config" and isinstance(value, Dict):
-                for q_key in value:
-                    setattr(config.quantization_config, q_key, value[q_key])
-                to_remove.append(key)
-                continue
             if hasattr(config, key):
                 setattr(config, key, value)
                 if key != "dtype":
@@ -889,11 +884,6 @@ class PretrainedConfig:
 
         # only serialize values that differ from the default config
         for key, value in config_dict.items():
-            if key == "quantization_config":
-                quantization_diff_dict = self.quantization_config.to_diff_dict()
-                if len(quantization_diff_dict) > 0:
-                    serializable_config_dict[key] = quantization_diff_dict
-                continue
             if (
                 key not in default_config_dict
                 or key == "paddlenlp_version"
@@ -942,16 +932,6 @@ class PretrainedConfig:
                 if key in self._unsavable_keys:
                     output.pop(key)
 
-        if hasattr(self, "quantization_config"):
-            output["quantization_config"] = (
-                self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
-                else self.quantization_config
-            )
-
-            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
-            _ = output.pop("_pre_quantization_dtype", None)
-
         return output
 
     def update(self, config_dict: Dict[str, Any]):
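After the first hunk, the kwargs-merge loop in `PretrainedConfig` simply copies recognised keys onto the config instead of special-casing `quantization_config`. A minimal standalone sketch of the retained behaviour (the helper name and the final cleanup of `kwargs` are assumptions inferred from the `to_remove` list; only the loop body comes from the diff context):

def merge_kwargs_into_config(config, kwargs):
    # Copy kwargs that match existing config attributes onto the config;
    # the removed quantization_config special case no longer intercepts them.
    to_remove = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
            if key != "dtype":
                to_remove.append(key)
    # Assumed follow-up step: merged keys are dropped from the remaining kwargs.
    for key in to_remove:
        kwargs.pop(key, None)
    return config, kwargs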

+ 7 - 101
paddlex/inference/models/common/vlm/transformers/model_utils.py

@@ -258,8 +258,6 @@ def load_state_dict(
                 tensor_parallel_split_mapping,
                 fliter_dict_keys,
                 "expected",
-                quantization_linear_list=None,
-                quantization_config=None,
                 dtype=None,
                 return_numpy=False,
                 convert_from_hf=convert_from_hf,
@@ -631,34 +629,6 @@ class PretrainedModel(
         config.weightonly_group_size = predictor_args.weightonly_group_size
         config.weight_block_size = predictor_args.weight_block_size
         config.moe_quant_type = predictor_args.moe_quant_type
-        if config.quantization_config.quant_method is not None:
-            predictor_args.weight_block_size = (
-                config.quantization_config.weight_block_size
-            )
-            config.weight_block_size = predictor_args.weight_block_size
-
-        if config.quantization_config.quant_type is not None:
-            if predictor_args.mode == "dynamic":
-                predictor_args.quant_type = config.quantization_config.quant_type
-                config.quant_type = config.quantization_config.quant_type
-            if "c8" in config.quant_type:
-                predictor_args.cachekv_int8_type = "static"
-                if predictor_args.mode == "dynamic":
-                    config.cachekv_int8_type = "static"
-
-            if predictor_args.mode == "dynamic":
-                ptq_multicards_num = 0
-                if os.path.exists(config.model_name_or_path):
-                    prefix = "act_scales_"
-                    for filename in os.listdir(config.model_name_or_path):
-                        if filename.startswith(prefix):
-                            ptq_multicards_num += 1
-
-                logging.info(
-                    f"PTQ from {ptq_multicards_num} cards, so we will not split"
-                )
-                if ptq_multicards_num > 1:
-                    config.single_card_ptq = False
 
         if predictor_args.block_attn:
             config.block_size = predictor_args.block_size
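The largest removal in this hunk is the dynamic-mode PTQ multi-card detection, which counted per-card `act_scales_*` files and disabled single-card PTQ when more than one was found. Roughly, the dropped behaviour corresponds to the following standalone helper (a sketch reconstructed from the removed lines; the function name is hypothetical):

import os

def count_ptq_cards(model_name_or_path: str) -> int:
    # Count per-card activation-scale files written by PTQ calibration; the
    # removed code set config.single_card_ptq = False when this exceeded 1.
    if not os.path.exists(model_name_or_path):
        return 0
    return sum(
        1
        for filename in os.listdir(model_name_or_path)
        if filename.startswith("act_scales_")
    )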
@@ -1323,45 +1293,6 @@ class PretrainedModel(
                     ".".join([prefix, s]) for s in quantization_linear_list
                 ]
 
-        # Weight quantization if not yet quantized & update loaded_keys
-        if (
-            hasattr(config, "quantization_config")
-            and config.quantization_config.is_weight_quantize()
-        ):
-            try:
-                from ..quantization.quantization_utils import (
-                    convert_to_quantize_state_dict,
-                    update_loaded_state_dict_keys,
-                )
-            except ImportError:
-                raise ImportError(
-                    "Quantization features require `paddlepaddle >= 2.5.2`"
-                )
-            if state_dict is not None:
-                state_dict = convert_to_quantize_state_dict(
-                    state_dict,
-                    quantization_linear_list,
-                    config.quantization_config,
-                    dtype,
-                )
-                loaded_keys = [k for k in state_dict.keys()]
-            else:
-                loaded_keys = update_loaded_state_dict_keys(
-                    loaded_keys, quantization_linear_list, config.quantization_config
-                )
-            if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = (
-                    ["quant_scale"]
-                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                    else None
-                )
-            else:
-                keep_in_fp32_modules = (
-                    keep_in_fp32_modules + ["quant_scale"]
-                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                    else keep_in_fp32_modules
-                )
-
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
@@ -1525,27 +1456,12 @@ class PretrainedModel(
                 ignore_mismatched_sizes,
             )
 
-            if (
-                hasattr(config, "quantization_config")
-                and config.quantization_config.is_weight_quantize()
-            ):
-                error_msgs = _load_state_dict_into_meta_model(
-                    model_to_load,
-                    state_dict,
-                    loaded_keys,
-                    start_prefix,
-                    expected_keys,
-                    dtype=dtype,
-                    is_safetensors=is_safetensors,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                )
-            else:
-                error_msgs = _load_state_dict_into_model(
-                    model_to_load,
-                    state_dict,
-                    start_prefix,
-                    convert_from_hf=convert_from_hf,
-                )
+            error_msgs = _load_state_dict_into_model(
+                model_to_load,
+                state_dict,
+                start_prefix,
+                convert_from_hf=convert_from_hf,
+            )
         else:
             # Sharded checkpoint or whole but low_cpu_mem_usage==True
 
@@ -1600,8 +1516,6 @@ class PretrainedModel(
                         if k[-1] in tp_actions:
                             fuse_actions.pop(k[-1], None)
 
-                if config.quantization_config.is_weight_quantize():
-                    filter_dict_keys = None
                 try:
                     transpose_weight_keys = model.get_transpose_weight_keys()
                 except NotImplementedError:
@@ -1630,14 +1544,6 @@ class PretrainedModel(
                 missing_keys = list(set(missing_keys) - set(new_keys))
                 unexpected_keys = list(set(unexpected_keys) - set(fused_keys))
 
-                if config.quantization_config.is_weight_quantize():
-                    state_dict = convert_to_quantize_state_dict(
-                        state_dict,
-                        quantization_linear_list,
-                        config.quantization_config,
-                        dtype,
-                    )
-
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
                 mismatched_keys += _find_mismatched_keys(
@@ -1664,7 +1570,7 @@ class PretrainedModel(
                     )
                     logging.info("Converted state_dict to Tensor Parallel Format")
 
-                if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+                if low_cpu_mem_usage:
                     new_error_msgs = _load_state_dict_into_meta_model(
                         model_to_load,
                         state_dict,
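Taken together, the last two hunks leave a single dispatch rule: whole checkpoints always go through `_load_state_dict_into_model`, and `_load_state_dict_into_meta_model` is reached only on the sharded path when `low_cpu_mem_usage` is set. A rough, stubbed sketch of that dispatch (the wrapper function, its arguments, and the non-low_cpu_mem_usage branch of the sharded path are assumptions; the two loader names come from the diff):

def _load_state_dict_into_model(model, state_dict, prefix, convert_from_hf=False):
    return []  # stub: the real loader copies tensors into the model in place

def _load_state_dict_into_meta_model(model, state_dict, *args, **kwargs):
    return []  # stub: the real loader materialises weights onto a meta model

def dispatch_load(model, state_dict, prefix, convert_from_hf, is_sharded, low_cpu_mem_usage):
    if not is_sharded:
        # Whole checkpoint: plain in-place loading, no quantization-driven branch.
        return _load_state_dict_into_model(model, state_dict, prefix, convert_from_hf=convert_from_hf)
    if low_cpu_mem_usage:
        # Sharded / low-memory path: only this flag selects the meta-model loader now.
        return _load_state_dict_into_meta_model(model, state_dict)
    # Assumed default for the sharded path without low_cpu_mem_usage.
    return _load_state_dict_into_model(model, state_dict, prefix, convert_from_hf=convert_from_hf)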