Ver código fonte

fix prediction in bg_replace

FlyingQianMM 5 anos atrás
pai
commit
41183ea7c6

+ 61 - 37
examples/human_segmentation/bg_replace.py

@@ -19,7 +19,7 @@ import os.path as osp
 import cv2
 import numpy as np
 
-from utils.humanseg_postprocess import postprocess, threshold_mask
+from postprocess import postprocess, threshold_mask
 import paddlex as pdx
 import paddlex.utils.logging as logging
 from paddlex.seg import transforms
@@ -74,44 +74,29 @@ def parse_args():
     return parser.parse_args()
 
 
-def predict(img, model, test_transforms):
-    model.arrange_transforms(transforms=test_transforms, mode='test')
-    img, im_info = test_transforms(img.astype('float32'))
-    img = np.expand_dims(img, axis=0)
-    result = model.exe.run(model.test_prog,
-                           feed={'image': img},
-                           fetch_list=list(model.test_outputs.values()))
-    score_map = result[1]
-    score_map = np.squeeze(score_map, axis=0)
-    score_map = np.transpose(score_map, (1, 2, 0))
-    return score_map, im_info
+def bg_replace(label_map, img, bg):
+    h, w, _ = img.shape
+    bg = cv2.resize(bg, (w, h))
+    label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
+    comb = (label_map * img + (1 - label_map) * bg).astype(np.uint8)
+    return comb
 
 
 def recover(img, im_info):
-    for info in im_info[::-1]:
-        if info[0] == 'resize':
-            w, h = info[1][1], info[1][0]
-            img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
-        elif info[0] == 'padding':
-            w, h = info[1][0], info[1][0]
-            img = img[0:h, 0:w, :]
+    if im_info[0] == 'resize':
+        w, h = im_info[1][1], im_info[1][0]
+        img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
+    elif im_info[0] == 'padding':
+        w, h = im_info[1][0], im_info[1][0]
+        img = img[0:h, 0:w, :]
     return img
 
 
-def bg_replace(score_map, img, bg):
-    h, w, _ = img.shape
-    bg = cv2.resize(bg, (w, h))
-    score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
-    comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
-    return comb
-
-
 def infer(args):
     resize_h = args.image_shape[1]
     resize_w = args.image_shape[0]
 
-    test_transforms = transforms.Compose(
-        [transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
+    test_transforms = transforms.Compose([transforms.Normalize()])
     model = pdx.load_model(args.model_dir)
 
     if not osp.exists(args.save_dir):
@@ -130,14 +115,27 @@ def infer(args):
                 raise Exception(
                     'The --background_image_path is not existed: {}'.format(
                         args.background_image_path))
+
         img = cv2.imread(args.image_path)
-        score_map, im_info = predict(img, model, test_transforms)
-        score_map = score_map[:, :, 1]
-        score_map = recover(score_map, im_info)
+        im_shape = img.shape
+        im_scale_x = float(resize_w) / float(im_shape[1])
+        im_scale_y = float(resize_h) / float(im_shape[0])
+        im = cv2.resize(
+            img,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=cv2.INTER_LINEAR)
+        image = im.astype('float32')
+        im_info = ('resize', im_shape[0:2])
+        pred = model.predict(image, test_transforms)
+        label_map = pred['label_map']
+        label_map = recover(label_map, im_info)
         bg = cv2.imread(args.background_image_path)
         save_name = osp.basename(args.image_path)
         save_path = osp.join(args.save_dir, save_name)
-        result = bg_replace(score_map, img, bg)
+        result = bg_replace(label_map, img, bg)
         cv2.imwrite(save_path, result)
 
     # Video background replacement: if a background video is provided, use it as the background; otherwise use the provided background image
@@ -192,8 +190,21 @@ def infer(args):
             while cap_video.isOpened():
                 ret, frame = cap_video.read()
                 if ret:
-                    score_map, im_info = predict(frame, model, test_transforms)
-                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    im_shape = frame.shape
+                    im_scale_x = float(resize_w) / float(im_shape[1])
+                    im_scale_y = float(resize_h) / float(im_shape[0])
+                    im = cv2.resize(
+                        frame,
+                        None,
+                        None,
+                        fx=im_scale_x,
+                        fy=im_scale_y,
+                        interpolation=cv2.INTER_LINEAR)
+                    image = im.astype('float32')
+                    im_info = ('resize', im_shape[0:2])
+                    pred = model.predict(image, test_transforms)
+                    score_map = pred['score_map']
+                    cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                     cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                     score_map = 255 * score_map[:, :, 1]
                     optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
@@ -248,8 +259,21 @@ def infer(args):
             while cap_video.isOpened():
                 ret, frame = cap_video.read()
                 if ret:
-                    score_map, im_info = predict(frame, model, test_transforms)
-                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    im_shape = frame.shape
+                    im_scale_x = float(resize_w) / float(im_shape[1])
+                    im_scale_y = float(resize_h) / float(im_shape[0])
+                    im = cv2.resize(
+                        frame,
+                        None,
+                        None,
+                        fx=im_scale_x,
+                        fy=im_scale_y,
+                        interpolation=cv2.INTER_LINEAR)
+                    image = im.astype('float32')
+                    im_info = ('resize', im_shape[0:2])
+                    pred = model.predict(image, test_transforms)
+                    score_map = pred['score_map']
+                    cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                     cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                     score_map = 255 * score_map[:, :, 1]
                     optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \

+ 3 - 3
examples/human_segmentation/video_infer.py

@@ -70,8 +70,8 @@ def video_infer(args):
     resize_h = args.image_shape[1]
     resize_w = args.image_shape[0]
 
-    test_transforms = transforms.Compose([transforms.Normalize()])
     model = pdx.load_model(args.model_dir)
+    test_transforms = transforms.Compose([transforms.Normalize()])
     if not args.video_path:
         cap = cv2.VideoCapture(0)
     else:
@@ -115,7 +115,7 @@ def video_infer(args):
                     interpolation=cv2.INTER_LINEAR)
                 image = im.astype('float32')
                 im_info = ('resize', im_shape[0:2])
-                pred = model.predict(image)
+                pred = model.predict(image, test_transforms)
                 score_map = pred['score_map']
                 cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                 score_map = 255 * score_map[:, :, 1]
@@ -155,7 +155,7 @@ def video_infer(args):
                     interpolation=cv2.INTER_LINEAR)
                 image = im.astype('float32')
                 im_info = ('resize', im_shape[0:2])
-                pred = model.predict(image)
+                pred = model.predict(image, test_transforms)
                 score_map = pred['score_map']
                 cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                 cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))