
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into param_tuning

will-jl944 committed 4 years ago
commit 6a94817fa4
29 changed files with 161 additions and 86 deletions
  1. dygraph/deploy/cpp/CMakeLists.txt (+6 -4)
  2. dygraph/deploy/cpp/demo/batch_infer.cpp (+2 -2)
  3. dygraph/deploy/cpp/demo/model_infer.cpp (+2 -3)
  4. dygraph/deploy/cpp/demo/onnx_tensorrt/model_infer.cpp (+1 -2)
  5. dygraph/deploy/cpp/demo/onnx_triton/model_infer.cpp (+2 -2)
  6. dygraph/deploy/cpp/demo/tensorrt_infer.cpp (+2 -2)
  7. dygraph/deploy/cpp/docs/apis/model.md (+3 -3)
  8. dygraph/deploy/cpp/model_deploy/common/include/base_model.h (+3 -2)
  9. dygraph/deploy/cpp/model_deploy/common/include/model_factory.h (+4 -4)
  10. dygraph/deploy/cpp/model_deploy/common/include/multi_gpu_model.h (+2 -3)
  11. dygraph/deploy/cpp/model_deploy/common/include/paddle_deploy.h (+1 -1)
  12. dygraph/deploy/cpp/model_deploy/common/include/transforms.h (+23 -11)
  13. dygraph/deploy/cpp/model_deploy/common/src/base_preprocess.cpp (+1 -0)
  14. dygraph/deploy/cpp/model_deploy/common/src/model_factory.cpp (+1 -1)
  15. dygraph/deploy/cpp/model_deploy/common/src/transforms.cpp (+1 -7)
  16. dygraph/deploy/cpp/model_deploy/engine/src/tensorrt_engine.cpp (+7 -0)
  17. dygraph/deploy/cpp/model_deploy/ppclas/src/clas_postprocess.cpp (+4 -0)
  18. dygraph/deploy/cpp/model_deploy/ppdet/src/det_postprocess.cpp (+1 -1)
  19. dygraph/deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp (+4 -0)
  20. dygraph/paddlex/cv/datasets/voc.py (+4 -4)
  21. dygraph/paddlex/cv/models/base.py (+7 -4)
  22. dygraph/paddlex/cv/models/classifier.py (+22 -13)
  23. dygraph/paddlex/cv/models/detector.py (+35 -2)
  24. dygraph/paddlex/cv/models/segmenter.py (+8 -2)
  25. dygraph/paddlex/cv/transforms/functions.py (+3 -9)
  26. dygraph/paddlex/cv/transforms/operators.py (+5 -2)
  27. dygraph/paddlex/utils/env.py (+5 -0)
  28. dygraph/requirements.txt (+0 -2)
  29. dygraph/submodules.txt (+2 -0)

+ 6 - 4
dygraph/deploy/cpp/CMakeLists.txt

@@ -189,7 +189,7 @@ if (WIN32)
         add_definitions(-DSTATIC_LIB)
     endif()
 else()
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o2 -fopenmp -std=c++11")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -fopenmp -std=c++11")
     set(CMAKE_STATIC_LIBRARY_PREFIX "")
     set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
     set(DEPS ${DEPS} ${EXTERNAL_LIB})
@@ -233,10 +233,12 @@ if(WIN32)
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/paddle_deploy
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/third_party/install/mkldnn/lib/mkldnn.dll  ${CMAKE_BINARY_DIR}/paddle_deploy
     COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_DIR}/paddle/lib/paddle_inference.dll ${CMAKE_BINARY_DIR}/paddle_deploy
-    if (WITH_TENSORRT)
+  )
+  if (WITH_TENSORRT)
+    add_custom_command(TARGET model_infer POST_BUILD
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/nvinfer.dll ${CMAKE_BINARY_DIR}/paddle_deploy
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/nvinfer_plugin.dll ${CMAKE_BINARY_DIR}/paddle_deploy
       COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_DIR}/lib/myelin64_1.dll ${CMAKE_BINARY_DIR}/paddle_deploy
-    endif()
-  )
+    )
+  endif()
 endif()

+ 2 - 2
dygraph/deploy/cpp/demo/batch_infer.cpp

@@ -35,8 +35,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -88,5 +87,6 @@ int main(int argc, char** argv) {
     }
   }
 
+  delete model;
   return 0;
 }

+ 2 - 3
dygraph/deploy/cpp/demo/model_infer.cpp

@@ -31,8 +31,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -54,6 +53,6 @@ int main(int argc, char** argv) {
   model->Predict(imgs, &results, 1);
 
   std::cout << results[0] << std::endl;
-
+  delete model;
   return 0;
 }

+ 1 - 2
dygraph/deploy/cpp/demo/onnx_tensorrt/model_infer.cpp

@@ -37,8 +37,7 @@ int main(int argc, char** argv) {
             << FLAGS_model_file << std::endl;
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
   if (!model) {
     std::cout << "no model_type: " << FLAGS_model_type
               << "  model=" << model << std::endl;

+ 2 - 2
dygraph/deploy/cpp/demo/onnx_triton/model_infer.cpp

@@ -36,8 +36,7 @@ int main(int argc, char** argv) {
             << FLAGS_model_name << std::endl;
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
   if (!model) {
     std::cout << "no model_type: " << FLAGS_model_type
               << "  model=" << model << std::endl;
@@ -97,5 +96,6 @@ int main(int argc, char** argv) {
       std::cout << results[j] << std::endl;
     }
   }
+  delete model;
   return 0;
 }

+ 2 - 2
dygraph/deploy/cpp/demo/tensorrt_infer.cpp

@@ -30,8 +30,7 @@ int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
 
   // create model
-  std::shared_ptr<PaddleDeploy::Model> model =
-        PaddleDeploy::CreateModel(FLAGS_model_type);
+  PaddleDeploy::Model* model = PaddleDeploy::CreateModel(FLAGS_model_type);
 
   // model init
   model->Init(FLAGS_cfg_file);
@@ -74,5 +73,6 @@ int main(int argc, char** argv) {
 
   std::cout << results[0] << std::endl;
 
+  delete model;
   return 0;
 }

+ 3 - 3
dygraph/deploy/cpp/docs/apis/model.md

@@ -15,10 +15,10 @@
 ## 1. Create a model object
 
 ```c++
-std::shared_ptr<PaddleDeploy::Model>  PaddleDeploy::ModelFactory::CreateObject(const std::string  &model_type)
+PaddleDeploy::Model*  PaddleDeploy::ModelFactory::CreateObject(const std::string  &model_type)
 ```
 
-> Based on the suite the model comes from, creates the corresponding suite object and returns a smart pointer to the base class. All inference-related operations, including preprocessing, inference, and postprocessing, are carried out through this object.
+> Based on the suite the model comes from, creates the corresponding suite object and returns a base-class pointer. All inference-related operations, including preprocessing, inference, and postprocessing, are carried out through this object.
 
 **Parameters**
 
@@ -31,7 +31,7 @@ std::shared_ptr<PaddleDeploy::Model>  PaddleDeploy::ModelFactory::CreateObject(c
 **Code example**
 
 > ```c++
-> std::shared_ptr<PaddleDeploy::Model> model = PaddleDeploy::ModelFactory::CreateObject("det")
+> PaddleDeploy::Model* model = PaddleDeploy::ModelFactory::CreateObject("det")
 > ```
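
With `CreateObject`/`CreateModel` returning a raw `Model*`, ownership moves to the caller, which is why every demo touched by this commit gains a `delete model;` before returning. A minimal usage sketch assembled from the updated demos (the config/image paths are placeholders, and the engine-initialization step that the hunks above do not show is elided):

```c++
#include <iostream>
#include <vector>

#include <opencv2/opencv.hpp>
#include "model_deploy/common/include/paddle_deploy.h"

int main() {
  // CreateModel now returns a raw pointer; nullptr means the model_type
  // was never registered with the factory.
  PaddleDeploy::Model* model = PaddleDeploy::CreateModel("det");
  if (!model) {
    std::cerr << "no model_type: det" << std::endl;
    return -1;
  }

  model->Init("path/to/deploy.yaml");  // placeholder config path
  // ... engine initialization as in the demos (not shown in the hunks above) ...

  std::vector<cv::Mat> imgs = {cv::imread("test.jpg")};  // placeholder image
  std::vector<PaddleDeploy::Result> results;
  model->Predict(imgs, &results, 1);
  std::cout << results[0] << std::endl;

  delete model;  // the caller now owns the object returned by the factory
  return 0;
}
```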
 
 

+ 3 - 2
dygraph/deploy/cpp/model_deploy/common/include/base_model.h

@@ -54,8 +54,9 @@ class Model {
   }
 
   virtual bool YamlConfigInit(const std::string& cfg_file) {
-    YAML::Node yaml_config_ = YAML::LoadFile(cfg_file);
-    return true;
+    // YAML::Node yaml_config_ = YAML::LoadFile(cfg_file);
+    std::cerr << "Error! The Base Model was incorrectly entered" << std::endl;
+    return false;
   }
 
   virtual bool PreprocessInit() {

+ 4 - 4
dygraph/deploy/cpp/model_deploy/common/include/model_factory.h

@@ -26,14 +26,14 @@
 
 namespace PaddleDeploy {
 
-typedef std::shared_ptr<Model> (*NewInstance)();
+typedef Model* (*NewInstance)();
 
 class ModelFactory {
  private:
   static std::map<std::string, NewInstance> model_map_;
 
  public:
-  static std::shared_ptr<Model> CreateObject(const std::string &name);
+  static Model* CreateObject(const std::string &name);
 
   static void Register(const std::string &name, NewInstance model);
 };
@@ -48,9 +48,9 @@ class Register {
 #define REGISTER_CLASS(model_type, class_name)                    \
   class class_name##Register {                                    \
    public:                                                        \
-    static std::shared_ptr<Model> newInstance() {                 \
+    static Model* newInstance() {                                 \
       std::cerr << "REGISTER_CLASS:" << #model_type << std::endl; \
-      return std::make_shared<class_name>(#model_type);           \
+      return new class_name(#model_type);           \
     }                                                             \
                                                                   \
    private:                                                       \
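
The registration machinery changes in step: `NewInstance` is now a plain function pointer returning `Model*`, and each `REGISTER_CLASS` creator builds the object with `new` instead of `std::make_shared`. A self-contained sketch of the same pattern, with simplified names rather than the real header, showing how the map, the creator functions, and `CreateObject` fit together:

```c++
#include <iostream>
#include <map>
#include <string>

struct Model {
  virtual ~Model() = default;
};
struct DetModel : Model {};

// Creator functions now return raw pointers, mirroring the NewInstance typedef.
typedef Model* (*NewInstance)();

static std::map<std::string, NewInstance> model_map;

Model* CreateObject(const std::string& name) {
  auto it = model_map.find(name);
  return it == model_map.end() ? nullptr : it->second();
}

int main() {
  // Roughly what REGISTER_CLASS(det, DetModel) wires up at static-init time.
  model_map["det"] = []() -> Model* { return new DetModel(); };

  Model* m = CreateObject("det");
  std::cout << (m != nullptr) << std::endl;
  delete m;  // ownership stays with the caller
  return 0;
}
```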

+ 2 - 3
dygraph/deploy/cpp/model_deploy/common/include/multi_gpu_model.h

@@ -30,8 +30,7 @@ class MultiGPUModel {
             const std::string& cfg_file, size_t gpu_num = 1) {
     models_.clear();
     for (auto i = 0; i < gpu_num; ++i) {
-      std::shared_ptr<Model> model =
-          PaddleDeploy::ModelFactory::CreateObject(model_type);
+      Model* model = PaddleDeploy::ModelFactory::CreateObject(model_type);
 
       if (!model) {
         std::cerr << "no model_type: " << model_type << std::endl;
@@ -45,7 +44,7 @@ class MultiGPUModel {
         return false;
       }
 
-      models_.push_back(model);
+      models_.push_back(std::shared_ptr<Model>(model));
     }
     return true;
   }
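
`MultiGPUModel` keeps its `shared_ptr` bookkeeping by wrapping the factory's raw pointer at the call site. Callers who prefer RAII over an explicit `delete` can do the same; a small sketch (a hypothetical helper, not part of the shipped API):

```c++
#include <memory>
#include <string>

#include "model_deploy/common/include/paddle_deploy.h"

// Wrap the factory's raw pointer right away so it is released automatically,
// mirroring what MultiGPUModel::Init now does with std::shared_ptr.
std::shared_ptr<PaddleDeploy::Model> CreateModelShared(
    const std::string& model_type) {
  return std::shared_ptr<PaddleDeploy::Model>(
      PaddleDeploy::CreateModel(model_type));
}
```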

+ 1 - 1
dygraph/deploy/cpp/model_deploy/common/include/paddle_deploy.h

@@ -21,7 +21,7 @@
 #include "model_deploy/engine/include/engine.h"
 
 namespace PaddleDeploy {
-inline std::shared_ptr<Model> CreateModel(const std::string &name) {
+inline Model* CreateModel(const std::string &name) {
   return PaddleDeploy::ModelFactory::CreateObject(name);
 }
 }  // namespace PaddleDeploy

+ 23 - 11
dygraph/deploy/cpp/model_deploy/common/include/transforms.h

@@ -42,22 +42,37 @@ class Transform {
 class Normalize : public Transform {
  public:
   virtual void Init(const YAML::Node& item) {
-    mean_ = item["mean"].as<std::vector<float>>();
-    std_ = item["std"].as<std::vector<float>>();
+    std::vector<double> mean_ = item["mean"].as<std::vector<double>>();
+    std::vector<double> std_ = item["std"].as<std::vector<double>>();
+    bool is_scale_;
+    std::vector<double> min_val_;
+    std::vector<double> max_val_;
     if (item["is_scale"].IsDefined()) {
       is_scale_ = item["is_scale"];
     } else {
       is_scale_ = true;
     }
     if (item["min_val"].IsDefined()) {
-      min_val_ = item["min_val"].as<std::vector<float>>();
+      min_val_ = item["min_val"].as<std::vector<double>>();
     } else {
-      min_val_ = std::vector<float>(mean_.size(), 0.);
+      min_val_ = std::vector<double>(mean_.size(), 0.);
     }
     if (item["max_val"].IsDefined()) {
-      max_val_ = item["max_val"].as<std::vector<float>>();
+      max_val_ = item["max_val"].as<std::vector<double>>();
     } else {
-      max_val_ = std::vector<float>(mean_.size(), 255.);
+      max_val_ = std::vector<double>(mean_.size(), 255.);
+    }
+
+    for (auto c = 0; c < mean_.size(); c++) {
+      double alpha = 1.0;
+      if (is_scale_) {
+        alpha /= (max_val_[c] - min_val_[c]);
+      }
+      alpha /= std_[c];
+      double beta = -1.0 * mean_[c] / std_[c];
+
+      alpha_.push_back(alpha);
+      beta_.push_back(beta);
     }
   }
   virtual bool Run(cv::Mat* im);
@@ -67,11 +82,8 @@ class Normalize : public Transform {
 
 
  private:
-  bool is_scale_;
-  std::vector<float> mean_;
-  std::vector<float> std_;
-  std::vector<float> min_val_;
-  std::vector<float> max_val_;
+  std::vector<float> alpha_;
+  std::vector<float> beta_;
 };
 
 class ResizeByShort : public Transform {
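
`Normalize` no longer stores `mean_`/`std_`/`min_val_`/`max_val_`; `Init` folds them into per-channel affine coefficients instead. With the default `min_val` of 0 and `is_scale` enabled, the old subtract/scale/standardize chain collapses into a single linear map (when `is_scale` is false, `alpha` reduces to `1/std`):

```latex
y \;=\; \frac{x/(\mathrm{max\_val}-\mathrm{min\_val}) \;-\; \mathrm{mean}}{\mathrm{std}}
  \;=\; \alpha x + \beta,
\qquad
\alpha = \frac{1}{(\mathrm{max\_val}-\mathrm{min\_val})\,\mathrm{std}},
\qquad
\beta = -\frac{\mathrm{mean}}{\mathrm{std}}
```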

+ 1 - 0
dygraph/deploy/cpp/model_deploy/common/src/base_preprocess.cpp

@@ -26,6 +26,7 @@ bool BasePreprocess::BuildTransform(const YAML::Node& yaml_config) {
     std::string name = it->first.as<std::string>();
     std::shared_ptr<Transform> transform = CreateTransform(name);
     if (!transform) {
+      std::cerr << "Failed to create " << name << " on Preprocess" << std::endl;
       return false;
     }
     transform->Init(it->second);

+ 1 - 1
dygraph/deploy/cpp/model_deploy/common/src/model_factory.cpp

@@ -23,7 +23,7 @@ REGISTER_CLASS(seg, SegModel);
 REGISTER_CLASS(clas, ClasModel);
 REGISTER_CLASS(paddlex, PaddleXModel);
 
-std::shared_ptr<Model> ModelFactory::CreateObject(const std::string &name) {
+Model* ModelFactory::CreateObject(const std::string &name) {
   std::map<std::string, NewInstance>::const_iterator it;
   it = model_map_.find(name);
   if (it == model_map_.end())

+ 1 - 7
dygraph/deploy/cpp/model_deploy/common/src/transforms.cpp

@@ -26,13 +26,7 @@ bool Normalize::Run(cv::Mat *im) {
   std::vector<cv::Mat> split_im;
   cv::split(*im, split_im);
   for (int c = 0; c < im->channels(); c++) {
-    cv::subtract(split_im[c], cv::Scalar(min_val_[c]), split_im[c]);
-    if (is_scale_) {
-      float range_val = max_val_[c] - min_val_[c];
-      cv::divide(split_im[c], cv::Scalar(range_val), split_im[c]);
-    }
-    cv::subtract(split_im[c], cv::Scalar(mean_[c]), split_im[c]);
-    cv::divide(split_im[c], cv::Scalar(std_[c]), split_im[c]);
+    split_im[c].convertTo(split_im[c], CV_32FC1, alpha_[c], beta_[c]);
   }
   cv::merge(split_im, *im);
   return true;
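
`Run` then applies that linear map with one `convertTo` per channel instead of four separate OpenCV calls, producing the `CV_32F` result in a single pass. A tiny standalone check of the fused coefficients against the old pipeline, using toy numbers rather than anything from the repo:

```c++
#include <iostream>

#include <opencv2/core.hpp>

int main() {
  // Single-channel, single-pixel "image" with value 128 (toy example).
  cv::Mat px(1, 1, CV_32FC1, cv::Scalar(128.0f));

  // Fused coefficients, computed the way Normalize::Init now does.
  const double mean = 0.5, stdv = 0.5, min_val = 0.0, max_val = 255.0;
  const double alpha = 1.0 / ((max_val - min_val) * stdv);
  const double beta = -mean / stdv;

  px.convertTo(px, CV_32FC1, alpha, beta);

  // Matches ((128 / 255) - 0.5) / 0.5 ~= 0.0039 from the old four-step pipeline.
  std::cout << px.at<float>(0, 0) << std::endl;
  return 0;
}
```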

+ 7 - 0
dygraph/deploy/cpp/model_deploy/engine/src/tensorrt_engine.cpp

@@ -70,6 +70,7 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
   auto builder = InferUniquePtr<nvinfer1::IBuilder>(
                      nvinfer1::createInferBuilder(logger_));
   if (!builder) {
+    std::cerr << "TensorRT init builder error" << std::endl;
     return false;
   }
 
@@ -78,22 +79,26 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
   auto network = InferUniquePtr<nvinfer1::INetworkDefinition>(
                      builder->createNetworkV2(explicitBatch));
   if (!network) {
+    std::cerr << "TensorRT init network error" << std::endl;
     return false;
   }
 
   auto parser = InferUniquePtr<nvonnxparser::IParser>(
                     nvonnxparser::createParser(*network, logger_));
   if (!parser) {
+    std::cerr << "TensorRT init parser error" << std::endl;
     return false;
   }
   if (!parser->parseFromFile(tensorrt_config.model_file_.c_str(),
                              static_cast<int>(logger_.mReportableSeverity))) {
+    std::cerr << "TensorRT init model_file error" << std::endl;
     return false;
   }
 
   auto config = InferUniquePtr<nvinfer1::IBuilderConfig>(
                      builder->createBuilderConfig());
   if (!config) {
+    std::cerr << "TensorRT init config error" << std::endl;
     return false;
   }
 
@@ -130,6 +135,7 @@ bool TensorRTInferenceEngine::Init(const InferenceConfig& engine_config) {
                     engine_->createExecutionContext(),
                     InferDeleter());
   if (!context_) {
+    std::cerr << "TensorRT init context error" << std::endl;
     return false;
   }
 
@@ -229,6 +235,7 @@ bool TensorRTInferenceEngine::Infer(const std::vector<DataBlob>& input_blobs,
   buffers.copyInputToDevice();
   bool status = context_->executeV2(buffers.getDeviceBindings().data());
   if (!status) {
+    std::cerr << "TensorRT create execute error" << std::endl;
     return false;
   }
   buffers.copyOutputToHost();

+ 4 - 0
dygraph/deploy/cpp/model_deploy/ppclas/src/clas_postprocess.cpp

@@ -30,6 +30,10 @@ bool ClasPostprocess::Init(const YAML::Node& yaml_config) {
 bool ClasPostprocess::Run(const std::vector<DataBlob>& outputs,
                          const std::vector<ShapeInfo>& shape_infos,
                          std::vector<Result>* results, int thread_num) {
+  if (outputs.size() == 0) {
+    std::cerr << "empty output on ClasPostprocess" << std::endl;
+    return false;
+  }
   results->clear();
   int batch_size = shape_infos.size();
   results->resize(batch_size);

+ 1 - 1
dygraph/deploy/cpp/model_deploy/ppdet/src/det_postprocess.cpp

@@ -197,7 +197,7 @@ bool DetPostprocess::Run(const std::vector<DataBlob>& outputs,
                          std::vector<Result>* results, int thread_num) {
   results->clear();
   if (outputs.size() == 0) {
-    std::cerr << "empty input image on DetPostprocess" << std::endl;
+    std::cerr << "empty output on DetPostprocess" << std::endl;
     return true;
   }
   results->resize(shape_infos.size());

+ 4 - 0
dygraph/deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -89,6 +89,10 @@ bool SegPostprocess::RunV2(const DataBlob& output,
 bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
                          const std::vector<ShapeInfo>& shape_infos,
                          std::vector<Result>* results, int thread_num) {
+  if (outputs.size() == 0) {
+    std::cerr << "empty output on SegPostprocess" << std::endl;
+    return true;
+  }
   results->clear();
   int batch_size = shape_infos.size();
   results->resize(batch_size);

+ 4 - 4
dygraph/paddlex/cv/datasets/voc.py

@@ -89,7 +89,7 @@ class VOCDetection(Dataset):
         for k, v in cname2cid.items():
             annotations['categories'].append({
                 'supercategory': 'component',
-                'id': v,
+                'id': v + 1,
                 'name': k
             })
         ct = 0
@@ -113,11 +113,11 @@ class VOCDetection(Dataset):
                 if not osp.isfile(xml_file):
                     continue
                 if not osp.exists(img_file):
-                    logging.warning('The image file {} is not exist!'.format(
+                    logging.warning('The image file {} does not exist!'.format(
                         img_file))
                     continue
                 if not osp.exists(xml_file):
-                    logging.warning('The annotation file {} is not exist!'.
+                    logging.warning('The annotation file {} does not exist!'.
                                     format(xml_file))
                     continue
                 tree = ET.parse(xml_file)
@@ -219,7 +219,7 @@ class VOCDetection(Dataset):
                         'image_id': int(im_id[0]),
                         'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
                         'area': float((x2 - x1 + 1) * (y2 - y1 + 1)),
-                        'category_id': cname2cid[cname],
+                        'category_id': cname2cid[cname] + 1,
                         'id': ann_ct,
                         'difficult': _difficult
                     })

+ 7 - 4
dygraph/paddlex/cv/models/base.py

@@ -191,11 +191,14 @@ class BaseModel:
             shuffle=dataset.shuffle,
             drop_last=mode == 'train')
 
-        shm_size = _get_shared_memory_size_in_M()
-        if shm_size is None or shm_size < 1024.:
-            use_shared_memory = False
+        if dataset.num_workers > 0:
+            shm_size = _get_shared_memory_size_in_M()
+            if shm_size is None or shm_size < 1024.:
+                use_shared_memory = False
+            else:
+                use_shared_memory = True
         else:
-            use_shared_memory = True
+            use_shared_memory = False
 
         loader = DataLoader(
             dataset,

+ 22 - 13
dygraph/paddlex/cv/models/classifier.py

@@ -243,7 +243,7 @@ class BaseClassifier(BaseModel):
                     "If don't want to use pretrain weights, "
                     "set pretrain_weights to be None.")
                 pretrain_weights = 'IMAGENET'
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
@@ -538,10 +538,13 @@ class AlexNet(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
+                +
+                'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
+                + 'should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
 
         input_spec = [
@@ -743,10 +746,13 @@ class ShuffleNetV2(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
+                +
+                'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
+                + 'should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(
@@ -766,10 +772,13 @@ class ShuffleNetV2_swish(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
+                +
+                'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
+                + 'should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(

+ 35 - 2
dygraph/paddlex/cv/models/detector.py

@@ -240,7 +240,7 @@ class BaseDetector(BaseModel):
                                 "If you don't want to use pretrain weights, "
                                 "set pretrain_weights to be None.".format(
                                     pretrain_weights))
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
@@ -512,7 +512,7 @@ class BaseDetector(BaseModel):
                     h = ymax - ymin
                     bbox = [xmin, ymin, w, h]
                     dt_res = {
-                        'category_id': int(num_id),
+                        'category_id': int(num_id) + 1,
                         'category': category,
                         'bbox': bbox,
                         'score': score
@@ -544,6 +544,7 @@ class BaseDetector(BaseModel):
                         if 'counts' in rle:
                             rle['counts'] = rle['counts'].decode("utf8")
                     sg_res = {
+                        'category_id': int(label) + 1,
                         'category': category,
                         'segmentation': rle,
                         'score': score
@@ -1377,6 +1378,38 @@ class PPYOLOv2(YOLOv3):
         self.downsample_ratios = downsample_ratios
         self.model_name = 'PPYOLOv2'
 
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            if len(image_shape) == 2:
+                image_shape = [None, 3] + image_shape
+            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
+                raise Exception(
+                    "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
+                    format(image_shape[-2:]))
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 608, 608]. '
+                +
+                'Please check image shape after transforms is [3, 608, 608], if not, fixed_input_shape '
+                + 'should be specified manually.')
+            image_shape = [None, 3, 608, 608]
+
+        input_spec = [{
+            "image": InputSpec(
+                shape=image_shape, name='image', dtype='float32'),
+            "im_shape": InputSpec(
+                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
+            "scale_factor": InputSpec(
+                shape=[image_shape[0], 2],
+                name='scale_factor',
+                dtype='float32')
+        }]
+
+        return input_spec
+
 
 class MaskRCNN(BaseDetector):
     def __init__(self,

+ 8 - 2
dygraph/paddlex/cv/models/segmenter.py

@@ -241,7 +241,7 @@ class BaseSegmenter(BaseModel):
                                         0]))
                 pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                     0]
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
@@ -463,7 +463,13 @@ class BaseSegmenter(BaseModel):
         label_map = label_map.numpy().astype('uint8')
         score_map = outputs['score_map']
         score_map = score_map.numpy().astype('float32')
-        return {'label_map': label_map, 'score_map': score_map}
+        prediction = [{
+            'label_map': l,
+            'score_map': s
+        } for l, s in zip(label_map, score_map)]
+        if isinstance(img_file, (str, np.ndarray)):
+            prediction = prediction[0]
+        return prediction
 
     def _preprocess(self, images, transforms, model_type):
         arrange_transforms(

+ 3 - 9
dygraph/paddlex/cv/transforms/functions.py

@@ -47,23 +47,17 @@ def center_crop(im, crop_size=224):
     h_start = (height - crop_size) // 2
     w_end = w_start + crop_size
     h_end = h_start + crop_size
-    im = im[h_start:h_end, w_start:w_end, :]
+    im = im[h_start:h_end, w_start:w_end, ...]
     return im
 
 
 def horizontal_flip(im):
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
+    im = im[:, ::-1, ...]
     return im
 
 
 def vertical_flip(im):
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
+    im = im[::-1, :, ...]
     return im
 
 

+ 5 - 2
dygraph/paddlex/cv/transforms/operators.py

@@ -814,7 +814,7 @@ class RandomCrop(Transform):
 
     def apply_mask(self, mask, crop):
         x1, y1, x2, y2 = crop
-        return mask[y1:y2, x1:x2, :]
+        return mask[y1:y2, x1:x2, ...]
 
     def apply(self, sample):
         crop_info = self._generate_crop_info(sample)
@@ -867,13 +867,16 @@ class RandomCrop(Transform):
 
 class RandomExpand(Transform):
     """
-    Randomly expand the input by padding to the lower right side of the image(s) in input.
+    Randomly expand the input by padding according to random offsets.
 
     Args:
         upper_ratio(float, optional): The maximum ratio to which the original image is expanded. Defaults to 4..
         prob(float, optional): The probability of apply expanding. Defaults to .5.
         im_padding_value(List[float] or Tuple[float], optional): RGB filling value for the image. Defaults to (127.5, 127.5, 127.5).
         label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
+
+    See Also:
+        paddlex.transforms.Padding
     """
 
     def __init__(self,

+ 5 - 0
dygraph/paddlex/utils/env.py

@@ -16,6 +16,7 @@ import sys
 import glob
 import os
 import os.path as osp
+import platform
 import random
 import numpy as np
 import multiprocessing as mp
@@ -47,6 +48,10 @@ def get_environ_info():
 
 
 def get_num_workers(num_workers):
+    if not platform.system() == 'Linux':
+        # Dataloader with multi-process model is not supported
+        # on MacOS and Windows currently.
+        return 0
     if num_workers == 'auto':
         num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 2 else 2
     return num_workers

+ 0 - 2
dygraph/requirements.txt

@@ -1,8 +1,6 @@
 -r ./PaddleClas/requirements.txt
 -r ./PaddleSeg/requirements.txt
-./PaddleSeg
 -r ./PaddleDetection/requirements.txt
-./PaddleDetection
 tqdm
 scipy
 colorama

+ 2 - 0
dygraph/submodules.txt

@@ -0,0 +1,2 @@
+./PaddleSeg
+./PaddleDetection