浏览代码

refine cpp seg deploy postprocess

will-jl944 4 年之前
父节点
当前提交
b1eca3d289
共有 1 个文件被更改,包括 27 次插入9 次删除
  1. 27 9
      deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

+ 27 - 9
deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -35,11 +35,7 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
   int ori_w = shape_info.shapes[0][0];
   int score_c = score_mat->channels();
   result->label_map.Resize({ori_h, ori_w});
-  if (score_c == 1) {
-    result->score_map.Resize({ori_h, ori_w});
-  } else {
-    result->score_map.Resize({ori_h, ori_w, score_c});
-  }
+  result->score_map.Resize({ori_h, ori_w, score_c});
 
   for (int j = shape_info.transforms.size() - 1; j > 0; --j) {
     std::vector<int> last_shape = shape_info.shapes[j - 1];
@@ -63,10 +59,32 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
       }
     }
   }
-  result->label_map.data.assign(label_mat->begin<uint8_t>(),
-                                label_mat->end<uint8_t>());
-  result->score_map.data.assign(score_mat->begin<float>(),
-                                score_mat->end<float>());
+  if (label_mat->isContinuous()) {
+    result->label_map.data.assign(
+        reinterpret_cast<const uint8_t*>(label_mat->data),
+        reinterpret_cast<const uint8_t*>(label_mat->data) +
+            label_mat->total() * (label_mat->channels()));
+  } else {
+    for (int i = 0; i < label_mat->rows; ++i) {
+      result->label_map.data.insert(
+          result->label_map.data.end(), label_mat->ptr<uint8_t>(i),
+          label_mat->ptr<uint8_t>(i) +
+              label_mat->cols * (label_mat->channels()));
+    }
+  }
+
+  if (score_mat->isContinuous()) {
+    result->score_map.data.assign(
+        reinterpret_cast<const float*>(score_mat->data),
+        reinterpret_cast<const float*>(score_mat->data) +
+            score_mat->total() * (score_mat->channels()));
+  } else {
+    for (int i = 0; i < score_mat->rows; ++i) {
+      result->score_map.data.insert(
+          result->score_map.data.end(), score_mat->ptr<float>(i),
+          score_mat->ptr<float>(i) + score_mat->cols * (score_mat->channels()));
+    }
+  }
 }
 
 // ppseg version >= 2.1  shape = [b, w, h]