
add python deploy

will-jl944 4 years ago
parent commit 683a44be2b

+ 1 - 0
paddlex/__init__.py

@@ -22,6 +22,7 @@ from . import seg
 from . import cls
 from . import det
 from . import tools
+from . import deploy
 
 from .cv.models.utils.visualize import visualize_detection as visualize_det
 from .cv.models.utils.visualize import visualize_segmentation as visualize_seg
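
With this one-line import, the `paddlex.deploy` module added later in this commit becomes reachable from the top-level package. A minimal, illustrative sketch (the model path is not from the patch):

    import paddlex

    predictor = paddlex.deploy.Predictor('./inference_model')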

+ 6 - 6
paddlex/cv/models/classifier.py

@@ -18,7 +18,6 @@ import os.path as osp
 from collections import OrderedDict
 import numpy as np
 import paddle
-from paddle import to_tensor
 import paddle.nn.functional as F
 from paddle.static import InputSpec
 from paddlex.utils import logging, TrainingStats, DisablePrint
@@ -411,7 +410,7 @@ class BaseClassifier(BaseModel):
         """
         Do inference.
         Args:
-            img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
+            img_file(List[np.ndarray or str], str or np.ndarray):
                 Image path or decoded image data in a BGR format, which also could constitute a list,
                 meaning all images to be predicted as a mini-batch.
             transforms(paddlex.transforms.Compose or None, optional):
@@ -436,7 +435,7 @@ class BaseClassifier(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        im = self._preprocess(images, transforms, self.model_type)
+        im = self._preprocess(images, transforms)
         self.net.eval()
         with paddle.no_grad():
             outputs = self.run(self.net, im, mode='test')
@@ -447,15 +446,16 @@ class BaseClassifier(BaseModel):
 
         return prediction
 
-    def _preprocess(self, images, transforms, model_type):
+    def _preprocess(self, images, transforms, to_tensor=True):
         arrange_transforms(
-            model_type=model_type, transforms=transforms, mode='test')
+            model_type=self.model_type, transforms=transforms, mode='test')
         batch_im = list()
         for im in images:
             sample = {'image': im}
             batch_im.append(transforms(sample))
 
-        batch_im = to_tensor(batch_im)
+        if to_tensor:
+            batch_im = paddle.to_tensor(batch_im)
 
         return batch_im,
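
The new `to_tensor` switch lets the same preprocessing serve both the dygraph `predict()` path and the inference-engine path added in `paddlex/deploy.py`. A minimal sketch of the two call styles, assuming `model` is a loaded `BaseClassifier` and `images` is a list of decoded BGR arrays (variable names are illustrative):

    # Dygraph prediction: the batch is converted to a paddle.Tensor.
    im, = model._preprocess(images, model.test_transforms)

    # Deployment path (paddlex/deploy.py): keep the raw batch so it can be
    # handed to the Paddle Inference predictor without entering dygraph.
    im_raw, = model._preprocess(images, model.test_transforms, to_tensor=False)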
 

+ 5 - 4
paddlex/cv/models/detector.py

@@ -470,7 +470,7 @@ class BaseDetector(BaseModel):
         """
         Do inference.
         Args:
-            img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
+            img_file(List[np.ndarray or str], str or np.ndarray):
                 Image path or decoded image data in a BGR format, which also could constitute a list,
                 meaning all images to be predicted as a mini-batch.
             transforms(paddlex.transforms.Compose or None, optional):
@@ -505,7 +505,7 @@ class BaseDetector(BaseModel):
             prediction = prediction[0]
         return prediction
 
-    def _preprocess(self, images, transforms):
+    def _preprocess(self, images, transforms, to_tensor=True):
         arrange_transforms(
             model_type=self.model_type, transforms=transforms, mode='test')
         batch_samples = list()
@@ -514,8 +514,9 @@ class BaseDetector(BaseModel):
             batch_samples.append(transforms(sample))
         batch_transforms = self._compose_batch_transform(transforms, 'test')
         batch_samples = batch_transforms(batch_samples)
-        for k, v in batch_samples.items():
-            batch_samples[k] = paddle.to_tensor(v)
+        if to_tensor:
+            for k, v in batch_samples.items():
+                batch_samples[k] = paddle.to_tensor(v)
         return batch_samples
 
     def _postprocess(self, batch_pred):
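
For detectors the preprocessed batch is already a dict, and with `to_tensor=False` its values stay as plain arrays; `Predictor.predict` in `paddlex/deploy.py` then copies them into the engine by matching input names. A hedged sketch of that hand-off (`model` and `predictor` are assumed to exist; key names such as 'image' are typical detector inputs, not taken from this patch):

    # `model` is a loaded BaseDetector, `predictor` a Paddle Inference predictor
    batch_samples = model._preprocess(images, model.test_transforms,
                                      to_tensor=False)
    for name in predictor.get_input_names():    # e.g. 'image', 'im_shape', ...
        predictor.get_input_handle(name).copy_from_cpu(batch_samples[name])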

+ 4 - 8
paddlex/cv/models/load_model.py

@@ -82,11 +82,10 @@ def load_model(model_dir, **params):
 
     model_info['_init_params'].update({'with_net': with_net})
 
-    if with_net:
-        with paddle.utils.unique_name.guard():
-            model = getattr(paddlex.cv.models, model_info['Model'])(
-                **model_info['_init_params'])
-
+    with paddle.utils.unique_name.guard():
+        model = getattr(paddlex.cv.models, model_info['Model'])(
+            **model_info['_init_params'])
+        if with_net:
             if status == 'Pruned' or osp.exists(
                     osp.join(model_dir, "prune.yml")):
                 with open(osp.join(model_dir, "prune.yml")) as f:
@@ -121,9 +120,6 @@ def load_model(model_dir, **params):
                 net_state_dict = paddle.load(
                     osp.join(model_dir, 'model.pdparams'))
             model.net.set_state_dict(net_state_dict)
-    else:
-        model = getattr(paddlex.cv.models, model_info['Model'])(
-            **model_info['_init_params'])
 
     if 'Transforms' in model_info:
         model.test_transforms = build_transforms(model_info['Transforms'])
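
After this restructuring the model object is always constructed, and `with_net=False` only skips building and restoring the network weights; `paddlex/deploy.py` relies on this to read the model type and saved transforms without a dygraph network. A minimal sketch, with an illustrative model directory:

    from paddlex.cv.models import load_model

    # Full load: the network is built and weights from model.pdparams restored.
    model = load_model('output/best_model')

    # Lightweight load used by deploy.Predictor: metadata and transforms only.
    meta = load_model('output/best_model', with_net=False)
    print(meta.model_type, meta.test_transforms)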

+ 5 - 4
paddlex/cv/models/segmenter.py

@@ -453,7 +453,7 @@ class BaseSegmenter(BaseModel):
         Do inference.
         Args:
             Args:
-            img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
+            img_file(List[np.ndarray or str], str or np.ndarray):
                 Image path or decoded image data in a BGR format, which also could constitute a list,
                 meaning all images to be predicted as a mini-batch.
             transforms(paddlex.transforms.Compose or None, optional):
@@ -495,9 +495,9 @@ class BaseSegmenter(BaseModel):
             prediction = {'label_map': label_map, 'score_map': score_map}
         return prediction
 
-    def _preprocess(self, images, transforms, model_type):
+    def _preprocess(self, images, transforms, to_tensor=True):
         arrange_transforms(
-            model_type=model_type, transforms=transforms, mode='test')
+            model_type=self.model_type, transforms=transforms, mode='test')
         batch_im = list()
         batch_ori_shape = list()
         for im in images:
@@ -508,7 +508,8 @@ class BaseSegmenter(BaseModel):
             im = transforms(sample)[0]
             batch_im.append(im)
             batch_ori_shape.append(ori_shape)
-        batch_im = paddle.to_tensor(batch_im)
+        if to_tensor:
+            batch_im = paddle.to_tensor(batch_im)
 
         return batch_im, batch_ori_shape
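
The segmenter variant additionally carries the original image shapes alongside the batch, which `Predictor.preprocess` in `paddlex/deploy.py` repackages under the 'image' and 'ori_shape' keys. A short sketch, assuming `model` is a loaded `BaseSegmenter`:

    batch_im, batch_ori_shape = model._preprocess(
        images, model.test_transforms, to_tensor=False)
    samples = {'image': batch_im, 'ori_shape': batch_ori_shape}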
 

+ 75 - 10
paddlex/deploy.py

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import os.path as osp
 import numpy as np
-import yaml
 from paddle.inference import Config
 from paddle.inference import create_predictor
 from paddle.inference import PrecisionType
-from paddlex.cv.transforms import build_transforms
-from paddlex.utils import logging
+from paddlex.cv.models import load_model
+from paddlex.utils import logging, Timer
 
 
 class Predictor(object):
@@ -51,10 +49,8 @@ class Predictor(object):
                max_trt_batch_size: the maximum batch size when using TensorRT, default 1
                trt_precision_mode: the precision used by TensorRT, default float32
         """
-        if not osp.isdir(model_dir):
-            logging.error(
-                "{} is not a valid model directory.".format(model_dir),
-                exit=True)
+        self.model_dir = model_dir
+        self._model = load_model(model_dir, with_net=False)
 
         if trt_precision_mode == 'float32':
             trt_precision_mode = PrecisionType.Float32
@@ -66,6 +62,19 @@ class Predictor(object):
                 .format(trt_precision_mode),
                 exit=True)
 
+        self.predictor = self.create_predictor(
+            use_gpu=use_gpu,
+            gpu_id=gpu_id,
+            cpu_thread_num=cpu_thread_num,
+            use_mkl=use_mkl,
+            mkl_thread_num=mkl_thread_num,
+            use_trt=use_trt,
+            use_glog=use_glog,
+            memory_optimize=memory_optimize,
+            max_trt_batch_size=max_trt_batch_size,
+            trt_precision_mode=trt_precision_mode)
+        self.timer = Timer()
+
     def create_predictor(self,
                          use_gpu=True,
                          gpu_id=0,
@@ -78,8 +87,8 @@ class Predictor(object):
                          max_trt_batch_size=1,
                          trt_precision_mode=PrecisionType.Float32):
         config = Config(
-            prog_file=osp.join(self.model_dir, 'model.pdmodel'),
-            params_file=osp.join(self.model_dir, 'model.pdiparams'))
+            osp.join(self.model_dir, 'model.pdmodel'),
+            osp.join(self.model_dir, 'model.pdiparams'))
 
         if use_gpu:
            # Set the initial GPU memory (in MB) and the device ID
@@ -117,3 +126,59 @@ class Predictor(object):
         config.switch_use_feed_fetch_ops(False)
         predictor = create_predictor(config)
         return predictor
+
+    def preprocess(self, images, transforms):
+        preprocessed_samples = self._model._preprocess(
+            images, transforms, to_tensor=False)
+        if self._model.model_type == 'classifier':
+            batch_samples = {'image': preprocessed_samples[0]}
+        elif self._model.model_type == 'segmenter':
+            batch_samples = {
+                'image': preprocessed_samples[0],
+                'ori_shape': preprocessed_samples[1]
+            }
+        elif self._model.model_type == 'detector':
+            batch_samples = preprocessed_samples
+        else:
+            logging.error(
+                "Invalid model type {}".format(self._model.model_type),
+                exit=True)
+        return batch_samples
+
+    def raw_predict(self, inputs):
+        """ 接受预处理过后的数据进行预测
+
+            Args:
+                inputs(dict): 预处理过后的数据
+        """
+
+    def predict(self, img_file, topk=1, transforms=None):
+        """ 图片预测
+
+            Args:
+                img_file(List[np.ndarray or str], str or np.ndarray):
+                    图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
+                topk(int): 分类预测时使用,表示预测前topk的结果。
+                transforms (paddlex.transforms): 数据预处理操作。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("Transforms need to be defined, now is None.")
+        if transforms is None:
+            transforms = self._model.test_transforms
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+
+        self.timer.preprocess_time_s.start()
+        batch_samples = self.preprocess(images, transforms)
+        self.timer.preprocess_time_s.end()
+
+        input_names = self.predictor.get_input_names()
+        for name in input_names:
+            input_tensor = self.predictor.get_input_handle(name)
+            input_tensor.copy_from_cpu(batch_samples[name])
+
+        self.timer.inference_time_s.start()
+        self.predictor.run()
+        output_names = self.predictor.get_output_names()
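
Putting the pieces together, `Predictor` loads the exported model's metadata via `load_model(..., with_net=False)`, builds a Paddle Inference engine from model.pdmodel/model.pdiparams, and reuses the dygraph models' preprocessing. A usage sketch (the diff above is shown truncated, so the postprocessing tail of `predict` is not reproduced here; the model path and image file are illustrative):

    import paddlex as pdx

    predictor = pdx.deploy.Predictor('./inference_model', use_gpu=True)
    # transforms=None falls back to the test_transforms saved with the model
    result = predictor.predict('test.jpg', transforms=None)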

+ 1 - 1
paddlex/utils/__init__.py

@@ -16,7 +16,7 @@ from . import logging
 from . import utils
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
                     EarlyStop, path_normalization, is_pic, MyEncoder,
-                    DisablePrint)
+                    DisablePrint, Timer)
 from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress

+ 78 - 0
paddlex/utils/utils.py

@@ -14,6 +14,7 @@
 
 import sys
 import os
+import time
 import math
 import chardet
 import json
@@ -138,3 +139,80 @@ class DisablePrint(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         sys.stdout.close()
         sys.stdout = self._original_stdout
+
+
+class Times(object):
+    def __init__(self):
+        self.time = 0.
+        # start time
+        self.st = 0.
+        # end time
+        self.et = 0.
+
+    def start(self):
+        self.st = time.time()
+
+    def end(self, repeats=1, accumulative=True):
+        self.et = time.time()
+        if accumulative:
+            self.time += (self.et - self.st) / repeats
+        else:
+            self.time = (self.et - self.st) / repeats
+
+    def reset(self):
+        self.time = 0.
+        self.st = 0.
+        self.et = 0.
+
+    def value(self):
+        return round(self.time, 4)
+
+
+class Timer(Times):
+    def __init__(self):
+        super(Timer, self).__init__()
+        self.preprocess_time_s = Times()
+        self.inference_time_s = Times()
+        self.postprocess_time_s = Times()
+        self.img_num = 0
+
+    def info(self, average=False):
+        total_time = self.preprocess_time_s.value(
+        ) + self.inference_time_s.value() + self.postprocess_time_s.value()
+        total_time = round(total_time, 4)
+        print("------------------ Inference Time Info ----------------------")
+        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
+                                                       self.img_num))
+        preprocess_time = round(
+            self.preprocess_time_s.value() / self.img_num,
+            4) if average else self.preprocess_time_s.value()
+        postprocess_time = round(
+            self.postprocess_time_s.value() / self.img_num,
+            4) if average else self.postprocess_time_s.value()
+        inference_time = round(self.inference_time_s.value() / self.img_num,
+                               4) if average else self.inference_time_s.value()
+
+        average_latency = total_time / self.img_num
+        print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
+            average_latency * 1000, 1 / average_latency))
+        print(
+            "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
+            format(preprocess_time * 1000, inference_time * 1000,
+                   postprocess_time * 1000))
+
+    def report(self, average=False):
+        dic = {}
+        dic['preprocess_time_s'] = round(
+            self.preprocess_time_s.value() / self.img_num,
+            4) if average else self.preprocess_time_s.value()
+        dic['postprocess_time_s'] = round(
+            self.postprocess_time_s.value() / self.img_num,
+            4) if average else self.postprocess_time_s.value()
+        dic['inference_time_s'] = round(
+            self.inference_time_s.value() / self.img_num,
+            4) if average else self.inference_time_s.value()
+        dic['img_num'] = self.img_num
+        total_time = self.preprocess_time_s.value(
+        ) + self.inference_time_s.value() + self.postprocess_time_s.value()
+        dic['total_time_s'] = round(total_time, 4)
+        return dic
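
`Times` accumulates wall-clock intervals via paired start()/end() calls, and `Timer` keeps one per inference stage plus an image counter; `Predictor.predict` above drives `preprocess_time_s` and `inference_time_s` the same way. A standalone sketch of the reporting API:

    import time
    from paddlex.utils import Timer

    timer = Timer()

    timer.preprocess_time_s.start()
    time.sleep(0.01)                      # stand-in for real preprocessing
    timer.preprocess_time_s.end()

    timer.inference_time_s.start()
    time.sleep(0.02)                      # stand-in for predictor.run()
    timer.inference_time_s.end(repeats=1)

    timer.img_num += 1                    # avoid dividing by zero in info()/report()
    timer.info(average=True)              # prints total time, average latency and QPS
    print(timer.report(average=True))     # same figures returned as a dict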