utils.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import os
  16. import time
  17. import math
  18. import chardet
  19. import json
  20. import numpy as np
  21. import paddlex
  22. from . import logging
  23. import platform
  24. def seconds_to_hms(seconds):
  25. h = math.floor(seconds / 3600)
  26. m = math.floor((seconds - h * 3600) / 60)
  27. s = int(seconds - h * 3600 - m * 60)
  28. hms_str = "{}:{}:{}".format(h, m, s)
  29. return hms_str
  30. def get_encoding(path):
  31. f = open(path, 'rb')
  32. data = f.read()
  33. file_encoding = chardet.detect(data).get('encoding')
  34. f.close()
  35. return file_encoding
  36. def get_single_card_bs(batch_size):
  37. card_num = paddlex.env_info['num']
  38. place = paddlex.env_info['place']
  39. if batch_size % card_num == 0:
  40. return int(batch_size // card_num)
  41. elif batch_size == 1:
  42. # Evaluation of detection task only supports single card with batch size 1
  43. return batch_size
  44. else:
  45. raise Exception("Please support correct batch_size, \
  46. which can be divided by available cards({}) in {}"
  47. .format(card_num, place))
  48. def dict2str(dict_input):
  49. out = ''
  50. for k, v in dict_input.items():
  51. try:
  52. v = '{:8.6f}'.format(float(v))
  53. except:
  54. pass
  55. out = out + '{}={}, '.format(k, v)
  56. return out.strip(', ')
  57. def path_normalization(path):
  58. win_sep = "\\"
  59. other_sep = "/"
  60. if platform.system() == "Windows":
  61. path = win_sep.join(path.split(other_sep))
  62. else:
  63. path = other_sep.join(path.split(win_sep))
  64. return path
  65. def is_pic(img_name):
  66. valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
  67. suffix = img_name.split('.')[-1]
  68. if suffix not in valid_suffix:
  69. return False
  70. return True
  71. class MyEncoder(json.JSONEncoder):
  72. def default(self, obj):
  73. if isinstance(obj, np.integer):
  74. return int(obj)
  75. elif isinstance(obj, np.floating):
  76. return float(obj)
  77. elif isinstance(obj, np.ndarray):
  78. return obj.tolist()
  79. else:
  80. return super(MyEncoder, self).default(obj)
  81. class EarlyStop:
  82. def __init__(self, patience, thresh):
  83. self.patience = patience
  84. self.counter = 0
  85. self.score = None
  86. self.max = 0
  87. self.thresh = thresh
  88. if patience < 1:
  89. raise Exception("Argument patience should be a positive integer.")
  90. def __call__(self, current_score):
  91. if self.score is None:
  92. self.score = current_score
  93. return False
  94. elif current_score > self.max:
  95. self.counter = 0
  96. self.score = current_score
  97. self.max = current_score
  98. return False
  99. else:
  100. if (abs(self.score - current_score) < self.thresh or
  101. current_score < self.score):
  102. self.counter += 1
  103. self.score = current_score
  104. logging.debug("EarlyStopping: %i / %i" %
  105. (self.counter, self.patience))
  106. if self.counter >= self.patience:
  107. logging.info("EarlyStopping: Stop training")
  108. return True
  109. return False
  110. else:
  111. self.counter = 0
  112. self.score = current_score
  113. return False
  114. class DisablePrint(object):
  115. def __enter__(self):
  116. self._original_stdout = sys.stdout
  117. sys.stdout = open(os.devnull, 'w')
  118. def __exit__(self, exc_type, exc_val, exc_tb):
  119. sys.stdout.close()
  120. sys.stdout = self._original_stdout
  121. class Times(object):
  122. def __init__(self):
  123. self.time = 0.
  124. # start time
  125. self.st = 0.
  126. # end time
  127. self.et = 0.
  128. def start(self):
  129. self.st = time.time()
  130. def end(self, iter_num=1, accumulative=True):
  131. self.et = time.time()
  132. if accumulative:
  133. self.time += (self.et - self.st) / iter_num
  134. else:
  135. self.time = (self.et - self.st) / iter_num
  136. def reset(self):
  137. self.time = 0.
  138. self.st = 0.
  139. self.et = 0.
  140. def value(self):
  141. return round(self.time, 4)
  142. class Timer(Times):
  143. def __init__(self):
  144. super(Timer, self).__init__()
  145. self.preprocess_time_s = Times()
  146. self.inference_time_s = Times()
  147. self.postprocess_time_s = Times()
  148. self.img_num = 0
  149. self.repeats = 0
  150. def info(self, average=False):
  151. total_time = self.preprocess_time_s.value(
  152. ) * self.img_num + self.inference_time_s.value(
  153. ) + self.postprocess_time_s.value() * self.img_num
  154. total_time = round(total_time, 4)
  155. logging.info("------------------ Inference Time Info ----------------------")
  156. logging.info("total_time(ms): {}, img_num: {}, batch_size: {}".format(
  157. total_time * 1000, self.img_num, self.img_num))
  158. preprocess_time = round(
  159. self.preprocess_time_s.value() / self.repeats,
  160. 4) if average else self.preprocess_time_s.value()
  161. postprocess_time = round(
  162. self.postprocess_time_s.value() / self.repeats,
  163. 4) if average else self.postprocess_time_s.value()
  164. inference_time = round(self.inference_time_s.value() / self.repeats,
  165. 4) if average else self.inference_time_s.value()
  166. average_latency = total_time / self.repeats
  167. logging.info("average latency time(ms): {:.2f}, QPS: {:2f}".format(
  168. average_latency * 1000, 1 / average_latency))
  169. logging.info("preprocess_time_per_im(ms): {:.2f}, "
  170. "inference_time_per_batch(ms): {:.2f}, "
  171. "postprocess_time_per_im(ms): {:.2f}".format(
  172. preprocess_time * 1000, inference_time * 1000,
  173. postprocess_time * 1000))
  174. def report(self, average=False):
  175. dic = {}
  176. dic['preprocess_time_s'] = round(
  177. self.preprocess_time_s.value() / self.repeats,
  178. 4) if average else self.preprocess_time_s.value()
  179. dic['postprocess_time_s'] = round(
  180. self.postprocess_time_s.value() / self.repeats,
  181. 4) if average else self.postprocess_time_s.value()
  182. dic['inference_time_s'] = round(
  183. self.inference_time_s.value() / self.repeats,
  184. 4) if average else self.inference_time_s.value()
  185. dic['img_num'] = self.img_num
  186. total_time = self.preprocess_time_s.value(
  187. ) + self.inference_time_s.value() + self.postprocess_time_s.value()
  188. dic['total_time_s'] = round(total_time, 4)
  189. dic['batch_size'] = self.img_num / self.repeats
  190. return dic
  191. def reset(self):
  192. self.preprocess_time_s.reset()
  193. self.inference_time_s.reset()
  194. self.postprocess_time_s.reset()
  195. self.img_num = 0
  196. self.repeats = 0