progbar.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import time
  17. import numpy as np
  18. class Progbar(object):
  19. """
  20. Displays a progress bar.
  21. It refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
  22. Args:
  23. target (int): Total number of steps expected, None if unknown.
  24. width (int): Progress bar width on screen.
  25. verbose (int): Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  26. stateful_metrics (list|tuple): Iterable of string names of metrics that should *not* be
  27. averaged over time. Metrics in this list will be displayed as-is. All
  28. others will be averaged by the progbar before display.
  29. interval (float): Minimum visual progress update interval (in seconds).
  30. unit_name (str): Display name for step counts (usually "step" or "sample").
  31. """
  32. def __init__(self,
  33. target,
  34. width=30,
  35. verbose=1,
  36. interval=0.05,
  37. stateful_metrics=None,
  38. unit_name='step'):
  39. self.target = target
  40. self.width = width
  41. self.verbose = verbose
  42. self.interval = interval
  43. self.unit_name = unit_name
  44. if stateful_metrics:
  45. self.stateful_metrics = set(stateful_metrics)
  46. else:
  47. self.stateful_metrics = set()
  48. self._dynamic_display = (
  49. (hasattr(sys.stderr, 'isatty') and
  50. sys.stderr.isatty()) or 'ipykernel' in sys.modules or
  51. 'posix' in sys.modules or 'PYCHARM_HOSTED' in os.environ)
  52. self._total_width = 0
  53. self._seen_so_far = 0
  54. # We use a dict + list to avoid garbage collection
  55. # issues found in OrderedDict
  56. self._values = {}
  57. self._values_order = []
  58. self._start = time.time()
  59. self._last_update = 0
  60. def update(self, current, values=None, finalize=None):
  61. """
  62. Updates the progress bar.
  63. Args:
  64. current (int): Index of current step.
  65. values (list): List of tuples: `(name, value_for_last_step)`. If `name` is in
  66. `stateful_metrics`, `value_for_last_step` will be displayed as-is.
  67. Else, an average of the metric over time will be displayed.
  68. finalize (bool): Whether this is the last update for the progress bar. If
  69. `None`, defaults to `current >= self.target`.
  70. """
  71. if finalize is None:
  72. if self.target is None:
  73. finalize = False
  74. else:
  75. finalize = current >= self.target
  76. values = values or []
  77. for k, v in values:
  78. if k not in self._values_order:
  79. self._values_order.append(k)
  80. if k not in self.stateful_metrics:
  81. # In the case that progress bar doesn't have a target value in the first
  82. # epoch, both on_batch_end and on_epoch_end will be called, which will
  83. # cause 'current' and 'self._seen_so_far' to have the same value. Force
  84. # the minimal value to 1 here, otherwise stateful_metric will be 0s.
  85. value_base = max(current - self._seen_so_far, 1)
  86. if k not in self._values:
  87. self._values[k] = [v * value_base, value_base]
  88. else:
  89. self._values[k][0] += v * value_base
  90. self._values[k][1] += value_base
  91. else:
  92. # Stateful metrics output a numeric value. This representation
  93. # means "take an average from a single value" but keeps the
  94. # numeric formatting.
  95. self._values[k] = [v, 1]
  96. self._seen_so_far = current
  97. now = time.time()
  98. info = ' - %.0fs' % (now - self._start)
  99. if self.verbose == 1:
  100. if now - self._last_update < self.interval and not finalize:
  101. return
  102. prev_total_width = self._total_width
  103. if self._dynamic_display:
  104. sys.stderr.write('\b' * prev_total_width)
  105. sys.stderr.write('\r')
  106. else:
  107. sys.stderr.write('\n')
  108. if self.target is not None:
  109. numdigits = int(np.log10(self.target)) + 1
  110. bar = ('%' + str(numdigits) + 'd/%d [') % (current,
  111. self.target)
  112. prog = float(current) / self.target
  113. prog_width = int(self.width * prog)
  114. if prog_width > 0:
  115. bar += ('=' * (prog_width - 1))
  116. if current < self.target:
  117. bar += '>'
  118. else:
  119. bar += '='
  120. bar += ('.' * (self.width - prog_width))
  121. bar += ']'
  122. else:
  123. bar = '%7d/Unknown' % current
  124. self._total_width = len(bar)
  125. sys.stderr.write(bar)
  126. if current:
  127. time_per_unit = (now - self._start) / current
  128. else:
  129. time_per_unit = 0
  130. if self.target is None or finalize:
  131. if time_per_unit >= 1 or time_per_unit == 0:
  132. info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
  133. elif time_per_unit >= 1e-3:
  134. info += ' %.0fms/%s' % (time_per_unit * 1e3,
  135. self.unit_name)
  136. else:
  137. info += ' %.0fus/%s' % (time_per_unit * 1e6,
  138. self.unit_name)
  139. else:
  140. eta = time_per_unit * (self.target - current)
  141. if eta > 3600:
  142. eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
  143. 60, eta % 60)
  144. elif eta > 60:
  145. eta_format = '%d:%02d' % (eta // 60, eta % 60)
  146. else:
  147. eta_format = '%ds' % eta
  148. info = ' - ETA: %s' % eta_format
  149. for k in self._values_order:
  150. info += ' - %s:' % k
  151. if isinstance(self._values[k], list):
  152. avg = np.mean(self._values[k][0] /
  153. max(1, self._values[k][1]))
  154. if abs(avg) > 1e-3:
  155. info += ' %.4f' % avg
  156. else:
  157. info += ' %.4e' % avg
  158. else:
  159. info += ' %s' % self._values[k]
  160. self._total_width += len(info)
  161. if prev_total_width > self._total_width:
  162. info += (' ' * (prev_total_width - self._total_width))
  163. if finalize:
  164. info += '\n'
  165. sys.stderr.write(info)
  166. sys.stderr.flush()
  167. elif self.verbose == 2:
  168. if finalize:
  169. numdigits = int(np.log10(self.target)) + 1
  170. count = ('%' + str(numdigits) + 'd/%d') % (current,
  171. self.target)
  172. info = count + info
  173. for k in self._values_order:
  174. info += ' - %s:' % k
  175. avg = np.mean(self._values[k][0] /
  176. max(1, self._values[k][1]))
  177. if avg > 1e-3:
  178. info += ' %.4f' % avg
  179. else:
  180. info += ' %.4e' % avg
  181. info += '\n'
  182. sys.stderr.write(info)
  183. sys.stderr.flush()
  184. self._last_update = now
  185. def add(self, n, values=None):
  186. self.update(self._seen_so_far + n, values)