Forráskód Böngészése

Merge remote-tracking branch 'paddle/develop' into cpp_trt

Channingss 5 éve
szülő
commit
ec3f106598
2 módosított fájl, 71 hozzáadás és 28 törlés
  1. 7 7
      paddlex/cv/models/deeplabv3p.py
  2. 64 21
      paddlex/cv/transforms/seg_transforms.py

+ 7 - 7
paddlex/cv/models/deeplabv3p.py

@@ -397,13 +397,13 @@ class DeepLabv3p(BaseAPI):
             fetch_list=list(self.test_outputs.values()))
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
-        keys = list(im_info.keys())
-        for k in keys[::-1]:
-            if k == 'shape_before_resize':
-                h, w = im_info[k][0], im_info[k][1]
+        for info in im_info[::-1]:
+            if info[0] == 'resize':
+                w, h = info[1][1], info[1][0]
                 pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
-            elif k == 'shape_before_padding':
-                h, w = im_info[k][0], im_info[k][1]
+            elif info[0] == 'padding':
+                w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
-
+            else:
+                raise Exception("Unexpected info '{}' in im_info".format(info[0]))
         return {'label_map': pred, 'score_map': result[1]}

+ 64 - 21
paddlex/cv/transforms/seg_transforms.py

@@ -48,9 +48,10 @@ class Compose:
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息,dict中的字段如下:
-                - shape_before_resize (tuple): 图像resize之前的大小(h, w)。
-                - shape_before_padding (tuple): 图像padding之前的大小(h, w)。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
 
         Returns:
@@ -58,7 +59,7 @@ class Compose:
         """
 
         if im_info is None:
-            im_info = dict()
+            im_info = list()
         try:
             im = cv2.imread(im).astype('float32')
         except:
@@ -93,7 +94,10 @@ class RandomHorizontalFlip:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -125,7 +129,10 @@ class RandomVerticalFlip:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -191,7 +198,10 @@ class Resize:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -208,7 +218,7 @@ class Resize:
         """
         if im_info is None:
             im_info = OrderedDict()
-        im_info['shape_before_resize'] = im.shape[:2]
+        im_info.append(('resize', im.shape[:2]))
 
         if not isinstance(im, np.ndarray):
             raise TypeError("ResizeImage: image type is not np.ndarray.")
@@ -264,7 +274,10 @@ class ResizeByLong:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -277,7 +290,7 @@ class ResizeByLong:
         if im_info is None:
             im_info = OrderedDict()
 
-        im_info['shape_before_resize'] = im.shape[:2]
+        im_info.append(('resize', im.shape[:2]))
         im = resize_long(im, self.long_size)
         if label is not None:
             label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
@@ -385,7 +398,10 @@ class ResizeRangeScaling:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -438,7 +454,10 @@ class ResizeStepScaling:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -506,7 +525,10 @@ class Normalize:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
          Returns:
@@ -560,7 +582,10 @@ class Padding:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -575,7 +600,7 @@ class Padding:
         """
         if im_info is None:
             im_info = OrderedDict()
-        im_info['shape_before_padding'] = im.shape[:2]
+        im_info.append(('padding', im.shape[:2]))
 
         im_height, im_width = im.shape[0], im.shape[1]
         if isinstance(self.target_size, int):
@@ -648,7 +673,10 @@ class RandomPaddingCrop:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
          Returns:
@@ -724,7 +752,10 @@ class RandomBlur:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -777,7 +808,10 @@ class RandomRotate:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -839,7 +873,10 @@ class RandomScaleAspect:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -921,7 +958,10 @@ class RandomDistort:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns:
@@ -996,7 +1036,10 @@ class ArrangeSegmenter:
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
 
         Returns: