utils.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import os
import os.path as osp
import numpy as np
import six
import yaml
import math

from . import logging

def seconds_to_hms(seconds):
    h = math.floor(seconds / 3600)
    m = math.floor((seconds - h * 3600) / 60)
    s = int(seconds - h * 3600 - m * 60)
    hms_str = "{}:{}:{}".format(h, m, s)
    return hms_str
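
# Worked example (illustrative): seconds_to_hms(3725) gives h=1, m=2, s=5 and
# returns "1:2:5". Note the fields are not zero-padded.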

def get_environ_info():
    import paddle.fluid as fluid
    info = dict()
    info['place'] = 'cpu'
    info['num'] = int(os.environ.get('CPU_NUM', 1))
    # An empty CUDA_VISIBLE_DEVICES means GPUs have been explicitly disabled.
    if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "":
        if hasattr(fluid.core, 'get_cuda_device_count'):
            gpu_num = 0
            try:
                gpu_num = fluid.core.get_cuda_device_count()
            except Exception:
                # CUDA is unavailable; hide all devices and fall back to CPU.
                os.environ['CUDA_VISIBLE_DEVICES'] = ''
            if gpu_num > 0:
                info['place'] = 'cuda'
                info['num'] = gpu_num
    return info
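
# Usage sketch (illustrative):
#
#   env = get_environ_info()        # e.g. {'place': 'cuda', 'num': 2}
#   use_gpu = env['place'] == 'cuda'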

def parse_param_file(param_file, return_shape=True):
    from paddle.fluid.proto.framework_pb2 import VarType
    with open(param_file, 'rb') as f:
        version = np.frombuffer(f.read(4), dtype='int32')  # format version (unused)
        lod_level = int(np.frombuffer(f.read(8), dtype='int64')[0])
        # Skip the LoD information; only the tensor itself is needed here.
        for i in range(lod_level):
            _size = int(np.frombuffer(f.read(8), dtype='int64')[0])
            _ = f.read(_size)
        version = np.frombuffer(f.read(4), dtype='int32')  # tensor version (unused)
        tensor_desc = VarType.TensorDesc()
        tensor_desc_size = int(np.frombuffer(f.read(4), dtype='int32')[0])
        tensor_desc.ParseFromString(f.read(tensor_desc_size))
        tensor_shape = tuple(tensor_desc.dims)
        if return_shape:
            return tensor_shape
        if tensor_desc.data_type != 5:  # 5 == VarType.FP32
            raise Exception(
                "Unexpected data type while parsing {}".format(param_file))
        data_size = 4
        for i in range(len(tensor_shape)):
            data_size *= tensor_shape[i]
        weight = np.frombuffer(f.read(data_size), dtype='float32')
    return np.reshape(weight, tensor_shape)
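
# Usage sketch (paths are hypothetical): read just the shape, or the full
# float32 tensor, from a parameter file saved in Paddle's LoDTensor format.
#
#   shape = parse_param_file('weights_dir/conv1_weights')
#   array = parse_param_file('weights_dir/conv1_weights', return_shape=False)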

def fuse_bn_weights(exe, main_prog, weights_dir):
    import paddle.fluid as fluid
    logging.info("Trying to fuse batch_norm weights...")
    bn_vars = list()
    for block in main_prog.blocks:
        ops = list(block.ops)
        for op in ops:
            if op.type == 'affine_channel':
                scale_name = op.input('Scale')[0]
                bias_name = op.input('Bias')[0]
                # Strip the trailing 'scale' to derive the mean/variance names.
                prefix = scale_name[:-5]
                mean_name = prefix + 'mean'
                variance_name = prefix + 'variance'
                if (not osp.exists(osp.join(weights_dir, mean_name)) or
                        not osp.exists(osp.join(weights_dir, variance_name))):
                    logging.info(
                        "There's no batch_norm weight found to fuse, skip fuse_bn."
                    )
                    return
                bias = block.var(bias_name)
                pretrained_shape = parse_param_file(
                    osp.join(weights_dir, bias_name))
                actual_shape = tuple(bias.shape)
                if pretrained_shape != actual_shape:
                    continue
                bn_vars.append(
                    [scale_name, bias_name, mean_name, variance_name])
    eps = 1e-5
    for names in bn_vars:
        scale_name, bias_name, mean_name, variance_name = names
        scale = parse_param_file(
            osp.join(weights_dir, scale_name), return_shape=False)
        bias = parse_param_file(
            osp.join(weights_dir, bias_name), return_shape=False)
        mean = parse_param_file(
            osp.join(weights_dir, mean_name), return_shape=False)
        variance = parse_param_file(
            osp.join(weights_dir, variance_name), return_shape=False)
        bn_std = np.sqrt(np.add(variance, eps))
        new_scale = np.float32(np.divide(scale, bn_std))
        new_bias = bias - mean * new_scale
        scale_tensor = fluid.global_scope().find_var(scale_name).get_tensor()
        bias_tensor = fluid.global_scope().find_var(bias_name).get_tensor()
        scale_tensor.set(new_scale, exe.place)
        bias_tensor.set(new_bias, exe.place)
    if len(bn_vars) == 0:
        logging.info(
            "There's no batch_norm weight found to fuse, skip fuse_bn.")
    else:
        logging.info("{} batch_norm ops have been fused.".format(len(bn_vars)))
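
# Fusing rationale (sketch): an affine_channel op computes y = scale * x + bias.
# Folding batch-norm statistics into it uses
#   new_scale = scale / sqrt(variance + eps)
#   new_bias  = bias - mean * new_scale
# so that y = new_scale * x + new_bias reproduces batch_norm at inference time.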

def load_pdparams(exe, main_prog, model_dir):
    import pickle
    import paddle.fluid as fluid
    vars_to_load = list()
    with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
        params_dict = pickle.load(f) if six.PY2 else pickle.load(
            f, encoding='latin1')
    unused_vars = list()
    for var in main_prog.list_vars():
        if not isinstance(var, fluid.framework.Parameter):
            continue
        if var.name not in params_dict:
            raise Exception("{} is not in the saved paddlex model".format(
                var.name))
        if var.shape != params_dict[var.name].shape:
            unused_vars.append(var.name)
            logging.warning(
                "[SKIP] Shape of pretrained weight {} doesn't match. (Pretrained: {}, Actual: {})"
                .format(var.name, params_dict[var.name].shape, var.shape))
            continue
        vars_to_load.append(var)
        logging.debug("Weight {} will be loaded".format(var.name))
    # Drop shape-mismatched entries so set_program_state only sees usable weights.
    for var_name in unused_vars:
        del params_dict[var_name]
    fluid.io.set_program_state(main_prog, params_dict)
    if len(vars_to_load) == 0:
        logging.warning(
            "No pretrained weights were loaded, maybe you should check your pretrained model!"
        )
    else:
        logging.info("{} variables from {} have been loaded.".format(
            len(vars_to_load), model_dir))
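
# Note: 'model.pdparams' is a pickled dict mapping parameter names to numpy
# arrays, so shape checks and pruning happen in Python before
# fluid.io.set_program_state applies the remaining compatible weights.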

def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
    if not osp.exists(weights_dir):
        raise Exception("Path {} does not exist.".format(weights_dir))
    # A 'model.pdparams' file means the directory holds a combined
    # paddlex-saved model rather than one file per parameter.
    if osp.exists(osp.join(weights_dir, "model.pdparams")):
        return load_pdparams(exe, main_prog, weights_dir)
    import paddle.fluid as fluid
    vars_to_load = list()
    for var in main_prog.list_vars():
        if not isinstance(var, fluid.framework.Parameter):
            continue
        if not osp.exists(osp.join(weights_dir, var.name)):
            logging.debug(
                "[SKIP] Pretrained weight {}/{} doesn't exist".format(
                    weights_dir, var.name))
            continue
        pretrained_shape = parse_param_file(osp.join(weights_dir, var.name))
        actual_shape = tuple(var.shape)
        if pretrained_shape != actual_shape:
            logging.warning(
                "[SKIP] Shape of pretrained weight {}/{} doesn't match. (Pretrained: {}, Actual: {})"
                .format(weights_dir, var.name, pretrained_shape, actual_shape))
            continue
        vars_to_load.append(var)
        logging.debug("Weight {} will be loaded".format(var.name))
    fluid.io.load_vars(
        executor=exe,
        dirname=weights_dir,
        main_program=main_prog,
        vars=vars_to_load)
    if len(vars_to_load) == 0:
        logging.warning(
            "No pretrained weights were loaded, maybe you should check your pretrained model!"
        )
    else:
        logging.info("{} variables from {} have been loaded.".format(
            len(vars_to_load), weights_dir))
    if fuse_bn:
        fuse_bn_weights(exe, main_prog, weights_dir)
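
# Usage sketch (illustrative; exe and main_prog come from the surrounding
# training setup, and the weights directory is hypothetical):
#
#   exe = fluid.Executor(fluid.CUDAPlace(0))
#   load_pretrain_weights(exe, main_prog, './pretrain/ResNet50', fuse_bn=True)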

class EarlyStop:
    def __init__(self, patience, thresh):
        self.patience = patience
        self.counter = 0
        self.score = None
        self.max = 0
        self.thresh = thresh
        if patience < 1:
            raise Exception("Argument patience should be a positive integer.")

    def __call__(self, current_score):
        if self.score is None:
            self.score = current_score
            return False
        elif current_score > self.max:
            # New best score: reset the counter and record the new maximum.
            self.counter = 0
            self.score = current_score
            self.max = current_score
            return False
        else:
            if (abs(self.score - current_score) < self.thresh
                    or current_score < self.score):
                # No meaningful improvement over the last score.
                self.counter += 1
                self.score = current_score
                logging.debug(
                    "EarlyStopping: %i / %i" % (self.counter, self.patience))
                if self.counter >= self.patience:
                    logging.info("EarlyStopping: Stop training")
                    return True
                return False
            else:
                self.counter = 0
                self.score = current_score
                return False
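
# Usage sketch (illustrative; evaluate() is a hypothetical metric callback):
#
#   early_stop = EarlyStop(patience=5, thresh=1e-4)
#   for epoch in range(num_epochs):
#       metric = evaluate(model)
#       if early_stop(metric):
#           break
#
# Note: self.max starts at 0, so this assumes a non-negative metric such as
# accuracy or mIoU.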