deploy.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path as osp
import numpy as np
import paddle.nn.functional as F
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddlex.cv.models import load_model
from paddlex.utils import logging, Timer


class Predictor(object):
    def __init__(self,
                 model_dir,
                 use_gpu=True,
                 gpu_id=0,
                 cpu_thread_num=1,
                 use_mkl=True,
                 mkl_thread_num=4,
                 use_trt=False,
                 use_glog=False,
                 memory_optimize=True,
                 max_trt_batch_size=1,
                 trt_precision_mode='float32'):
        """Create a Paddle Predictor.

        Args:
            model_dir: Model directory (must contain an exported deployment
                or quantized model).
            use_gpu: Whether to use the GPU. Defaults to True.
            gpu_id: ID of the GPU to use. Defaults to 0.
            cpu_thread_num: Number of threads for CPU inference.
                Defaults to 1.
            use_mkl: Whether to use the MKL-DNN library (CPU only).
                Defaults to True.
            mkl_thread_num: Number of MKL-DNN threads. Defaults to 4.
            use_trt: Whether to use TensorRT. Defaults to False.
            use_glog: Whether to enable glog logging. Defaults to False.
            memory_optimize: Whether to enable memory optimization.
                Defaults to True.
            max_trt_batch_size: Maximum batch size configured when using
                TensorRT. Defaults to 1.
            trt_precision_mode: Precision used with TensorRT.
                Defaults to 'float32'.
        """
        self.model_dir = model_dir
        self._model = load_model(model_dir, with_net=False)

        if trt_precision_mode == 'float32':
            trt_precision_mode = PrecisionType.Float32
        elif trt_precision_mode == 'float16':
            trt_precision_mode = PrecisionType.Float16
        else:
            logging.error(
                "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
                .format(trt_precision_mode),
                exit=True)

        self.predictor = self.create_predictor(
            use_gpu=use_gpu,
            gpu_id=gpu_id,
            cpu_thread_num=cpu_thread_num,
            use_mkl=use_mkl,
            mkl_thread_num=mkl_thread_num,
            use_trt=use_trt,
            use_glog=use_glog,
            memory_optimize=memory_optimize,
            max_trt_batch_size=max_trt_batch_size,
            trt_precision_mode=trt_precision_mode)
        self.timer = Timer()

    def create_predictor(self,
                         use_gpu=True,
                         gpu_id=0,
                         cpu_thread_num=1,
                         use_mkl=True,
                         mkl_thread_num=4,
                         use_trt=False,
                         use_glog=False,
                         memory_optimize=True,
                         max_trt_batch_size=1,
                         trt_precision_mode=PrecisionType.Float32):
        config = Config(
            osp.join(self.model_dir, 'model.pdmodel'),
            osp.join(self.model_dir, 'model.pdiparams'))
        if use_gpu:
            # Set the initial GPU memory (in MB) and the device ID.
            config.enable_use_gpu(100, gpu_id)
            config.switch_ir_optim(True)
            if use_trt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 10,
                    max_batch_size=max_trt_batch_size,
                    min_subgraph_size=3,
                    precision_mode=trt_precision_mode,
                    use_static=False,
                    use_calib_mode=False)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(cpu_thread_num)
            if use_mkl:
                try:
                    # Cache 10 different input shapes for MKL-DNN to avoid a
                    # memory leak.
                    config.set_mkldnn_cache_capacity(10)
                    config.enable_mkldnn()
                    config.set_cpu_math_library_num_threads(mkl_thread_num)
                except Exception:
                    logging.warning(
                        "The current environment does not support `mkldnn`, so disable mkldnn."
                    )

        if use_glog:
            config.enable_glog_info()
        else:
            config.disable_glog_info()
        if memory_optimize:
            config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        predictor = create_predictor(config)
        return predictor
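
    # Illustrative configuration sketches (assumptions, not from the original
    # source). A GPU predictor with a TensorRT FP16 engine could be built as
    #     Predictor(model_dir, use_gpu=True, use_trt=True,
    #               trt_precision_mode='float16', max_trt_batch_size=4)
    # and a CPU predictor with MKL-DNN enabled as
    #     Predictor(model_dir, use_gpu=False, use_mkl=True, mkl_thread_num=8)
    # where model_dir is assumed to contain model.pdmodel / model.pdiparams.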

    def preprocess(self, images, transforms):
        preprocessed_samples = self._model._preprocess(
            images, transforms, to_tensor=False)
        if self._model.model_type == 'classifier':
            preprocessed_samples = {'image': preprocessed_samples[0]}
        elif self._model.model_type == 'segmenter':
            preprocessed_samples = {
                'image': preprocessed_samples[0],
                'ori_shape': preprocessed_samples[1]
            }
        elif self._model.model_type == 'detector':
            pass
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preprocessed_samples

    def postprocess(self, net_outputs, topk=1, ori_shape=None,
                    transforms=None):
        if self._model.model_type == 'classifier':
            true_topk = min(self._model.num_classes, topk)
            preds = self._model._postprocess(net_outputs[0], true_topk)
        elif self._model.model_type == 'segmenter':
            score_map, label_map = net_outputs
            # Stack the score map and the label map so that both are restored
            # to the original image shape in one pass, then split them again.
            combo = np.concatenate([score_map, label_map], axis=-1)
            combo = self._model._postprocess(
                combo,
                batch_origin_shape=ori_shape,
                transforms=transforms.transforms)
            score_map = np.squeeze(combo[..., :-1])
            label_map = np.squeeze(combo[..., -1])
            if len(score_map.shape) == 3:
                preds = {'label_map': label_map, 'score_map': score_map}
            else:
                preds = [{
                    'label_map': l,
                    'score_map': s
                } for l, s in zip(label_map, score_map)]
        elif self._model.model_type == 'detector':
            net_outputs = {
                k: v
                for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
            }
            preds = self._model._postprocess(net_outputs)
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preds

    def raw_predict(self, inputs):
        """Run inference on preprocessed data.

        Args:
            inputs (dict): Preprocessed data.
        """
        input_names = self.predictor.get_input_names()
        for name in input_names:
            input_tensor = self.predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(inputs[name])
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        net_outputs = list()
        for name in output_names:
            output_tensor = self.predictor.get_output_handle(name)
            net_outputs.append(output_tensor.copy_to_cpu())
        return net_outputs
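
    # Note: the keys of `inputs` must match self.predictor.get_input_names().
    # For a classifier this is typically {'image': batch} where batch is a
    # float32 np.ndarray of shape (N, 3, H, W), as produced by preprocess().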

    def predict(self, img_file, topk=1, transforms=None):
        """Run prediction on images.

        Args:
            img_file (List[np.ndarray or str], str or np.ndarray): Image
                path(s), or decoded array(s) in (H, W, C) layout, float32
                type and BGR channel order.
            topk (int): Used for classification; return the top-k results.
            transforms (paddlex.transforms): Data preprocessing operators.
        """
        if transforms is None and not hasattr(self._model, 'test_transforms'):
            raise Exception("Transforms need to be defined, now is None.")
        if transforms is None:
            transforms = self._model.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file

        self.timer.preprocess_time_s.start()
        preprocessed_input = self.preprocess(images, transforms)
        self.timer.preprocess_time_s.end()

        self.timer.inference_time_s.start()
        net_outputs = self.raw_predict(preprocessed_input)
        self.timer.inference_time_s.end()

        self.timer.postprocess_time_s.start()
        results = self.postprocess(
            net_outputs,
            topk,
            ori_shape=preprocessed_input.get('ori_shape', None),
            transforms=transforms.transforms)
        self.timer.postprocess_time_s.end()

        return results
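

# Minimal usage sketch (an illustration, not part of the module). Assumes
# './inference_model' is a directory produced by PaddleX model export
# (containing model.pdmodel / model.pdiparams) and 'test.jpg' is any
# readable image; both paths are placeholders:
#
#     predictor = Predictor('./inference_model', use_gpu=False)
#     result = predictor.predict('test.jpg', topk=5)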