|
|
@@ -16,6 +16,7 @@ import os.path as osp
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import yaml
|
|
|
+import multiprocessing as mp
|
|
|
import paddlex
|
|
|
import paddle.fluid as fluid
|
|
|
from paddlex.cv.transforms import build_transforms
|
|
|
@@ -79,6 +80,15 @@ class Predictor:
|
|
|
self.predictor = self.create_predictor(use_gpu, gpu_id, use_mkl,
|
|
|
mkl_thread_num, use_trt,
|
|
|
use_glog, memory_optimize)
|
|
|
+ # 线程池,在模型在预测时用于对输入数据以图片为单位进行并行处理
|
|
|
+ # 主要用于batch_predict接口
|
|
|
+ thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
|
|
|
+ self.thread_pool = mp.pool.ThreadPool(thread_num)
|
|
|
+
|
|
|
+ def reset_thread_pool(self, thread_num):
|
|
|
+ self.thread_pool.close()
|
|
|
+ self.thread_pool.join()
|
|
|
+ self.thread_pool = mp.pool.ThreadPool(thread_num)
|
|
|
|
|
|
def create_predictor(self,
|
|
|
use_gpu=True,
|
|
|
@@ -114,7 +124,7 @@ class Predictor:
|
|
|
predictor = fluid.core.create_paddle_predictor(config)
|
|
|
return predictor
|
|
|
|
|
|
- def preprocess(self, image, thread_num=1):
|
|
|
+ def preprocess(self, image, thread_pool=None):
|
|
|
""" 对图像做预处理
|
|
|
|
|
|
Args:
|
|
|
@@ -128,7 +138,7 @@ class Predictor:
|
|
|
self.transforms,
|
|
|
self.model_type,
|
|
|
self.model_name,
|
|
|
- thread_num=thread_num)
|
|
|
+ thread_pool=thread_pool)
|
|
|
res['image'] = im
|
|
|
elif self.model_type == "detector":
|
|
|
if self.model_name in ["PPYOLO", "YOLOv3"]:
|
|
|
@@ -137,7 +147,7 @@ class Predictor:
|
|
|
self.transforms,
|
|
|
self.model_type,
|
|
|
self.model_name,
|
|
|
- thread_num=thread_num)
|
|
|
+ thread_pool=thread_pool)
|
|
|
res['image'] = im
|
|
|
res['im_size'] = im_size
|
|
|
if self.model_name.count('RCNN') > 0:
|
|
|
@@ -146,7 +156,7 @@ class Predictor:
|
|
|
self.transforms,
|
|
|
self.model_type,
|
|
|
self.model_name,
|
|
|
- thread_num=thread_num)
|
|
|
+ thread_pool=thread_pool)
|
|
|
res['image'] = im
|
|
|
res['im_info'] = im_resize_info
|
|
|
res['im_shape'] = im_shape
|
|
|
@@ -156,7 +166,7 @@ class Predictor:
|
|
|
self.transforms,
|
|
|
self.model_type,
|
|
|
self.model_name,
|
|
|
- thread_num=thread_num)
|
|
|
+ thread_pool=thread_pool)
|
|
|
res['image'] = im
|
|
|
res['im_info'] = im_info
|
|
|
return res
|
|
|
@@ -253,17 +263,16 @@ class Predictor:
|
|
|
|
|
|
return results[0]
|
|
|
|
|
|
- def batch_predict(self, image_list, topk=1, thread_num=2):
|
|
|
+ def batch_predict(self, image_list, topk=1):
|
|
|
""" 图片预测
|
|
|
|
|
|
Args:
|
|
|
image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
|
|
|
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
|
|
|
- thread_num (int): 并发执行各图像预处理时的线程数。
|
|
|
|
|
|
topk(int): 分类预测时使用,表示预测前topk的结果
|
|
|
"""
|
|
|
- preprocessed_input = self.preprocess(image_list)
|
|
|
+ preprocessed_input = self.preprocess(image_list, self.thread_pool)
|
|
|
model_pred = self.raw_predict(preprocessed_input)
|
|
|
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
|
|
|
'im_shape']
|