Przeglądaj źródła

cast to uint8 during seg visualization

will-jl944 4 lat temu
rodzic
commit
ea5fb1c3eb

+ 4 - 3
paddlex/cv/models/segmenter.py

@@ -114,9 +114,10 @@ class BaseSegmenter(BaseModel):
                 score_map_list = []
                 for logit in logit_list:
                     logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
-                    label_map_list.append((paddle.argmax(
-                        logit, axis=-1, keepdim=False, dtype='int32')).squeeze(
-                        ).numpy().astype('int32'))
+                    label_map_list.append(
+                        paddle.argmax(
+                            logit, axis=-1, keepdim=False, dtype='int32')
+                        .squeeze().numpy())
                     score_map_list.append(
                         F.softmax(
                             logit, axis=-1).squeeze().numpy().astype(

+ 1 - 1
paddlex/cv/models/utils/visualize.py

@@ -61,7 +61,7 @@ def visualize_segmentation(image,
         save_dir: the directory for saving visual image
         color: the list of a BGR-mode color for each label.
     """
-    label_map = result['label_map']
+    label_map = result['label_map'].astype("uint8")
     color_map = get_color_map_list(256)
     if color is not None:
         for i in range(len(color) // 3):