Selaa lähdekoodia

PaddleOCR-VL supports FP32 (#4658)

Lin Manhui 3 viikkoa sitten
vanhempi
commit
803bdd105e

+ 7 - 3
paddlex/inference/models/common/vlm/transformers/model_utils.py

@@ -275,7 +275,7 @@ def _load_state_dict_into_model(
     model_to_load, state_dict, start_prefix, convert_from_hf
 ):
     # torch will cast dtype in load_state_dict, but paddle strictly check dtype
-    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)
+    _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)
 
     error_msgs = []
 
@@ -305,12 +305,16 @@ def _load_state_dict_into_model(
     return error_msgs
 
 
-def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
+def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
     # convert the dtype of state dict
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]
 
-    for key, value in model_to_load.state_dict().items():
+    if convert_from_hf:
+        model_state_dict = model_to_load.get_hf_state_dict()
+    else:
+        model_state_dict = model_to_load.state_dict()
+    for key, value in model_state_dict.items():
         if key in list(state_dict.keys()):
             if isinstance(state_dict[key], np.ndarray):
                 raise ValueError(

+ 2 - 9
paddlex/inference/models/doc_vlm/predictor.py

@@ -28,8 +28,8 @@ from ....modules.doc_vlm.model_list import MODELS
 from ....utils import logging
 from ....utils.deps import require_genai_client_plugin
 from ....utils.device import TemporaryDeviceChanger
-from ....utils.env import get_device_type
 from ...common.batch_sampler import DocVLMBatchSampler
+from ...utils.misc import is_bfloat16_available
 from ..base import BasePredictor
 from .result import DocVLMResult
 
@@ -53,15 +53,8 @@ class DocVLMPredictor(BasePredictor):
         super().__init__(*args, **kwargs)
 
         if self._use_local_model:
-            import paddle
-
             self.device = kwargs.get("device", None)
-            self.dtype = (
-                "bfloat16"
-                if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
-                and (self.device is None or "cpu" not in self.device)
-                else "float32"
-            )
+            self.dtype = "bfloat16" if is_bfloat16_available(self.device) else "float32"
 
             self.infer, self.processor = self._build(**kwargs)
 

+ 14 - 0
paddlex/inference/utils/misc.py

@@ -12,9 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ...utils.device import get_default_device, parse_device
+from ...utils.env import get_device_type
+
 
 def is_mkldnn_available():
     # XXX: Not sure if this is the best way to check if MKL-DNN is available
     from paddle.inference import Config
 
     return hasattr(Config, "set_mkldnn_cache_capacity")
+
+
+def is_bfloat16_available(device):
+    import paddle.amp
+
+    if device is None:
+        device = get_default_device()
+    device_type, _ = parse_device(device)
+    return (
+        "npu" in get_device_type() or paddle.amp.is_bfloat16_supported()
+    ) and device_type in ("gpu", "npu", "xpu", "mlu", "dcu")