video_infer.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp

import cv2
import numpy as np

from postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
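# NOTE: `postprocess` and `threshold_mask` are assumed to be provided by the
# sibling postprocess.py module that ships with this HumanSeg example; they
# are not part of the paddlex package itself.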


def parse_args():
    parser = argparse.ArgumentParser(
        description='HumanSeg inference for video')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for inference',
        type=str)
    parser.add_argument(
        '--video_path',
        dest='video_path',
        help='Video path for inference; the camera is used if no video path is given',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output')
    parser.add_argument(
        '--image_shape',
        dest='image_shape',
        help='The input shape (width height) for the network.',
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()
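
# Example invocations (the paths below are illustrative; substitute your own
# exported model directory and video file):
#   python video_infer.py --model_dir output/best_model --video_path data/video_test.mp4
#   python video_infer.py --model_dir output/best_model          # read from the camera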


def recover(img, im_info):
    """Restore a predicted mask to the original frame size.

    `im_info` is a tuple of the preprocessing op name ('resize' or
    'padding') and the original (height, width) of the frame.
    """
    if im_info[0] == 'resize':
        w, h = im_info[1][1], im_info[1][0]
        img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
    elif im_info[0] == 'padding':
        w, h = im_info[1][1], im_info[1][0]
        img = img[0:h, 0:w, :]
    return img
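
# For example, a (192, 192, 3) matting mask computed from a 1280x720 frame is
# restored with recover(mask, ('resize', (720, 1280))), yielding a
# (720, 1280, 3) array aligned with the original frame.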


def video_infer(args):
    resize_h = args.image_shape[1]
    resize_w = args.image_shape[0]
    model = pdx.load_model(args.model_dir)
    test_transforms = transforms.Compose([transforms.Normalize()])
    if not args.video_path:
        cap = cv2.VideoCapture(0)
    else:
        cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        raise IOError("Error opening video stream or file. Please check "
                      "whether --video_path exists: {}, or whether the "
                      "camera is working.".format(args.video_path))

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
    prev_gray = np.zeros((resize_h, resize_w), np.uint8)
    prev_cfd = np.zeros((resize_h, resize_w), np.float32)
    is_init = True
    fps = cap.get(cv2.CAP_PROP_FPS)

    if args.video_path:
        logging.info("Please wait while the video is being processed...")
        # Video writer for saving the prediction results.
        if not osp.exists(args.save_dir):
            os.makedirs(args.save_dir)
        out = cv2.VideoWriter(
            osp.join(args.save_dir, 'result.avi'),
            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
        # Start reading video frames.
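        # Per-frame pipeline: resize the frame to the network input size,
        # run the segmentation model, take the foreground channel of the
        # score map, smooth it temporally with DIS optical flow (postprocess)
        # and spatially with a Gaussian blur, clamp low/high confidences with
        # threshold_mask, resize the mask back to the original frame size,
        # and alpha-blend the frame over a white background.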
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                im_shape = frame.shape
                im_scale_x = float(resize_w) / float(im_shape[1])
                im_scale_y = float(resize_h) / float(im_shape[0])
                im = cv2.resize(
                    frame,
                    None,
                    None,
                    fx=im_scale_x,
                    fy=im_scale_y,
                    interpolation=cv2.INTER_LINEAR)
                image = im.astype('float32')
                im_info = ('resize', im_shape[0:2])
                pred = model.predict(image, test_transforms)
                score_map = pred['score_map']
                cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                score_map = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                          prev_cfd, disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_matting = np.repeat(
                    optflow_map[:, :, np.newaxis], 3, axis=2)
                img_matting = recover(img_matting, im_info)
                bg_im = np.ones_like(img_matting) * 255
                comb = (img_matting * frame +
                        (1 - img_matting) * bg_im).astype(np.uint8)
                out.write(comb)
            else:
                break
        cap.release()
        out.release()
    else:
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                im_shape = frame.shape
                im_scale_x = float(resize_w) / float(im_shape[1])
                im_scale_y = float(resize_h) / float(im_shape[0])
                im = cv2.resize(
                    frame,
                    None,
                    None,
                    fx=im_scale_x,
                    fy=im_scale_y,
                    interpolation=cv2.INTER_LINEAR)
                image = im.astype('float32')
                im_info = ('resize', im_shape[0:2])
                pred = model.predict(image, test_transforms)
                score_map = pred['score_map']
                cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                score_map = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                          prev_cfd, disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_matting = np.repeat(
                    optflow_map[:, :, np.newaxis], 3, axis=2)
                img_matting = recover(img_matting, im_info)
                bg_im = np.ones_like(img_matting) * 255
                comb = (img_matting * frame +
                        (1 - img_matting) * bg_im).astype(np.uint8)
                cv2.imshow('HumanSegmentation', comb)
                # Press 'q' to quit the live preview.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    args = parse_args()
    video_infer(args)