瀏覽代碼

add tile_predict and overlap_tile_predict for seg model

FlyingQianMM 5 年之前
父節點
當前提交
00b64384a5
共有 2 個文件被更改,包括 121 次插入14 次删除
  1. 106 0
      paddlex/cv/models/deeplabv3p.py
  2. 15 14
      paddlex/cv/transforms/seg_transforms.py

+ 106 - 0
paddlex/cv/models/deeplabv3p.py

@@ -548,3 +548,109 @@ class DeepLabv3p(BaseAPI):
 
         preds = DeepLabv3p._postprocess(result, im_info)
         return preds
+
+    def tile_predict(self,
+                     img_file,
+                     tile_size=[512, 512],
+                     batch_size=32,
+                     thread_num=8):
+        """Predict a large image tile by tile and stitch the results.
+
+        The image is cut into non-overlapping tiles (tiles on the right and
+        bottom edges may be smaller than ``tile_size``), the tiles are
+        predicted in batches with ``batch_predict``, and each tile result is
+        written back to the region it was cropped from.
+
+        Args:
+            img_file (str): Path of the image, read with ``cv2.imread``.
+            tile_size (list|tuple): Tile size as ``[width, height]``.
+            batch_size (int): Number of tiles per ``batch_predict`` call.
+            thread_num (int): Forwarded to ``batch_predict``.
+
+        Returns:
+            dict: ``label_map`` (height x width, uint8) and ``score_map``
+                (height x width x ``self.num_classes``, float32).
+
+        Raises:
+            ValueError: If ``img_file`` cannot be read as an image.
+        """
+        image = cv2.imread(img_file)
+        if image is None:
+            # cv2.imread signals failure by returning None instead of raising.
+            raise ValueError("Unable to read image file: {}".format(img_file))
+        height, width = image.shape[:2]
+        tile_w, tile_h = tile_size[0], tile_size[1]
+
+        # Record every tile as (left, upper, right, lower) in row-major
+        # order. Keeping the boxes avoids re-deriving positions from a flat
+        # index, which was wrong whenever the width was an exact multiple of
+        # the tile width.
+        tile_boxes = [(left, upper, min(left + tile_w, width),
+                       min(upper + tile_h, height))
+                      for upper in range(0, height, tile_h)
+                      for left in range(0, width, tile_w)]
+
+        label_map = np.zeros((height, width), dtype=np.uint8)
+        score_map = np.zeros(
+            (height, width, self.num_classes), dtype=np.float32)
+
+        for begin in range(0, len(tile_boxes), batch_size):
+            batch_boxes = tile_boxes[begin:begin + batch_size]
+            batch_tiles = [
+                image[upper:lower, left:right, :]
+                for left, upper, right, lower in batch_boxes
+            ]
+            # NOTE(review): assumes batch_predict accepts decoded arrays and
+            # returns one result per tile, in order, at the tile's own
+            # resolution -- confirm against batch_predict.
+            res = self.batch_predict(
+                img_file_list=batch_tiles, thread_num=thread_num)
+            for (left, upper, right, lower), pred in zip(batch_boxes, res):
+                label_map[upper:lower, left:right] = pred["label_map"]
+                score_map[upper:lower, left:right, :] = pred["score_map"]
+        result = {"label_map": label_map, "score_map": score_map}
+        return result
+
+    def overlap_tile_predict(self,
+                             img_file,
+                             tile_size=[512, 512],
+                             pad_size=[64, 64],
+                             batch_size=32,
+                             thread_num=8):
+        """Predict a large image with overlapping tiles and stitch the results.
+
+        The image is mirror-padded by ``pad_size`` on every side. Each tile
+        of the original image is predicted together with a surrounding
+        margin of ``pad_size`` pixels taken from the padded image, and only
+        the central (unpadded) part of each prediction is written to the
+        output.
+
+        Args:
+            img_file (str): Path of the image, read with ``cv2.imread``.
+            tile_size (list|tuple): Size of the kept tile centre as
+                ``[width, height]``.
+            pad_size (list|tuple): Margin added around each tile as
+                ``[width, height]``; ``0`` is allowed.
+            batch_size (int): Number of tiles per ``batch_predict`` call.
+            thread_num (int): Forwarded to ``batch_predict``.
+
+        Returns:
+            dict: ``label_map`` (height x width, uint8) and ``score_map``
+                (height x width x ``self.num_classes``, float32).
+
+        Raises:
+            ValueError: If ``img_file`` cannot be read as an image.
+        """
+        image = cv2.imread(img_file)
+        if image is None:
+            # cv2.imread signals failure by returning None instead of raising.
+            raise ValueError("Unable to read image file: {}".format(img_file))
+        height, width = image.shape[:2]
+        tile_w, tile_h = tile_size[0], tile_size[1]
+        pad_w, pad_h = pad_size[0], pad_size[1]
+
+        # Mirror the borders (edge pixel included), which is what flipping
+        # and concatenating the outermost strips produced. copyMakeBorder
+        # also copes with a zero pad and with pads larger than the image.
+        padding_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w,
+                                           cv2.BORDER_REFLECT)
+
+        # Tile boxes (left, upper, right, lower) in ORIGINAL image
+        # coordinates, row-major. The grid is built from the unpadded size
+        # so that cropping and write-back use the same tiles.
+        tile_boxes = [(left, upper, min(left + tile_w, width),
+                       min(upper + tile_h, height))
+                      for upper in range(0, height, tile_h)
+                      for left in range(0, width, tile_w)]
+
+        label_map = np.zeros((height, width), dtype=np.uint8)
+        score_map = np.zeros(
+            (height, width, self.num_classes), dtype=np.float32)
+
+        for begin in range(0, len(tile_boxes), batch_size):
+            batch_boxes = tile_boxes[begin:begin + batch_size]
+            # In the padded image the original pixel (x, y) sits at
+            # (x + pad_w, y + pad_h), so a tile plus its margin spans
+            # [left, right + 2 * pad_w) x [upper, lower + 2 * pad_h).
+            batch_tiles = [
+                padding_image[upper:lower + 2 * pad_h,
+                              left:right + 2 * pad_w, :]
+                for left, upper, right, lower in batch_boxes
+            ]
+            # NOTE(review): assumes batch_predict accepts decoded arrays and
+            # returns one result per tile, in order, at the tile's own
+            # resolution -- confirm against batch_predict.
+            res = self.batch_predict(
+                img_file_list=batch_tiles, thread_num=thread_num)
+            for (left, upper, right, lower), pred in zip(batch_boxes, res):
+                # Explicit end indices: a "-pad" end index would turn into
+                # an empty slice when the pad is 0.
+                inner_h = lower - upper
+                inner_w = right - left
+                label_map[upper:lower, left:right] = \
+                    pred["label_map"][pad_h:pad_h + inner_h,
+                                      pad_w:pad_w + inner_w]
+                score_map[upper:lower, left:right, :] = \
+                    pred["score_map"][pad_h:pad_h + inner_h,
+                                      pad_w:pad_w + inner_w, :]
+        result = {"label_map": label_map, "score_map": score_map}
+        return result

