Pārlūkot izejas kodu

[CustomDevice] Fix some infer errors on CustomDevice (#4143)

* Fix MaskRcnn if error

* Empty commit

* Fix GroundingDINO-T: TypeError: 'int' object is not iterable

* Add slice to CUSTOM_DEVICE_BLACK_LIST for NPU

* Fix Cascade-MaskRCNN-ResNet50-FPN: disable roi_align on npu

* restore CUSTOM_DEVICE_BLACK_LIST change
Zhou Xin 5 mēneši atpakaļ
vecāks
revīzija
075e77e239

+ 3 - 0
paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py

@@ -199,6 +199,9 @@ class GroundingDINOPostProcessor(object):
         tokenized = self.tokenizer(prompt)
         if posmap.dim() == 1:
             non_zero_idx = posmap.nonzero(as_tuple=True)[0].squeeze(-1).tolist()
+            non_zero_idx = (
+                [non_zero_idx] if not isinstance(non_zero_idx, list) else non_zero_idx
+            )
             token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
             return self.tokenizer.decode(token_ids)
         else:

+ 4 - 1
paddlex/utils/device.py

@@ -132,7 +132,10 @@ def set_env_for_device_type(device_type):
         }
         _set(envs)
     if device_type.lower() == "mlu":
-        envs = {"FLAGS_use_stride_kernel": "0"}
+        envs = {
+            "FLAGS_use_stride_kernel": "0",
+            "FLAGS_use_stream_safe_cuda_allocator": "0",
+        }
         _set(envs)
     if device_type.lower() == "gcu":
         envs = {"FLAGS_use_stride_kernel": "0"}