deploy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os.path as osp
  15. import numpy as np
  16. from paddle.inference import Config
  17. from paddle.inference import create_predictor
  18. from paddle.inference import PrecisionType
  19. from paddlex.cv.models import load_model
  20. from paddlex.utils import logging, Timer
  21. class Predictor(object):
  22. def __init__(self,
  23. model_dir,
  24. use_gpu=False,
  25. gpu_id=0,
  26. cpu_thread_num=1,
  27. use_mkl=True,
  28. mkl_thread_num=4,
  29. use_trt=False,
  30. use_glog=False,
  31. memory_optimize=True,
  32. max_trt_batch_size=1,
  33. trt_precision_mode='float32'):
  34. """ 创建Paddle Predictor
  35. Args:
  36. model_dir: 模型路径(必须是导出的部署或量化模型)
  37. use_gpu: 是否使用gpu,默认False
  38. gpu_id: 使用gpu的id,默认0
  39. cpu_thread_num:使用cpu进行预测时的线程数,默认为1
  40. use_mkl: 是否使用mkldnn计算库,CPU情况下使用,默认False
  41. mkl_thread_num: mkldnn计算线程数,默认为4
  42. use_trt: 是否使用TensorRT,默认False
  43. use_glog: 是否启用glog日志, 默认False
  44. memory_optimize: 是否启动内存优化,默认True
  45. max_trt_batch_size: 在使用TensorRT时配置的最大batch size,默认1
  46. trt_precision_mode:在使用TensorRT时采用的精度,可选值['float32', 'float16']。默认'float32',
  47. """
  48. self.model_dir = model_dir
  49. self._model = load_model(model_dir, with_net=False)
  50. if trt_precision_mode.lower() == 'float32':
  51. trt_precision_mode = PrecisionType.Float32
  52. elif trt_precision_mode.lower() == 'float16':
  53. trt_precision_mode = PrecisionType.Float16
  54. else:
  55. logging.error(
  56. "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
  57. .format(trt_precision_mode),
  58. exit=True)
  59. self.predictor = self.create_predictor(
  60. use_gpu=use_gpu,
  61. gpu_id=gpu_id,
  62. cpu_thread_num=cpu_thread_num,
  63. use_mkl=use_mkl,
  64. mkl_thread_num=mkl_thread_num,
  65. use_trt=use_trt,
  66. use_glog=use_glog,
  67. memory_optimize=memory_optimize,
  68. max_trt_batch_size=max_trt_batch_size,
  69. trt_precision_mode=trt_precision_mode)
  70. self.timer = Timer()
  71. def create_predictor(self,
  72. use_gpu=True,
  73. gpu_id=0,
  74. cpu_thread_num=1,
  75. use_mkl=True,
  76. mkl_thread_num=4,
  77. use_trt=False,
  78. use_glog=False,
  79. memory_optimize=True,
  80. max_trt_batch_size=1,
  81. trt_precision_mode=PrecisionType.Float32):
  82. config = Config(
  83. osp.join(self.model_dir, 'model.pdmodel'),
  84. osp.join(self.model_dir, 'model.pdiparams'))
  85. if use_gpu:
  86. # 设置GPU初始显存(单位M)和Device ID
  87. config.enable_use_gpu(200, gpu_id)
  88. config.switch_ir_optim(True)
  89. if use_trt:
  90. if self._model.model_type == 'segmenter':
  91. logging.warning(
  92. "Semantic segmentation models do not support TensorRT acceleration, "
  93. "TensorRT is forcibly disabled.")
  94. elif 'RCNN' in self._model.__class__.__name__:
  95. logging.warning(
  96. "RCNN models do not support TensorRT acceleration, "
  97. "TensorRT is forcibly disabled.")
  98. else:
  99. config.enable_tensorrt_engine(
  100. workspace_size=1 << 10,
  101. max_batch_size=max_trt_batch_size,
  102. min_subgraph_size=3,
  103. precision_mode=trt_precision_mode,
  104. use_static=False,
  105. use_calib_mode=False)
  106. else:
  107. config.disable_gpu()
  108. config.set_cpu_math_library_num_threads(cpu_thread_num)
  109. if use_mkl:
  110. if self._model.__class__.__name__ == 'MaskRCNN':
  111. logging.warning(
  112. "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled"
  113. )
  114. else:
  115. try:
  116. # cache 10 different shapes for mkldnn to avoid memory leak
  117. config.set_mkldnn_cache_capacity(10)
  118. config.enable_mkldnn()
  119. config.set_cpu_math_library_num_threads(mkl_thread_num)
  120. except Exception as e:
  121. logging.warning(
  122. "The current environment does not support MKL-DNN, MKL-DNN is disabled."
  123. )
  124. pass
  125. if not use_glog:
  126. config.disable_glog_info()
  127. if memory_optimize:
  128. config.enable_memory_optim()
  129. config.switch_use_feed_fetch_ops(False)
  130. predictor = create_predictor(config)
  131. return predictor
  132. def preprocess(self, images, transforms):
  133. preprocessed_samples = self._model._preprocess(
  134. images, transforms, to_tensor=False)
  135. if self._model.model_type == 'classifier':
  136. preprocessed_samples = {'image': preprocessed_samples[0]}
  137. elif self._model.model_type == 'segmenter':
  138. preprocessed_samples = {
  139. 'image': preprocessed_samples[0],
  140. 'ori_shape': preprocessed_samples[1]
  141. }
  142. elif self._model.model_type == 'detector':
  143. pass
  144. else:
  145. logging.error(
  146. "Invalid model type {}".format(self._model.model_type),
  147. exit=True)
  148. return preprocessed_samples
  149. def postprocess(self, net_outputs, topk=1, ori_shape=None,
  150. transforms=None):
  151. if self._model.model_type == 'classifier':
  152. true_topk = min(self._model.num_classes, topk)
  153. preds = self._model._postprocess(net_outputs[0], true_topk)
  154. elif self._model.model_type == 'segmenter':
  155. label_map, score_map = self._model._postprocess(
  156. net_outputs,
  157. batch_origin_shape=ori_shape,
  158. transforms=transforms.transforms)
  159. preds = [{
  160. 'label_map': l,
  161. 'score_map': s
  162. } for l, s in zip(label_map, score_map)]
  163. elif self._model.model_type == 'detector':
  164. net_outputs = {
  165. k: v
  166. for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
  167. }
  168. preds = self._model._postprocess(net_outputs)
  169. else:
  170. logging.error(
  171. "Invalid model type {}.".format(self._model.model_type),
  172. exit=True)
  173. return preds
  174. def raw_predict(self, inputs):
  175. """ 接受预处理过后的数据进行预测
  176. Args:
  177. inputs(dict): 预处理过后的数据
  178. """
  179. input_names = self.predictor.get_input_names()
  180. for name in input_names:
  181. input_tensor = self.predictor.get_input_handle(name)
  182. input_tensor.copy_from_cpu(inputs[name])
  183. self.predictor.run()
  184. output_names = self.predictor.get_output_names()
  185. net_outputs = list()
  186. for name in output_names:
  187. output_tensor = self.predictor.get_output_handle(name)
  188. net_outputs.append(output_tensor.copy_to_cpu())
  189. return net_outputs
  190. def _run(self, images, topk=1, transforms=None):
  191. self.timer.preprocess_time_s.start()
  192. preprocessed_input = self.preprocess(images, transforms)
  193. self.timer.preprocess_time_s.end(iter_num=len(images))
  194. self.timer.inference_time_s.start()
  195. net_outputs = self.raw_predict(preprocessed_input)
  196. self.timer.inference_time_s.end(iter_num=1)
  197. self.timer.postprocess_time_s.start()
  198. results = self.postprocess(
  199. net_outputs,
  200. topk,
  201. ori_shape=preprocessed_input.get('ori_shape', None),
  202. transforms=transforms)
  203. self.timer.postprocess_time_s.end(iter_num=len(images))
  204. return results
  205. def predict(self,
  206. img_file,
  207. topk=1,
  208. transforms=None,
  209. warmup_iters=0,
  210. repeats=1):
  211. """ 图片预测
  212. Args:
  213. img_file(List[np.ndarray or str], str or np.ndarray):
  214. 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
  215. topk(int): 分类预测时使用,表示预测前topk的结果。默认值为1。
  216. transforms (paddlex.transforms): 数据预处理操作。默认值为None, 即使用`model.yml`中保存的数据预处理操作。
  217. warmup_iters (int): 预热轮数,用于评估模型推理以及前后处理速度。若大于1,会预先重复预测warmup_iters,而后才开始正式的预测及其速度评估。默认为0。
  218. repeats (int): 重复次数,用于评估模型推理以及前后处理速度。若大于1,会预测repeats次取时间平均值。默认值为1。
  219. """
  220. if repeats < 1:
  221. logging.error("`repeats` must be greater than 1.", exit=True)
  222. if transforms is None and not hasattr(self._model, 'test_transforms'):
  223. raise Exception("Transforms need to be defined, now is None.")
  224. if transforms is None:
  225. transforms = self._model.test_transforms
  226. if isinstance(img_file, (str, np.ndarray)):
  227. images = [img_file]
  228. else:
  229. images = img_file
  230. for _ in range(warmup_iters):
  231. self._run(images=images, topk=topk, transforms=transforms)
  232. self.timer.reset()
  233. for _ in range(repeats):
  234. results = self._run(
  235. images=images, topk=topk, transforms=transforms)
  236. self.timer.repeats = repeats
  237. self.timer.img_num = len(images)
  238. self.timer.info(average=True)
  239. if isinstance(img_file, (str, np.ndarray)):
  240. results = results[0]
  241. return results
  242. def batch_predict(self, image_list, **params):
  243. return self.predict(img_file=image_list, **params)