
preprocess optimize (#840)

* code optimization

* code optimization2

* fix bug
heliqi · 4 years ago
Parent commit: 6fd365740c

+ 6 - 4
dygraph/deploy/cpp/CMakeLists.txt

@@ -189,7 +189,7 @@ if (WIN32)
         add_definitions(-DSTATIC_LIB)
     endif()
 else()
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o2 -fopenmp -std=c++11")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -fopenmp -std=c++11")
     set(CMAKE_STATIC_LIBRARY_PREFIX "")
     set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
     set(DEPS ${DEPS} ${EXTERNAL_LIB})
@@ -233,10 +233,12 @@ if(WIN32)
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/paddle_deploy
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/third_party/install/mkldnn/lib/mkldnn.dll  ${CMAKE_BINARY_DIR}/paddle_deploy
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/paddle/lib/paddle_inference.dll ${CMAKE_BINARY_DIR}/paddle_deploy
-    if (WITH_TENSORRT)
+  )
+  if (WITH_TENSORRT)
+    add_custom_command(TARGET model_infer POST_BUILD
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/nvinfer.dll ${CMAKE_BINARY_DIR}/paddle_deploy
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/nvinfer_plugin.dll ${CMAKE_BINARY_DIR}/paddle_deploy
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/myelin64_1.dll ${CMAKE_BINARY_DIR}/paddle_deploy
-    endif()
-  )
+    )
+  endif()
 endif()
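
Two notes on the CMakeLists changes. CMake does not honor if()/endif() inside an add_custom_command() invocation; the keywords are parsed as literal arguments, so the original TensorRT guard never worked. The fix closes the first command and wraps the TensorRT DLL copies in their own add_custom_command inside a regular if (WITH_TENSORRT) block. Separately, GCC/Clang optimization levels are spelled with a capital O (-O2, -O3); a lowercase -o names the output file, so the -o3 written here presumably intends -O3.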

+ 2 - 2
dygraph/deploy/cpp/demo/batch_infer.cpp

@@ -35,8 +35,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -88,5 +87,6 @@ int main(int argc, char** argv) {
     }
   }
 
+  delete model;
   return 0;
 }
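
Across the demos the pattern is identical: CreateModel now returns a raw Model* that the caller owns and must delete. A caller who wants the old automatic cleanup back can take ownership immediately; a minimal sketch, assuming only the repo's paddle_deploy.h header plus placeholder model-type and config values:

    #include <memory>
    #include <string>

    #include "model_deploy/common/include/paddle_deploy.h"

    int main() {
      // Wrap the raw pointer right away: the unique_ptr deletes the model on
      // every return path, so no explicit `delete` (and no leak on an early
      // error return) is needed.
      std::unique_ptr<PaddleDeploy::Model> model(
          PaddleDeploy::CreateModel("det"));  // "det" is a placeholder type
      if (!model) {
        return 1;  // unknown model_type: the factory returned nullptr
      }
      model->Init("model/infer_cfg.yml");  // hypothetical config path
      return 0;
    }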

+ 2 - 3
dygraph/deploy/cpp/demo/model_infer.cpp

@@ -31,8 +31,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -54,6 +53,6 @@ int main(int argc, char** argv) {
   model->Predict(imgs, &results, 1);
 
   std::cout << results[0] << std::endl;
-
+  delete model;
   return 0;
 }

+ 1 - 2
dygraph/deploy/cpp/demo/onnx_tensorrt/model_infer.cpp

@@ -37,8 +37,7 @@ int main(int argc, char** argv) {
             << FLAGS_model_file << std::endl;
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
   if (!model) {
     std::cout << "no model_type: " << FLAGS_model_type
               << "  model=" << model << std::endl;

+ 2 - 2
dygraph/deploy/cpp/demo/onnx_triton/model_infer.cpp

@@ -36,8 +36,7 @@ int main(int argc, char** argv) {
             << FLAGS_model_name << std::endl;
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
   if (!model) {
     std::cout << "no model_type: " << FLAGS_model_type
               << "  model=" << model << std::endl;
@@ -97,5 +96,6 @@ int main(int argc, char** argv) {
       std::cout << results[j] << std::endl;
     }
   }
+  delete model;
   return 0;
 }

+ 2 - 2
dygraph/deploy/cpp/demo/tensorrt_infer.cpp

@@ -30,8 +30,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -74,5 +73,6 @@ int main(int argc, char** argv) {
 
   std::cout << results[0] << std::endl;
 
+  delete model;
   return 0;
 }

+ 3 - 2
dygraph/deploy/cpp/model_deploy/common/include/base_model.h

@@ -54,8 +54,9 @@ class Model {
   }
 
   virtual bool YamlConfigInit(const std::string& cfg_file) {
-    YAML::Node yaml_config_ = YAML::LoadFile(cfg_file);
-    return true;
+    // YAML::Node yaml_config_ = YAML::LoadFile(cfg_file);
+    std::cerr << "Error! The Base Model was incorrectly entered" << std::endl;
+    return false;
   }
 
   virtual bool PreprocessInit() {
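
This turns a silent no-op into a loud failure: the old base implementation loaded the YAML into a local variable (note the member-style yaml_config_ name) and reported success, so a model class that forgot to override YamlConfigInit appeared to initialize fine. Falling through to the base class now prints an error and returns false.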

+ 4 - 4
dygraph/deploy/cpp/model_deploy/common/include/model_factory.h

@@ -26,14 +26,14 @@
 
 namespace PaddleDeploy {
 
-typedef std::shared_ptr<Model> (*NewInstance)();
+typedef Model* (*NewInstance)();
 
 class ModelFactory {
  private:
   static std::map<std::string, NewInstance> model_map_;
 
  public:
-  static std::shared_ptr<Model> CreateObject(const std::string &name);
+  static Model* CreateObject(const std::string &name);
 
   static void Register(const std::string &name, NewInstance model);
 };
@@ -48,9 +48,9 @@ class Register {
 #define REGISTER_CLASS(model_type, class_name)                    \
   class class_name##Register {                                    \
    public:                                                        \
-    static std::shared_ptr<Model> newInstance() {                 \
+    static Model* newInstance() {                                 \
       std::cerr << "REGISTER_CLASS:" << #model_type << std::endl; \
-      return std::make_shared<class_name>(#model_type);           \
+      return new class_name(#model_type);           \
     }                                                             \
                                                                   \
    private:                                                       \
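
For readers unfamiliar with the macro, here is a compressed, self-contained sketch of the factory after this change; the names mirror the repo, but the bodies are simplified for illustration:

    #include <iostream>
    #include <map>
    #include <string>

    struct Model {
      explicit Model(const std::string& type) : type_(type) {}
      virtual ~Model() = default;
      std::string type_;
    };
    struct SegModel : Model { using Model::Model; };

    typedef Model* (*NewInstance)();  // was: std::shared_ptr<Model> (*)()

    std::map<std::string, NewInstance>& model_map() {
      static std::map<std::string, NewInstance> m;  // keyed by model_type
      return m;
    }

    Model* CreateObject(const std::string& name) {
      auto it = model_map().find(name);
      return it == model_map().end() ? nullptr : it->second();
    }

    int main() {
      // REGISTER_CLASS(seg, SegModel) boils down to registering a creator:
      model_map()["seg"] = [] { return static_cast<Model*>(new SegModel("seg")); };
      Model* m = CreateObject("seg");
      std::cout << (m ? m->type_ : "not registered") << std::endl;
      delete m;  // the caller owns the raw pointer now
      return 0;
    }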

+ 2 - 3
dygraph/deploy/cpp/model_deploy/common/include/multi_gpu_model.h

@@ -30,8 +30,7 @@ class MultiGPUModel {
             const std::string& cfg_file, size_t gpu_num = 1) {
     models_.clear();
     for (auto i = 0; i < gpu_num; ++i) {
-      std::shared_ptr<Model> model =
-          PaddleDeploy::ModelFactory::CreateObject(model_type);
+      Model* model = PaddleDeploy::ModelFactory::CreateObject(model_type);
 
       if (!model) {
         std::cerr << "no model_type: " << model_type << std::endl;
@@ -45,7 +44,7 @@ class MultiGPUModel {
         return false;
       }
 
-      models_.push_back(model);
+      models_.push_back(std::shared_ptr<Model>(model));
     }
     return true;
   }
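
A note on the design: MultiGPUModel immediately re-wraps the raw pointer in a std::shared_ptr, so ownership inside models_ is unchanged. Only the factory boundary now trades in raw pointers; single-model callers such as the demos above take on the matching delete themselves.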

+ 1 - 1
dygraph/deploy/cpp/model_deploy/common/include/paddle_deploy.h

@@ -21,7 +21,7 @@
 #include "model_deploy/engine/include/engine.h"
 
 namespace PaddleDeploy {
-inline std::shared_ptr<Model> CreateModel(const std::string &name) {
+inline Model* CreateModel(const std::string &name) {
   return PaddleDeploy::ModelFactory::CreateObject(name);
 }
 }  // namespace PaddleDeploy

+ 23 - 11
dygraph/deploy/cpp/model_deploy/common/include/transforms.h

@@ -42,22 +42,37 @@ class Transform {
 class Normalize : public Transform {
  public:
   virtual void Init(const YAML::Node& item) {
-    mean_ = item["mean"].as<std::vector<float>>();
-    std_ = item["std"].as<std::vector<float>>();
+    std::vector<double> mean_ = item["mean"].as<std::vector<double>>();
+    std::vector<double> std_ = item["std"].as<std::vector<double>>();
+    bool is_scale_;
+    std::vector<double> min_val_;
+    std::vector<double> max_val_;
     if (item["is_scale"].IsDefined()) {
       is_scale_ = item["is_scale"];
     } else {
       is_scale_ = true;
     }
     if (item["min_val"].IsDefined()) {
-      min_val_ = item["min_val"].as<std::vector<float>>();
+      min_val_ = item["min_val"].as<std::vector<double>>();
     } else {
-      min_val_ = std::vector<float>(mean_.size(), 0.);
+      min_val_ = std::vector<double>(mean_.size(), 0.);
     }
     if (item["max_val"].IsDefined()) {
-      max_val_ = item["max_val"].as<std::vector<float>>();
+      max_val_ = item["max_val"].as<std::vector<double>>();
     } else {
-      max_val_ = std::vector<float>(mean_.size(), 255.);
+      max_val_ = std::vector<double>(mean_.size(), 255.);
+    }
+
+    for (auto c = 0; c < mean_.size(); c++) {
+      double alpha = 1.0;
+      if (is_scale_) {
+        alpha /= (max_val_[c] - min_val_[c]);
+      }
+      alpha /= std_[c];
+      double beta = -1.0 * mean_[c] / std_[c];
+
+      alpha_.push_back(alpha);
+      beta_.push_back(beta);
     }
   }
   virtual bool Run(cv::Mat* im);
@@ -67,11 +82,8 @@ class Normalize : public Transform {
 
 
  private:
-  bool is_scale_;
-  std::vector<float> mean_;
-  std::vector<float> std_;
-  std::vector<float> min_val_;
-  std::vector<float> max_val_;
+  std::vector<float> alpha_;
+  std::vector<float> beta_;
 };
 
 class ResizeByShort : public Transform {
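
The refactor folds the old four-step, per-channel pipeline (subtract min_val, scale by the range, subtract mean, divide by std) into one affine transform alpha * x + beta, computed once at Init() time instead of per image. With min_val at its default of 0 the two forms agree, since ((x / (max - min)) - mean) / std == x / ((max - min) * std) - mean / std. A standalone numeric check under that assumption:

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const double mean = 0.485, stdv = 0.229, min_v = 0.0, max_v = 255.0;
      const double x = 128.0;  // one pixel, one channel

      // Old pipeline: subtract min, scale by range, subtract mean, divide by std.
      const double old_val = (((x - min_v) / (max_v - min_v)) - mean) / stdv;

      // Folded form from the new Init(): one multiply-add per pixel.
      const double alpha = 1.0 / ((max_v - min_v) * stdv);
      const double beta = -mean / stdv;
      const double new_val = alpha * x + beta;

      assert(std::fabs(old_val - new_val) < 1e-12);
      std::printf("%f == %f\n", old_val, new_val);
      return 0;
    }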

+ 1 - 0
dygraph/deploy/cpp/model_deploy/common/src/base_preprocess.cpp

@@ -26,6 +26,7 @@ bool BasePreprocess::BuildTransform(const YAML::Node& yaml_config) {
     std::string name = it->first.as<std::string>();
     std::shared_ptr<Transform> transform = CreateTransform(name);
     if (!transform) {
+      std::cerr << "Failed to create " << name << " on Preprocess" << std::endl;
       return false;
     }
     transform->Init(it->second);

+ 1 - 1
dygraph/deploy/cpp/model_deploy/common/src/model_factory.cpp

@@ -23,7 +23,7 @@ REGISTER_CLASS(seg, SegModel);
 REGISTER_CLASS(clas, ClasModel);
 REGISTER_CLASS(paddlex, PaddleXModel);
 
-std::shared_ptr<Model> ModelFactory::CreateObject(const std::string &name) {
+Model* ModelFactory::CreateObject(const std::string &name) {
   std::map<std::string, NewInstance>::const_iterator it;
   it = model_map_.find(name);
   if (it == model_map_.end())

+ 1 - 7
dygraph/deploy/cpp/model_deploy/common/src/transforms.cpp

@@ -26,13 +26,7 @@ bool Normalize::Run(cv::Mat *im) {
   std::vector<cv::Mat> split_im;
   cv::split(*im, split_im);
   for (int c = 0; c < im->channels(); c++) {
-    cv::subtract(split_im[c], cv::Scalar(min_val_[c]), split_im[c]);
-    if (is_scale_) {
-      float range_val = max_val_[c] - min_val_[c];
-      cv::divide(split_im[c], cv::Scalar(range_val), split_im[c]);
-    }
-    cv::subtract(split_im[c], cv::Scalar(mean_[c]), split_im[c]);
-    cv::divide(split_im[c], cv::Scalar(std_[c]), split_im[c]);
+    split_im[c].convertTo(split_im[c], CV_32FC1, alpha_[c], beta_[c]);
   }
   cv::merge(split_im, *im);
   return true;
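
cv::Mat::convertTo computes dst = src * alpha + beta (with saturation) in a single pass, so the four OpenCV calls per channel collapse into one fused multiply-add; this is the speedup the precomputed alpha_/beta_ vectors above exist to enable.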

+ 7 - 0
dygraph/deploy/cpp/model_deploy/engine/src/tensorrt_engine.cpp

@@ -70,6 +70,7 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
   auto builder = InferUniquePtr<nvinfer1::IBuilder>(
                      nvinfer1::createInferBuilder(logger_));
   if (!builder) {
+    std::cerr << "TensorRT init builder error" << std::endl;
     return false;
   }
 
@@ -78,22 +79,26 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
   auto network = InferUniquePtr<nvinfer1::INetworkDefinition>(
                      builder->createNetworkV2(explicitBatch));
   if (!network) {
+    std::cerr << "TensorRT init network error" << std::endl;
     return false;
   }
 
   auto parser = InferUniquePtr<nvonnxparser::IParser>(
                     nvonnxparser::createParser(*network, logger_));
   if (!parser) {
+    std::cerr << "TensorRT init parser error" << std::endl;
     return false;
   }
   if (!parser->parseFromFile(tensorrt_config.model_file_.c_str(),
                              static_cast<int>(logger_.mReportableSeverity))) {
+    std::cerr << "TensorRT init model_file error" << std::endl;
     return false;
   }
 
   auto config = InferUniquePtr<nvinfer1::IBuilderConfig>(
                      builder->createBuilderConfig());
   if (!config) {
+    std::cerr << "TensorRT init config error" << std::endl;
     return false;
   }
 
@@ -130,6 +135,7 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
                     engine_->createExecutionContext(),
                     InferDeleter());
   if (!context_) {
+    std::cerr << "TensorRT init context error" << std::endl;
     return false;
   }
 
@@ -229,6 +235,7 @@ bool TensorRTInferenceEngine::Infer(const std::vector<DataBlob>& input_blobs,
   buffers.copyInputToDevice();
   bool status = context_->executeV2(buffers.getDeviceBindings().data());
   if (!status) {
+    std::cerr << "TensorRT create execute error" << std::endl;
     return false;
   }
   buffers.copyOutputToHost();

+ 4 - 0
dygraph/deploy/cpp/model_deploy/ppclas/src/clas_postprocess.cpp

@@ -30,6 +30,10 @@ bool ClasPostprocess::Init(const YAML::Node& yaml_config) {
 bool ClasPostprocess::Run(const std::vector<DataBlob>& outputs,
                          const std::vector<ShapeInfo>& shape_infos,
                          std::vector<Result>* results, int thread_num) {
+  if (outputs.size() == 0) {
+    std::cerr << "empty output on ClasPostprocess" << std::endl;
+    return false;
+  }
   results->clear();
   int batch_size = shape_infos.size();
   results->resize(batch_size);

+ 1 - 1
dygraph/deploy/cpp/model_deploy/ppdet/src/det_postprocess.cpp

@@ -197,7 +197,7 @@ bool DetPostprocess::Run(const std::vector<DataBlob>& outputs,
                          std::vector<Result>* results, int thread_num) {
   results->clear();
   if (outputs.size() == 0) {
-    std::cerr << "empty input image on DetPostprocess" << std::endl;
+    std::cerr << "empty output on DetPostprocess" << std::endl;
     return true;
   }
   results->resize(shape_infos.size());

+ 4 - 0
dygraph/deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -89,6 +89,10 @@ bool SegPostprocess::RunV2(const DataBlob& output,
 bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
                          const std::vector<ShapeInfo>& shape_infos,
                          std::vector<Result>* results, int thread_num) {
+  if (outputs.size() == 0) {
+    std::cerr << "empty output on SegPostprocess" << std::endl;
+    return true;
+  }
   results->clear();
   int batch_size = shape_infos.size();
   results->resize(batch_size);
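
Note the new guards are not uniform: ClasPostprocess returns false on empty outputs, while DetPostprocess and SegPostprocess log the same condition but return true, so their callers see success with empty results.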