浏览代码

add seg batch predict

jack 5 年之前
父节点
当前提交
26cefa3dc2
共有 4 个文件被更改,包括 220 次插入26 次删除
  1. 19 14
      deploy/cpp/demo/classifier.cpp
  2. 54 12
      deploy/cpp/demo/segmenter.cpp
  3. 4 0
      deploy/cpp/include/paddlex/paddlex.h
  4. 143 0
      deploy/cpp/src/paddlex.cpp

+ 19 - 14
deploy/cpp/demo/classifier.cpp

@@ -32,7 +32,7 @@ DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
-DEFINE_int32(batch_size, 1, "Batch size when infering");
+DEFINE_int32(batch_size, 1, "Batch size of infering");
 
 int main(int argc, char** argv) {
   // Parsing command-line
@@ -53,8 +53,8 @@ int main(int argc, char** argv) {
 
   // 进行预测
   double total_running_time_s = 0.0;
-  double total_imreaad_time_s = 0.0;
-
+  double total_imread_time_s = 0.0;
+  int imgs = 1;
   if (FLAGS_image_list != "") {
     std::ifstream inf(FLAGS_image_list);
     if (!inf) {
@@ -63,31 +63,32 @@ int main(int argc, char** argv) {
     }
     // 多batch预测
     std::string image_path;
-    std::vector<std::string> image_path_vec;
+    std::vector<std::string> image_paths;
     while (getline(inf, image_path)) {
-      image_path_vec.push_back(image_path);
+      image_paths.push_back(image_path);
     }
-    for(int i = 0; i < image_path_vec.size(); i += FLAGS_batch_size) {
+    imgs = image_paths.size();
+    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
       auto start = system_clock::now();
         // 读图像
-      int im_vec_size = std::min((int)image_path_vec.size(), i + FLAGS_batch_size);      
+      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);      
       std::vector<cv::Mat> im_vec(im_vec_size - i);
       std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
       #pragma omp parallel for num_threads(im_vec_size - i)
       for(int j = i; j < im_vec_size; ++j){
-        im_vec[j - i] = std::move(cv::imread(image_path_vec[j], 1));
+        im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
       auto imread_end = system_clock::now();
       model.predict(im_vec, results);
 
       auto imread_duration = duration_cast<microseconds>(imread_end - start);
-      total_imreaad_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
 
       auto end = system_clock::now();
       auto duration = duration_cast<microseconds>(end - start);
       total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
       for(int j = i; j < im_vec_size; ++j) {
-            std::cout << "Path:" << image_path_vec[j]
+            std::cout << "Path:" << image_paths[j]
                       << ", predict label: " << results[j - i].category
                       << ", label_id:" << results[j - i].category_id
                       << ", score: " << results[j - i].score << std::endl;
@@ -105,11 +106,15 @@ int main(int argc, char** argv) {
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;
   }
-  std::cout << "Total average running time: " 
+  std::cout << "Total running time: " 
 	    << total_running_time_s
-	    << " s, total average read img time: " 
-	    << total_imreaad_time_s
-	    << " s, batch_size = " 
+      << " s, average running time: "
+      << total_running_time_s / imgs 
+	    << " s/img, total read img time: " 
+	    << total_imread_time_s
+      << " s, average read time: "
+      << total_imread_time_s / imgs  
+	    << " s/img, batch_size = " 
 	    << FLAGS_batch_size 
 	    << std::endl;
   return 0;

+ 54 - 12
deploy/cpp/demo/segmenter.cpp

@@ -14,14 +14,18 @@
 
 #include <glog/logging.h>
 
+#include <algorithm>
+#include <chrono>
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
-
+#include <utility>
 #include "include/paddlex/paddlex.h"
 #include "include/paddlex/visualize.h"
 
+using namespace std::chrono;
+
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
 DEFINE_bool(use_trt, false, "Infering with TensorRT");
@@ -30,6 +34,7 @@ DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
 DEFINE_string(save_dir, "output", "Path to save visualized image");
+DEFINE_int32(batch_size, 1, "Batch size of infering");
 
 int main(int argc, char** argv) {
   // 解析命令行参数
@@ -46,8 +51,11 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
 
+  double total_running_time_s = 0.0;
+  double total_imread_time_s = 0.0;
+  int imgs = 1;
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   // 进行预测
   if (FLAGS_image_list != "") {
@@ -57,23 +65,46 @@ int main(int argc, char** argv) {
       return -1;
     }
     std::string image_path;
+    std::vector<std::string> image_paths;
     while (getline(inf, image_path)) {
-      PaddleX::SegResult result;
-      cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
+      image_paths.push_back(image_path);
+    }
+    imgs = image_paths.size();
+    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size){
+      auto start = system_clock::now();
+      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
+      std::vector<cv::Mat> im_vec(im_vec_size - i);
+      std::vector<PaddleX::SegResult> results(im_vec_size - i, PaddleX::SegResult());
+      #pragma omp parallel for num_threads(im_vec_size - i)
+      for(int j = i; j < im_vec_size; ++j){
+        im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
+      }
+      auto imread_end = system_clock::now();
+      model.predict(im_vec, results);
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
       // 可视化
-      cv::Mat vis_img =
-          PaddleX::Visualize(im, result, model.labels, colormap);
-      std::string save_path =
-          PaddleX::generate_save_path(FLAGS_save_dir, image_path);
-      cv::imwrite(save_path, vis_img);
-      result.clear();
-      std::cout << "Visualized output saved as " << save_path << std::endl;
+      for(int j = 0; j < im_vec_size - i; ++j) {
+        cv::Mat vis_img =
+            PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap);
+        std::string save_path =
+            PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
+        cv::imwrite(save_path, vis_img);
+        std::cout << "Visualized output saved as " << save_path << std::endl;
+      }
     }
   } else {
+    auto start = system_clock::now();
     PaddleX::SegResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
     model.predict(im, &result);
+    auto end = system_clock::now();
+    auto duration = duration_cast<microseconds>(end - start);
+    total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
     // 可视化
     cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
     std::string save_path =
@@ -82,6 +113,17 @@ int main(int argc, char** argv) {
     result.clear();
     std::cout << "Visualized output saved as " << save_path << std::endl;
   }
+  std::cout << "Total running time: " 
+	    << total_running_time_s
+      << " s, average running time: "
+      << total_running_time_s / imgs
+	    << " s/img, total read img time: " 
+	    << total_imread_time_s
+      << " s, average read img time: "
+      << total_imread_time_s / imgs
+	    << " s, batch_size = " 
+	    << FLAGS_batch_size 
+	    << std::endl;
 
   return 0;
 }

+ 4 - 0
deploy/cpp/include/paddlex/paddlex.h

@@ -69,8 +69,12 @@ class Model {
 
   bool predict(const cv::Mat& im, DetResult* result);
 
+  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result);
+  
   bool predict(const cv::Mat& im, SegResult* result);
 
+  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result);
+  
   bool postprocess(SegResult* result);
 
   bool postprocess(DetResult* result);

+ 143 - 0
deploy/cpp/src/paddlex.cpp

@@ -161,6 +161,7 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   result->category_id = std::distance(std::begin(outputs_), ptr);
   result->score = *ptr;
   result->category = labels[result->category_id];
+  return true;
 }
 
 bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results) {
@@ -322,6 +323,7 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
                          static_cast<int>(box->coordinate[3])};
     }
   }
