ソースを参照

PaddleOCR-VL supports CPU and CUDA 11 (#4666)

Lin Manhui 3 週間 前
コミット
406d84dd66

+ 10 - 6
paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_config.py

@@ -26,6 +26,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ......utils.device import parse_device
+from ......utils.env import get_paddle_cuda_version
 from ....common.vlm.transformers import PretrainedConfig
 
 
@@ -120,6 +122,8 @@ class PaddleOCRVLConfig(PretrainedConfig):
         vision_config=None,
         **kwargs,
     ):
+        import paddle
+
         # Set default for tied embeddings if not specified.
         super().__init__(
             pad_token_id=pad_token_id,
@@ -165,13 +169,13 @@ class PaddleOCRVLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         # Currently, these configuration items are hard-coded
-        from ......utils.env import get_paddle_cuda_version
 
-        cuda_version = get_paddle_cuda_version()
-        if cuda_version and cuda_version[0] > 11:
-            self.fuse_rms_norm = True
-        else:
-            self.fuse_rms_norm = False
+        self.fuse_rms_norm = False
+        device_type, _ = parse_device(paddle.device.get_device())
+        if device_type == "gpu":
+            cuda_version = get_paddle_cuda_version()
+            if cuda_version and cuda_version[0] > 11:
+                self.fuse_rms_norm = True
         self.use_sparse_flash_attn = True
         self.use_var_len_flash_attn = False
         self.scale_qk_coeff = 1.0