bg_replace.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import os
  17. import os.path as osp
  18. import cv2
  19. import numpy as np
  20. from utils.humanseg_postprocess import postprocess, threshold_mask
  21. import paddlex as pdx
  22. import paddlex.utils.logging as logging
  23. from paddlex.seg import transforms
  24. def parse_args():
  25. parser = argparse.ArgumentParser(
  26. description='HumanSeg inference for video')
  27. parser.add_argument(
  28. '--model_dir',
  29. dest='model_dir',
  30. help='Model path for inference',
  31. type=str)
  32. parser.add_argument(
  33. '--image_path',
  34. dest='image_path',
  35. help='Image including human',
  36. type=str,
  37. default=None)
  38. parser.add_argument(
  39. '--background_image_path',
  40. dest='background_image_path',
  41. help='Background image for replacing',
  42. type=str,
  43. default=None)
  44. parser.add_argument(
  45. '--video_path',
  46. dest='video_path',
  47. help='Video path for inference',
  48. type=str,
  49. default=None)
  50. parser.add_argument(
  51. '--background_video_path',
  52. dest='background_video_path',
  53. help='Background video path for replacing',
  54. type=str,
  55. default=None)
  56. parser.add_argument(
  57. '--save_dir',
  58. dest='save_dir',
  59. help='The directory for saving the inference results',
  60. type=str,
  61. default='./output')
  62. parser.add_argument(
  63. "--image_shape",
  64. dest="image_shape",
  65. help="The image shape for net inputs.",
  66. nargs=2,
  67. default=[192, 192],
  68. type=int)
  69. return parser.parse_args()
  70. def predict(img, model, test_transforms):
  71. model.arrange_transforms(transforms=test_transforms, mode='test')
  72. img, im_info = test_transforms(img.astype('float32'))
  73. img = np.expand_dims(img, axis=0)
  74. result = model.exe.run(model.test_prog,
  75. feed={'image': img},
  76. fetch_list=list(model.test_outputs.values()))
  77. score_map = result[1]
  78. score_map = np.squeeze(score_map, axis=0)
  79. score_map = np.transpose(score_map, (1, 2, 0))
  80. return score_map, im_info
  81. def recover(img, im_info):
  82. for info in im_info[::-1]:
  83. if info[0] == 'resize':
  84. w, h = info[1][1], info[1][0]
  85. img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
  86. elif info[0] == 'padding':
  87. w, h = info[1][0], info[1][0]
  88. img = img[0:h, 0:w, :]
  89. return img
  90. def bg_replace(score_map, img, bg):
  91. h, w, _ = img.shape
  92. bg = cv2.resize(bg, (w, h))
  93. score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
  94. comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
  95. return comb
  96. def infer(args):
  97. resize_h = args.image_shape[1]
  98. resize_w = args.image_shape[0]
  99. test_transforms = transforms.Compose(
  100. [transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
  101. model = pdx.load_model(args.model_dir)
  102. if not osp.exists(args.save_dir):
  103. os.makedirs(args.save_dir)
  104. # 图像背景替换
  105. if args.image_path is not None:
  106. if not osp.exists(args.image_path):
  107. raise Exception('The --image_path is not existed: {}'.format(
  108. args.image_path))
  109. if args.background_image_path is None:
  110. raise Exception(
  111. 'The --background_image_path is not set. Please set it')
  112. else:
  113. if not osp.exists(args.background_image_path):
  114. raise Exception(
  115. 'The --background_image_path is not existed: {}'.format(
  116. args.background_image_path))
  117. img = cv2.imread(args.image_path)
  118. score_map, im_info = predict(img, model, test_transforms)
  119. score_map = score_map[:, :, 1]
  120. score_map = recover(score_map, im_info)
  121. bg = cv2.imread(args.background_image_path)
  122. save_name = osp.basename(args.image_path)
  123. save_path = osp.join(args.save_dir, save_name)
  124. result = bg_replace(score_map, img, bg)
  125. cv2.imwrite(save_path, result)
  126. # 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
  127. else:
  128. is_video_bg = False
  129. if args.background_video_path is not None:
  130. if not osp.exists(args.background_video_path):
  131. raise Exception(
  132. 'The --background_video_path is not existed: {}'.format(
  133. args.background_video_path))
  134. is_video_bg = True
  135. elif args.background_image_path is not None:
  136. if not osp.exists(args.background_image_path):
  137. raise Exception(
  138. 'The --background_image_path is not existed: {}'.format(
  139. args.background_image_path))
  140. else:
  141. raise Exception(
  142. 'Please offer backgound image or video. You should set --backbground_iamge_paht or --background_video_path'
  143. )
  144. disflow = cv2.DISOpticalFlow_create(
  145. cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
  146. prev_gray = np.zeros((resize_h, resize_w), np.uint8)
  147. prev_cfd = np.zeros((resize_h, resize_w), np.float32)
  148. is_init = True
  149. if args.video_path is not None:
  150. logging.info('Please wait. It is computing......')
  151. if not osp.exists(args.video_path):
  152. raise Exception('The --video_path is not existed: {}'.format(
  153. args.video_path))
  154. cap_video = cv2.VideoCapture(args.video_path)
  155. fps = cap_video.get(cv2.CAP_PROP_FPS)
  156. width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
  157. height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
  158. save_name = osp.basename(args.video_path)
  159. save_name = save_name.split('.')[0]
  160. save_path = osp.join(args.save_dir, save_name + '.avi')
  161. cap_out = cv2.VideoWriter(
  162. save_path,
  163. cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
  164. (width, height))
  165. if is_video_bg:
  166. cap_bg = cv2.VideoCapture(args.background_video_path)
  167. frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
  168. current_frame_bg = 1
  169. else:
  170. img_bg = cv2.imread(args.background_image_path)
  171. while cap_video.isOpened():
  172. ret, frame = cap_video.read()
  173. if ret:
  174. score_map, im_info = predict(frame, model, test_transforms)
  175. cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  176. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  177. score_map = 255 * score_map[:, :, 1]
  178. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  179. disflow, is_init)
  180. prev_gray = cur_gray.copy()
  181. prev_cfd = optflow_map.copy()
  182. is_init = False
  183. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  184. optflow_map = threshold_mask(
  185. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  186. score_map = recover(optflow_map, im_info)
  187. #循环读取背景帧
  188. if is_video_bg:
  189. ret_bg, frame_bg = cap_bg.read()
  190. if ret_bg:
  191. if current_frame_bg == frames_bg:
  192. current_frame_bg = 1
  193. cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
  194. else:
  195. break
  196. current_frame_bg += 1
  197. comb = bg_replace(score_map, frame, frame_bg)
  198. else:
  199. comb = bg_replace(score_map, frame, img_bg)
  200. cap_out.write(comb)
  201. else:
  202. break
  203. if is_video_bg:
  204. cap_bg.release()
  205. cap_video.release()
  206. cap_out.release()
  207. # 当没有输入预测图像和视频的时候,则打开摄像头
  208. else:
  209. cap_video = cv2.VideoCapture(0)
  210. if not cap_video.isOpened():
  211. raise IOError("Error opening video stream or file, "
  212. "--video_path whether existing: {}"
  213. " or camera whether working".format(
  214. args.video_path))
  215. return
  216. if is_video_bg:
  217. cap_bg = cv2.VideoCapture(args.background_video_path)
  218. frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
  219. current_frame_bg = 1
  220. else:
  221. img_bg = cv2.imread(args.background_image_path)
  222. while cap_video.isOpened():
  223. ret, frame = cap_video.read()
  224. if ret:
  225. score_map, im_info = predict(frame, model, test_transforms)
  226. cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  227. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  228. score_map = 255 * score_map[:, :, 1]
  229. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  230. disflow, is_init)
  231. prev_gray = cur_gray.copy()
  232. prev_cfd = optflow_map.copy()
  233. is_init = False
  234. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  235. optflow_map = threshold_mask(
  236. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  237. score_map = recover(optflow_map, im_info)
  238. #循环读取背景帧
  239. if is_video_bg:
  240. ret_bg, frame_bg = cap_bg.read()
  241. if ret_bg:
  242. if current_frame_bg == frames_bg:
  243. current_frame_bg = 1
  244. cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
  245. else:
  246. break
  247. current_frame_bg += 1
  248. comb = bg_replace(score_map, frame, frame_bg)
  249. else:
  250. comb = bg_replace(score_map, frame, img_bg)
  251. cv2.imshow('HumanSegmentation', comb)
  252. if cv2.waitKey(1) & 0xFF == ord('q'):
  253. break
  254. else:
  255. break
  256. if is_video_bg:
  257. cap_bg.release()
  258. cap_video.release()
  259. if __name__ == "__main__":
  260. args = parse_args()
  261. infer(args)