load_model.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import yaml
  16. import paddle
  17. import paddleslim
  18. import paddlex
  19. import paddlex.utils.logging as logging
  20. from paddlex.cv.transforms import build_transforms
  21. def load_model(model_dir):
  22. """
  23. Load saved model from a given directory.
  24. Args:
  25. model_dir(str): The directory where the model is saved.
  26. Returns:
  27. The model loaded from the directory.
  28. """
  29. if not osp.exists(model_dir):
  30. logging.error("model_dir '{}' does not exists!".format(model_dir))
  31. if not osp.exists(osp.join(model_dir, "model.yml")):
  32. raise Exception("There's no model.yml in {}".format(model_dir))
  33. with open(osp.join(model_dir, "model.yml")) as f:
  34. model_info = yaml.load(f.read(), Loader=yaml.Loader)
  35. f.close()
  36. version = model_info['version']
  37. if int(version.split('.')[0]) < 2:
  38. raise Exception(
  39. 'Current version is {}, a model trained by PaddleX={} cannot be load.'.
  40. format(paddlex.__version__, version))
  41. status = model_info['status']
  42. if not hasattr(paddlex.cv.models, model_info['Model']):
  43. raise Exception("There's no attribute {} in paddlex.cv.models".format(
  44. model_info['Model']))
  45. if 'model_name' in model_info['_init_params']:
  46. del model_info['_init_params']['model_name']
  47. with paddle.utils.unique_name.guard():
  48. model = getattr(paddlex.cv.models, model_info['Model'])(
  49. **model_info['_init_params'])
  50. if 'Transforms' in model_info:
  51. model.test_transforms = build_transforms(model_info['Transforms'])
  52. if '_Attributes' in model_info:
  53. for k, v in model_info['_Attributes'].items():
  54. if k in model.__dict__:
  55. model.__dict__[k] = v
  56. if status == 'Pruned' or osp.exists(osp.join(model_dir, "prune.yml")):
  57. with open(osp.join(model_dir, "prune.yml")) as f:
  58. pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
  59. inputs = pruning_info['pruner_inputs']
  60. model.pruner = getattr(paddleslim, pruning_info['pruner'])(
  61. model.net, inputs=inputs)
  62. model.pruning_ratios = pruning_info['pruning_ratios']
  63. model.pruner.prune_vars(
  64. ratios=model.pruning_ratios,
  65. axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
  66. if status == 'Infer':
  67. if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
  68. #net_state_dict = paddle.load(
  69. # model_dir,
  70. # params_filename='model.pdiparams',
  71. # model_filename='model.pdmodel')
  72. net = paddle.jit.load(osp.join(model_dir, 'model'))
  73. #load_param_dict = paddle.load(osp.join(model_dir, 'model.pdiparams'))
  74. #print(load_param_dict)
  75. import pickle
  76. var_info_path = osp.join(model_dir, 'model.pdiparams.info')
  77. with open(var_info_path, 'rb') as f:
  78. extra_var_info = pickle.load(f)
  79. net_state_dict = dict()
  80. static_state_dict = dict()
  81. for name, var in net.state_dict().items():
  82. print(name, var.name)
  83. static_state_dict[var.name] = var.numpy()
  84. exit()
  85. for var_name in static_state_dict:
  86. if var_name not in extra_var_info:
  87. print(var_name)
  88. continue
  89. structured_name = extra_var_info[var_name].get(
  90. 'structured_name', None)
  91. if structured_name is None:
  92. continue
  93. net_state_dict[structured_name] = static_state_dict[
  94. var_name]
  95. #model.net = paddle.jit.load(
  96. # model_dir,
  97. # params_filename='model.pdiparams',
  98. # model_filename='model.pdmodel')
  99. #net_state_dict = paddle.load(osp.join(model_dir, 'model'))
  100. else:
  101. net_state_dict = paddle.load(osp.join(model_dir, 'model'))
  102. else:
  103. net_state_dict = paddle.load(osp.join(model_dir, 'model.pdparams'))
  104. model.net.set_state_dict(net_state_dict)
  105. logging.info("Model[{}] loaded.".format(model_info['Model']))
  106. model.status = status
  107. return model