Procházet zdrojové kódy

add yolov3 batch predict

jack před 5 roky
rodič
revize
eeec4b151b

+ 64 - 21
deploy/cpp/demo/detector.cpp

@@ -14,14 +14,19 @@
 
 #include <glog/logging.h>
 
+#include <algorithm>
+#include <chrono>
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
+#include <utility>
 
 #include "include/paddlex/paddlex.h"
 #include "include/paddlex/visualize.h"
 
+using namespace std::chrono;
+
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
 DEFINE_bool(use_trt, false, "Infering with TensorRT");
@@ -30,6 +35,7 @@ DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
 DEFINE_string(save_dir, "output", "Path to save visualized image");
+DEFINE_int32(batch_size, 1, "");
 
 int main(int argc, char** argv) {
   // 解析命令行参数
@@ -46,8 +52,11 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
 
+  double total_running_time_s = 0.0;
+  double total_imread_time_s = 0.0;
+  int imgs = 1;
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   std::string save_dir = "output";
   // 进行预测
@@ -58,35 +67,57 @@ int main(int argc, char** argv) {
       return -1;
     }
     std::string image_path;
+    std::vector<std::string> image_paths;
     while (getline(inf, image_path)) {
-      PaddleX::DetResult result;
-      cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
-      for (int i = 0; i < result.boxes.size(); ++i) {
-        std::cout << "image file: " << image_path
-                  << ", predict label: " << result.boxes[i].category
-                  << ", label_id:" << result.boxes[i].category_id
-                  << ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
-                  << result.boxes[i].coordinate[0] << ", "
-                  << result.boxes[i].coordinate[1] << ", "
-                  << result.boxes[i].coordinate[2] << ", "
-                  << result.boxes[i].coordinate[3] << ")" << std::endl;
+      image_paths.push_back(image_path);
+    }
+    imgs = image_paths.size();
+    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
+      auto start = system_clock::now();
+      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
+      std::vector<cv::Mat> im_vec(im_vec_size - i);
+      std::vector<PaddleX::DetResult> results(im_vec_size - i, PaddleX::DetResult());
+      #pragma omp parallel for num_threads(im_vec_size - i)
+      for(int j = i; j < im_vec_size; ++j){
+        im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
+      }
+      auto imread_end = system_clock::now();
+      model.predict(im_vec, results);
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+      //输出结果目标框
+      for(int j = 0; j < im_vec_size - i; ++j) {
+        std::cout << "image file: " << image_paths[i + j] << std::endl;          
+        for(int k = 0; k < results[j].boxes.size(); ++k) {
+          std::cout << "predict label: " << results[j].boxes[k].category
+                    << ", label_id:" << results[j].boxes[k].category_id
+                    << ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):("
+                    << results[j].boxes[k].coordinate[0] << ", "
+                    << results[j].boxes[k].coordinate[1] << ", "
+                    << results[j].boxes[k].coordinate[2] << ", "
+                    << results[j].boxes[k].coordinate[3] << ")" << std::endl;
+          
+        }
       }
-
       // 可视化
-      cv::Mat vis_img =
-          PaddleX::Visualize(im, result, model.labels, colormap, 0.5);
-      std::string save_path =
-          PaddleX::generate_save_path(FLAGS_save_dir, image_path);
-      cv::imwrite(save_path, vis_img);
-      result.clear();
-      std::cout << "Visualized output saved as " << save_path << std::endl;
+      for(int j = 0; j < im_vec_size - i; ++j) {
+        cv::Mat vis_img =
+            PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, 0.5);
+        std::string save_path =
+            PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
+        cv::imwrite(save_path, vis_img);
+        std::cout << "Visualized output saved as " << save_path << std::endl;
+      }
     }
   } else {
     PaddleX::DetResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
     model.predict(im, &result);
     for (int i = 0; i < result.boxes.size(); ++i) {
+      std::cout << "image file: " << FLAGS_image << std::endl;          
       std::cout << ", predict label: " << result.boxes[i].category
                 << ", label_id:" << result.boxes[i].category_id
                 << ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
@@ -105,6 +136,18 @@ int main(int argc, char** argv) {
     result.clear();
     std::cout << "Visualized output saved as " << save_path << std::endl;
   }
+  
+  std::cout << "Total running time: " 
+	    << total_running_time_s
+      << " s, average running time: "
+      << total_running_time_s / imgs
+	    << " s/img, total read img time: " 
+	    << total_imread_time_s
+      << " s, average read img time: "
+      << total_imread_time_s / imgs
+	    << " s, batch_size = " 
+	    << FLAGS_batch_size 
+	    << std::endl;
 
   return 0;
 }

+ 0 - 1
deploy/cpp/demo/segmenter.cpp

@@ -83,7 +83,6 @@ int main(int argc, char** argv) {
       model.predict(im_vec, results);
       auto imread_duration = duration_cast<microseconds>(imread_end - start);
       total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
-
       auto end = system_clock::now();
       auto duration = duration_cast<microseconds>(end - start);
       total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;

+ 7 - 2
deploy/cpp/include/paddlex/transforms.h

@@ -58,6 +58,7 @@ class Transform {
  public:
   virtual void Init(const YAML::Node& item) = 0;
   virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
+  virtual void SetPaddingSize(int max_h, int max_w) {}
 };
 
 class Normalize : public Transform {
@@ -169,11 +170,13 @@ class Padding : public Transform {
     }
   }
   virtual bool Run(cv::Mat* im, ImageBlob* data);
-
+  virtual void SetPaddingSize(int max_h, int max_w);
  private:
   int coarsest_stride_ = -1;
   int width_ = 0;
   int height_ = 0;
+  int max_height_ = 0;
+  int max_width_ = 0;
 };
 
 class Transforms {
@@ -181,10 +184,12 @@ class Transforms {
   void Init(const YAML::Node& node, bool to_rgb = true);
   std::shared_ptr<Transform> CreateTransform(const std::string& name);
   bool Run(cv::Mat* im, ImageBlob* data);
-
+  void SetPaddingSize(int max_h, int max_w);
  private:
   std::vector<std::shared_ptr<Transform>> transforms_;
   bool to_rgb_ = true;
+  int max_h_ = 0;
+  int max_w_ = 0;
 };
 
 }  // namespace PaddleX

+ 155 - 6
deploy/cpp/src/paddlex.cpp

@@ -14,6 +14,7 @@
 #include <algorithm>
 #include <omp.h>
 #include "include/paddlex/paddlex.h"
+#include <fstream>
 namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
@@ -100,6 +101,9 @@ bool Model::load_config(const std::string& model_dir) {
 
 bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
   cv::Mat im = input_im.clone();
+  int max_h = im.rows;
+  int max_w = im.cols;
+  transforms_.SetPaddingSize(max_h, max_w);
   if (!transforms_.Run(&im, blob)) {
     return false;
   }
@@ -110,7 +114,13 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
 bool Model::preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch) {
   int batch_size = inputs_batch_.size();
   bool success = true;
-  //int i;
+  int max_h = -1;
+  int max_w = -1;
+  for(int i = 0; i < input_im_batch.size(); ++i) {
+    max_h = std::max(max_h, input_im_batch[i].rows);
+    max_w = std::max(max_w, input_im_batch[i].cols);
+  }
+  transforms_.SetPaddingSize(max_h, max_w);
   #pragma omp parallel for num_threads(batch_size)
   for(int i = 0; i < input_im_batch.size(); ++i) {
     cv::Mat im = input_im_batch[i].clone();
@@ -126,10 +136,6 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                  "function predict()!"
-              << std::endl;
-    return false;
-  } else if (type == "segmenter") {
-    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                  "to function predict()!"
               << std::endl;
     return false;
@@ -224,7 +230,6 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
 
 bool Model::predict(const cv::Mat& im, DetResult* result) {
   result->clear();
-  inputs_.clear();
   if (type == "classifier") {
     std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
                  "to function predict()!"
@@ -248,9 +253,15 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
   auto im_tensor = predictor_->GetInputTensor("image");
   im_tensor->Reshape({1, 3, h, w});
   im_tensor->copy_from_cpu(inputs_.im_data_.data());
+
+  std::ofstream fout("test_single.dat", std::ios::out);
   if (name == "YOLOv3") {
     auto im_size_tensor = predictor_->GetInputTensor("im_size");
     im_size_tensor->Reshape({1, 2});
+    for(int i = 0; i < inputs_.ori_im_size_.size(); ++i) {
+      fout << inputs_.ori_im_size_[i] << " ";
+    }
+    fout << std::endl;
     im_size_tensor->copy_from_cpu(inputs_.ori_im_size_.data());
   } else if (name == "FasterRCNN" || name == "MaskRCNN") {
     auto im_info_tensor = predictor_->GetInputTensor("im_info");
@@ -283,6 +294,9 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
     std::cerr << "[WARNING] There's no object detected." << std::endl;
     return true;
   }
+  for(int i = 0; i < output_box.size(); ++i) {
+    fout << output_box[i] << " ";
+  }
   int num_boxes = size / 6;
   // 解析预测框box
   for (int i = 0; i < num_boxes; ++i) {
@@ -326,6 +340,141 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
   return true;
 }
 
+bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result) {
+  if (type == "classifier") {
+    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
+                 "to function predict()!"
+              << std::endl;
+    return false;
+  } else if (type == "segmenter") {
+    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
+                 "to function predict()!"
+              << std::endl;
+    return false;
+  }
+
+  // 处理输入图像
+  if (!preprocess(im_batch, inputs_batch_)) {
+    std::cerr << "Preprocess failed!" << std::endl;
+    return false;
+  }
+  int batch_size = im_batch.size();
+  int h = inputs_batch_[0].new_im_size_[0];
+  int w = inputs_batch_[0].new_im_size_[1];
+  auto im_tensor = predictor_->GetInputTensor("image");
+  im_tensor->Reshape({batch_size, 3, h, w});
+  std::vector<float> inputs_data(batch_size * 3 * h * w);
+  for(int i = 0; i < inputs_batch_.size(); ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  }
+  im_tensor->copy_from_cpu(inputs_data.data());
+  std::ofstream fout("test_batch.dat", std::ios::out);
+  if (name == "YOLOv3") {
+    auto im_size_tensor = predictor_->GetInputTensor("im_size");
+    im_size_tensor->Reshape({batch_size, 2});
+    std::vector<int> inputs_data_size(batch_size  * 2);
+    for(int i = 0; i < inputs_batch_.size(); ++i){
+      std::copy(inputs_batch_[i].ori_im_size_.begin(), inputs_batch_[i].ori_im_size_.end(), inputs_data_size.begin() + 2 * i);
+    }
+    for(int i = 0; i < inputs_data_size.size(); ++i) {
+      fout << inputs_data_size[i] << " ";
+    }
+    fout << std::endl;
+    im_size_tensor->copy_from_cpu(inputs_data_size.data());
+  } else if (name == "FasterRCNN" || name == "MaskRCNN") {
+    auto im_info_tensor = predictor_->GetInputTensor("im_info");
+    auto im_shape_tensor = predictor_->GetInputTensor("im_shape");
+    im_info_tensor->Reshape({batch_size, 3});
+    im_shape_tensor->Reshape({batch_size, 3});
+    
+    std::vector<float> im_info(3 * batch_size);
+    std::vector<float> im_shape(3 * batch_size);
+    for(int i = 0; i < inputs_batch_.size(); ++i) {
+      float ori_h = static_cast<float>(inputs_batch_[i].ori_im_size_[0]);
+      float ori_w = static_cast<float>(inputs_batch_[i].ori_im_size_[1]);
+      float new_h = static_cast<float>(inputs_batch_[i].new_im_size_[0]);
+      float new_w = static_cast<float>(inputs_batch_[i].new_im_size_[1]);
+      im_info[i * 3] = new_h;
+      im_info[i * 3 + 1] = new_w;
+      im_info[i * 3 + 2] = inputs_batch_[i].scale;
+      im_shape[i * 3] = ori_h;
+      im_shape[i * 3 + 1] = ori_w;
+      im_shape[i * 3 + 2] = 1.0;
+    }
+    im_info_tensor->copy_from_cpu(im_info.data());
+    im_shape_tensor->copy_from_cpu(im_shape.data());
+  }
+  // 使用加载的模型进行预测
+  predictor_->ZeroCopyRun();
+
+  // 读取所有box
+  std::vector<float> output_box;
+  auto output_names = predictor_->GetOutputNames();
+  auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]);
+  std::vector<int> output_box_shape = output_box_tensor->shape();
+  int size = 1;
+  for (const auto& i : output_box_shape) {
+    size *= i;
+  }
+  output_box.resize(size);
+  output_box_tensor->copy_to_cpu(output_box.data());
+  if (size < 6) {
+    std::cerr << "[WARNING] There's no object detected." << std::endl;
+    return true;
+  }
+  for(int i = 0; i < output_box.size(); ++i) {
+    fout << output_box[i] << " ";
+  }
+  auto lod_vector = output_box_tensor->lod();
+  int num_boxes = size / 6;
+  // 解析预测框box
+  for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
+    for(int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
+      Box box;
+      box.category_id = static_cast<int> (round(output_box[j * 6]));
+      box.category = labels[box.category_id];
+      box.score = output_box[j * 6 + 1];
+      float xmin = output_box[j * 6 + 2];
+      float ymin = output_box[j * 6 + 3];
+      float xmax = output_box[j * 6 + 4];
+      float ymax = output_box[j * 6 + 5];
+      float w = xmax - xmin + 1;
+      float h = ymax - ymin + 1;
+      box.coordinate = {xmin, ymin, w, h};
+      result[i].boxes.push_back(std::move(box));
+    }
+  }
+
+  // 实例分割需解析mask
+  if (name == "MaskRCNN") {
+    std::vector<float> output_mask;
+    auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
+    std::vector<int> output_mask_shape = output_mask_tensor->shape();
+    int masks_size = 1;
+    for (const auto& i : output_mask_shape) {
+      masks_size *= i;
+    }
+    int mask_pixels = output_mask_shape[2] * output_mask_shape[3];
+    int classes = output_mask_shape[1];
+    output_mask.resize(masks_size);
+    output_mask_tensor->copy_to_cpu(output_mask.data());
+    int mask_idx = 0;
+    for(int i = 0; i < lod_vector[0].size() - 1; ++i) {
+      result[i].mask_resolution = output_mask_shape[2];
+      for(int j = 0; j < result[i].boxes.size(); ++j) {
+        Box* box = &result[i].boxes[j];
+        auto begin_mask = output_mask.begin() + (mask_idx * classes + box->category_id) * mask_pixels;
+        auto end_mask = begin_mask + mask_pixels;
+        box->mask.data.assign(begin_mask, end_mask);
+        box->mask.shape = {static_cast<int>(box->coordinate[2]),
+                           static_cast<int>(box->coordinate[3])};
+        mask_idx++;
+      }
+    }
+  }
+  return true; 
+}
+
 bool Model::predict(const cv::Mat& im, SegResult* result) {
   result->clear();
   inputs_.clear();

+ 15 - 3
deploy/cpp/src/transforms.cpp

@@ -95,11 +95,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
   if (width_ > 1 & height_ > 1) {
     padding_w = width_ - im->cols;
     padding_h = height_ - im->rows;
-  } else if (coarsest_stride_ > 1) {
+  } else if (coarsest_stride_ >= 1) {
     padding_h =
-        ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
+        ceil(max_height_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
     padding_w =
-        ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
+        ceil(max_width_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
   }
 
   if (padding_h < 0 || padding_w < 0) {
@@ -115,6 +115,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
   return true;
 }
 
+void Padding::SetPaddingSize(int max_h, int max_w) {
+  max_height_ = max_h;
+  max_width_ = max_w;
+}
+
 bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
   if (long_size_ <= 0) {
     std::cerr << "[ResizeByLong] long_size should be greater than 0"
@@ -201,6 +206,7 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;
   for (int i = 0; i < transforms_.size(); ++i) {
+    transforms_[i]->SetPaddingSize(max_h_, max_w_);
     if (!transforms_[i]->Run(im, data)) {
       std::cerr << "Apply transforms to image failed!" << std::endl;
       return false;
@@ -219,4 +225,10 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
   }
   return true;
 }
+
+void Transforms::SetPaddingSize(int max_h, int max_w) {
+  max_h_ = max_h;
+  max_w_ = max_w;
+}
+
 }  // namespace PaddleX