load_model.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import numpy as np
  16. import yaml
  17. import paddle
  18. import paddleslim
  19. import paddlex
  20. import paddlex.utils.logging as logging
  21. from paddlex.cv.transforms import build_transforms
  22. def load_rcnn_inference_model(model_dir):
  23. paddle.enable_static()
  24. exe = paddle.static.Executor(paddle.CPUPlace())
  25. path_prefix = osp.join(model_dir, "model")
  26. prog, _, _ = paddle.static.load_inference_model(path_prefix, exe)
  27. paddle.disable_static()
  28. extra_var_info = paddle.load(osp.join(model_dir, "model.pdiparams.info"))
  29. net_state_dict = dict()
  30. static_state_dict = dict()
  31. for name, var in prog.state_dict().items():
  32. static_state_dict[name] = np.array(var)
  33. for var_name in static_state_dict:
  34. if var_name not in extra_var_info:
  35. continue
  36. structured_name = extra_var_info[var_name].get('structured_name', None)
  37. if structured_name is None:
  38. continue
  39. net_state_dict[structured_name] = static_state_dict[var_name]
  40. return net_state_dict
  41. def load_model(model_dir, **params):
  42. """
  43. Load saved model from a given directory.
  44. Args:
  45. model_dir(str): The directory where the model is saved.
  46. Returns:
  47. The model loaded from the directory.
  48. """
  49. if not osp.exists(model_dir):
  50. logging.error("model_dir '{}' does not exists!".format(model_dir))
  51. if not osp.exists(osp.join(model_dir, "model.yml")):
  52. raise Exception("There's no model.yml in {}".format(model_dir))
  53. with open(osp.join(model_dir, "model.yml")) as f:
  54. model_info = yaml.load(f.read(), Loader=yaml.Loader)
  55. f.close()
  56. version = model_info['version']
  57. if int(version.split('.')[0]) < 2:
  58. raise Exception(
  59. 'Current version is {}, a model trained by PaddleX={} cannot be load.'.
  60. format(paddlex.__version__, version))
  61. status = model_info['status']
  62. with_net = params.get('with_net', True)
  63. if not with_net:
  64. assert status == 'Infer', \
  65. "Only exported inference models can be deployed, current model status is {}".format(status)
  66. if not hasattr(paddlex.cv.models, model_info['Model']):
  67. raise Exception("There's no attribute {} in paddlex.cv.models".format(
  68. model_info['Model']))
  69. if 'model_name' in model_info['_init_params']:
  70. del model_info['_init_params']['model_name']
  71. model_info['_init_params'].update({'with_net': with_net})
  72. with paddle.utils.unique_name.guard():
  73. model = getattr(paddlex.cv.models, model_info['Model'])(
  74. **model_info['_init_params'])
  75. if with_net:
  76. if status == 'Pruned' or osp.exists(
  77. osp.join(model_dir, "prune.yml")):
  78. with open(osp.join(model_dir, "prune.yml")) as f:
  79. pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
  80. inputs = pruning_info['pruner_inputs']
  81. if model.model_type == 'detector':
  82. inputs = [{
  83. k: paddle.to_tensor(v)
  84. for k, v in inputs.items()
  85. }]
  86. model.net.eval()
  87. model.pruner = getattr(paddleslim, pruning_info['pruner'])(
  88. model.net, inputs=inputs)
  89. model.pruning_ratios = pruning_info['pruning_ratios']
  90. model.pruner.prune_vars(
  91. ratios=model.pruning_ratios,
  92. axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
  93. if status == 'Quantized':
  94. with open(osp.join(model_dir, "quant.yml")) as f:
  95. quant_info = yaml.load(f.read(), Loader=yaml.Loader)
  96. model.quant_config = quant_info['quant_config']
  97. model.quantizer = paddleslim.QAT(model.quant_config)
  98. model.quantizer.quantize(model.net)
  99. if status == 'Infer':
  100. if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
  101. net_state_dict = load_rcnn_inference_model(model_dir)
  102. else:
  103. net_state_dict = paddle.load(osp.join(model_dir, 'model'))
  104. else:
  105. net_state_dict = paddle.load(
  106. osp.join(model_dir, 'model.pdparams'))
  107. model.net.set_state_dict(net_state_dict)
  108. if 'Transforms' in model_info:
  109. model.test_transforms = build_transforms(model_info['Transforms'])
  110. if '_Attributes' in model_info:
  111. for k, v in model_info['_Attributes'].items():
  112. if k in model.__dict__:
  113. model.__dict__[k] = v
  114. logging.info("Model[{}] loaded.".format(model_info['Model']))
  115. model.status = status
  116. return model