|
|
@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging
|
|
|
import paddlex
|
|
|
from paddlex.cv.transforms import arrange_transforms
|
|
|
from paddlex.cv.datasets import generate_minibatch
|
|
|
+from paddlex.cv.transforms.seg_transforms import Compose
|
|
|
from collections import OrderedDict
|
|
|
from .base import BaseAPI
|
|
|
from .utils.seg_eval import ConfusionMatrix
|
|
|
@@ -448,7 +449,11 @@ class DeepLabv3p(BaseAPI):
|
|
|
return metrics
|
|
|
|
|
|
@staticmethod
|
|
|
- def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
|
|
|
+ def _preprocess(images,
|
|
|
+ transforms,
|
|
|
+ model_type,
|
|
|
+ class_name,
|
|
|
+ thread_pool=None):
|
|
|
arrange_transforms(
|
|
|
model_type=model_type,
|
|
|
class_name=class_name,
|
|
|
@@ -554,3 +559,102 @@ class DeepLabv3p(BaseAPI):
|
|
|
|
|
|
preds = DeepLabv3p._postprocess(result, im_info)
|
|
|
return preds
|
|
|
+
|
|
|
def overlap_tile_predict(self,
                         img_file,
                         tile_size=(512, 512),
                         pad_size=(64, 64),
                         batch_size=32,
                         transforms=None):
    """Predict a large image tile by tile using overlapping sliding windows.

    The image is mirror-padded by ``pad_size`` on every side and cut into
    windows of ``tile_size`` plus the surrounding padding. Each window is
    predicted with ``self.batch_predict``; only the central ``tile_size``
    region of each prediction is written into the stitched output, so the
    extended border only provides context.

    Args:
        img_file (str|np.ndarray): Path of the image to predict, or an
            already decoded array laid out as (H, W, C), float32, BGR.
        tile_size (list|tuple): Size of the sliding window whose
            prediction is kept for stitching, as (W, H).
            Defaults to (512, 512).
        pad_size (list|tuple): How far the window is extended on each
            side, as (W, H). Predictions in the extended area are
            discarded. Defaults to (64, 64).
        batch_size (int): Number of windows predicted per call to
            ``batch_predict``. Defaults to 32.
        transforms (paddlex.cv.transforms): Preprocessing operations,
            forwarded unchanged to ``batch_predict``.

    Returns:
        dict: ``'label_map'`` holds the (h, w) uint8 map of predicted
            class ids; ``'score_map'`` holds the per-class scores with
            shape (h, w, num_classes) and dtype float32.

    Raises:
        Exception: If ``transforms`` is None and the model has no
            ``test_transforms`` attribute.
        TypeError: If ``img_file`` is neither a str nor an np.ndarray.
        ValueError: If ``tile_size`` is not positive, ``pad_size`` is
            negative, or ``batch_size`` is smaller than 1.
    """
    if transforms is None and not hasattr(self, 'test_transforms'):
        raise Exception("transforms need to be defined, now is None.")

    tile_w, tile_h = tile_size[0], tile_size[1]
    pad_w, pad_h = pad_size[0], pad_size[1]
    if tile_w <= 0 or tile_h <= 0:
        raise ValueError(
            "tile_size must be positive, but got {}".format(tile_size))
    if pad_w < 0 or pad_h < 0:
        raise ValueError(
            "pad_size must be non-negative, but got {}".format(pad_size))
    if batch_size < 1:
        raise ValueError(
            "batch_size must be at least 1, but got {}".format(batch_size))

    if isinstance(img_file, str):
        image, _ = Compose.decode_image(img_file, None)
    elif isinstance(img_file, np.ndarray):
        # No defensive copy needed: np.pad below always allocates a new
        # array, so the caller's data is never aliased.
        image = img_file
    else:
        raise TypeError("img_file must be str/np.ndarray")

    height, width, _ = image.shape

    # 'symmetric' mirrors about the image edge (edge pixel repeated), which
    # is what flipping the border strip and concatenating it produces.
    # Unlike slicing a strip, it also works when the pad is larger than the
    # image itself.
    padded = np.pad(
        image, ((pad_h, pad_h), (pad_w, pad_w), (0, 0)), mode='symmetric')

    # Ceil division: an exact multiple of the tile size must not yield an
    # extra row/column of tiles that contain nothing but padding.
    num_rows = -(-height // tile_h)
    num_cols = -(-width // tile_w)

    # Top-left corner of every tile in original-image coordinates. The same
    # numbers index the padded image, where the window simply starts
    # ``pad`` pixels earlier relative to the content.
    origins = [(row * tile_h, col * tile_w)
               for row in range(num_rows) for col in range(num_cols)]
    # Slicing clips at the array border, so edge tiles are just smaller.
    tiles = [
        padded[top:top + tile_h + 2 * pad_h,
               left:left + tile_w + 2 * pad_w, :] for top, left in origins
    ]

    label_map = np.zeros((height, width), dtype=np.uint8)
    score_map = np.zeros(
        (height, width, self.num_classes), dtype=np.float32)

    for begin in range(0, len(tiles), batch_size):
        end = begin + batch_size
        results = self.batch_predict(
            img_file_list=tiles[begin:end], transforms=transforms)
        for (top, left), res in zip(origins[begin:end], results):
            bottom = min(top + tile_h, height)
            right = min(left + tile_w, width)
            tile_label_map = res["label_map"]
            tile_score_map = res["score_map"]
            # Drop the padded border of the prediction.
            # NOTE(review): assumes batch_predict returns maps with the
            # same spatial size as the input tile -- confirm.
            rows = slice(pad_h, tile_label_map.shape[0] - pad_h)
            cols = slice(pad_w, tile_label_map.shape[1] - pad_w)
            label_map[top:bottom, left:right] = tile_label_map[rows, cols]
            score_map[top:bottom, left:right, :] = \
                tile_score_map[rows, cols, :]

    return {"label_map": label_map, "score_map": score_map}
|