callbacks.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import datetime
  20. import six
  21. import numpy as np
  22. import paddle
  23. import paddle.distributed as dist
  24. from paddlex.ppdet.utils.checkpoint import save_model
  25. from paddlex.ppdet.utils.logger import setup_logger
  26. logger = setup_logger('ppdet.engine')
  27. __all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer']
  28. class Callback(object):
  29. def __init__(self, model):
  30. self.model = model
  31. def on_step_begin(self, status):
  32. pass
  33. def on_step_end(self, status):
  34. pass
  35. def on_epoch_begin(self, status):
  36. pass
  37. def on_epoch_end(self, status):
  38. pass
  39. class ComposeCallback(object):
  40. def __init__(self, callbacks):
  41. callbacks = [c for c in list(callbacks) if c is not None]
  42. for c in callbacks:
  43. assert isinstance(
  44. c, Callback), "callback should be subclass of Callback"
  45. self._callbacks = callbacks
  46. def on_step_begin(self, status):
  47. for c in self._callbacks:
  48. c.on_step_begin(status)
  49. def on_step_end(self, status):
  50. for c in self._callbacks:
  51. c.on_step_end(status)
  52. def on_epoch_begin(self, status):
  53. for c in self._callbacks:
  54. c.on_epoch_begin(status)
  55. def on_epoch_end(self, status):
  56. for c in self._callbacks:
  57. c.on_epoch_end(status)
  58. class LogPrinter(Callback):
  59. def __init__(self, model):
  60. super(LogPrinter, self).__init__(model)
  61. def on_step_end(self, status):
  62. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  63. mode = status['mode']
  64. if mode == 'train':
  65. epoch_id = status['epoch_id']
  66. step_id = status['step_id']
  67. steps_per_epoch = status['steps_per_epoch']
  68. training_staus = status['training_staus']
  69. batch_time = status['batch_time']
  70. data_time = status['data_time']
  71. epoches = self.model.cfg.epoch
  72. batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
  73. ))]['batch_size']
  74. logs = training_staus.log()
  75. space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
  76. if step_id % self.model.cfg.log_iter == 0:
  77. eta_steps = (epoches - epoch_id
  78. ) * steps_per_epoch - step_id
  79. eta_sec = eta_steps * batch_time.global_avg
  80. eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
  81. ips = float(batch_size) / batch_time.avg
  82. fmt = ' '.join([
  83. 'Epoch: [{}]',
  84. '[{' + space_fmt + '}/{}]',
  85. 'learning_rate: {lr:.6f}',
  86. '{meters}',
  87. 'eta: {eta}',
  88. 'batch_cost: {btime}',
  89. 'data_cost: {dtime}',
  90. 'ips: {ips:.4f} images/s',
  91. ])
  92. fmt = fmt.format(
  93. epoch_id,
  94. step_id,
  95. steps_per_epoch,
  96. lr=status['learning_rate'],
  97. meters=logs,
  98. eta=eta_str,
  99. btime=str(batch_time),
  100. dtime=str(data_time),
  101. ips=ips)
  102. logger.info(fmt)
  103. if mode == 'eval':
  104. step_id = status['step_id']
  105. if step_id % 100 == 0:
  106. logger.info("Eval iter: {}".format(step_id))
  107. def on_epoch_end(self, status):
  108. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  109. mode = status['mode']
  110. if mode == 'eval':
  111. sample_num = status['sample_num']
  112. cost_time = status['cost_time']
  113. logger.info('Total sample number: {}, averge FPS: {}'.format(
  114. sample_num, sample_num / cost_time))
  115. class Checkpointer(Callback):
  116. def __init__(self, model):
  117. super(Checkpointer, self).__init__(model)
  118. cfg = self.model.cfg
  119. self.best_ap = 0.
  120. self.save_dir = os.path.join(self.model.cfg.save_dir,
  121. self.model.cfg.filename)
  122. if hasattr(self.model.model, 'student_model'):
  123. self.weight = self.model.model.student_model
  124. else:
  125. self.weight = self.model.model
  126. def on_epoch_end(self, status):
  127. # Checkpointer only performed during training
  128. mode = status['mode']
  129. epoch_id = status['epoch_id']
  130. weight = None
  131. save_name = None
  132. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  133. if mode == 'train':
  134. end_epoch = self.model.cfg.epoch
  135. if (
  136. epoch_id + 1
  137. ) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
  138. save_name = str(
  139. epoch_id
  140. ) if epoch_id != end_epoch - 1 else "model_final"
  141. weight = self.weight
  142. elif mode == 'eval':
  143. if 'save_best_model' in status and status['save_best_model']:
  144. for metric in self.model._metrics:
  145. map_res = metric.get_results()
  146. if 'bbox' in map_res:
  147. key = 'bbox'
  148. elif 'keypoint' in map_res:
  149. key = 'keypoint'
  150. else:
  151. key = 'mask'
  152. if key not in map_res:
  153. logger.warn("Evaluation results empty, this may be due to " \
  154. "training iterations being too few or not " \
  155. "loading the correct weights.")
  156. return
  157. if map_res[key][0] > self.best_ap:
  158. self.best_ap = map_res[key][0]
  159. save_name = 'best_model'
  160. weight = self.weight
  161. logger.info("Best test {} ap is {:0.3f}.".format(
  162. key, self.best_ap))
  163. if weight:
  164. save_model(weight, self.model.optimizer, self.save_dir,
  165. save_name, epoch_id + 1)
  166. class WiferFaceEval(Callback):
  167. def __init__(self, model):
  168. super(WiferFaceEval, self).__init__(model)
  169. def on_epoch_begin(self, status):
  170. assert self.model.mode == 'eval', \
  171. "WiferFaceEval can only be set during evaluation"
  172. for metric in self.model._metrics:
  173. metric.update(self.model.model)
  174. sys.exit()
  175. class VisualDLWriter(Callback):
  176. """
  177. Use VisualDL to log data or image
  178. """
  179. def __init__(self, model):
  180. super(VisualDLWriter, self).__init__(model)
  181. assert six.PY3, "VisualDL requires Python >= 3.5"
  182. try:
  183. from visualdl import LogWriter
  184. except Exception as e:
  185. logger.error('visualdl not found, plaese install visualdl. '
  186. 'for example: `pip install visualdl`.')
  187. raise e
  188. self.vdl_writer = LogWriter(
  189. model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
  190. self.vdl_loss_step = 0
  191. self.vdl_mAP_step = 0
  192. self.vdl_image_step = 0
  193. self.vdl_image_frame = 0
  194. def on_step_end(self, status):
  195. mode = status['mode']
  196. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  197. if mode == 'train':
  198. training_staus = status['training_staus']
  199. for loss_name, loss_value in training_staus.get().items():
  200. self.vdl_writer.add_scalar(loss_name, loss_value,
  201. self.vdl_loss_step)
  202. self.vdl_loss_step += 1
  203. elif mode == 'test':
  204. ori_image = status['original_image']
  205. result_image = status['result_image']
  206. self.vdl_writer.add_image(
  207. "original/frame_{}".format(self.vdl_image_frame),
  208. ori_image, self.vdl_image_step)
  209. self.vdl_writer.add_image(
  210. "result/frame_{}".format(self.vdl_image_frame),
  211. result_image, self.vdl_image_step)
  212. self.vdl_image_step += 1
  213. # each frame can display ten pictures at most.
  214. if self.vdl_image_step % 10 == 0:
  215. self.vdl_image_step = 0
  216. self.vdl_image_frame += 1
  217. def on_epoch_end(self, status):
  218. mode = status['mode']
  219. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  220. if mode == 'eval':
  221. for metric in self.model._metrics:
  222. for key, map_value in metric.get_results().items():
  223. self.vdl_writer.add_scalar("{}-mAP".format(key),
  224. map_value[0],
  225. self.vdl_mAP_step)
  226. self.vdl_mAP_step += 1