| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271 |
- # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os.path as osp
- import numpy as np
- from paddle.inference import Config
- from paddle.inference import create_predictor
- from paddle.inference import PrecisionType
- from paddlex.cv.models import load_model
- from paddlex.utils import logging, Timer
class Predictor(object):
    """Run inference on an exported PaddleX model through Paddle Inference.

    The wrapped model is loaded with ``load_model(model_dir, with_net=False)``
    and its ``model_type`` ('classifier', 'segmenter' or 'detector') selects
    the pre/post-processing path.
    """

    def __init__(self,
                 model_dir,
                 use_gpu=False,
                 gpu_id=0,
                 cpu_thread_num=1,
                 use_mkl=True,
                 mkl_thread_num=4,
                 use_trt=False,
                 use_glog=False,
                 memory_optimize=True,
                 max_trt_batch_size=1,
                 trt_precision_mode='float32'):
        """Create a Paddle Inference predictor for an exported model.

        Args:
            model_dir (str): Model directory; must be an exported deployment
                model or a quantized model.
            use_gpu (bool): Whether to run on GPU. Defaults to False.
            gpu_id (int): ID of the GPU to use. Defaults to 0.
            cpu_thread_num (int): Number of math-library threads used for
                CPU inference. Defaults to 1.
            use_mkl (bool): Whether to enable MKL-DNN (CPU only).
                Defaults to True.
            mkl_thread_num (int): Number of threads used when MKL-DNN is
                enabled. Defaults to 4.
            use_trt (bool): Whether to use TensorRT (GPU only).
                Defaults to False.
            use_glog (bool): Whether to print glog info. Defaults to False.
            memory_optimize (bool): Whether to enable memory optimization.
                Defaults to True.
            max_trt_batch_size (int): Maximum batch size configured for
                TensorRT. Defaults to 1.
            trt_precision_mode (str): TensorRT precision, 'float32' or
                'float16' (case-insensitive). Defaults to 'float32'.
        """
        # Resolve the precision before loading the model so that an invalid
        # value is reported without first paying for the model load.
        mode = trt_precision_mode.lower()
        if mode == 'float32':
            trt_precision_mode = PrecisionType.Float32
        elif mode == 'float16':
            trt_precision_mode = PrecisionType.Float16
        else:
            logging.error(
                "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
                .format(trt_precision_mode),
                exit=True)

        self.model_dir = model_dir
        self._model = load_model(model_dir, with_net=False)
        self.predictor = self.create_predictor(
            use_gpu=use_gpu,
            gpu_id=gpu_id,
            cpu_thread_num=cpu_thread_num,
            use_mkl=use_mkl,
            mkl_thread_num=mkl_thread_num,
            use_trt=use_trt,
            use_glog=use_glog,
            memory_optimize=memory_optimize,
            max_trt_batch_size=max_trt_batch_size,
            trt_precision_mode=trt_precision_mode)
        # Accumulates per-stage timings for preprocess / inference / postprocess.
        self.timer = Timer()
