
fix mkldnn

syyxsxx 5 years ago
Parent
Current commit
067ba31f16

+ 2 - 2
deploy/cpp/demo/classifier.cpp

@@ -61,9 +61,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
 
   // Predict
   int imgs = 1;
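
The same reordering is applied to every demo below: FLAGS_mkl_thread_num now comes before FLAGS_gpu_id and FLAGS_key, so the call sites match the new parameter order in deploy/cpp/include/paddlex/paddlex.h (further down). For reference, a minimal sketch of a full call in the new order — the model variable, the Init method name, and FLAGS_model_dir are assumptions taken from the surrounding demo code, which this hunk does not show:

  PaddleX::Model model;              // predictor wrapper from paddlex.h (assumed name)
  model.Init(FLAGS_model_dir,        // directory containing model.yml (assumed flag)
             FLAGS_use_gpu,
             FLAGS_use_trt,
             FLAGS_use_mkl,
             FLAGS_mkl_thread_num,   // new position: before gpu_id and key
             FLAGS_gpu_id,
             FLAGS_key);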

+ 2 - 2
deploy/cpp/demo/detector.cpp

@@ -66,9 +66,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
   int imgs = 1;
   std::string save_dir = "output";
   // Predict

+ 2 - 2
deploy/cpp/demo/segmenter.cpp

@@ -63,9 +63,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
   int imgs = 1;
   // Predict
   if (FLAGS_image_list != "") {

+ 2 - 2
deploy/cpp/demo/video_classifier.cpp

@@ -67,9 +67,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
 
   // Open video
   cv::VideoCapture capture;

+ 2 - 2
deploy/cpp/demo/video_detector.cpp

@@ -69,9 +69,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
   // Open video
   cv::VideoCapture capture;
   if (FLAGS_use_camera) {

+ 2 - 2
deploy/cpp/demo/video_segmenter.cpp

@@ -67,9 +67,9 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_use_mkl,
+             FLAGS_mkl_thread_num,
              FLAGS_gpu_id,
-             FLAGS_key,
-             FLAGS_mkl_thread_num);
+             FLAGS_key);
   // Open video
   cv::VideoCapture capture;
   if (FLAGS_use_camera) {

+ 5 - 4
deploy/cpp/include/paddlex/paddlex.h

@@ -70,6 +70,8 @@ class Model {
    * @param model_dir: the directory which contains model.yml
    * @param use_gpu: use gpu or not when infering
    * @param use_trt: use Tensor RT or not when infering
+   * @param use_mkl: use mkl or not when infering
+   * @param mkl_thread_num: the number of mkl threads when infering
    * @param gpu_id: the id of gpu when infering with using gpu
    * @param key: the key of encryption when using encrypted model
    * @param use_ir_optim: use ir optimization when infering
@@ -78,28 +80,27 @@ class Model {
             bool use_gpu = false,
             bool use_trt = false,
             bool use_mkl = true,
+            int mkl_thread_num = 4,
             int gpu_id = 0,
             std::string key = "",
-            int mkl_thread_num = 4,
             bool use_ir_optim = true) {
     create_predictor(
                      model_dir,
                      use_gpu,
                      use_trt,
                      use_mkl,
+                     mkl_thread_num,
                      gpu_id,
                      key,
-                     mkl_thread_num,
                      use_ir_optim);
   }
-
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
                         bool use_trt = false,
                         bool use_mkl = true,
+                        int mkl_thread_num = 4,
                         int gpu_id = 0,
                         std::string key = "",
-                        int mkl_thread_num = 4,
                         bool use_ir_optim = true);
 
   /*
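
Because every parameter after model_dir has a default value, this reordering is source-incompatible for positional callers: code written against the old order (gpu_id, key, mkl_thread_num) no longer binds the same way against the new one (mkl_thread_num, gpu_id, key). Since key is a std::string sitting between the two ints, a stale positional call fails at compile time rather than silently swapping two integers. A hypothetical before/after sketch (method name Init assumed from the demo code):

  // Old order: Init(dir, use_gpu, use_trt, use_mkl, gpu_id, key, mkl_thread_num, ...)
  // model.Init("model_dir", false, false, true, 0, "", 4);  // now a compile error: "" lands on int gpu_id
  // New order: Init(dir, use_gpu, use_trt, use_mkl, mkl_thread_num, gpu_id, key, ...)
  model.Init("model_dir", false, false, true, 4, 0, "");     // mkl_thread_num=4, gpu_id=0, key=""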

+ 1 - 1
deploy/cpp/src/paddlex.cpp

@@ -29,9 +29,9 @@ void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
                              bool use_trt,
                              bool use_mkl,
+                             int mkl_thread_num,
                              int gpu_id,
                              std::string key,
-                             int mkl_thread_num,
                              bool use_ir_optim) {
   paddle::AnalysisConfig config;
   std::string model_file = model_dir + OS_PATH_SEP + "__model__";
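
The paddlex.cpp hunk shows only the reordered parameter list; the body that actually consumes mkl_thread_num is below the fold. As a rough sketch of what the "fix mkldnn" wiring plausibly looks like — EnableMKLDNN() and SetCpuMathLibraryNumThreads() are real paddle::AnalysisConfig methods, but their exact placement here is an assumption, not shown in this diff:

  // Hypothetical CPU branch inside create_predictor (not part of this commit's visible hunk):
  if (!use_gpu && use_mkl) {
    config.EnableMKLDNN();                               // enable MKL-DNN kernels where available
    config.SetCpuMathLibraryNumThreads(mkl_thread_num);  // thread count for the CPU math library
  }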