+  return true;
 }
 
 bool Model::predict(const cv::Mat& im, SegResult* result) {
@@ -430,6 +432,147 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
   result->score_map.data.assign(mask_score.begin<float>(),
                                 mask_score.end<float>());
   result->score_map.shape = {mask_score.rows, mask_score.cols};
+  return true;
+}
+
+bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result) {
+  for(auto &inputs: inputs_batch_) {
+    inputs.clear();
+  }
+  if (type == "classifier") {
+    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
+                 "to function predict()!"
+              << std::endl;
+    return false;
+  } else if (type == "detector") {
+    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
+                 "function predict()!"
+              << std::endl;
+    return false;
+  }
+
+  // 处理输入图像
+  if (!preprocess(im_batch, inputs_batch_)) {
+    std::cerr << "Preprocess failed!" << std::endl;
+    return false;
+  }
+
+  int batch_size = im_batch.size();
+  result.clear();
+  result.resize(batch_size);
+  int h = inputs_batch_[0].new_im_size_[0];
+  int w = inputs_batch_[0].new_im_size_[1];
+  auto im_tensor = predictor_->GetInputTensor("image");
+  im_tensor->Reshape({batch_size, 3, h, w});
+  std::vector<float> inputs_data(batch_size * 3 * h * w);
+  for(int i = 0; i <inputs_batch_.size(); ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  }
+  im_tensor->copy_from_cpu(inputs_data.data());
+  //im_tensor->copy_from_cpu(inputs_.im_data_.data());
+
+  // 使用加载的模型进行预测
+  predictor_->ZeroCopyRun();
+
+  // 获取预测置信度,经过argmax后的labelmap
+  auto output_names = predictor_->GetOutputNames();
+  auto output_label_tensor = predictor_->GetOutputTensor(output_names[0]);
+  std::vector<int> output_label_shape = output_label_tensor->shape();
+  int size = 1;
+  for (const auto& i : output_label_shape) {
+    size *= i;
+  }
+
+  std::vector<int64_t> output_labels(size, 0);
+  output_label_tensor->copy_to_cpu(output_labels.data());
+  auto output_labels_iter = output_labels.begin();
+
+  int single_batch_size = size / batch_size;
+  for(int i = 0; i < batch_size; ++i) {
+    result[i].label_map.data.resize(single_batch_size);
+    result[i].label_map.shape.push_back(1);
+    for(int j = 1; j < output_label_shape.size(); ++j) {
+      result[i].label_map.shape.push_back(output_label_shape[j]);
+    }
+    std::copy(output_labels_iter + i * single_batch_size, output_labels_iter + (i + 1) * single_batch_size, result[i].label_map.data.data());
+  }
+
+  // 获取预测置信度scoremap
+  auto output_score_tensor = predictor_->GetOutputTensor(output_names[1]);
+  std::vector<int> output_score_shape = output_score_tensor->shape();
+  size = 1;
+  for (const auto& i : output_score_shape) {
+    size *= i;
+  }
+
+  std::vector<float> output_scores(size, 0);
+  output_score_tensor->copy_to_cpu(output_scores.data());
+  auto output_scores_iter = output_scores.begin();
+
+  int single_batch_score_size = size / batch_size;
+  for(int i = 0; i < batch_size; ++i) {
+    result[i].score_map.data.resize(single_batch_score_size);
+    result[i].score_map.shape.push_back(1);
+    for(int j = 1; j < output_score_shape.size(); ++j) {
+      result[i].score_map.shape.push_back(output_score_shape[j]);
+    }
+    std::copy(output_scores_iter + i * single_batch_score_size, output_scores_iter + (i + 1) * single_batch_score_size, result[i].score_map.data.data());
+  }
+
+  // 解析输出结果到原图大小
+  for(int i = 0; i < batch_size; ++i) {
+    std::vector<uint8_t> label_map(result[i].label_map.data.begin(),
+                                   result[i].label_map.data.end());
+    cv::Mat mask_label(result[i].label_map.shape[1],
+                       result[i].label_map.shape[2],
+                       CV_8UC1,
+                       label_map.data());
+  
+    cv::Mat mask_score(result[i].score_map.shape[2],
+                       result[i].score_map.shape[3],
+                       CV_32FC1,
+                       result[i].score_map.data.data());
+    int idx = 1;
+    int len_postprocess = inputs_batch_[i].im_size_before_resize_.size();
+    for (std::vector<std::string>::reverse_iterator iter =
+             inputs_batch_[i].reshape_order_.rbegin();
+         iter != inputs_batch_[i].reshape_order_.rend();
+         ++iter) {
+      if (*iter == "padding") {
+        auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
+        inputs_batch_[i].im_size_before_resize_.pop_back();
+        auto padding_w = before_shape[0];
+        auto padding_h = before_shape[1];
+        mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
+        mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
+      } else if (*iter == "resize") {
+        auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
+        inputs_batch_[i].im_size_before_resize_.pop_back();
+        auto resize_w = before_shape[0];
+        auto resize_h = before_shape[1];
+        cv::resize(mask_label,
+                   mask_label,
+                   cv::Size(resize_h, resize_w),
+                   0,
+                   0,
+                   cv::INTER_NEAREST);
+        cv::resize(mask_score,
+                   mask_score,
+                   cv::Size(resize_h, resize_w),
+                   0,
+                   0,
+                   cv::INTER_LINEAR); 
+      }
+      ++idx;
+    }
+    result[i].label_map.data.assign(mask_label.begin<uint8_t>(),
+                                  mask_label.end<uint8_t>());
+    result[i].label_map.shape = {mask_label.rows, mask_label.cols};
+    result[i].score_map.data.assign(mask_score.begin<float>(),
+                                  mask_score.end<float>());
+    result[i].score_map.shape = {mask_score.rows, mask_score.cols};
+  }
+  return true;
 }
 
 }  // namespce of PaddleX