@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import os.path as osp
 import numpy as np
-import yaml
 from paddle.inference import Config
 from paddle.inference import create_predictor
 from paddle.inference import PrecisionType
-from paddlex.cv.transforms import build_transforms
-from paddlex.utils import logging
+from paddlex.cv.models import load_model
+from paddlex.utils import logging, Timer
 
 
 class Predictor(object):
@@ -51,10 +49,8 @@ class Predictor(object):
             max_trt_batch_size: maximum batch size configured for TensorRT, defaults to 1
             trt_precision_mode: precision used when running with TensorRT, defaults to float32
         """
-        if not osp.isdir(model_dir):
-            logging.error(
-                "{} is not a valid model directory.".format(model_dir),
-                exit=True)
+        self.model_dir = model_dir
+        self._model = load_model(model_dir, with_net=False)
 
         if trt_precision_mode == 'float32':
             trt_precision_mode = PrecisionType.Float32
@@ -66,6 +62,19 @@ class Predictor(object):
                 .format(trt_precision_mode),
                 exit=True)
 
+        self.predictor = self.create_predictor(
+            use_gpu=use_gpu,
+            gpu_id=gpu_id,
+            cpu_thread_num=cpu_thread_num,
+            use_mkl=use_mkl,
+            mkl_thread_num=mkl_thread_num,
+            use_trt=use_trt,
+            use_glog=use_glog,
+            memory_optimize=memory_optimize,
+            max_trt_batch_size=max_trt_batch_size,
+            trt_precision_mode=trt_precision_mode)
+        self.timer = Timer()
+
     def create_predictor(self,
                          use_gpu=True,
                          gpu_id=0,
@@ -78,8 +87,8 @@ class Predictor(object):
                          max_trt_batch_size=1,
                          trt_precision_mode=PrecisionType.Float32):
         config = Config(
-            prog_file=osp.join(self.model_dir, 'model.pdmodel'),
-            params_file=osp.join(self.model_dir, 'model.pdiparams'))
+            osp.join(self.model_dir, 'model.pdmodel'),
+            osp.join(self.model_dir, 'model.pdiparams'))
 
         if use_gpu:
             # Set the initial GPU memory (in MB) and the device ID
@@ -117,3 +126,65 @@ class Predictor(object):
         config.switch_use_feed_fetch_ops(False)
         predictor = create_predictor(config)
         return predictor
+
+    def preprocess(self, images, transforms):
+        preprocessed_samples = self._model._preprocess(
+            images, transforms, to_tensor=False)
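+        # Pack the shared preprocessing outputs into the input dict
+        # layout that each model type expects.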
+        if self._model.model_type == 'classifier':
+            batch_samples = {'image': preprocessed_samples[0]}
+        elif self._model.model_type == 'segmenter':
+            batch_samples = {
+                'image': preprocessed_samples[0],
+                'ori_shape': preprocessed_samples[1]
+            }
+        elif self._model.model_type == 'detector':
+            batch_samples = preprocessed_samples
+        else:
+            logging.error(
+                "Invalid model type {}".format(self._model.model_type),
+                exit=True)
+        return batch_samples
+
+    def raw_predict(self, inputs):
+        """ Run prediction on preprocessed data.
+
+        Args:
+            inputs(dict): the preprocessed data
+        """
+
+    def predict(self, img_file, topk=1, transforms=None):
+        """ Predict on images.
+
+        Args:
+            img_file(List[np.ndarray or str], str or np.ndarray):
+                Image path(s), or decoded array(s) in (H, W, C) layout, float32 dtype, BGR format.
+            topk(int): used in classification prediction, returns the top-k results.
+            transforms (paddlex.transforms): data preprocessing operators.
+        """
+        if transforms is None and self._model.test_transforms is None:
+            raise Exception("Transforms need to be defined; got None.")
+        if transforms is None:
+            transforms = self._model.test_transforms
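+        # Accept a single image (file path or decoded array) or a list of them.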
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+
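+        # Time preprocessing separately from model execution.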
+        self.timer.preprocess_time_s.start()
+        batch_samples = self.preprocess(images, transforms)
+        self.timer.preprocess_time_s.end()
+
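+        # Copy the preprocessed batch into the predictor's input tensors.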
+        input_names = self.predictor.get_input_names()
+        for name in input_names:
+            input_tensor = self.predictor.get_input_handle(name)
+            input_tensor.copy_from_cpu(batch_samples[name])
+
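+        # Run the forward pass and time it.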
+        self.timer.inference_time_s.start()
+        self.predictor.run()
+        output_names = self.predictor.get_output_names()
|