Răsfoiți Sursa

fix seg postprocess int64 bug (#858)

* fix seg postprocess int64 bug

* use enum
heliqi 4 ani în urmă
părinte
comite
02a0726779

+ 20 - 9
dygraph/deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -59,29 +59,40 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
     score_mat->begin<float>(), score_mat->end<float>());
 }
 
+// ppseg version >= 2.1  shape = [b, w, h]
 bool SegPostprocess::RunV2(const DataBlob& output,
                            const std::vector<ShapeInfo>& shape_infos,
                            std::vector<Result>* results, int thread_num) {
   int batch_size = shape_infos.size();
-  std::vector<int> score_map_shape = output.shape;
-  int score_map_size = std::accumulate(output.shape.begin() + 1,
-                                       output.shape.end(), 1,
-                                       std::multiplies<int>());
-  const uint8_t* score_map_data =
-          reinterpret_cast<const uint8_t*>(output.data.data());
-  int num_map_pixels = output.shape[1] * output.shape[2];
+  int label_map_size = output.shape[1] * output.shape[2];
+  const uint8_t* label_data;
+  std::vector<uint8_t> label_vector;
+  if (output.dtype == INT64) {  // int64
+    const int64_t* output_data =
+          reinterpret_cast<const int64_t*>(output.data.data());
+    std::transform(output_data, output_data + label_map_size * batch_size,
+                   std::back_inserter(label_vector),
+                   [](int64_t x) { return (uint8_t)x;});
+    label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
+  } else if (output.dtype == INT8) {  // uint8
+    label_data = reinterpret_cast<const uint8_t*>(output.data.data());
+  } else {
+    std::cerr << "Output dtype is not support on seg posrtprocess "
+              << output.dtype << std::endl;
+    return false;
+  }
 
   for (int i = 0; i < batch_size; ++i) {
     (*results)[i].model_type = "seg";
     (*results)[i].seg_result = new SegResult();
-    const uint8_t* current_start_ptr = score_map_data + i * score_map_size;
+    const uint8_t* current_start_ptr = label_data + i * label_map_size;
     cv::Mat score_mat(output.shape[1], output.shape[2],
                       CV_32FC1, cv::Scalar(1.0));
     cv::Mat label_mat(output.shape[1], output.shape[2],
                       CV_8UC1, const_cast<uint8_t*>(current_start_ptr));
 
     RestoreSegMap(shape_infos[i], &label_mat,
-                &score_mat, (*results)[i].seg_result);
+                 &score_mat, (*results)[i].seg_result);
   }
   return true;
 }