+ 15 - 14
paddlex/cv/transforms/seg_transforms.py

@@ -661,27 +661,28 @@ class Padding(SegTransform):
         pad_height = target_height - im_height
         pad_width = target_width - im_width
         if pad_height < 0 or pad_width < 0:
-            raise ValueError(
+            logging.warning(
                 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
                 .format(im_width, im_height, target_width, target_height))
-        else:
-            im = cv2.copyMakeBorder(
-                im,
+        pad_height = max(pad_height, 0)
+        pad_width = max(pad_width, 0)
+        im = cv2.copyMakeBorder(
+            im,
+            0,
+            pad_height,
+            0,
+            pad_width,
+            cv2.BORDER_CONSTANT,
+            value=self.im_padding_value)
+        if label is not None:
+            label = cv2.copyMakeBorder(
+                label,
                 0,
                 pad_height,
                 0,
                 pad_width,
                 cv2.BORDER_CONSTANT,
-                value=self.im_padding_value)
-            if label is not None:
-                label = cv2.copyMakeBorder(
-                    label,
-                    0,
-                    pad_height,
-                    0,
-                    pad_width,
-                    cv2.BORDER_CONSTANT,
-                    value=self.label_padding_value)
+                value=self.label_padding_value)
         if label is None:
             return (im, im_info)
         else: