# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path as osp

import numpy as np
import yaml
import paddle
import paddleslim

import paddlex
import paddlex.utils.logging as logging
from paddlex.cv.transforms import build_transforms


def load_rcnn_inference_model(model_dir):
    """Load an exported FasterRCNN/MaskRCNN inference model and map its
    static-graph parameters back to their dygraph structured names."""
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    path_prefix = osp.join(model_dir, "model")
    prog, _, _ = paddle.static.load_inference_model(path_prefix, exe)
    paddle.disable_static()
    # 'model.pdiparams.info' records, for each static-graph variable,
    # the structured (dygraph) name it corresponds to.
    extra_var_info = paddle.load(osp.join(model_dir, "model.pdiparams.info"))
    net_state_dict = dict()
    static_state_dict = dict()
    for name, var in prog.state_dict().items():
        static_state_dict[name] = np.array(var)
    for var_name in static_state_dict:
        if var_name not in extra_var_info:
            continue
        structured_name = extra_var_info[var_name].get('structured_name', None)
        if structured_name is None:
            continue
        net_state_dict[structured_name] = static_state_dict[var_name]
    return net_state_dict
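
# For illustration (hypothetical names): a parameter stored in the static
# graph as "conv2d_0.w_0", whose extra info records the structured name
# "backbone.conv1.weight", ends up in the returned dict as
# {"backbone.conv1.weight": <np.ndarray>} -- a state dict that
# paddle.nn.Layer.set_state_dict can consume directly.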


def load_model(model_dir, **params):
    """
    Load a saved model from a given directory.

    Args:
        model_dir (str): The directory where the model is saved.

    Returns:
        The model loaded from the directory.
    """
    if not osp.exists(model_dir):
        logging.error("model_dir '{}' does not exist!".format(model_dir))
    if not osp.exists(osp.join(model_dir, "model.yml")):
        raise Exception("There's no model.yml in {}".format(model_dir))
    with open(osp.join(model_dir, "model.yml")) as f:
        model_info = yaml.load(f.read(), Loader=yaml.Loader)

    version = model_info['version']
    if int(version.split('.')[0]) < 2:
        raise Exception(
            'Current version is {}, a model trained by PaddleX={} cannot be loaded.'.
            format(paddlex.__version__, version))

    status = model_info['status']
    with_net = params.get('with_net', True)
    if not with_net:
        assert status == 'Infer', \
            "Only exported inference models can be deployed, current model status is {}".format(
                status)

    if not hasattr(paddlex.cv.models, model_info['Model']):
        raise Exception("There's no attribute {} in paddlex.cv.models".format(
            model_info['Model']))
    if 'model_name' in model_info['_init_params']:
        del model_info['_init_params']['model_name']
    model_info['_init_params'].update({'with_net': with_net})

    with paddle.utils.unique_name.guard():
        model = getattr(paddlex.cv.models, model_info['Model'])(
            **model_info['_init_params'])

        if with_net:
            if status == 'Pruned' or osp.exists(
                    osp.join(model_dir, "prune.yml")):
                with open(osp.join(model_dir, "prune.yml")) as f:
                    pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                inputs = pruning_info['pruner_inputs']
                if model.model_type == 'detector':
                    # Detector pruners take a list of input dicts of tensors,
                    # and the net must be in eval mode to be traced.
                    inputs = [{
                        k: paddle.to_tensor(v)
                        for k, v in inputs.items()
                    }]
                    model.net.eval()
                model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                    model.net, inputs=inputs)
                model.pruning_ratios = pruning_info['pruning_ratios']
                model.pruner.prune_vars(
                    ratios=model.pruning_ratios,
                    axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)

            if status == 'Quantized' or osp.exists(
                    osp.join(model_dir, "quant.yml")):
                with open(osp.join(model_dir, "quant.yml")) as f:
                    quant_info = yaml.load(f.read(), Loader=yaml.Loader)
                model.quant_config = quant_info['quant_config']
                # Re-wrap the network with quantize/dequantize ops so the
                # quantized weights can be loaded into it.
                model.quantizer = paddleslim.QAT(model.quant_config)
                model.quantizer.quantize(model.net)

            if status == 'Infer':
                if osp.exists(osp.join(model_dir, "quant.yml")):
                    logging.error(
                        "Exported quantized model cannot be loaded, only deployment is supported.",
                        exit=True)
                model.net = model._build_inference_net()
                if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                    net_state_dict = load_rcnn_inference_model(model_dir)
                else:
                    net_state_dict = paddle.load(osp.join(model_dir, 'model'))
                    if model.model_type in ['classifier',
                                            'segmenter'] and 'rc' in version:
                        # When a classifier or segmenter is exported by a
                        # PaddleX 2.0.0 release candidate, an InferNet wrapper
                        # appends softmax and argmax operators to the model,
                        # so every parameter name carries the prefix 'net.'.
                        new_net_state_dict = {}
                        for k, v in net_state_dict.items():
                            new_net_state_dict['net.' + k] = v
                        net_state_dict = new_net_state_dict
            else:
                net_state_dict = paddle.load(
                    osp.join(model_dir, 'model.pdparams'))
            model.net.set_state_dict(net_state_dict)

    if 'Transforms' in model_info:
        model.test_transforms = build_transforms(model_info['Transforms'])

    if '_Attributes' in model_info:
        for k, v in model_info['_Attributes'].items():
            if k in model.__dict__:
                model.__dict__[k] = v

    logging.info("Model[{}] loaded.".format(model_info['Model']))
    model.status = status
    return model
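

# Minimal usage sketch (paths are hypothetical; `predict` is the standard
# prediction entry point of PaddleX 2.x model objects). This is an
# illustration, not part of the library:
if __name__ == '__main__':
    model = load_model('output/best_model')  # hypothetical model directory
    print('loaded model status:', model.status)
    result = model.predict('test.jpg')  # hypothetical test image
    print(result)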