
Merge pull request #1082 from will-jl944/deploy_jf

Refine cpp seg model deployment postprocess
will-jl944 4 years ago
commit 15f17498e9

+ 1 - 0
deploy/cpp/CMakeLists.txt

@@ -43,6 +43,7 @@ include_directories("${PROJECT_SOURCE_DIR}")
 
 # common
 aux_source_directory(${PROJECT_SOURCE_DIR}/model_deploy/common/src SRC)
+aux_source_directory(${PROJECT_SOURCE_DIR}/model_deploy/utils/src SRC)
 
 # det seg clas pdx src
 aux_source_directory(${PROJECT_SOURCE_DIR}/model_deploy/ppdet/src DETECTOR_SRC)

+ 10 - 6
deploy/cpp/model_deploy/ppseg/include/seg_postprocess.h

@@ -39,14 +39,18 @@ class SegPostprocess : public BasePostprocess {
                    const std::vector<ShapeInfo>& shape_infos,
                    std::vector<Result>* results, int thread_num = 1);
 
-  void RestoreSegMap(const ShapeInfo& shape_info,
-                     cv::Mat* label_mat,
-                     cv::Mat* score_mat,
-                     SegResult* result);
+  void RestoreSegMap(const ShapeInfo& shape_info, cv::Mat* label_mat,
+                     cv::Mat* score_mat, SegResult* result);
 
-  bool RunV2(const DataBlob& outputs,
-             const std::vector<ShapeInfo>& shape_infos,
+  bool RunV2(const DataBlob& outputs, const std::vector<ShapeInfo>& shape_infos,
              std::vector<Result>* results, int thread_num);
+
+  bool RunXV2(const std::vector<DataBlob>& outputs,
+              const std::vector<ShapeInfo>& shape_infos,
+              std::vector<Result>* results, int thread_num);
+
+ private:
+  std::string version_;
 };
 
 }  // namespace PaddleDeploy
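
For orientation, the class is driven through Init() and Run(); RunV2/RunXV2 are internal branches chosen from the exported model YAML. A minimal calling sketch, assuming the inference engine has already filled the DataBlob outputs and per-image ShapeInfo records (the helper name and YAML loading below are illustrative, not part of this change):

#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
#include "model_deploy/ppseg/include/seg_postprocess.h"

// Hedged usage sketch: Init() reads "toolkit"/"version" from the model YAML,
// then Run() dispatches to RunXV2 / RunV2 / the legacy 4-D path as needed.
bool PostprocessBatch(const std::string& yaml_path,
                      const std::vector<PaddleDeploy::DataBlob>& outputs,
                      const std::vector<PaddleDeploy::ShapeInfo>& shape_infos,
                      std::vector<PaddleDeploy::Result>* results) {
  YAML::Node config = YAML::LoadFile(yaml_path);  // assumed YAML layout
  PaddleDeploy::SegPostprocess postprocess;
  if (!postprocess.Init(config)) return false;
  return postprocess.Run(outputs, shape_infos, results, /*thread_num=*/1);
}
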

+ 105 - 33
deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -12,24 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <time.h>
-
 #include "model_deploy/ppseg/include/seg_postprocess.h"
 
+#include <time.h>
+
 namespace PaddleDeploy {
 
 bool SegPostprocess::Init(const YAML::Node& yaml_config) {
+  if (yaml_config["version"].IsDefined() &&
+      yaml_config["toolkit"].as<std::string>() == "PaddleX") {
+    version_ = yaml_config["version"].as<std::string>();
+  } else {
+    version_ = "0.0.0";
+  }
   return true;
 }
 
 void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
-                                   cv::Mat* label_mat,
-                                   cv::Mat*  score_mat,
+                                   cv::Mat* label_mat, cv::Mat* score_mat,
                                    SegResult* result) {
   int ori_h = shape_info.shapes[0][1];
   int ori_w = shape_info.shapes[0][0];
+  int score_c = score_mat->channels();
   result->label_map.Resize({ori_h, ori_w});
-  result->score_map.Resize({ori_h, ori_w});
+  result->score_map.Resize({ori_h, ori_w, score_c});
 
   for (int j = shape_info.transforms.size() - 1; j > 0; --j) {
     std::vector<int> last_shape = shape_info.shapes[j - 1];
@@ -38,13 +44,13 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
         shape_info.transforms[j] == "ResizeByShort" ||
         shape_info.transforms[j] == "ResizeByLong") {
       if (last_shape[0] != label_mat->cols ||
-            last_shape[1] != label_mat->rows) {
+          last_shape[1] != label_mat->rows) {
         cv::resize(*label_mat, *label_mat,
-                cv::Size(last_shape[0], last_shape[1]),
-                0, 0, cv::INTER_NEAREST);
+                   cv::Size(last_shape[0], last_shape[1]), 0, 0,
+                   cv::INTER_NEAREST);
         cv::resize(*score_mat, *score_mat,
-                cv::Size(last_shape[0], last_shape[1]),
-                0, 0, cv::INTER_LINEAR);
+                   cv::Size(last_shape[0], last_shape[1]), 0, 0,
+                   cv::INTER_LINEAR);
       }
     } else if (shape_info.transforms[j] == "Padding") {
       if (last_shape[0] < label_mat->cols || last_shape[1] < label_mat->rows) {
@@ -53,10 +59,32 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
       }
     }
   }
-  result->label_map.data.assign(
-    label_mat->begin<uint8_t>(), label_mat->end<uint8_t>());
-  result->score_map.data.assign(
-    score_mat->begin<float>(), score_mat->end<float>());
+  if (label_mat->isContinuous()) {
+    result->label_map.data.assign(
+        reinterpret_cast<const uint8_t*>(label_mat->data),
+        reinterpret_cast<const uint8_t*>(label_mat->data) +
+            label_mat->total() * (label_mat->channels()));
+  } else {
+    for (int i = 0; i < label_mat->rows; ++i) {
+      result->label_map.data.insert(
+          result->label_map.data.end(), label_mat->ptr<uint8_t>(i),
+          label_mat->ptr<uint8_t>(i) +
+              label_mat->cols * (label_mat->channels()));
+    }
+  }
+
+  if (score_mat->isContinuous()) {
+    result->score_map.data.assign(
+        reinterpret_cast<const float*>(score_mat->data),
+        reinterpret_cast<const float*>(score_mat->data) +
+            score_mat->total() * (score_mat->channels()));
+  } else {
+    for (int i = 0; i < score_mat->rows; ++i) {
+      result->score_map.data.insert(
+          result->score_map.data.end(), score_mat->ptr<float>(i),
+          score_mat->ptr<float>(i) + score_mat->cols * (score_mat->channels()));
+    }
+  }
 }
 
 // ppseg version >= 2.1  shape = [b, w, h]
@@ -69,17 +97,17 @@ bool SegPostprocess::RunV2(const DataBlob& output,
   std::vector<uint8_t> label_vector;
   if (output.dtype == INT64) {  // int64
     const int64_t* output_data =
-          reinterpret_cast<const int64_t*>(output.data.data());
+        reinterpret_cast<const int64_t*>(output.data.data());
     std::transform(output_data, output_data + label_map_size * batch_size,
                    std::back_inserter(label_vector),
-                   [](int64_t x) { return (uint8_t)x;});
+                   [](int64_t x) { return (uint8_t)x; });
     label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
   } else if (output.dtype == INT32) {  // int32
     const int32_t* output_data =
-          reinterpret_cast<const int32_t*>(output.data.data());
+        reinterpret_cast<const int32_t*>(output.data.data());
     std::transform(output_data, output_data + label_map_size * batch_size,
                    std::back_inserter(label_vector),
-                   [](int32_t x) { return (uint8_t)x;});
+                   [](int32_t x) { return (uint8_t)x; });
     label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
   } else if (output.dtype == INT8) {  // uint8
     label_data = reinterpret_cast<const uint8_t*>(output.data.data());
@@ -93,13 +121,52 @@ bool SegPostprocess::RunV2(const DataBlob& output,
     (*results)[i].model_type = "seg";
     (*results)[i].seg_result = new SegResult();
     const uint8_t* current_start_ptr = label_data + i * label_map_size;
-    cv::Mat score_mat(output.shape[1], output.shape[2],
-                      CV_32FC1, cv::Scalar(1.0));
-    cv::Mat label_mat(output.shape[1], output.shape[2],
-                      CV_8UC1, const_cast<uint8_t*>(current_start_ptr));
+    cv::Mat score_mat(output.shape[1], output.shape[2], CV_32FC1,
+                      cv::Scalar(1.0));
+    cv::Mat label_mat(output.shape[1], output.shape[2], CV_8UC1,
+                      const_cast<uint8_t*>(current_start_ptr));
 
-    RestoreSegMap(shape_infos[i], &label_mat,
-                 &score_mat, (*results)[i].seg_result);
+    RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
+                  (*results)[i].seg_result);
+  }
+  return true;
+}
+
+// paddlex version >= 2.0.0 shape = [b, h, w, c]
+bool SegPostprocess::RunXV2(const std::vector<DataBlob>& outputs,
+                            const std::vector<ShapeInfo>& shape_infos,
+                            std::vector<Result>* results, int thread_num) {
+  int batch_size = shape_infos.size();
+  int label_map_size = outputs[0].shape[1] * outputs[0].shape[2];
+  std::vector<int> score_map_shape = outputs[1].shape;
+  int score_map_size =
+      std::accumulate(score_map_shape.begin() + 1, score_map_shape.end(), 1,
+                      std::multiplies<int>());
+  const uint8_t* label_map_data;
+  std::vector<uint8_t> label_map_vector;
+  if (outputs[0].dtype == INT32) {
+    const int32_t* output_data =
+        reinterpret_cast<const int32_t*>(outputs[0].data.data());
+    std::transform(output_data, output_data + label_map_size * batch_size,
+                   std::back_inserter(label_map_vector),
+                   [](int32_t x) { return (uint8_t)x; });
+    label_map_data = reinterpret_cast<const uint8_t*>(label_map_vector.data());
+  }
+  const float* score_map_data =
+      reinterpret_cast<const float*>(outputs[1].data.data());
+  for (int i = 0; i < batch_size; ++i) {
+    (*results)[i].model_type = "seg";
+    (*results)[i].seg_result = new SegResult();
+    const uint8_t* current_label_start_ptr =
+        label_map_data + i * label_map_size;
+    const float* current_score_start_ptr = score_map_data + i * score_map_size;
+    cv::Mat label_mat(outputs[0].shape[1], outputs[0].shape[2], CV_8UC1,
+                      const_cast<uint8_t*>(current_label_start_ptr));
+    cv::Mat score_mat(score_map_shape[1], score_map_shape[2],
+                      CV_32FC(score_map_shape[3]),
+                      const_cast<float*>(current_score_start_ptr));
+    RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
+                  (*results)[i].seg_result);
   }
   return true;
 }
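
RunXV2 assumes the PaddleX >= 2.0.0 export yields two blobs: an INT32 label map of shape [b, h, w] (outputs[0]) and an FP32 score map of shape [b, h, w, c] (outputs[1]). Because the score map is channel-last, one image's scores can be wrapped as a multi-channel cv::Mat without copying; a small sketch of that wrapping (names illustrative):

#include <opencv2/opencv.hpp>

// View one image's HWC float scores as an h x w Mat with num_classes channels;
// CV_32FC(n) builds the n-channel float type at runtime.
cv::Mat WrapScoreMap(const float* score_data, int h, int w, int num_classes) {
  return cv::Mat(h, w, CV_32FC(num_classes), const_cast<float*>(score_data));
}

With this layout, score_mat.ptr<float>(y) + x * num_classes points at the per-class scores of column x in row y, matching the {ori_h, ori_w, score_c} buffer that RestoreSegMap stores into result->score_map.
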
@@ -115,21 +182,26 @@ bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
   int batch_size = shape_infos.size();
   results->resize(batch_size);
 
-  // tricks for PaddleX, which segmentation model has two outputs
+  // tricks for PaddleX, whose segmentation model has two outputs
   int index = 0;
   if (outputs.size() == 2) {
     index = 1;
   }
   std::vector<int> score_map_shape = outputs[index].shape;
-  // ppseg version >= 2.1  shape = [b, w, h]
+  // paddlex version >= 2.0.0 shape = [b, h, w, c]
+  if (version_ >= "2.0.0") {
+    return RunXV2(outputs, shape_infos, results, thread_num);
+  }
+  // ppseg version >= 2.1  shape = [b, h, w]
   if (score_map_shape.size() == 3) {
     return RunV2(outputs[index], shape_infos, results, thread_num);
   }
 
-  int score_map_size = std::accumulate(score_map_shape.begin() + 1,
-                    score_map_shape.end(), 1, std::multiplies<int>());
+  int score_map_size =
+      std::accumulate(score_map_shape.begin() + 1, score_map_shape.end(), 1,
+                      std::multiplies<int>());
   const float* score_map_data =
-        reinterpret_cast<const float*>(outputs[index].data.data());
+      reinterpret_cast<const float*>(outputs[index].data.data());
   int num_map_pixels = score_map_shape[2] * score_map_shape[3];
 
   for (int i = 0; i < batch_size; ++i) {
@@ -137,8 +209,8 @@ bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
     (*results)[i].seg_result = new SegResult();
     const float* current_start_ptr = score_map_data + i * score_map_size;
     cv::Mat ori_score_mat(score_map_shape[1],
-            score_map_shape[2] * score_map_shape[3],
-            CV_32FC1, const_cast<float*>(current_start_ptr));
+                          score_map_shape[2] * score_map_shape[3], CV_32FC1,
+                          const_cast<float*>(current_start_ptr));
     ori_score_mat = ori_score_mat.t();
     cv::Mat score_mat(score_map_shape[2], score_map_shape[3], CV_32FC1);
     cv::Mat label_mat(score_map_shape[2], score_map_shape[3], CV_8UC1);
@@ -149,8 +221,8 @@ bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
       score_mat.at<float>(j) = max_value;
       label_mat.at<uchar>(j) = max_id.x;
     }
-    RestoreSegMap(shape_infos[i], &label_mat,
-                &score_mat, (*results)[i].seg_result);
+    RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
+                  (*results)[i].seg_result);
   }
   return true;
 }
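
After Run() succeeds, each SegResult carries a flattened uint8 label map (ori_h x ori_w) and a float score map (ori_h x ori_w x c), per the Resize() calls in RestoreSegMap. A hedged sketch of turning a result back into cv::Mat views, assuming label_map.data / score_map.data are contiguous std::vector-style buffers as used above (helper names illustrative):

#include <cstdint>
#include <opencv2/opencv.hpp>
#include "model_deploy/ppseg/include/seg_postprocess.h"  // assumed to provide SegResult

// Non-owning Mat headers over the flattened result buffers; the SegResult
// must outlive the returned Mats.
cv::Mat LabelMatOf(const PaddleDeploy::SegResult& res, int h, int w) {
  return cv::Mat(h, w, CV_8UC1,
                 const_cast<uint8_t*>(res.label_map.data.data()));
}

cv::Mat ScoreMatOf(const PaddleDeploy::SegResult& res, int h, int w, int c) {
  return cv::Mat(h, w, CV_32FC(c),
                 const_cast<float*>(res.score_map.data.data()));
}
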

+ 1 - 0
deploy/cpp/model_deploy/utils/src/visualize.cpp

@@ -124,6 +124,7 @@ bool Visualize(const cv::Mat& img,
       vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
     }
   }
+  cv::addWeighted(img, .5, *vis_img, .5, 0, *vis_img);
   return true;
 }
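
The added cv::addWeighted call alpha-blends the input image into the colorized mask so the visualization is semi-transparent rather than a solid color overlay. In isolation the operation looks like this (both inputs must share size and type, CV_8UC3 here):

#include <opencv2/opencv.hpp>

// 50/50 blend of the source image and the colorized segmentation mask.
cv::Mat BlendOverlay(const cv::Mat& img, const cv::Mat& colored_mask) {
  cv::Mat blended;
  cv::addWeighted(img, 0.5, colored_mask, 0.5, /*gamma=*/0.0, blended);
  return blended;
}
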
 

+ 61 - 52
paddlex/cv/models/segmenter.py

@@ -105,28 +105,35 @@ class BaseSegmenter(BaseModel):
         if mode == 'test':
             origin_shape = inputs[1]
             if self.status == 'Infer':
-                score_map, label_map = self._postprocess(
+                label_map_list, score_map_list = self._postprocess(
                     net_out, origin_shape, transforms=inputs[2])
             else:
-                logit = self._postprocess(
+                logit_list = self._postprocess(
                     logit, origin_shape, transforms=inputs[2])
-                score_map = paddle.transpose(
-                    F.softmax(
-                        logit, axis=1), perm=[0, 2, 3, 1])
-                label_map = paddle.argmax(
-                    score_map, axis=-1, keepdim=True, dtype='int32')
-            outputs['label_map'] = paddle.squeeze(label_map)
-            outputs['score_map'] = paddle.squeeze(score_map)
+                label_map_list = []
+                score_map_list = []
+                for logit in logit_list:
+                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
+                    label_map_list.append((paddle.argmax(
+                        logit, axis=-1, keepdim=False, dtype='int32')).squeeze(
+                        ).numpy().astype('int32'))
+                    score_map_list.append(
+                        F.softmax(
+                            logit, axis=-1).squeeze().numpy().astype(
+                                'float32'))
+            outputs['label_map'] = label_map_list
+            outputs['score_map'] = score_map_list
 
         if mode == 'eval':
             if self.status == 'Infer':
-                pred = paddle.transpose(net_out[1], perm=[0, 3, 1, 2])
+                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
                 pred = paddle.argmax(
                     logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[1]
             origin_shape = [label.shape[-2:]]
-            pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
+            pred = self._postprocess(
+                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
             intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
@@ -477,8 +484,8 @@ class BaseSegmenter(BaseModel):
             If img_file is a string or np.array, the result is a dict with key-value pairs:
             {"label map": `label map`, "score_map": `score map`}.
             If img_file is a list, the result is a list composed of dicts with the corresponding fields:
-            label_map(np.ndarray): the predicted label map
-            score_map(np.ndarray): the prediction score map (NHWC)
+            label_map(np.ndarray): the predicted label map (HW)
+            score_map(np.ndarray): the prediction score map (HWC)
 
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -494,19 +501,18 @@ class BaseSegmenter(BaseModel):
         self.net.eval()
         data = (batch_im, batch_origin_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
-        label_map = outputs['label_map']
-        label_map = label_map.numpy().astype('uint8')
-        score_map = outputs['score_map']
-        score_map = score_map.numpy().astype('float32')
-        if isinstance(img_file, list) and len(img_file) > 1:
+        label_map_list = outputs['label_map']
+        score_map_list = outputs['score_map']
+        if isinstance(img_file, list):
             prediction = [{
                 'label_map': l,
                 'score_map': s
-            } for l, s in zip(label_map, score_map)]
-        elif isinstance(img_file, list):
-            prediction = [{'label_map': label_map, 'score_map': score_map}]
+            } for l, s in zip(label_map_list, score_map_list)]
         else:
-            prediction = {'label_map': label_map, 'score_map': score_map}
+            prediction = {
+                'label_map': label_map_list[0],
+                'score_map': score_map_list[0]
+            }
         return prediction
 
     def _preprocess(self, images, transforms, to_tensor=True):
@@ -582,70 +588,73 @@ class BaseSegmenter(BaseModel):
             batch_origin_shape, transforms)
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
-                batch_score_map=batch_pred[0],
-                batch_label_map=batch_pred[1],
+                batch_label_map=batch_pred[0],
+                batch_score_map=batch_pred[1],
                 batch_restore_list=batch_restore_list)
         results = []
+        if batch_pred.dtype == paddle.float32:
+            mode = 'bilinear'
+        else:
+            mode = 'nearest'
         for pred, restore_list in zip(batch_pred, batch_restore_list):
             pred = paddle.unsqueeze(pred, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
                 if item[0] == 'resize':
                     pred = F.interpolate(
-                        pred, (h, w), mode='nearest', data_format='NCHW')
+                        pred, (h, w), mode=mode, data_format='NCHW')
                 elif item[0] == 'padding':
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                     pass
             results.append(pred)
-        batch_pred = paddle.concat(results, axis=0)
-        return batch_pred
+        return results
 
-    def _infer_postprocess(self, batch_score_map, batch_label_map,
+    def _infer_postprocess(self, batch_label_map, batch_score_map,
                            batch_restore_list):
-        score_maps = []
         label_maps = []
-        for score_map, label_map, restore_list in zip(
-                batch_score_map, batch_label_map, batch_restore_list):
-            if not isinstance(score_map, np.ndarray):
+        score_maps = []
+        for label_map, score_map, restore_list in zip(
+                batch_label_map, batch_score_map, batch_restore_list):
+            if not isinstance(label_map, np.ndarray):
+                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                 score_map = paddle.unsqueeze(score_map, axis=0)
-                label_map = paddle.unsqueeze(label_map, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
                 if item[0] == 'resize':
-                    if isinstance(score_map, np.ndarray):
-                        score_map = cv2.resize(
-                            score_map, (h, w), interpolation=cv2.INTER_LINEAR)
+                    if isinstance(label_map, np.ndarray):
                         label_map = cv2.resize(
-                            label_map, (h, w), interpolation=cv2.INTER_NEAREST)
+                            label_map, (w, h), interpolation=cv2.INTER_NEAREST)
+                        score_map = cv2.resize(
+                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                     else:
-                        score_map = F.interpolate(
-                            score_map, (h, w),
-                            mode='bilinear',
-                            data_format='NHWC')
                         label_map = F.interpolate(
                             label_map, (h, w),
                             mode='nearest',
                             data_format='NHWC')
+                        score_map = F.interpolate(
+                            score_map, (h, w),
+                            mode='bilinear',
+                            data_format='NHWC')
                 elif item[0] == 'padding':
                     x, y = item[2]
-                    if isinstance(score_map, np.ndarray):
-                        score_map = score_map[..., y:y + h, x:x + w]
+                    if isinstance(label_map, np.ndarray):
                         label_map = label_map[..., y:y + h, x:x + w]
+                        score_map = score_map[..., y:y + h, x:x + w]
                     else:
-                        score_map = score_map[:, :, y:y + h, x:x + w]
                         label_map = label_map[:, :, y:y + h, x:x + w]
+                        score_map = score_map[:, :, y:y + h, x:x + w]
                 else:
                     pass
-            score_maps.append(score_map)
-            label_maps.append(label_map)
-        if isinstance(score_maps[0], np.ndarray):
-            return np.stack(score_maps, axis=0), np.stack(label_maps, axis=0)
-        else:
-            return paddle.concat(
-                score_maps, axis=0), paddle.concat(
-                    label_maps, axis=0)
+            label_map = label_map.squeeze()
+            score_map = score_map.squeeze()
+            if not isinstance(label_map, np.ndarray):
+                label_map = label_map.numpy()
+                score_map = score_map.numpy()
+            label_maps.append(label_map.squeeze())
+            score_maps.append(score_map.squeeze())
+        return label_maps, score_maps
 
 
 class UNet(BaseSegmenter):

+ 3 - 4
paddlex/cv/models/utils/infer_nets.py

@@ -24,11 +24,10 @@ class PostProcessor(paddle.nn.Layer):
         if self.model_type == 'classifier':
             outputs = paddle.nn.functional.softmax(net_outputs, axis=1)
         else:
-            # score_map, label_map
+            # label_map [NHW], score_map [NHWC]
             logit = net_outputs[0]
-            outputs = paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1]), \
-                      paddle.transpose(paddle.argmax(logit, axis=1, keepdim=True, dtype='int32'),
-                                       perm=[0, 2, 3, 1])
+            outputs = paddle.argmax(logit, axis=1, keepdim=False, dtype='int32'), \
+                      paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1])
 
         return outputs
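
This exported output order is what the new RunXV2 path on the C++ side consumes: outputs[0] is the argmax label map ([N, H, W], int32) and outputs[1] the softmax score map ([N, H, W, num_classes], float32). A hedged sanity-check sketch of that contract, using the DataBlob fields seen in seg_postprocess.cpp (the helper name is illustrative):

#include <vector>
#include "model_deploy/ppseg/include/seg_postprocess.h"  // assumed to expose DataBlob

// Rough shape check for a PaddleX >= 2.0.0 segmentation export.
bool LooksLikeSegExportV2(const std::vector<PaddleDeploy::DataBlob>& outputs) {
  return outputs.size() == 2 &&
         outputs[0].shape.size() == 3 &&  // label map: [N, H, W]
         outputs[1].shape.size() == 4;    // score map: [N, H, W, C]
}
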
 

+ 2 - 2
paddlex/deploy.py

@@ -152,12 +152,12 @@ class Predictor(object):
             if len(preds) == 1:
                 preds = preds[0]
         elif self._model.model_type == 'segmenter':
-            score_map, label_map = self._model._postprocess(
+            label_map, score_map = self._model._postprocess(
                 net_outputs,
                 batch_origin_shape=ori_shape,
                 transforms=transforms.transforms)
-            score_map = np.squeeze(score_map)
             label_map = np.squeeze(label_map)
+            score_map = np.squeeze(score_map)
             if score_map.ndim == 3:
                 preds = {'label_map': label_map, 'score_map': score_map}
             else: