瀏覽代碼

fix some bug

Channingss 5 年之前
父節點
當前提交
5433fc46de
共有 3 個文件被更改,包括 26 次插入和 19 次删除
  1. 0 1
      paddlex/cv/models/load_model.py
  2. 13 10
      paddlex/cv/transforms/det_transforms.py
  3. 13 8
      paddlex/cv/transforms/seg_transforms.py

+ 0 - 1
paddlex/cv/models/load_model.py

@@ -125,7 +125,6 @@ def fix_input_shape(info, fixed_input_shape=None):
                 logging.warning(
                     "fixed_input_shape must == input shape when trainning")
         else:
-            print("*" * 10)
             resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
             resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
             padding['Padding']['target_size'] = list(fixed_input_shape)

+ 13 - 10
paddlex/cv/transforms/det_transforms.py

@@ -208,10 +208,10 @@ class Padding:
 
     Args:
         coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
-        target_size (int|list): 填充后的图像长、宽,默认为1
+        target_size (int|list): 填充后的图像长、宽,默认为None
     """
 
-    def __init__(self, coarsest_stride=1, target_size=1):
+    def __init__(self, coarsest_stride=1, target_size=None):
         self.coarsest_stride = coarsest_stride
         self.target_size = target_size
 
@@ -230,15 +230,15 @@ class Padding:
         Raises:
             TypeError: 形参数据类型不满足需求。
             ValueError: 数据长度不匹配。
+            ValueError: coarsest_stride,target_size需有且只有一个被指定,coarsest_stride优先级更高。
             ValueError: target_size小于原图的大小。
         """
 
-        if self.coarsest_stride == 1:
-            if isinstance(self.target_size, int) and self.target_size == 1:
-                if label_info is None:
-                    return (im, im_info)
-                else:
-                    return (im, im_info, label_info)
+        if self.coarsest_stride == 1 and self.target_size is None:
+            if label_info is None:
+                return (im, im_info)
+            else:
+                return (im, im_info, label_info)
         if im_info is None:
             im_info = dict()
         if not isinstance(im, np.ndarray):
@@ -251,13 +251,16 @@ class Padding:
                 np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
             padding_im_w = int(
                 np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
-
-        if isinstance(self.target_size, int) and self.target_size != 1:
+        elif isinstance(self.target_size, int):
             padding_im_h = self.target_size
             padding_im_w = self.target_size
         elif isinstance(self.target_size, list):
             padding_im_w = self.target_size[0]
             padding_im_h = self.target_size[1]
+        else:
+            raise ValueError(
+                "coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
+            )
         pad_height = padding_im_h - im_h
         pad_width = padding_im_w - im_w
         if pad_height < 0 or pad_width < 0:

+ 13 - 8
paddlex/cv/transforms/seg_transforms.py

@@ -287,6 +287,7 @@ class ResizeByLong:
         else:
             return (im, im_info, label)
 
+
 class ResizeByShort:
     """根据图像的短边调整图像大小(resize)。
 
@@ -315,12 +316,12 @@ class ResizeByShort:
         if not (isinstance(self.max_size, int)):
             raise TypeError("max_size: input type is invalid.")
 
-    def __call__(self, im, im_info=None, label_info=None):
+    def __call__(self, im, im_info=None, label=None):
         """
         Args:
             im (np.ndarray): 图像np.ndarray数据。
             im_info (dict, 可选): 存储与图像相关的信息。
-            label_info (dict, 可选): 存储与标注框相关的信息
+            label (np.ndarray): 标注图像np.ndarray数据
 
         Returns:
             tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
@@ -335,11 +336,12 @@ class ResizeByShort:
             ValueError: 数据长度不匹配。
         """
         if im_info is None:
-            im_info = dict()
+            im_info = OrderedDict()
         if not isinstance(im, np.ndarray):
             raise TypeError("ResizeByShort: image type is not numpy.")
         if len(im.shape) != 3:
             raise ValueError('ResizeByShort: image is not 3-dimensional.')
+        im_info['shape_before_resize'] = im.shape[:2]
         im_short_size = min(im.shape[0], im.shape[1])
         im_long_size = max(im.shape[0], im.shape[1])
         scale = float(self.short_size) / im_short_size
@@ -348,15 +350,18 @@ class ResizeByShort:
             scale = float(self.max_size) / float(im_long_size)
         resized_width = int(round(im.shape[1] * scale))
         resized_height = int(round(im.shape[0] * scale))
-        im_resize_info = [resized_height, resized_width, scale]
         im = cv2.resize(
             im, (resized_width, resized_height),
-            interpolation=cv2.INTER_LINEAR)
-        im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32)
-        if label_info is None:
+            interpolation=cv2.INTER_NEAREST)
+        if label is not None:
+            im = cv2.resize(
+                label, (resized_width, resized_height),
+                interpolation=cv2.INTER_NEAREST)
+        if label is None:
             return (im, im_info)
         else:
-            return (im, im_info, label_info)
+            return (im, im_info, label)
+
 
 class ResizeRangeScaling:
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。