# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp

import cv2
import numpy as np

import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
from utils.humanseg_postprocess import postprocess, threshold_mask
  24. def parse_args():
  25. parser = argparse.ArgumentParser(
  26. description='HumanSeg inference for video')
  27. parser.add_argument(
  28. '--model_dir',
  29. dest='model_dir',
  30. help='Model path for inference',
  31. type=str)
  32. parser.add_argument(
  33. '--video_path',
  34. dest='video_path',
  35. help='Video path for inference, camera will be used if the path not existing',
  36. type=str,
  37. default=None)
  38. parser.add_argument(
  39. '--save_dir',
  40. dest='save_dir',
  41. help='The directory for saving the inference results',
  42. type=str,
  43. default='./output')
  44. parser.add_argument(
  45. "--image_shape",
  46. dest="image_shape",
  47. help="The image shape for net inputs.",
  48. nargs=2,
  49. default=[192, 192],
  50. type=int)
  51. return parser.parse_args()
  52. def predict(img, model, test_transforms):
  53. model.arrange_transforms(transforms=test_transforms, mode='test')
  54. img, im_info = test_transforms(img.astype('float32'))
  55. img = np.expand_dims(img, axis=0)
  56. result = model.exe.run(model.test_prog,
  57. feed={'image': img},
  58. fetch_list=list(model.test_outputs.values()))
  59. score_map = result[1]
  60. score_map = np.squeeze(score_map, axis=0)
  61. score_map = np.transpose(score_map, (1, 2, 0))
  62. return score_map, im_info
  63. def recover(img, im_info):
  64. for info in im_info[::-1]:
  65. if info[0] == 'resize':
  66. w, h = info[1][1], info[1][0]
  67. img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
  68. elif info[0] == 'padding':
  69. w, h = info[1][0], info[1][0]
  70. img = img[0:h, 0:w, :]
  71. return img
  72. def video_infer(args):
  73. resize_h = args.image_shape[1]
  74. resize_w = args.image_shape[0]
  75. test_transforms = transforms.Compose(
  76. [transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
  77. model = pdx.load_model(args.model_dir)
  78. if not args.video_path:
  79. cap = cv2.VideoCapture(0)
  80. else:
  81. cap = cv2.VideoCapture(args.video_path)
  82. if not cap.isOpened():
  83. raise IOError("Error opening video stream or file, "
  84. "--video_path whether existing: {}"
  85. " or camera whether working".format(args.video_path))
  86. return
  87. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  88. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  89. disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
  90. prev_gray = np.zeros((resize_h, resize_w), np.uint8)
  91. prev_cfd = np.zeros((resize_h, resize_w), np.float32)
  92. is_init = True
  93. fps = cap.get(cv2.CAP_PROP_FPS)
  94. if args.video_path:
  95. logging.info("Please wait. It is computing......")
  96. # 用于保存预测结果视频
  97. if not osp.exists(args.save_dir):
  98. os.makedirs(args.save_dir)
  99. out = cv2.VideoWriter(
  100. osp.join(args.save_dir, 'result.avi'),
  101. cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
  102. # 开始获取视频帧
  103. while cap.isOpened():
  104. ret, frame = cap.read()
  105. if ret:
  106. score_map, im_info = predict(frame, model, test_transforms)
  107. cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  108. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  109. score_map = 255 * score_map[:, :, 1]
  110. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  111. disflow, is_init)
  112. prev_gray = cur_gray.copy()
  113. prev_cfd = optflow_map.copy()
  114. is_init = False
  115. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  116. optflow_map = threshold_mask(
  117. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  118. img_matting = np.repeat(
  119. optflow_map[:, :, np.newaxis], 3, axis=2)
  120. img_matting = recover(img_matting, im_info)
  121. bg_im = np.ones_like(img_matting) * 255
  122. comb = (img_matting * frame +
  123. (1 - img_matting) * bg_im).astype(np.uint8)
  124. out.write(comb)
  125. else:
  126. break
  127. cap.release()
  128. out.release()
  129. else:
  130. while cap.isOpened():
  131. ret, frame = cap.read()
  132. if ret:
  133. score_map, im_info = predict(frame, model, test_transforms)
  134. cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  135. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  136. score_map = 255 * score_map[:, :, 1]
  137. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  138. disflow, is_init)
  139. prev_gray = cur_gray.copy()
  140. prev_cfd = optflow_map.copy()
  141. is_init = False
  142. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  143. optflow_map = threshold_mask(
  144. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  145. img_matting = np.repeat(
  146. optflow_map[:, :, np.newaxis], 3, axis=2)
  147. img_matting = recover(img_matting, im_info)
  148. bg_im = np.ones_like(img_matting) * 255
  149. comb = (img_matting * frame +
  150. (1 - img_matting) * bg_im).astype(np.uint8)
  151. cv2.imshow('HumanSegmentation', comb)
  152. if cv2.waitKey(1) & 0xFF == ord('q'):
  153. break
  154. else:
  155. break
  156. cap.release()
  157. if __name__ == "__main__":
  158. args = parse_args()
  159. video_infer(args)