Browse Source

Merge pull request #881 from FlyingQianMM/devellop_fix

export pipeline info
FlyingQianMM 4 năm trước cách đây
mục cha
commit
1f90f463e7

+ 3 - 0
dygraph/deploy/cpp/model_deploy/common/include/base_preprocess.h

@@ -46,6 +46,8 @@ class BasePreprocess {
                    std::vector<ShapeInfo>* shape_info,
                    int thread_num = 1) = 0;
 
+  virtual std::string GetModelName() { return model_name_; }
+
  protected:
   bool BuildTransform(const YAML::Node& yaml_config);
   std::vector<std::shared_ptr<Transform>> transforms_;
@@ -54,6 +56,7 @@ class BasePreprocess {
   std::shared_ptr<Transform> CreateTransform(const std::string& name);
   Padding batch_padding_;
   Permute permute_;
+  std::string model_name_;
 };
 
 }  // namespace PaddleDeploy

+ 33 - 9
dygraph/deploy/cpp/model_deploy/common/include/output_struct.h

@@ -243,8 +243,24 @@ struct Result {
     return stream;
   }
 
+  void Clear() {
+    if ("det" == model_type) {
+      delete det_result;
+      det_result = NULL;
+    } else if ("seg" == model_type) {
+      delete seg_result;
+      seg_result = NULL;
+    } else if ("clas" == model_type) {
+      delete clas_result;
+      clas_result = NULL;
+    } else if ("ocr" == model_type) {
+      delete ocr_result;
+      ocr_result = NULL;
+    }
+  }
 
   Result(const Result& result) {
+    Clear();
     model_type = result.model_type;
     if ("det" == model_type) {
       det_result = new DetResult();
@@ -261,20 +277,28 @@ struct Result {
     }
   }
 
-  ~Result() {
+  Result& operator=(const Result& result) {
+    Clear();
+    model_type = result.model_type;
     if ("det" == model_type) {
-      delete det_result;
-      det_result = NULL;
+      det_result = new DetResult();
+      *det_result = *(result.det_result);
     } else if ("seg" == model_type) {
-      delete seg_result;
-      seg_result = NULL;
+      seg_result = new SegResult();
+      *seg_result = *(result.seg_result);
     } else if ("clas" == model_type) {
-      delete clas_result;
-      clas_result = NULL;
+      clas_result = new ClasResult();
+      *clas_result = *(result.clas_result);
     } else if ("ocr" == model_type) {
-      delete ocr_result;
-      ocr_result = NULL;
+      ocr_result = new OcrResult();
+      *ocr_result = *(result.ocr_result);
     }
+    return *this;
+  }
+
+
+  ~Result() {
+    Clear();
   }
 };
 

+ 1 - 0
dygraph/deploy/cpp/model_deploy/paddlex/include/x_preprocess.h

@@ -33,6 +33,7 @@ class XPreprocess : public BasePreprocess {
                    std::vector<DataBlob>* inputs,
                    std::vector<ShapeInfo>* shape_info,
                    int thread_num = 1);
+  virtual std::string GetModelName() { return model_name_; }
 
  private:
   std::string model_type_;

+ 28 - 0
dygraph/deploy/cpp/model_deploy/utils/include/bbox_utils.h

@@ -0,0 +1,28 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <iostream>
+#include <vector>
+
+#include "model_deploy/common/include/output_struct.h"
+
+namespace PaddleDeploy {
+
+bool FilterBbox(const std::vector<Result> &results,
+                const float &score_thresh,
+                std::vector<Result>* filter_results);
+
+}  // namespace PaddleDeploy

+ 57 - 0
dygraph/deploy/cpp/model_deploy/utils/src/bbox_utils.cpp

@@ -0,0 +1,57 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "model_deploy/utils/include/bbox_utils.h"
+
+namespace PaddleDeploy {
+
+bool FilterBbox(const std::vector<Result> &results,
+                const float &score_thresh,
+                std::vector<Result>* filter_results) {
+  for (auto i = 0; i < results.size(); ++i) {
+    if ("det" != results[i].model_type) {
+       std::cerr << "FilterBbox can be only done on results from a det model, "
+                 << "but the received results are from a "
+                 << results[i].model_type << " model." << std::endl;
+       return false;
+    }
+  }
+
+  for (auto i = 0; i < results.size(); ++i) {
+    Result result;
+    result.model_type = "det";
+    result.det_result = new DetResult();
+    std::vector<Box> boxes = results[i].det_result->boxes;
+    for (auto j = 0; j < boxes.size(); ++j) {
+      if (boxes[j].score >= score_thresh) {
+        Box box;
+        box.category_id = boxes[j].category_id;
+        box.category = boxes[j].category;
+        box.score = boxes[j].score;
+        box.coordinate.assign(boxes[j].coordinate.begin(),
+                              boxes[j].coordinate.end());
+        box.mask.data.assign(boxes[j].mask.data.begin(),
+                             boxes[j].mask.data.end());
+        box.mask.shape.assign(boxes[j].mask.shape.begin(),
+                              boxes[j].mask.shape.end());
+        result.det_result->boxes.push_back(std::move(box));
+      }
+    }
+    result.det_result->mask_resolution = results[i].det_result->mask_resolution;
+    filter_results->push_back(std::move(result));
+  }
+  return true;
+}
+
+}  // namespace PaddleDeploy

+ 130 - 0
dygraph/deploy/cpp/model_deploy/utils/src/visualize.cpp

@@ -0,0 +1,130 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "model_deploy/utils/include/visualize.h"
+
+namespace PaddleDeploy {
+
+bool GenerateColorMap(const int &num_classes,
+                      std::vector<int> *color_map) {
+  *color_map = std::vector<int>(3 * num_classes, 0);
+  for (int i = 0; i < num_classes; ++i) {
+    int j = 0;
+    int lab = i;
+    while (lab) {
+      (*color_map)[i * 3] |= (((lab >> 0) & 1) << (7 - j));
+      (*color_map)[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j));
+      (*color_map)[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j));
+      ++j;
+      lab >>= 3;
+    }
+  }
+  return true;
+}
+
+bool Visualize(const cv::Mat& img,
+               const DetResult& result,
+               cv::Mat* vis_img,
+               const int& num_classes) {
+  std::vector<int> color_map;
+  GenerateColorMap(num_classes + 2, &color_map);
+  *vis_img = img.clone();
+  std::vector<Box> boxes = result.boxes;
+  for (auto i = 0; i < boxes.size(); ++i) {
+    cv::Rect roi = cv::Rect(boxes[i].coordinate[0],
+                            boxes[i].coordinate[1],
+                            boxes[i].coordinate[2],
+                            boxes[i].coordinate[3]);
+    // draw box and title
+    std::string text = boxes[i].category;
+    int category_id = boxes[i].category_id + 2;
+    int c1 = color_map[3 * category_id + 0];
+    int c2 = color_map[3 * category_id + 1];
+    int c3 = color_map[3 * category_id + 2];
+    cv::Scalar roi_color = cv::Scalar(c1, c2, c3);
+    text += std::to_string(static_cast<int>(boxes[i].score * 100)) + "%";
+    int font_face = cv::FONT_HERSHEY_SIMPLEX;
+    double font_scale = 0.5f;
+    float thickness = 0.5;
+    cv::Size text_size =
+        cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
+    cv::Point origin;
+    origin.x = roi.x;
+    origin.y = roi.y;
+
+    // background
+    cv::Rect text_back = cv::Rect(boxes[i].coordinate[0],
+                                  boxes[i].coordinate[1] - text_size.height,
+                                  text_size.width,
+                                  text_size.height);
+
+    // draw
+    cv::rectangle(*vis_img, roi, roi_color, 2);
+    cv::rectangle(*vis_img, text_back, roi_color, -1);
+    cv::putText(*vis_img,
+                text,
+                origin,
+                font_face,
+                font_scale,
+                cv::Scalar(255, 255, 255),
+                thickness);
+
+    // mask
+    if (boxes[i].mask.data.size() == 0) {
+      continue;
+    }
+
+    cv::Mat full_mask(boxes[i].mask.shape[0],
+                      boxes[i].mask.shape[1],
+                      CV_8UC1,
+                      boxes[i].mask.data.data());
+    cv::Mat mask_ch[3];
+    mask_ch[0] = full_mask * c1;
+    mask_ch[1] = full_mask * c2;
+    mask_ch[2] = full_mask * c3;
+    cv::Mat mask;
+    cv::merge(mask_ch, 3, mask);
+    cv::addWeighted(*vis_img, 1, mask, 0.5, 0, *vis_img);
+  }
+  return true;
+}
+
+bool Visualize(const cv::Mat& img,
+               const SegResult& result,
+               cv::Mat* vis_img,
+               const int& num_classes) {
+  std::vector<int> color_map;
+  GenerateColorMap(num_classes, &color_map);
+  std::vector<uint8_t> label_map(result.label_map.data.begin(),
+                                 result.label_map.data.end());
+  cv::Mat mask(result.label_map.shape[0],
+               result.label_map.shape[1],
+               CV_8UC1,
+               label_map.data());
+  *vis_img = cv::Mat::zeros(
+      result.label_map.shape[0], result.label_map.shape[1], CV_8UC3);
+  int rows = img.rows;
+  int cols = img.cols;
+  for (int i = 0; i < rows; i++) {
+    for (int j = 0; j < cols; j++) {
+      int category_id = static_cast<int>(mask.at<uchar>(i, j));
+      vis_img->at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
+      vis_img->at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
+      vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
+    }
+  }
+  return true;
+}
+
+}  // namespace PaddleDeploy

