|
|
@@ -35,7 +35,6 @@ class Predictor(object):
|
|
|
max_trt_batch_size=1,
|
|
|
trt_precision_mode='float32'):
|
|
|
""" 创建Paddle Predictor
|
|
|
-
|
|
|
Args:
|
|
|
model_dir: 模型路径(必须是导出的部署或量化模型)
|
|
|
use_gpu: 是否使用gpu,默认True
|
|
|
@@ -183,7 +182,6 @@ class Predictor(object):
|
|
|
|
|
|
def raw_predict(self, inputs):
|
|
|
""" 接受预处理过后的数据进行预测
|
|
|
-
|
|
|
Args:
|
|
|
inputs(dict): 预处理过后的数据
|
|
|
"""
|
|
|
@@ -204,7 +202,6 @@ class Predictor(object):
|
|
|
|
|
|
def predict(self, img_file, topk=1, transforms=None):
|
|
|
""" 图片预测
|
|
|
-
|
|
|
Args:
|
|
|
img_file(List[np.ndarray or str], str or np.ndarray):
|
|
|
图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
|