# load_model.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import yaml
  15. import os.path as osp
  16. import six
  17. import copy
  18. from collections import OrderedDict
  19. import paddle.fluid as fluid
  20. from paddle.fluid.framework import Parameter
  21. import paddlex
  22. import paddlex.utils.logging as logging
  23. from paddlex.cv.transforms import build_transforms, build_transforms_v1
  24. def load_model(model_dir, fixed_input_shape=None):
  25. model_scope = fluid.Scope()
  26. if not osp.exists(model_dir):
  27. logging.error("model_dir '{}' is not exists!".format(model_dir))
  28. if not osp.exists(osp.join(model_dir, "model.yml")):
  29. raise Exception("There's not model.yml in {}".format(model_dir))
  30. with open(osp.join(model_dir, "model.yml")) as f:
  31. info = yaml.load(f.read(), Loader=yaml.Loader)
  32. if 'status' in info:
  33. status = info['status']
  34. elif 'save_method' in info:
  35. # 兼容老版本PaddleX
  36. status = info['save_method']
  37. if not hasattr(paddlex.cv.models, info['Model']):
  38. raise Exception("There's no attribute {} in paddlex.cv.models".format(
  39. info['Model']))
  40. if 'model_name' in info['_init_params']:
  41. del info['_init_params']['model_name']
  42. model = getattr(paddlex.cv.models, info['Model'])(**info['_init_params'])
  43. model.fixed_input_shape = fixed_input_shape
  44. if '_Attributes' in info:
  45. if 'fixed_input_shape' in info['_Attributes']:
  46. fixed_input_shape = info['_Attributes']['fixed_input_shape']
  47. if fixed_input_shape is not None:
  48. logging.info("Model already has fixed_input_shape with {}".
  49. format(fixed_input_shape))
  50. model.fixed_input_shape = fixed_input_shape
  51. else:
  52. info['_Attributes']['fixed_input_shape'] = model.fixed_input_shape
  53. if info['Model'].count('RCNN') > 0:
  54. if info['_init_params']['with_fpn']:
  55. if model.fixed_input_shape is not None:
  56. if model.fixed_input_shape[0] % 32 > 0:
  57. raise Exception(
  58. "The first value in fixed_input_shape must be a multiple of 32, but recieved {}.".
  59. format(model.fixed_input_shape[0]))
  60. if model.fixed_input_shape[1] % 32 > 0:
  61. raise Exception(
  62. "The second value in fixed_input_shape must be a multiple of 32, but recieved {}.".
  63. format(model.fixed_input_shape[1]))
  64. with fluid.scope_guard(model_scope):
  65. if status == "Normal" or \
  66. status == "Prune" or status == "fluid.save":
  67. startup_prog = fluid.Program()
  68. model.test_prog = fluid.Program()
  69. with fluid.program_guard(model.test_prog, startup_prog):
  70. with fluid.unique_name.guard():
  71. model.test_inputs, model.test_outputs = model.build_net(
  72. mode='test')
  73. model.test_prog = model.test_prog.clone(for_test=True)
  74. model.exe.run(startup_prog)
  75. if status == "Prune":
  76. from .slim.prune import update_program
  77. model.test_prog = update_program(
  78. model.test_prog,
  79. model_dir,
  80. model.places[0],
  81. scope=model_scope)
  82. import pickle
  83. with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
  84. load_dict = pickle.load(f)
  85. fluid.io.set_program_state(model.test_prog, load_dict)
  86. elif status == "Infer" or \
  87. status == "Quant" or status == "fluid.save_inference_model":
  88. [prog, input_names, outputs] = fluid.io.load_inference_model(
  89. model_dir, model.exe, params_filename='__params__')
  90. model.test_prog = prog
  91. test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
  92. model.test_inputs = OrderedDict()
  93. model.test_outputs = OrderedDict()
  94. for name in input_names:
  95. model.test_inputs[name] = model.test_prog.global_block().var(
  96. name)
  97. for i, out in enumerate(outputs):
  98. var_desc = test_outputs_info[i]
  99. model.test_outputs[var_desc[0]] = out
  100. if 'Transforms' in info:
  101. transforms_mode = info.get('TransformsMode', 'RGB')
  102. # 固定模型的输入shape
  103. fix_input_shape(info, fixed_input_shape=model.fixed_input_shape)
  104. if transforms_mode == 'RGB':
  105. to_rgb = True
  106. else:
  107. to_rgb = False
  108. if 'BatchTransforms' in info:
  109. # 兼容老版本PaddleX模型
  110. model.test_transforms = build_transforms_v1(
  111. model.model_type, info['Transforms'], info['BatchTransforms'])
  112. model.eval_transforms = copy.deepcopy(model.test_transforms)
  113. else:
  114. model.test_transforms = build_transforms(
  115. model.model_type, info['Transforms'], to_rgb)
  116. model.eval_transforms = copy.deepcopy(model.test_transforms)
  117. if '_Attributes' in info:
  118. for k, v in info['_Attributes'].items():
  119. if k in model.__dict__:
  120. model.__dict__[k] = v
  121. logging.info("Model[{}] loaded.".format(info['Model']))
  122. model.scope = model_scope
  123. model.trainable = False
  124. model.status = status
  125. return model
  126. def fix_input_shape(info, fixed_input_shape=None):
  127. if fixed_input_shape is not None:
  128. input_channel = 3
  129. if 'input_channel' in info['_init_params']:
  130. input_channel = info['_init_params']['input_channel']
  131. resize = {'ResizeByShort': {}}
  132. padding = {'Padding': {}}
  133. if info['_Attributes']['model_type'] == 'classifier':
  134. pass
  135. elif info['Model'].count('YOLO') > 0:
  136. resize_op_index = None
  137. for i in range(len(info['Transforms'])):
  138. if list(info['Transforms'][i].keys())[0] == 'Resize':
  139. resize_op_index = i
  140. if resize_op_index is not None:
  141. info['Transforms'][resize_op_index]['Resize'][
  142. 'target_size'] = fixed_input_shape[0]
  143. elif info['Model'].count('RCNN') > 0:
  144. resize_op_index = None
  145. for i in range(len(info['Transforms'])):
  146. if list(info['Transforms'][i].keys())[0] == 'ResizeByShort':
  147. resize_op_index = i
  148. if resize_op_index is not None:
  149. info['Transforms'][resize_op_index]['ResizeByShort'][
  150. 'short_size'] = min(fixed_input_shape)
  151. info['Transforms'][resize_op_index]['ResizeByShort'][
  152. 'max_size'] = max(fixed_input_shape)
  153. else:
  154. resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
  155. resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
  156. info['Transforms'].append(resize)
  157. padding_op_index = None
  158. for i in range(len(info['Transforms'])):
  159. if list(info['Transforms'][i].keys())[0] == 'Padding':
  160. padding_op_index = i
  161. if padding_op_index is not None:
  162. info['Transforms'][padding_op_index]['Padding'][
  163. 'target_size'] = list(fixed_input_shape)
  164. else:
  165. padding['Padding']['target_size'] = list(fixed_input_shape)
  166. info['Transforms'].append(padding)
  167. elif info['_Attributes']['model_type'] == 'segmenter':
  168. resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
  169. resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
  170. padding['Padding']['target_size'] = list(fixed_input_shape)
  171. padding['Padding']['im_padding_value'] = [0.] * input_channel
  172. info['Transforms'].append(resize)
  173. info['Transforms'].append(padding)