!12 Bump version to 3.3.5
Merge pull request !12 from zhch158/release/3.3

zhch158 3 weeks ago
parent
commit
0c3a4ecf56

+ 1 - 1
paddlex/.version

@@ -1 +1 @@
-3.3.4
+3.3.5

+ 0 - 20
paddlex/inference/models/common/vlm/transformers/configuration_utils.py

@@ -823,11 +823,6 @@ class PretrainedConfig:
                 )
         to_remove = []
         for key, value in kwargs.items():
-            if key == "quantization_config" and isinstance(value, Dict):
-                for q_key in value:
-                    setattr(config.quantization_config, q_key, value[q_key])
-                to_remove.append(key)
-                continue
             if hasattr(config, key):
                 setattr(config, key, value)
                 if key != "dtype":
@@ -889,11 +884,6 @@ class PretrainedConfig:
 
         # only serialize values that differ from the default config
         for key, value in config_dict.items():
-            if key == "quantization_config":
-                quantization_diff_dict = self.quantization_config.to_diff_dict()
-                if len(quantization_diff_dict) > 0:
-                    serializable_config_dict[key] = quantization_diff_dict
-                continue
             if (
                 key not in default_config_dict
                 or key == "paddlenlp_version"
@@ -942,16 +932,6 @@ class PretrainedConfig:
                 if key in self._unsavable_keys:
                     output.pop(key)
 
-        if hasattr(self, "quantization_config"):
-            output["quantization_config"] = (
-                self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
-                else self.quantization_config
-            )
-
-            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
-            _ = output.pop("_pre_quantization_dtype", None)
-
         return output
 
     def update(self, config_dict: Dict[str, Any]):

+ 14 - 104
paddlex/inference/models/common/vlm/transformers/model_utils.py

@@ -258,8 +258,6 @@ def load_state_dict(
                 tensor_parallel_split_mapping,
                 fliter_dict_keys,
                 "expected",
-                quantization_linear_list=None,
-                quantization_config=None,
                 dtype=None,
                 return_numpy=False,
                 convert_from_hf=convert_from_hf,
@@ -277,7 +275,7 @@ def _load_state_dict_into_model(
     model_to_load, state_dict, start_prefix, convert_from_hf
 ):
     # torch will cast dtype in load_state_dict, but paddle strictly check dtype
-    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)
+    _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)
 
     error_msgs = []
 
@@ -307,12 +305,16 @@ def _load_state_dict_into_model(
     return error_msgs
 
 
-def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
+def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
     # convert the dtype of state dict
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]
 
-    for key, value in model_to_load.state_dict().items():
+    if convert_from_hf:
+        model_state_dict = model_to_load.get_hf_state_dict()
+    else:
+        model_state_dict = model_to_load.state_dict()
+    for key, value in model_state_dict.items():
         if key in list(state_dict.keys()):
             if isinstance(state_dict[key], np.ndarray):
                 raise ValueError(
@@ -631,34 +633,6 @@ class PretrainedModel(
         config.weightonly_group_size = predictor_args.weightonly_group_size
         config.weight_block_size = predictor_args.weight_block_size
         config.moe_quant_type = predictor_args.moe_quant_type
-        if config.quantization_config.quant_method is not None:
-            predictor_args.weight_block_size = (
-                config.quantization_config.weight_block_size
-            )
-            config.weight_block_size = predictor_args.weight_block_size
-
-        if config.quantization_config.quant_type is not None:
-            if predictor_args.mode == "dynamic":
-                predictor_args.quant_type = config.quantization_config.quant_type
-                config.quant_type = config.quantization_config.quant_type
-            if "c8" in config.quant_type:
-                predictor_args.cachekv_int8_type = "static"
-                if predictor_args.mode == "dynamic":
-                    config.cachekv_int8_type = "static"
-
-            if predictor_args.mode == "dynamic":
-                ptq_multicards_num = 0
-                if os.path.exists(config.model_name_or_path):
-                    prefix = "act_scales_"
-                    for filename in os.listdir(config.model_name_or_path):
-                        if filename.startswith(prefix):
-                            ptq_multicards_num += 1
-
-                logging.info(
-                    f"PTQ from {ptq_multicards_num} cards, so we will not split"
-                )
-                if ptq_multicards_num > 1:
-                    config.single_card_ptq = False
 
         if predictor_args.block_attn:
             config.block_size = predictor_args.block_size
@@ -1323,45 +1297,6 @@ class PretrainedModel(
                     ".".join([prefix, s]) for s in quantization_linear_list
                 ]
 
-        # Weight quantization if not yet quantized & update loaded_keys
-        if (
-            hasattr(config, "quantization_config")
-            and config.quantization_config.is_weight_quantize()
-        ):
-            try:
-                from ..quantization.quantization_utils import (
-                    convert_to_quantize_state_dict,
-                    update_loaded_state_dict_keys,
-                )
-            except ImportError:
-                raise ImportError(
-                    "Quantization features require `paddlepaddle >= 2.5.2`"
-                )
-            if state_dict is not None:
-                state_dict = convert_to_quantize_state_dict(
-                    state_dict,
-                    quantization_linear_list,
-                    config.quantization_config,
-                    dtype,
-                )
-                loaded_keys = [k for k in state_dict.keys()]
-            else:
-                loaded_keys = update_loaded_state_dict_keys(
-                    loaded_keys, quantization_linear_list, config.quantization_config
-                )
-            if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = (
-                    ["quant_scale"]
-                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                    else None
-                )
-            else:
-                keep_in_fp32_modules = (
-                    keep_in_fp32_modules + ["quant_scale"]
-                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
-                    else keep_in_fp32_modules
-                )
-
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
@@ -1525,27 +1460,12 @@ class PretrainedModel(
                 ignore_mismatched_sizes,
             )
 
-            if (
-                hasattr(config, "quantization_config")
-                and config.quantization_config.is_weight_quantize()
-            ):
-                error_msgs = _load_state_dict_into_meta_model(
-                    model_to_load,
-                    state_dict,
-                    loaded_keys,
-                    start_prefix,
-                    expected_keys,
-                    dtype=dtype,
-                    is_safetensors=is_safetensors,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                )
-            else:
-                error_msgs = _load_state_dict_into_model(
-                    model_to_load,
-                    state_dict,
-                    start_prefix,
-                    convert_from_hf=convert_from_hf,
-                )
+            error_msgs = _load_state_dict_into_model(
+                model_to_load,
+                state_dict,
+                start_prefix,
+                convert_from_hf=convert_from_hf,
+            )
         else:
             # Sharded checkpoint or whole but low_cpu_mem_usage==True
 
@@ -1600,8 +1520,6 @@ class PretrainedModel(
                         if k[-1] in tp_actions:
                             fuse_actions.pop(k[-1], None)
 
-                if config.quantization_config.is_weight_quantize():
-                    filter_dict_keys = None
                 try:
                     transpose_weight_keys = model.get_transpose_weight_keys()
                 except NotImplementedError:
@@ -1630,14 +1548,6 @@ class PretrainedModel(
                 missing_keys = list(set(missing_keys) - set(new_keys))
                 unexpected_keys = list(set(unexpected_keys) - set(fused_keys))
 
-                if config.quantization_config.is_weight_quantize():
-                    state_dict = convert_to_quantize_state_dict(
-                        state_dict,
-                        quantization_linear_list,
-                        config.quantization_config,
-                        dtype,
-                    )
-
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
                 mismatched_keys += _find_mismatched_keys(
@@ -1664,7 +1574,7 @@ class PretrainedModel(
                     )
                     logging.info("Converted state_dict to Tensor Parallel Format")
 
-                if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+                if low_cpu_mem_usage:
                     new_error_msgs = _load_state_dict_into_meta_model(
                         model_to_load,
                         state_dict,

+ 2 - 9
paddlex/inference/models/doc_vlm/predictor.py

@@ -28,8 +28,8 @@ from ....modules.doc_vlm.model_list import MODELS
 from ....utils import logging
 from ....utils.deps import require_genai_client_plugin
 from ....utils.device import TemporaryDeviceChanger
-from ....utils.env import get_device_type
 from ...common.batch_sampler import DocVLMBatchSampler
+from ...utils.misc import is_bfloat16_available
 from ..base import BasePredictor
 from .result import DocVLMResult
 
@@ -53,15 +53,8 @@ class DocVLMPredictor(BasePredictor):
         super().__init__(*args, **kwargs)
 
         if self._use_local_model:
-            import paddle
-
             self.device = kwargs.get("device", None)
-            self.dtype = (
-                "bfloat16"
-                if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
-                and (self.device is None or "cpu" not in self.device)
-                else "float32"
-            )
+            self.dtype = "bfloat16" if is_bfloat16_available(self.device) else "float32"
 
             self.infer, self.processor = self._build(**kwargs)
 

+ 17 - 2
paddlex/inference/pipelines/paddleocr_vl/pipeline.py

@@ -249,10 +249,14 @@ class _PaddleOCRVLPipeline(BasePipeline):
                     vlm_block_ids.append((i, j))
                     drop_figures_set.update(drop_figures)
 
+        if vlm_kwargs is None:
+            vlm_kwargs = {}
+        elif vlm_kwargs.get("max_new_tokens", None) is None:
+            vlm_kwargs["max_new_tokens"] = 4096
+
         kwargs = {
             "use_cache": True,
-            "max_new_tokens": 4096,
-            **(vlm_kwargs or {}),
+            **vlm_kwargs,
         }
         vl_rec_results = list(
             self.vl_rec_model.predict(
@@ -358,6 +362,7 @@ class _PaddleOCRVLPipeline(BasePipeline):
         top_p: Optional[float] = None,
         min_pixels: Optional[int] = None,
         max_pixels: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         **kwargs,
     ) -> PaddleOCRVLResult:
         """
@@ -376,6 +381,15 @@ class _PaddleOCRVLPipeline(BasePipeline):
                 If it's a tuple of two numbers, then they are used separately for width and height respectively.
                 If it's None, then no unclipping will be performed.
             layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            use_queues (Optional[bool], optional): Whether to use queues. Defaults to None.
+            prompt_label (Optional[Union[str, None]], optional): The label of the prompt in ['ocr', 'formula', 'table', 'chart']. Defaults to None.
+            format_block_content (Optional[bool]): Whether to format the block content. Default is None.
+            repetition_penalty (Optional[float]): The repetition penalty parameter used for VL model sampling. Default is None.
+            temperature (Optional[float]): Temperature parameter used for VL model sampling. Default is None.
+            top_p (Optional[float]): Top-p parameter used for VL model sampling. Default is None.
+            min_pixels (Optional[int]): The minimum number of pixels allowed when the VL model preprocesses images. Default is None.
+            max_pixels (Optional[int]): The maximum number of pixels allowed when the VL model preprocesses images. Default is None.
+            max_new_tokens (Optional[int]): The maximum number of new tokens. Default is None.
             **kwargs (Any): Additional settings to extend functionality.
 
         Returns:
@@ -499,6 +513,7 @@ class _PaddleOCRVLPipeline(BasePipeline):
                         "top_p": top_p,
                         "min_pixels": min_pixels,
                         "max_pixels": max_pixels,
+                        "max_new_tokens": max_new_tokens,
                     },
                 )
             )
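
For context, a minimal usage sketch of the new max_new_tokens option at the pipeline level. This assumes the pipeline is created through PaddleX's create_pipeline under the name "PaddleOCR-VL"; the pipeline name and the image path are illustrative, and per the vlm_kwargs handling above the value falls back to 4096 when left unset:

from paddlex import create_pipeline

# Request a shorter generation budget than the pipeline's 4096-token fallback.
pipeline = create_pipeline(pipeline="PaddleOCR-VL")
for res in pipeline.predict("demo.png", max_new_tokens=1024):
    res.print()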

+ 3 - 1
paddlex/inference/utils/io/readers.py

@@ -267,7 +267,9 @@ class OpenCVImageReaderBackend(_ImageReaderBackend):
 
     def read_file(self, in_path):
         """read image file from path by OpenCV"""
-        return cv2.imread(in_path, flags=self.flags)
+        with open(in_path, "rb") as f:
+            img_array = np.frombuffer(f.read(), np.uint8)
+        return cv2.imdecode(img_array, flags=self.flags)
 
 
 class PILImageReaderBackend(_ImageReaderBackend):
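
The reader change swaps cv2.imread for reading the file bytes in Python and decoding them with cv2.imdecode, which sidesteps cv2.imread's trouble with paths containing non-ASCII characters on some platforms. A standalone sketch of the same pattern (the helper name is illustrative):

import cv2
import numpy as np

def read_image_bytes(path, flags=cv2.IMREAD_COLOR):
    # Let Python handle the path and hand OpenCV an in-memory buffer,
    # so decoding no longer depends on OpenCV's native file-path handling.
    with open(path, "rb") as f:
        buf = np.frombuffer(f.read(), np.uint8)
    return cv2.imdecode(buf, flags=flags)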

+ 14 - 0
paddlex/inference/utils/misc.py

@@ -12,9 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ...utils.device import get_default_device, parse_device
+from ...utils.env import get_device_type
+
 
 def is_mkldnn_available():
     # XXX: Not sure if this is the best way to check if MKL-DNN is available
     from paddle.inference import Config
 
     return hasattr(Config, "set_mkldnn_cache_capacity")
+
+
+def is_bfloat16_available(device):
+    import paddle.amp
+
+    if device is None:
+        device = get_default_device()
+    device_type, _ = parse_device(device)
+    return (
+        "npu" in get_device_type() or paddle.amp.is_bfloat16_supported()
+    ) and device_type in ("gpu", "npu", "xpu", "mlu", "dcu")
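
A hedged usage sketch of the new helper; the device string is illustrative, and the behaviour follows the definition above (fall back to the default device when None, then report bfloat16 only for accelerator device types):

from paddlex.inference.utils.misc import is_bfloat16_available

# Pick the compute dtype the same way DocVLMPredictor now does.
dtype = "bfloat16" if is_bfloat16_available("gpu:0") else "float32"
print(dtype)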

+ 10 - 5
paddlex/inference/utils/official_models.py

@@ -432,6 +432,9 @@ class _BaseModelHoster(ABC):
                 f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
             )
             self._download(model_name, model_dir)
+            logging.debug(
+                f"`{model_name}` model files has been download from model source: `{self.alias}`!"
+            )
 
         if model_name == "PaddleOCR-VL":
             vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
@@ -531,7 +534,12 @@ class _AIStudioModelHoster(_BaseModelHoster):
 
     def _download(self, model_name, save_dir):
         def _clone(local_dir):
-            aistudio_download(repo_id=f"PaddleX/{model_name}", local_dir=local_dir)
+            if model_name == "PaddleOCR-VL":
+                aistudio_download(
+                    repo_id=f"PaddlePaddle/{model_name}", local_dir=local_dir
+                )
+            else:
+                aistudio_download(repo_id=f"PaddleX/{model_name}", local_dir=local_dir)
 
         if os.path.exists(save_dir):
             _clone(save_dir)
@@ -586,9 +594,6 @@ Otherwise, only local models can be used."""
             if model_name in hoster.model_list:
                 try:
                     model_path = hoster.get_model(model_name)
-                    logging.debug(
-                        f"`{model_name}` model files has been download from model source: `{hoster.alias}`!"
-                    )
                     return model_path
 
                 except Exception as e:
@@ -597,7 +602,7 @@ Otherwise, only local models can be used."""
                             f"Encounter exception when download model from {hoster.alias}. No model source is available! Please check network or use local model files!"
                         )
                     logging.warning(
-                        f"Encountering exception when download model from {hoster.alias}: \n{e}, will try to download from other model sources: `hosters[idx + 1].alias`."
+                        f"Encountering exception when download model from {hoster.alias}: \n{e}, will try to download from other model sources: `{hosters[idx + 1].alias}`."
                     )
                     return self._download_from_hoster(hosters[idx + 1 :], model_name)