- def create_predictor(self,
- use_gpu=True,
- gpu_id=0,
- cpu_thread_num=1,
- use_mkl=True,
- mkl_thread_num=4,
- use_trt=False,
- use_glog=False,
- memory_optimize=True,
- max_trt_batch_size=1,
- trt_precision_mode=PrecisionType.Float32):
- config = Config(
- osp.join(self.model_dir, 'model.pdmodel'),
- osp.join(self.model_dir, 'model.pdiparams'))
- if use_gpu:
- # 设置GPU初始显存(单位M)和Device ID
- config.enable_use_gpu(100, gpu_id)
- config.switch_ir_optim(True)
- if use_trt:
- config.enable_tensorrt_engine(
- workspace_size=1 << 10,
- max_batch_size=max_trt_batch_size,
- min_subgraph_size=3,
- precision_mode=trt_precision_mode,
- use_static=False,
- use_calib_mode=False)
- else:
- config.disable_gpu()
- config.set_cpu_math_library_num_threads(cpu_thread_num)
- if use_mkl:
- try:
- # cache 10 different shapes for mkldnn to avoid memory leak
- config.set_mkldnn_cache_capacity(10)
- config.enable_mkldnn()
- config.set_cpu_math_library_num_threads(mkl_thread_num)
- except Exception as e:
- logging.warning(
- "The current environment does not support `mkldnn`, so disable mkldnn."
- )
- pass
- if use_glog:
- config.enable_glog_info()
- else:
- config.disable_glog_info()
- if memory_optimize:
- config.enable_memory_optim()
- config.switch_use_feed_fetch_ops(False)
- predictor = create_predictor(config)
- return predictor
- def preprocess(self, images, transforms):
- preprocessed_samples = self._model._preprocess(
- images, transforms, to_tensor=False)
- if self._model.model_type == 'classifier':
- preprocessed_samples = {'image': preprocessed_samples[0]}
- elif self._model.model_type == 'segmenter':
- preprocessed_samples = {
- 'image': preprocessed_samples[0],
- 'ori_shape': preprocessed_samples[1]
- }
- elif self._model.model_type == 'detector':
- pass
- else:
- logging.error(
- "Invalid model type {}".format(self._model.model_type),
- exit=True)
- return preprocessed_samples
- def postprocess(self, net_outputs, topk=1, ori_shape=None,
- transforms=None):
- if self._model.model_type == 'classifier':
- true_topk = min(self._model.num_classes, topk)
- preds = self._model._postprocess(net_outputs[0], true_topk)
- if len(preds) == 1:
- preds = preds[0]
- elif self._model.model_type == 'segmenter':
- score_map, label_map = self._model._postprocess(
- net_outputs,
- batch_origin_shape=ori_shape,
- transforms=transforms.transforms)
- score_map = np.squeeze(score_map)
- label_map = np.squeeze(label_map)
- if score_map.ndim == 3:
- preds = {'label_map': label_map, 'score_map': score_map}
- else:
- preds = [{
- 'label_map': l,
- 'score_map': s
- } for l, s in zip(label_map, score_map)]
- elif self._model.model_type == 'detector':
- if 'RCNN' in self._model.__class__.__name__:
- net_outputs = [{
- k: v
- for k, v in zip(['bbox', 'bbox_num', 'mask'], res)
- } for res in net_outputs]
- else:
- net_outputs = {
- k: v
- for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
- }
- preds = self._model._postprocess(net_outputs)
- if len(preds) == 1:
- preds = preds[0]
- else:
- logging.error(
- "Invalid model type {}.".format(self._model.model_type),
- exit=True)
- return preds
- def raw_predict(self, inputs):
- """ 接受预处理过后的数据进行预测
- Args:
- inputs(dict): 预处理过后的数据
- """
- input_names = self.predictor.get_input_names()
- for name in input_names:
- input_tensor = self.predictor.get_input_handle(name)
- input_tensor.copy_from_cpu(inputs[name])
- self.predictor.run()
- output_names = self.predictor.get_output_names()
- net_outputs = list()
- for name in output_names:
- output_tensor = self.predictor.get_output_handle(name)
- net_outputs.append(output_tensor.copy_to_cpu())
- return net_outputs
- def _run(self, images, topk=1, transforms=None, repeats=1, verbose=False):
- self.timer.preprocess_time_s.start()
- preprocessed_input = self.preprocess(images, transforms)
- self.timer.preprocess_time_s.end()
- self.timer.inference_time_s.start()
- if 'RCNN' in self._model.__class__.__name__:
- if len(preprocessed_input) > 1:
- logging.warning(
- "{} only supports inference with batch size equal to 1."
- .format(self._model.__class__.__name__))
- for step in range(repeats):
- net_outputs = [
- self.raw_predict(sample) for sample in preprocessed_input
- ]
- self.timer.inference_time_s.end(repeats=len(preprocessed_input) *
- repeats)
- ori_shape = None
- else:
- for step in range(repeats):
- net_outputs = self.raw_predict(preprocessed_input)
- self.timer.inference_time_s.end(repeats=repeats)
- ori_shape = preprocessed_input.get('ori_shape', None)
- self.timer.postprocess_time_s.start()
- results = self.postprocess(
- net_outputs, topk, ori_shape=ori_shape, transforms=transforms)
- self.timer.postprocess_time_s.end()
- self.timer.img_num = len(images)
- if verbose:
- self.timer.info(average=True)
- return results
- def predict(self,
- img_file,
- topk=1,
- transforms=None,
- warmup_iters=0,
- repeats=1):
- """ 图片预测
- Args:
- img_file(List[np.ndarray or str], str or np.ndarray):
- 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
- topk(int): 分类预测时使用,表示预测前topk的结果。
- transforms (paddlex.transforms): 数据预处理操作。
- """
- if transforms is None and not hasattr(self._model, 'test_transforms'):
- raise Exception("Transforms need to be defined, now is None.")
- if transforms is None:
- transforms = self._model.test_transforms
- if isinstance(img_file, (str, np.ndarray)):
- images = [img_file]
- else:
- images = img_file
- for step in range(warmup_iters):
- self._run(
- images=images, topk=topk, transforms=transforms, verbose=False)
- self.timer.reset()
|