# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import errno
import os
import re
import shutil
import tempfile

import paddle
from paddle.static import load_program_state

from paddlex.ppcls.utils import logger

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
  26. def _mkdir_if_not_exist(path):
  27. """
  28. mkdir if not exists, ignore the exception when multiprocess mkdir together
  29. """
  30. if not os.path.exists(path):
  31. try:
  32. os.makedirs(path)
  33. except OSError as e:
  34. if e.errno == errno.EEXIST and os.path.isdir(path):
  35. logger.warning(
  36. 'be happy if some process has already created {}'.format(
  37. path))
  38. else:
  39. raise OSError('Failed to mkdir {}'.format(path))
  40. def load_dygraph_pretrain(model, path=None, load_static_weights=False):
  41. if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
  42. raise ValueError("Model pretrain path {} does not "
  43. "exists.".format(path))
  44. if load_static_weights:
  45. pre_state_dict = load_program_state(path)
  46. param_state_dict = {}
  47. model_dict = model.state_dict()
  48. for key in model_dict.keys():
  49. weight_name = model_dict[key].name
  50. if weight_name in pre_state_dict.keys():
  51. logger.info('Load weight: {}, shape: {}'.format(
  52. weight_name, pre_state_dict[weight_name].shape))
  53. param_state_dict[key] = pre_state_dict[weight_name]
  54. else:
  55. param_state_dict[key] = model_dict[key]
  56. model.set_dict(param_state_dict)
  57. return
  58. param_state_dict = paddle.load(path + ".pdparams")
  59. model.set_dict(param_state_dict)
  60. return
  61. def load_distillation_model(model, pretrained_model, load_static_weights):
  62. logger.info("In distillation mode, teacher model will be "
  63. "loaded firstly before student model.")
  64. if not isinstance(pretrained_model, list):
  65. pretrained_model = [pretrained_model]
  66. if not isinstance(load_static_weights, list):
  67. load_static_weights = [load_static_weights] * len(pretrained_model)
  68. teacher = model.teacher if hasattr(model,
  69. "teacher") else model._layers.teacher
  70. student = model.student if hasattr(model,
  71. "student") else model._layers.student
  72. load_dygraph_pretrain(
  73. teacher,
  74. path=pretrained_model[0],
  75. load_static_weights=load_static_weights[0])
  76. logger.info("Finish initing teacher model from {}".format(
  77. pretrained_model))
  78. # load student model
  79. if len(pretrained_model) >= 2:
  80. load_dygraph_pretrain(
  81. student,
  82. path=pretrained_model[1],
  83. load_static_weights=load_static_weights[1])
  84. logger.info("Finish initing student model from {}".format(
  85. pretrained_model))
  86. def init_model(config, net, optimizer=None):
  87. """
  88. load model from checkpoint or pretrained_model
  89. """
  90. checkpoints = config.get('checkpoints')
  91. if checkpoints and optimizer is not None:
  92. assert os.path.exists(checkpoints + ".pdparams"), \
  93. "Given dir {}.pdparams not exist.".format(checkpoints)
  94. assert os.path.exists(checkpoints + ".pdopt"), \
  95. "Given dir {}.pdopt not exist.".format(checkpoints)
  96. para_dict = paddle.load(checkpoints + ".pdparams")
  97. opti_dict = paddle.load(checkpoints + ".pdopt")
  98. net.set_dict(para_dict)
  99. optimizer.set_state_dict(opti_dict)
  100. logger.info("Finish load checkpoints from {}".format(checkpoints))
  101. return
  102. pretrained_model = config.get('pretrained_model')
  103. load_static_weights = config.get('load_static_weights', False)
  104. use_distillation = config.get('use_distillation', False)
  105. if pretrained_model:
  106. if use_distillation:
  107. load_distillation_model(net, pretrained_model, load_static_weights)
  108. else: # common load
  109. load_dygraph_pretrain(
  110. net,
  111. path=pretrained_model,
  112. load_static_weights=load_static_weights)
  113. logger.info(
  114. logger.coloring("Finish load pretrained model from {}".format(
  115. pretrained_model), "HEADER"))
  116. def _save_student_model(net, model_prefix):
  117. """
  118. save student model if the net is the network contains student
  119. """
  120. student_model_prefix = model_prefix + "_student.pdparams"
  121. if hasattr(net, "_layers"):
  122. net = net._layers
  123. if hasattr(net, "student"):
  124. paddle.save(net.student.state_dict(), student_model_prefix)
  125. logger.info("Already save student model in {}".format(
  126. student_model_prefix))
  127. def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
  128. """
  129. save model to the target path
  130. """
  131. if paddle.distributed.get_rank() != 0:
  132. return
  133. model_path = os.path.join(model_path, str(epoch_id))
  134. _mkdir_if_not_exist(model_path)
  135. model_prefix = os.path.join(model_path, prefix)
  136. _save_student_model(net, model_prefix)
  137. paddle.save(net.state_dict(), model_prefix + ".pdparams")
  138. paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
  139. logger.info("Already save model in {}".format(model_path))