save_load.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import errno
import os
import re
import shutil
import tempfile

import paddle

from paddlex.ppcls.utils import logger
from .download import get_weights_path_from_url

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']


def _mkdir_if_not_exist(path):
    """
    Create the directory if it does not exist; ignore the EEXIST error raised
    when several processes try to create the same directory at once.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    '{} has already been created by another process'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def load_dygraph_pretrain(model, path=None):
    """Load the parameters stored at `path` + ".pdparams" into `model`."""
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrained path {} does not "
                         "exist.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    model.set_dict(param_state_dict)
    return
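
# Illustrative usage sketch (not part of the original file): "MyNet" and the
# path below are hypothetical placeholders. Note that `path` is a prefix
# without the ".pdparams" suffix; the suffix is appended inside the function.
#
#   model = MyNet()
#   load_dygraph_pretrain(model, path="./output/MyNet/best_model/ppcls")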


def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld=False):
    """
    Download pretrained parameters from `pretrained_url` and load them into
    `model`. When `use_ssld` is True, the SSLD variant of the weights is used.
    """
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
        ".pdparams", "")
    load_dygraph_pretrain(model, path=local_weight_path)
    return
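
# Illustrative usage sketch (not part of the original file): the URL below is
# a hypothetical placeholder. The URL is expected to end in ".pdparams", which
# is stripped above to recover the weight-file prefix; with use_ssld=True the
# "_pretrained" part of the URL is rewritten to "_ssld_pretrained" before
# downloading.
#
#   model = MyNet()
#   load_dygraph_pretrain_from_url(
#       model, pretrained_url="https://example.com/MyNet_pretrained.pdparams")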


def load_distillation_model(model, pretrained_model):
    """
    Load weights for a distillation model. `pretrained_model` is a single
    teacher path or a [teacher_path, student_path] list; the teacher is
    loaded first.
    """
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")
    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]
    teacher = model.teacher if hasattr(model,
                                       "teacher") else model._layers.teacher
    student = model.student if hasattr(model,
                                       "student") else model._layers.student
    load_dygraph_pretrain(teacher, path=pretrained_model[0])
    logger.info("Finished initializing teacher model from {}".format(
        pretrained_model[0]))
    # load student model
    if len(pretrained_model) >= 2:
        load_dygraph_pretrain(student, path=pretrained_model[1])
        logger.info("Finished initializing student model from {}".format(
            pretrained_model[1]))
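
# Illustrative usage sketch (not part of the original file): the paths are
# hypothetical. The student weights are loaded only when a second entry is
# given in the list.
#
#   load_distillation_model(
#       distillation_net,
#       ["./pretrained/teacher/ppcls", "./pretrained/student/ppcls"])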


def init_model(config, net, optimizer=None):
    """
    Initialize the model either from `checkpoints` (resuming training,
    together with the optimizer state) or from `pretrained_model`.
    """
    checkpoints = config.get('checkpoints')
    if checkpoints and optimizer is not None:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given checkpoint {}.pdparams does not exist.".format(checkpoints)
        assert os.path.exists(checkpoints + ".pdopt"), \
            "Given checkpoint {}.pdopt does not exist.".format(checkpoints)
        para_dict = paddle.load(checkpoints + ".pdparams")
        opti_dict = paddle.load(checkpoints + ".pdopt")
        metric_dict = paddle.load(checkpoints + ".pdstates")
        net.set_dict(para_dict)
        optimizer.set_state_dict(opti_dict)
        logger.info("Finished loading checkpoints from {}".format(checkpoints))
        return metric_dict

    pretrained_model = config.get('pretrained_model')
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
            load_distillation_model(net, pretrained_model)
        else:  # common load
            load_dygraph_pretrain(net, path=pretrained_model)
        logger.info(
            logger.coloring("Finished loading pretrained model from {}".format(
                pretrained_model), "HEADER"))
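
# Illustrative config sketch (not part of the original file): init_model reads
# the keys below; the values are hypothetical. Both 'checkpoints' and
# 'pretrained_model' are prefixes, with ".pdparams"/".pdopt"/".pdstates"
# appended inside the function. Resuming from 'checkpoints' also requires
# passing the optimizer.
#
#   config = {
#       "checkpoints": None,
#       "pretrained_model": "./pretrained/MyNet/ppcls",
#       "use_distillation": False,
#   }
#   metric_dict = init_model(config, net, optimizer)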


def save_model(net,
               optimizer,
               metric_info,
               model_path,
               model_name="",
               prefix='ppcls'):
    """
    Save the model parameters, optimizer state and metric info to the target
    path. Only rank 0 writes the files in distributed training.
    """
    if paddle.distributed.get_rank() != 0:
        return
    model_path = os.path.join(model_path, model_name)
    _mkdir_if_not_exist(model_path)
    model_path = os.path.join(model_path, prefix)
    paddle.save(net.state_dict(), model_path + ".pdparams")
    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
    paddle.save(metric_info, model_path + ".pdstates")
    logger.info("Already saved model in {}".format(model_path))
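
# Illustrative usage sketch (not part of the original file): paths and metric
# values are hypothetical. This writes "<model_path>/<model_name>/<prefix>"
# with the ".pdparams", ".pdopt" and ".pdstates" suffixes, on rank 0 only.
#
#   save_model(
#       net,
#       optimizer,
#       metric_info={"metric": 0.91, "epoch": 10},
#       model_path="./output",
#       model_name="MyNet",
#       prefix="epoch_10")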