浏览代码

support loading exported infer model

will-jl944 4 年之前
父节点
当前提交
8df2fcd3ec
共有 4 个文件被更改,包括 99 次插入和 37 次删除
  1. 5 1
      paddlex/cv/models/classifier.py
  2. 9 1
      paddlex/cv/models/load_model.py
  3. 74 29
      paddlex/cv/models/segmenter.py
  4. 11 6
      paddlex/deploy.py

+ 5 - 1
paddlex/cv/models/classifier.py

@@ -98,7 +98,7 @@ class BaseClassifier(BaseModel):
 
     def run(self, net, inputs, mode):
         net_out = net(inputs[0])
-        softmax_out = F.softmax(net_out)
+        softmax_out = net_out if self.status == 'Infer' else F.softmax(net_out)
         if mode == 'test':
             outputs = OrderedDict([('prediction', softmax_out)])
 
@@ -228,6 +228,10 @@ class BaseClassifier(BaseModel):
                 `pretrain_weights` can be set simultaneously. Defaults to None.
 
         """
+        if self.status == 'Infer':
+            logging.error(
+                "Exported inference model does not support training.",
+                exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
                 "pretrain_weights and resume_checkpoint cannot be set simultaneously.",

+ 9 - 1
paddlex/cv/models/load_model.py

@@ -20,6 +20,7 @@ import paddleslim
 import paddlex
 import paddlex.utils.logging as logging
 from paddlex.cv.transforms import build_transforms
+from .utils.infer_nets import InferNet
 
 
 def load_rcnn_inference_model(model_dir):
@@ -104,7 +105,8 @@ def load_model(model_dir, **params):
                         ratios=model.pruning_ratios,
                         axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
 
-            if status == 'Quantized':
+            if status == 'Quantized' or osp.exists(
+                    osp.join(model_dir, "quant.yml")):
                 with open(osp.join(model_dir, "quant.yml")) as f:
                     quant_info = yaml.load(f.read(), Loader=yaml.Loader)
                     model.quant_config = quant_info['quant_config']
@@ -112,6 +114,12 @@ def load_model(model_dir, **params):
                     model.quantizer.quantize(model.net)
 
             if status == 'Infer':
+                if osp.exists(osp.join(model_dir, "quant.yml")):
+                    logging.error(
+                        "Exported quantized model can not be loaded, only deployment is supported.",
+                        exit=True)
+                model.net = InferNet(
+                    net=model.net, model_type=model.model_type)
                 if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                     net_state_dict = load_rcnn_inference_model(model_dir)
                 else:

+ 74 - 29
paddlex/cv/models/segmenter.py

@@ -104,22 +104,28 @@ class BaseSegmenter(BaseModel):
         outputs = OrderedDict()
         if mode == 'test':
             origin_shape = inputs[1]
-            score_map = self._postprocess(
-                logit, origin_shape, transforms=inputs[2])
-            label_map = paddle.argmax(
-                score_map, axis=1, keepdim=True, dtype='int32')
-            score_map = paddle.transpose(
-                paddle.nn.functional.softmax(
-                    score_map, axis=1),
-                perm=[0, 2, 3, 1])
-            score_map = paddle.squeeze(score_map)
-            label_map = paddle.squeeze(label_map)
-            outputs = {'label_map': label_map, 'score_map': score_map}
+            if self.status == 'Infer':
+                score_map, label_map = self._postprocess(
+                    net_out, origin_shape, transforms=inputs[2])
+            else:
+                logit = self._postprocess(
+                    logit, origin_shape, transforms=inputs[2])
+                score_map = paddle.transpose(
+                    F.softmax(
+                        logit, axis=1), perm=[0, 2, 3, 1])
+                label_map = paddle.argmax(
+                    score_map, axis=-1, keepdim=True, dtype='int32')
+            outputs['label_map'] = paddle.squeeze(label_map)
+            outputs['score_map'] = paddle.squeeze(score_map)
+
         if mode == 'eval':
-            pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
+            if self.status == 'Infer':
+                pred = paddle.transpose(net_out[1], perm=[0, 3, 1, 2])
+            else:
+                pred = paddle.argmax(
+                    logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[1]
             origin_shape = [label.shape[-2:]]
-            # TODO: 替换cv2后postprocess移出run
             pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
             intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
@@ -570,33 +576,72 @@ class BaseSegmenter(BaseModel):
     def _postprocess(self, batch_pred, batch_origin_shape, transforms):
         batch_restore_list = BaseSegmenter.get_transforms_shape_info(
             batch_origin_shape, transforms)
-        results = list()
+        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
+            return self._infer_postprocess(
+                batch_score_map=batch_pred[0],
+                batch_label_map=batch_pred[1],
+                batch_restore_list=batch_restore_list)
+        results = []
         for pred, restore_list in zip(batch_pred, batch_restore_list):
-            if not isinstance(pred, np.ndarray):
-                pred = paddle.unsqueeze(pred, axis=0)
+            pred = paddle.unsqueeze(pred, axis=0)
             for item in restore_list[::-1]:
-                # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
                 h, w = item[1][0], item[1][1]
                 if item[0] == 'resize':
-                    if not isinstance(pred, np.ndarray):
-                        pred = F.interpolate(pred, (h, w), mode='nearest')
+                    pred = F.interpolate(
+                        pred, (h, w), mode='nearest', data_format='NCHW')
+                elif item[0] == 'padding':
+                    x, y = item[2]
+                    pred = pred[:, :, y:y + h, x:x + w]
+                else:
+                    pass
+            results.append(pred)
+        batch_pred = paddle.concat(results, axis=0)
+        return batch_pred
+
+    def _infer_postprocess(self, batch_score_map, batch_label_map,
+                           batch_restore_list):
+        score_maps = []
+        label_maps = []
+        for score_map, label_map, restore_list in zip(
+                batch_score_map, batch_label_map, batch_restore_list):
+            if not isinstance(score_map, np.ndarray):
+                score_map = paddle.unsqueeze(score_map, axis=0)
+                label_map = paddle.unsqueeze(label_map, axis=0)
+            for item in restore_list[::-1]:
+                h, w = item[1][0], item[1][1]
+                if item[0] == 'resize':
+                    if isinstance(score_map, np.ndarray):
+                        score_map = cv2.resize(
+                            score_map, (h, w), interpolation=cv2.INTER_LINEAR)
+                        label_map = cv2.resize(
+                            label_map, (h, w), interpolation=cv2.INTER_NEAREST)
                     else:
-                        pred = cv2.resize(
-                            pred, (h, w), interpolation=cv2.INTER_NEAREST)
+                        score_map = F.interpolate(
+                            score_map, (h, w),
+                            mode='bilinear',
+                            data_format='NHWC')
+                        label_map = F.interpolate(
+                            label_map, (h, w),
+                            mode='nearest',
+                            data_format='NHWC')
                 elif item[0] == 'padding':
                     x, y = item[2]
-                    if not isinstance(pred, np.ndarray):
-                        pred = pred[:, :, y:y + h, x:x + w]
+                    if isinstance(score_map, np.ndarray):
+                        score_map = score_map[..., y:y + h, x:x + w]
+                        label_map = label_map[..., y:y + h, x:x + w]
                     else:
-                        pred = pred[..., y:y + h, x:x + w]
+                        score_map = score_map[:, :, y:y + h, x:x + w]
+                        label_map = label_map[:, :, y:y + h, x:x + w]
                 else:
                     pass
-            results.append(pred)
-        if not isinstance(pred, np.ndarray):
-            batch_pred = paddle.concat(results, axis=0)
+            score_maps.append(score_map)
+            label_maps.append(label_map)
+        if isinstance(score_maps[0], np.ndarray):
+            return np.stack(score_maps, axis=0), np.stack(label_maps, axis=0)
         else:
-            batch_pred = np.stack(results, axis=0)
-        return batch_pred
+            return paddle.concat(
+                score_maps, axis=0), paddle.concat(
+                    label_maps, axis=0)
 
 
 class UNet(BaseSegmenter):

+ 11 - 6
paddlex/deploy.py

@@ -150,15 +150,15 @@ class Predictor(object):
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             preds = self._model._postprocess(net_outputs[0], true_topk)
+            if len(preds) == 1:
+                preds = preds[0]
         elif self._model.model_type == 'segmenter':
-            score_map, label_map = net_outputs
-            combo = np.concatenate([score_map, label_map], axis=-1)
-            combo = self._model._postprocess(
-                combo,
+            score_map, label_map = self._model._postprocess(
+                net_outputs,
                 batch_origin_shape=ori_shape,
                 transforms=transforms.transforms)
-            score_map = np.squeeze(combo[..., :-1])
-            label_map = np.squeeze(combo[..., -1])
+            score_map = np.squeeze(score_map)
+            label_map = np.squeeze(label_map)
             if len(score_map.shape) == 3:
                 preds = {'label_map': label_map, 'score_map': score_map}
             else:
@@ -172,6 +172,8 @@ class Predictor(object):
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
             }
             preds = self._model._postprocess(net_outputs)
+            if len(preds) == 1:
+                preds = preds[0]
         else:
             logging.error(
                 "Invalid model type {}.".format(self._model.model_type),
@@ -234,4 +236,7 @@ class Predictor(object):
             transforms=transforms)
         self.timer.postprocess_time_s.end()
 
+        self.timer.img_num = len(images)
+        self.timer.info(average=True)
+
         return results