+ 39 - 0
dygraph/paddlex/cv/models/base.py

@@ -527,6 +527,39 @@ class BaseModel:
                 .format(self.quant_config),
                 exit=True)
 
+    def _get_pipeline_info(self, save_dir):
+        pipeline_info = {}
+        pipeline_info["pipeline_name"] = self.model_type
+        nodes = [{
+            "src0": {
+                "type": "Source",
+                "next": "decode0"
+            }
+        }, {
+            "decode0": {
+                "type": "Decode",
+                "next": "predict0"
+            }
+        }, {
+            "predict0": {
+                "type": "Predict",
+                "init_params": {
+                    "use_gpu": False,
+                    "gpu_id": 0,
+                    "use_trt": False,
+                    "model_dir": save_dir,
+                },
+                "next": "sink0"
+            }
+        }, {
+            "sink0": {
+                "type": "Sink"
+            }
+        }]
+        pipeline_info["pipeline_nodes"] = nodes
+        pipeline_info["version"] = "1.0.0"
+        return pipeline_info
+
     def _export_inference_model(self, save_dir, image_shape=None):
         save_dir = osp.join(save_dir, 'inference_model')
         self.net.eval()
@@ -560,6 +593,12 @@ class BaseModel:
                 mode='w') as f:
             yaml.dump(model_info, f)
 
+        pipeline_info = self._get_pipeline_info(save_dir)
+        with open(
+                osp.join(save_dir, 'pipeline.yml'), encoding='utf-8',
+                mode='w') as f:
+            yaml.dump(pipeline_info, f)
+
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("The model for the inference deployment is saved in {}.".