load_model.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import yaml
  15. import os.path as osp
  16. import six
  17. import copy
  18. from collections import OrderedDict
  19. import paddle.fluid as fluid
  20. from paddle.fluid.framework import Parameter
  21. import paddlex
  22. import paddlex.utils.logging as logging
  23. def load_model(model_dir, fixed_input_shape=None):
  24. if not osp.exists(osp.join(model_dir, "model.yml")):
  25. raise Exception("There's not model.yml in {}".format(model_dir))
  26. with open(osp.join(model_dir, "model.yml")) as f:
  27. info = yaml.load(f.read(), Loader=yaml.Loader)
  28. if 'status' in info:
  29. status = info['status']
  30. elif 'save_method' in info:
  31. # 兼容老版本PaddleX
  32. status = info['save_method']
  33. if not hasattr(paddlex.cv.models, info['Model']):
  34. raise Exception("There's no attribute {} in paddlex.cv.models".format(
  35. info['Model']))
  36. info['_init_params']['fixed_input_shape'] = fixed_input_shape
  37. if info['_Attributes']['model_type'] == 'classifier':
  38. model = paddlex.cv.models.BaseClassifier(**info['_init_params'])
  39. else:
  40. model = getattr(paddlex.cv.models,
  41. info['Model'])(**info['_init_params'])
  42. if status == "Normal" or \
  43. status == "Prune" or status == "fluid.save":
  44. startup_prog = fluid.Program()
  45. model.test_prog = fluid.Program()
  46. with fluid.program_guard(model.test_prog, startup_prog):
  47. with fluid.unique_name.guard():
  48. model.test_inputs, model.test_outputs = model.build_net(
  49. mode='test')
  50. model.test_prog = model.test_prog.clone(for_test=True)
  51. model.exe.run(startup_prog)
  52. if status == "Prune":
  53. from .slim.prune import update_program
  54. model.test_prog = update_program(model.test_prog, model_dir,
  55. model.places[0])
  56. import pickle
  57. with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
  58. load_dict = pickle.load(f)
  59. fluid.io.set_program_state(model.test_prog, load_dict)
  60. elif status == "Infer" or \
  61. status == "Quant" or status == "fluid.save_inference_model":
  62. [prog, input_names, outputs] = fluid.io.load_inference_model(
  63. model_dir, model.exe, params_filename='__params__')
  64. model.test_prog = prog
  65. test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
  66. model.test_inputs = OrderedDict()
  67. model.test_outputs = OrderedDict()
  68. for name in input_names:
  69. model.test_inputs[name] = model.test_prog.global_block().var(name)
  70. for i, out in enumerate(outputs):
  71. var_desc = test_outputs_info[i]
  72. model.test_outputs[var_desc[0]] = out
  73. if 'Transforms' in info:
  74. transforms_mode = info.get('TransformsMode', 'RGB')
  75. if transforms_mode == 'RGB':
  76. to_rgb = True
  77. else:
  78. to_rgb = False
  79. if 'BatchTransforms' in info:
  80. # 兼容老版本PaddleX模型
  81. model.test_transforms = build_transforms_v1(
  82. model.model_type, info['Transforms'], info['BatchTransforms'])
  83. model.eval_transforms = copy.deepcopy(model.test_transforms)
  84. else:
  85. model.test_transforms = build_transforms(
  86. model.model_type, info['Transforms'], to_rgb)
  87. model.eval_transforms = copy.deepcopy(model.test_transforms)
  88. if '_Attributes' in info:
  89. for k, v in info['_Attributes'].items():
  90. if k in model.__dict__:
  91. model.__dict__[k] = v
  92. logging.info("Model[{}] loaded.".format(info['Model']))
  93. model.trainable = False
  94. return model
  95. def build_transforms(model_type, transforms_info, to_rgb=True):
  96. if model_type == "classifier":
  97. import paddlex.cv.transforms.cls_transforms as T
  98. elif model_type == "detector":
  99. import paddlex.cv.transforms.det_transforms as T
  100. elif model_type == "segmenter":
  101. import paddlex.cv.transforms.seg_transforms as T
  102. transforms = list()
  103. for op_info in transforms_info:
  104. op_name = list(op_info.keys())[0]
  105. op_attr = op_info[op_name]
  106. if not hasattr(T, op_name):
  107. raise Exception(
  108. "There's no operator named '{}' in transforms of {}".format(
  109. op_name, model_type))
  110. transforms.append(getattr(T, op_name)(**op_attr))
  111. eval_transforms = T.Compose(transforms)
  112. eval_transforms.to_rgb = to_rgb
  113. return eval_transforms
  114. def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
  115. """ 老版本模型加载,仅支持PaddleX前端导出的模型
  116. """
  117. logging.debug("Use build_transforms_v1 to reconstruct transforms")
  118. if model_type == "classifier":
  119. import paddlex.cv.transforms.cls_transforms as T
  120. elif model_type == "detector":
  121. import paddlex.cv.transforms.det_transforms as T
  122. elif model_type == "segmenter":
  123. import paddlex.cv.transforms.seg_transforms as T
  124. transforms = list()
  125. for op_info in transforms_info:
  126. op_name = op_info[0]
  127. op_attr = op_info[1]
  128. if op_name == 'DecodeImage':
  129. continue
  130. if op_name == 'Permute':
  131. continue
  132. if op_name == 'ResizeByShort':
  133. op_attr_new = dict()
  134. if 'short_size' in op_attr:
  135. op_attr_new['short_size'] = op_attr['short_size']
  136. else:
  137. op_attr_new['short_size'] = op_attr['target_size']
  138. op_attr_new['max_size'] = op_attr.get('max_size', -1)
  139. op_attr = op_attr_new
  140. if op_name.startswith('Arrange'):
  141. continue
  142. if not hasattr(T, op_name):
  143. raise Exception(
  144. "There's no operator named '{}' in transforms of {}".format(
  145. op_name, model_type))
  146. transforms.append(getattr(T, op_name)(**op_attr))
  147. if model_type == "detector" and len(batch_transforms_info) > 0:
  148. op_name = batch_transforms_info[0][0]
  149. op_attr = batch_transforms_info[0][1]
  150. assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
  151. padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
  152. transforms.append(padding)
  153. eval_transforms = T.Compose(transforms)
  154. return eval_transforms