bg_replace.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import os
  17. import os.path as osp
  18. import cv2
  19. import numpy as np
  20. from postprocess import postprocess, threshold_mask
  21. import paddlex as pdx
  22. import paddlex.utils.logging as logging
  23. from paddlex.seg import transforms
  24. def parse_args():
  25. parser = argparse.ArgumentParser(
  26. description='HumanSeg inference for video')
  27. parser.add_argument(
  28. '--model_dir',
  29. dest='model_dir',
  30. help='Model path for inference',
  31. type=str)
  32. parser.add_argument(
  33. '--image_path',
  34. dest='image_path',
  35. help='Image including human',
  36. type=str,
  37. default=None)
  38. parser.add_argument(
  39. '--background_image_path',
  40. dest='background_image_path',
  41. help='Background image for replacing',
  42. type=str,
  43. default=None)
  44. parser.add_argument(
  45. '--video_path',
  46. dest='video_path',
  47. help='Video path for inference',
  48. type=str,
  49. default=None)
  50. parser.add_argument(
  51. '--background_video_path',
  52. dest='background_video_path',
  53. help='Background video path for replacing',
  54. type=str,
  55. default=None)
  56. parser.add_argument(
  57. '--save_dir',
  58. dest='save_dir',
  59. help='The directory for saving the inference results',
  60. type=str,
  61. default='./output')
  62. parser.add_argument(
  63. "--image_shape",
  64. dest="image_shape",
  65. help="The image shape for net inputs.",
  66. nargs=2,
  67. default=[192, 192],
  68. type=int)
  69. return parser.parse_args()
  70. def bg_replace(label_map, img, bg):
  71. h, w, _ = img.shape
  72. bg = cv2.resize(bg, (w, h))
  73. label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
  74. comb = (label_map * img + (1 - label_map) * bg).astype(np.uint8)
  75. return comb
  76. def recover(img, im_info):
  77. if im_info[0] == 'resize':
  78. w, h = im_info[1][1], im_info[1][0]
  79. img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
  80. elif im_info[0] == 'padding':
  81. w, h = im_info[1][0], im_info[1][0]
  82. img = img[0:h, 0:w, :]
  83. return img
  84. def infer(args):
  85. resize_h = args.image_shape[1]
  86. resize_w = args.image_shape[0]
  87. test_transforms = transforms.Compose([transforms.Normalize()])
  88. model = pdx.load_model(args.model_dir)
  89. if not osp.exists(args.save_dir):
  90. os.makedirs(args.save_dir)
  91. # 图像背景替换
  92. if args.image_path is not None:
  93. if not osp.exists(args.image_path):
  94. raise Exception('The --image_path is not existed: {}'.format(
  95. args.image_path))
  96. if args.background_image_path is None:
  97. raise Exception(
  98. 'The --background_image_path is not set. Please set it')
  99. else:
  100. if not osp.exists(args.background_image_path):
  101. raise Exception(
  102. 'The --background_image_path is not existed: {}'.format(
  103. args.background_image_path))
  104. img = cv2.imread(args.image_path)
  105. im_shape = img.shape
  106. im_scale_x = float(resize_w) / float(im_shape[1])
  107. im_scale_y = float(resize_h) / float(im_shape[0])
  108. im = cv2.resize(
  109. img,
  110. None,
  111. None,
  112. fx=im_scale_x,
  113. fy=im_scale_y,
  114. interpolation=cv2.INTER_LINEAR)
  115. image = im.astype('float32')
  116. im_info = ('resize', im_shape[0:2])
  117. pred = model.predict(image, test_transforms)
  118. label_map = pred['label_map']
  119. label_map = recover(label_map, im_info)
  120. bg = cv2.imread(args.background_image_path)
  121. save_name = osp.basename(args.image_path)
  122. save_path = osp.join(args.save_dir, save_name)
  123. result = bg_replace(label_map, img, bg)
  124. cv2.imwrite(save_path, result)
  125. # 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
  126. else:
  127. is_video_bg = False
  128. if args.background_video_path is not None:
  129. if not osp.exists(args.background_video_path):
  130. raise Exception(
  131. 'The --background_video_path is not existed: {}'.format(
  132. args.background_video_path))
  133. is_video_bg = True
  134. elif args.background_image_path is not None:
  135. if not osp.exists(args.background_image_path):
  136. raise Exception(
  137. 'The --background_image_path is not existed: {}'.format(
  138. args.background_image_path))
  139. else:
  140. raise Exception(
  141. 'Please offer backgound image or video. You should set --backbground_iamge_paht or --background_video_path'
  142. )
  143. disflow = cv2.DISOpticalFlow_create(
  144. cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
  145. prev_gray = np.zeros((resize_h, resize_w), np.uint8)
  146. prev_cfd = np.zeros((resize_h, resize_w), np.float32)
  147. is_init = True
  148. if args.video_path is not None:
  149. logging.info('Please wait. It is computing......')
  150. if not osp.exists(args.video_path):
  151. raise Exception('The --video_path is not existed: {}'.format(
  152. args.video_path))
  153. cap_video = cv2.VideoCapture(args.video_path)
  154. fps = cap_video.get(cv2.CAP_PROP_FPS)
  155. width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
  156. height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
  157. save_name = osp.basename(args.video_path)
  158. save_name = save_name.split('.')[0]
  159. save_path = osp.join(args.save_dir, save_name + '.avi')
  160. cap_out = cv2.VideoWriter(
  161. save_path,
  162. cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
  163. (width, height))
  164. if is_video_bg:
  165. cap_bg = cv2.VideoCapture(args.background_video_path)
  166. frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
  167. current_frame_bg = 1
  168. else:
  169. img_bg = cv2.imread(args.background_image_path)
  170. while cap_video.isOpened():
  171. ret, frame = cap_video.read()
  172. if ret:
  173. im_shape = frame.shape
  174. im_scale_x = float(resize_w) / float(im_shape[1])
  175. im_scale_y = float(resize_h) / float(im_shape[0])
  176. im = cv2.resize(
  177. frame,
  178. None,
  179. None,
  180. fx=im_scale_x,
  181. fy=im_scale_y,
  182. interpolation=cv2.INTER_LINEAR)
  183. image = im.astype('float32')
  184. im_info = ('resize', im_shape[0:2])
  185. pred = model.predict(image, test_transforms)
  186. score_map = pred['score_map']
  187. cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
  188. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  189. score_map = 255 * score_map[:, :, 1]
  190. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  191. disflow, is_init)
  192. prev_gray = cur_gray.copy()
  193. prev_cfd = optflow_map.copy()
  194. is_init = False
  195. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  196. optflow_map = threshold_mask(
  197. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  198. score_map = recover(optflow_map, im_info)
  199. #循环读取背景帧
  200. if is_video_bg:
  201. ret_bg, frame_bg = cap_bg.read()
  202. if ret_bg:
  203. if current_frame_bg == frames_bg:
  204. current_frame_bg = 1
  205. cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
  206. else:
  207. break
  208. current_frame_bg += 1
  209. comb = bg_replace(score_map, frame, frame_bg)
  210. else:
  211. comb = bg_replace(score_map, frame, img_bg)
  212. cap_out.write(comb)
  213. else:
  214. break
  215. if is_video_bg:
  216. cap_bg.release()
  217. cap_video.release()
  218. cap_out.release()
  219. # 当没有输入预测图像和视频的时候,则打开摄像头
  220. else:
  221. cap_video = cv2.VideoCapture(0)
  222. if not cap_video.isOpened():
  223. raise IOError("Error opening video stream or file, "
  224. "--video_path whether existing: {}"
  225. " or camera whether working".format(
  226. args.video_path))
  227. return
  228. if is_video_bg:
  229. cap_bg = cv2.VideoCapture(args.background_video_path)
  230. frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
  231. current_frame_bg = 1
  232. else:
  233. img_bg = cv2.imread(args.background_image_path)
  234. while cap_video.isOpened():
  235. ret, frame = cap_video.read()
  236. if ret:
  237. im_shape = frame.shape
  238. im_scale_x = float(resize_w) / float(im_shape[1])
  239. im_scale_y = float(resize_h) / float(im_shape[0])
  240. im = cv2.resize(
  241. frame,
  242. None,
  243. None,
  244. fx=im_scale_x,
  245. fy=im_scale_y,
  246. interpolation=cv2.INTER_LINEAR)
  247. image = im.astype('float32')
  248. im_info = ('resize', im_shape[0:2])
  249. pred = model.predict(image, test_transforms)
  250. score_map = pred['score_map']
  251. cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
  252. cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
  253. score_map = 255 * score_map[:, :, 1]
  254. optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
  255. disflow, is_init)
  256. prev_gray = cur_gray.copy()
  257. prev_cfd = optflow_map.copy()
  258. is_init = False
  259. optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
  260. optflow_map = threshold_mask(
  261. optflow_map, thresh_bg=0.2, thresh_fg=0.8)
  262. score_map = recover(optflow_map, im_info)
  263. #循环读取背景帧
  264. if is_video_bg:
  265. ret_bg, frame_bg = cap_bg.read()
  266. if ret_bg:
  267. if current_frame_bg == frames_bg:
  268. current_frame_bg = 1
  269. cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
  270. else:
  271. break
  272. current_frame_bg += 1
  273. comb = bg_replace(score_map, frame, frame_bg)
  274. else:
  275. comb = bg_replace(score_map, frame, img_bg)
  276. cv2.imshow('HumanSegmentation', comb)
  277. if cv2.waitKey(1) & 0xFF == ord('q'):
  278. break
  279. else:
  280. break
  281. if is_video_bg:
  282. cap_bg.release()
  283. cap_video.release()
  284. if __name__ == "__main__":
  285. args = parse_args()
  286. infer(args)