deploy.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import numpy as np
  16. from paddle.inference import Config
  17. from paddle.inference import create_predictor
  18. from paddle.inference import PrecisionType
  19. from paddlex.cv.models import load_model
  20. from paddlex.utils import logging, Timer
  21. class Predictor(object):
  22. def __init__(self,
  23. model_dir,
  24. use_gpu=True,
  25. gpu_id=0,
  26. cpu_thread_num=1,
  27. use_mkl=True,
  28. mkl_thread_num=4,
  29. use_trt=False,
  30. use_glog=False,
  31. memory_optimize=True,
  32. max_trt_batch_size=1,
  33. trt_precision_mode='float32'):
  34. """ 创建Paddle Predictor
  35. Args:
  36. model_dir: 模型路径(必须是导出的部署或量化模型)
  37. use_gpu: 是否使用gpu,默认True
  38. gpu_id: 使用gpu的id,默认0
  39. cpu_thread_num=1:使用cpu进行预测时的线程数,默认为1
  40. use_mkl: 是否使用mkldnn计算库,CPU情况下使用,默认False
  41. mkl_thread_num: mkldnn计算线程数,默认为4
  42. use_trt: 是否使用TensorRT,默认False
  43. use_glog: 是否启用glog日志, 默认False
  44. memory_optimize: 是否启动内存优化,默认True
  45. max_trt_batch_size: 在使用TensorRT时配置的最大batch size,默认1
  46. trt_precision_mode:在使用TensorRT时采用的精度,默认float32
  47. """
  48. self.model_dir = model_dir
  49. self._model = load_model(model_dir, with_net=False)
  50. if trt_precision_mode == 'float32':
  51. trt_precision_mode = PrecisionType.Float32
  52. elif trt_precision_mode == 'float16':
  53. trt_precision_mode = PrecisionType.Float16
  54. else:
  55. logging.error(
  56. "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
  57. .format(trt_precision_mode),
  58. exit=True)
  59. self.predictor = self.create_predictor(
  60. use_gpu=use_gpu,
  61. gpu_id=gpu_id,
  62. cpu_thread_num=cpu_thread_num,
  63. use_mkl=use_mkl,
  64. mkl_thread_num=mkl_thread_num,
  65. use_trt=use_trt,
  66. use_glog=use_glog,
  67. memory_optimize=memory_optimize,
  68. max_trt_batch_size=max_trt_batch_size,
  69. trt_precision_mode=trt_precision_mode)
  70. self.timer = Timer()
  71. def create_predictor(self,
  72. use_gpu=True,
  73. gpu_id=0,
  74. cpu_thread_num=1,
  75. use_mkl=True,
  76. mkl_thread_num=4,
  77. use_trt=False,
  78. use_glog=False,
  79. memory_optimize=True,
  80. max_trt_batch_size=1,
  81. trt_precision_mode=PrecisionType.Float32):
  82. config = Config(
  83. osp.join(self.model_dir, 'model.pdmodel'),
  84. osp.join(self.model_dir, 'model.pdiparams'))
  85. if use_gpu:
  86. # 设置GPU初始显存(单位M)和Device ID
  87. config.enable_use_gpu(100, gpu_id)
  88. config.switch_ir_optim(True)
  89. if use_trt:
  90. config.enable_tensorrt_engine(
  91. workspace_size=1 << 10,
  92. max_batch_size=max_trt_batch_size,
  93. min_subgraph_size=3,
  94. precision_mode=trt_precision_mode,
  95. use_static=False,
  96. use_calib_mode=False)
  97. else:
  98. config.disable_gpu()
  99. config.set_cpu_math_library_num_threads(cpu_thread_num)
  100. if use_mkl:
  101. try:
  102. # cache 10 different shapes for mkldnn to avoid memory leak
  103. config.set_mkldnn_cache_capacity(10)
  104. config.enable_mkldnn()
  105. config.set_cpu_math_library_num_threads(mkl_thread_num)
  106. except Exception as e:
  107. logging.warning(
  108. "The current environment does not support `mkldnn`, so disable mkldnn."
  109. )
  110. pass
  111. if use_glog:
  112. config.enable_glog_info()
  113. else:
  114. config.disable_glog_info()
  115. if memory_optimize:
  116. config.enable_memory_optim()
  117. config.switch_use_feed_fetch_ops(False)
  118. predictor = create_predictor(config)
  119. return predictor
  120. def preprocess(self, images, transforms):
  121. preprocessed_samples = self._model._preprocess(
  122. images, transforms, to_tensor=False)
  123. if self._model.model_type == 'classifier':
  124. batch_samples = {'image': preprocessed_samples[0]}
  125. elif self._model.model_type == 'segmenter':
  126. batch_samples = {
  127. 'image': preprocessed_samples[0],
  128. 'ori_shape': preprocessed_samples[1]
  129. }
  130. elif self._model.model_type == 'detector':
  131. batch_samples = preprocessed_samples
  132. else:
  133. logging.error(
  134. "Invalid model type {}".format(self._model.model_type),
  135. exit=True)
  136. return batch_samples
  137. def raw_predict(self, inputs):
  138. """ 接受预处理过后的数据进行预测
  139. Args:
  140. inputs(dict): 预处理过后的数据
  141. """
  142. def predict(self, img_file, topk=1, transforms=None):
  143. """ 图片预测
  144. Args:
  145. img_file(List[np.ndarray or str], str or np.ndarray):
  146. 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
  147. topk(int): 分类预测时使用,表示预测前topk的结果。
  148. transforms (paddlex.transforms): 数据预处理操作。
  149. """
  150. if transforms is None and not hasattr(self, 'test_transforms'):
  151. raise Exception("Transforms need to be defined, now is None.")
  152. if transforms is None:
  153. transforms = self._model.test_transforms
  154. if isinstance(img_file, (str, np.ndarray)):
  155. images = [img_file]
  156. else:
  157. images = img_file
  158. self.timer.preprocess_time_s.start()
  159. batch_samples = self.preprocess(images, transforms)
  160. self.timer.preprocess_time_s.end()
  161. input_names = self.predictor.get_input_names()
  162. for name in input_names:
  163. input_tensor = self.predictor.get_input_handle(name)
  164. input_tensor.copy_from_cpu(batch_samples[name])
  165. self.timer.inference_time_s.start()
  166. self.predictor.run()
  167. output_names = self.predictor.get_output_names()