utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import os
  16. import math
  17. import chardet
  18. import json
  19. import numpy as np
  20. import paddlex
  21. from . import logging
  22. import platform
  23. def seconds_to_hms(seconds):
  24. h = math.floor(seconds / 3600)
  25. m = math.floor((seconds - h * 3600) / 60)
  26. s = int(seconds - h * 3600 - m * 60)
  27. hms_str = "{}:{}:{}".format(h, m, s)
  28. return hms_str
  29. def get_encoding(path):
  30. f = open(path, 'rb')
  31. data = f.read()
  32. file_encoding = chardet.detect(data).get('encoding')
  33. f.close()
  34. return file_encoding
  35. def get_single_card_bs(batch_size):
  36. card_num = paddlex.env_info['num']
  37. place = paddlex.env_info['place']
  38. if batch_size % card_num == 0:
  39. return int(batch_size // card_num)
  40. elif batch_size == 1:
  41. # Evaluation of detection task only supports single card with batch size 1
  42. return batch_size
  43. else:
  44. raise Exception("Please support correct batch_size, \
  45. which can be divided by available cards({}) in {}"
  46. .format(card_num, place))
  47. def dict2str(dict_input):
  48. out = ''
  49. for k, v in dict_input.items():
  50. try:
  51. v = '{:8.6f}'.format(float(v))
  52. except:
  53. pass
  54. out = out + '{}={}, '.format(k, v)
  55. return out.strip(', ')
  56. def path_normalization(path):
  57. win_sep = "\\"
  58. other_sep = "/"
  59. if platform.system() == "Windows":
  60. path = win_sep.join(path.split(other_sep))
  61. else:
  62. path = other_sep.join(path.split(win_sep))
  63. return path
  64. def is_pic(img_name):
  65. valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
  66. suffix = img_name.split('.')[-1]
  67. if suffix not in valid_suffix:
  68. return False
  69. return True
  70. class MyEncoder(json.JSONEncoder):
  71. def default(self, obj):
  72. if isinstance(obj, np.integer):
  73. return int(obj)
  74. elif isinstance(obj, np.floating):
  75. return float(obj)
  76. elif isinstance(obj, np.ndarray):
  77. return obj.tolist()
  78. else:
  79. return super(MyEncoder, self).default(obj)
  80. class EarlyStop:
  81. def __init__(self, patience, thresh):
  82. self.patience = patience
  83. self.counter = 0
  84. self.score = None
  85. self.max = 0
  86. self.thresh = thresh
  87. if patience < 1:
  88. raise Exception("Argument patience should be a positive integer.")
  89. def __call__(self, current_score):
  90. if self.score is None:
  91. self.score = current_score
  92. return False
  93. elif current_score > self.max:
  94. self.counter = 0
  95. self.score = current_score
  96. self.max = current_score
  97. return False
  98. else:
  99. if (abs(self.score - current_score) < self.thresh or
  100. current_score < self.score):
  101. self.counter += 1
  102. self.score = current_score
  103. logging.debug("EarlyStopping: %i / %i" %
  104. (self.counter, self.patience))
  105. if self.counter >= self.patience:
  106. logging.info("EarlyStopping: Stop training")
  107. return True
  108. return False
  109. else:
  110. self.counter = 0
  111. self.score = current_score
  112. return False
  113. class DisablePrint(object):
  114. def __enter__(self):
  115. self._original_stdout = sys.stdout
  116. sys.stdout = open(os.devnull, 'w')
  117. def __exit__(self, exc_type, exc_val, exc_tb):
  118. sys.stdout.close()
  119. sys.stdout = self._original_stdout