
Merge pull request #1 from PaddlePaddle/develop

SunAhong1993 · 5 years ago · commit b7f908353e
53 changed files with 2081 additions and 383 deletions
  1. deploy/README.md (+3 −1)
  2. deploy/cpp/CMakeLists.txt (+8 −8)
  3. deploy/cpp/include/paddlex/paddlex.h (+3 −1)
  4. deploy/cpp/include/paddlex/transforms.h (+4 −6)
  5. deploy/cpp/scripts/build.sh (+11 −1)
  6. deploy/cpp/src/classifier.cpp (+2 −1)
  7. deploy/cpp/src/detector.cpp (+4 −3)
  8. deploy/cpp/src/paddlex.cpp (+21 −6)
  9. deploy/cpp/src/segmenter.cpp (+3 −1)
  10. deploy/cpp/src/transforms.cpp (+7 −10)
  11. docs/apis/load_model.md (+1 −1)
  12. docs/apis/models.md (+17 −11)
  13. docs/apis/transforms/cls_transforms.md (+3 −1)
  14. docs/apis/transforms/det_transforms.md (+9 −3)
  15. docs/apis/transforms/seg_transforms.md (+3 −1)
  16. docs/deploy/deploy.md (+14 −5)
  17. docs/deploy/deploy_cpp_linux.md (+34 −14)
  18. docs/deploy/deploy_cpp_win_vs2019.md (+18 −8)
  19. docs/index.rst (+1 −1)
  20. paddlex/__init__.py (+19 −7)
  21. paddlex/command.py (+20 −2)
  22. paddlex/cv/datasets/coco.py (+7 −3)
  23. paddlex/cv/datasets/dataset.py (+8 −0)
  24. paddlex/cv/datasets/easydata_det.py (+1 −1)
  25. paddlex/cv/datasets/voc.py (+1 −2)
  26. paddlex/cv/models/base.py (+59 −23)
  27. paddlex/cv/models/classifier.py (+15 −5)
  28. paddlex/cv/models/deeplabv3p.py (+16 −11)
  29. paddlex/cv/models/faster_rcnn.py (+13 −7)
  30. paddlex/cv/models/load_model.py (+31 −1)
  31. paddlex/cv/models/mask_rcnn.py (+12 −6)
  32. paddlex/cv/models/slim/prune_config.py (+20 −16)
  33. paddlex/cv/models/unet.py (+12 −9)
  34. paddlex/cv/models/utils/pretrain_weights.py (+44 −31)
  35. paddlex/cv/models/utils/visualize.py (+14 −4)
  36. paddlex/cv/models/yolo_v3.py (+8 −3)
  37. paddlex/cv/nets/detection/faster_rcnn.py (+13 −3)
  38. paddlex/cv/nets/detection/mask_rcnn.py (+13 −3)
  39. paddlex/cv/nets/detection/yolo_v3.py (+12 −3)
  40. paddlex/cv/nets/segmentation/deeplabv3p.py (+14 −3)
  41. paddlex/cv/nets/segmentation/unet.py (+14 −3)
  42. paddlex/cv/transforms/cls_transforms.py (+50 −18)
  43. paddlex/cv/transforms/det_transforms.py (+139 −72)
  44. paddlex/cv/transforms/imgaug_support.py (+131 −0)
  45. paddlex/cv/transforms/seg_transforms.py (+198 −47)
  46. paddlex/tools/__init__.py (+24 −0)
  47. paddlex/tools/base.py (+43 −0)
  48. paddlex/tools/x2coco.py (+257 −0)
  49. paddlex/tools/x2imagenet.py (+58 −0)
  50. paddlex/tools/x2seg.py (+332 −0)
  51. paddlex/tools/x2voc.py (+199 −0)
  52. paddlex/utils/utils.py (+115 −13)
  53. setup.py (+3 −4)

deploy/README.md (+3 −1)

@@ -1,3 +1,5 @@
 # Model Deployment
 
-This directory contains the PaddleX model deployment code.
+This directory contains the PaddleX model deployment code; for the build and usage tutorial see:
+
+- [C++ deployment docs](../docs/deploy/deploy.md#C部署)

deploy/cpp/CMakeLists.txt (+8 −8)

@@ -3,9 +3,10 @@ project(PaddleX CXX C)
 
 option(WITH_MKL        "Compile demo with MKL/OpenBlas support,defaultuseMKL."          ON)
 option(WITH_GPU        "Compile demo with GPU/CPU, default use CPU."                    ON)
-option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static."   ON)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static."   OFF)
 option(WITH_TENSORRT "Compile demo with TensorRT."   OFF)
 
+SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
 SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
 SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
 SET(CUDA_LIB "" CACHE PATH "Location of libraries")
@@ -111,8 +112,8 @@ endif()
 
 if (NOT WIN32)
   if (WITH_TENSORRT AND WITH_GPU)
-      include_directories("${PADDLE_DIR}/third_party/install/tensorrt/include")
-      link_directories("${PADDLE_DIR}/third_party/install/tensorrt/lib")
+      include_directories("${TENSORRT_DIR}/include")
+      link_directories("${TENSORRT_DIR}/lib")
   endif()
 endif(NOT WIN32)
@@ -169,7 +170,7 @@ endif()
 
 if (NOT WIN32)
     set(DEPS ${DEPS}
-        ${MATH_LIB} ${MKLDNN_LIB} 
+        ${MATH_LIB} ${MKLDNN_LIB}
         glog gflags protobuf z xxhash yaml-cpp
         )
     if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
@@ -194,8 +195,8 @@ endif(NOT WIN32)
 if(WITH_GPU)
   if(NOT WIN32)
     if (WITH_TENSORRT)
-      set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
-      set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
     endif()
     set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
     set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
@@ -211,7 +212,7 @@ if (NOT WIN32)
     set(DEPS ${DEPS} ${EXTERNAL_LIB})
 endif()
 
-set(DEPS ${DEPS} ${OpenCV_LIBS}) 
+set(DEPS ${DEPS} ${OpenCV_LIBS})
 add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
 ADD_DEPENDENCIES(classifier ext-yaml-cpp)
 target_link_libraries(classifier ${DEPS})
@@ -251,4 +252,3 @@ if (WIN32 AND WITH_MKL)
     )
 
 endif()
-

deploy/cpp/include/paddlex/paddlex.h (+3 −1)

@@ -38,12 +38,14 @@ class Model {
  public:
   void Init(const std::string& model_dir,
             bool use_gpu = false,
+            bool use_trt = false,
             int gpu_id = 0) {
-    create_predictor(model_dir, use_gpu, gpu_id);
+    create_predictor(model_dir, use_gpu, use_trt, gpu_id);
   }
 
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
+                        bool use_trt = false,
                         int gpu_id = 0);
 
   bool load_config(const std::string& model_dir);

deploy/cpp/include/paddlex/transforms.h (+4 −6)

@@ -35,10 +35,8 @@ class ImageBlob {
   std::vector<int> ori_im_size_ = std::vector<int>(2);
   // Newest image height and width after process
   std::vector<int> new_im_size_ = std::vector<int>(2);
-  // Image height and width before padding
-  std::vector<int> im_size_before_padding_ = std::vector<int>(2);
   // Image height and width before resize
-  std::vector<int> im_size_before_resize_ = std::vector<int>(2);
+  std::vector<std::vector<int>> im_size_before_resize_;
   // Reshape order
   std::vector<std::string> reshape_order_;
   // Resize scale
@@ -49,7 +47,6 @@ class ImageBlob {
   void clear() {
     ori_im_size_.clear();
     new_im_size_.clear();
-    im_size_before_padding_.clear();
     im_size_before_resize_.clear();
     reshape_order_.clear();
     im_data_.clear();
@@ -155,12 +152,13 @@ class Padding : public Transform {
   virtual void Init(const YAML::Node& item) {
     if (item["coarsest_stride"].IsDefined()) {
       coarsest_stride_ = item["coarsest_stride"].as<int>();
-      if (coarsest_stride_ <= 1) {
+      if (coarsest_stride_ < 1) {
         std::cerr << "[Padding] coarest_stride should greater than 0"
                   << std::endl;
         exit(-1);
       }
-    } else {
+    }
+    if (item["target_size"].IsDefined()) {
       if (item["target_size"].IsScalar()) {
         width_ = item["target_size"].as<int>();
         height_ = item["target_size"].as<int>();

deploy/cpp/scripts/build.sh (+11 −1)

@@ -1,9 +1,16 @@
 # Whether to use GPU (i.e., whether to use CUDA)
-WITH_GPU=ON
+WITH_GPU=OFF
+# Use MKL or OpenBLAS
+WITH_MKL=ON
 # Whether to integrate TensorRT (only effective when WITH_GPU=ON)
 WITH_TENSORRT=OFF
+# Path to the TensorRT libs
+TENSORRT_DIR=/path/to/TensorRT/
 # Path to the Paddle inference library
 PADDLE_DIR=/path/to/fluid_inference/
+# Whether to link the Paddle inference library statically
+# When using TensorRT, the Paddle inference library is usually a shared library
+WITH_STATIC_LIB=OFF
 # Path to the CUDA libs
 CUDA_LIB=/path/to/cuda/lib/
 # Path to the cuDNN libs
@@ -19,8 +26,11 @@ mkdir -p build
 cd build
 cmake .. \
     -DWITH_GPU=${WITH_GPU} \
+    -DWITH_MKL=${WITH_MKL} \
     -DWITH_TENSORRT=${WITH_TENSORRT} \
+    -DTENSORRT_DIR=${TENSORRT_DIR} \
     -DPADDLE_DIR=${PADDLE_DIR} \
+    -DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
     -DCUDA_LIB=${CUDA_LIB} \
     -DCUDNN_LIB=${CUDNN_LIB} \
     -DOPENCV_DIR=${OPENCV_DIR}

deploy/cpp/src/classifier.cpp (+2 −1)

@@ -23,6 +23,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -42,7 +43,7 @@ int main(int argc, char** argv) {
 
   // Load the model
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
 
   // Run prediction
   if (FLAGS_image_list != "") {

deploy/cpp/src/detector.cpp (+4 −3)

@@ -24,6 +24,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -44,7 +45,7 @@ int main(int argc, char** argv) {
 
   // Load the model
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
 
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   std::string save_dir = "output";
@@ -68,7 +69,7 @@ int main(int argc, char** argv) {
                   << result.boxes[i].coordinate[0] << ", "
                   << result.boxes[i].coordinate[1] << ", "
                   << result.boxes[i].coordinate[2] << ", "
-                  << result.boxes[i].coordinate[3] << std::endl;
+                  << result.boxes[i].coordinate[3] << ")" << std::endl;
       }
 
       // Visualization
@@ -91,7 +92,7 @@ int main(int argc, char** argv) {
                 << result.boxes[i].coordinate[0] << ", "
                 << result.boxes[i].coordinate[1] << ", "
                 << result.boxes[i].coordinate[2] << ", "
-                << result.boxes[i].coordinate[3] << std::endl;
+                << result.boxes[i].coordinate[3] << ")" << std::endl;
     }
 
     // Visualization

deploy/cpp/src/paddlex.cpp (+21 −6)

@@ -18,6 +18,7 @@ namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
+                             bool use_trt,
                              int gpu_id) {
   // Read the config file
   if (!load_config(model_dir)) {
@@ -37,6 +38,15 @@ void Model::create_predictor(const std::string& model_dir,
   config.SwitchSpecifyInputNames(true);
   // Enable memory optimization
   config.EnableMemoryOptim();
+  if (use_trt) {
+    config.EnableTensorRtEngine(
+        1 << 20 /* workspace_size*/,
+        32 /* max_batch_size*/,
+        20 /* min_subgraph_size*/,
+        paddle::AnalysisConfig::Precision::kFloat32 /* precision*/,
+        true /* use_static*/,
+        false /* use_calib_mode*/);
+  }
   predictor_ = std::move(CreatePaddlePredictor(config));
 }
@@ -246,7 +256,6 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
   auto im_tensor = predictor_->GetInputTensor("image");
   im_tensor->Reshape({1, 3, h, w});
   im_tensor->copy_from_cpu(inputs_.im_data_.data());
-  std::cout << "input image: " << h << " " << w << std::endl;
 
   // Run prediction with the loaded model
   predictor_->ZeroCopyRun();
@@ -286,19 +295,24 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                      result->score_map.shape[3],
                      CV_32FC1,
                      result->score_map.data.data());
-
+  int idx = 1;
+  int len_postprocess = inputs_.im_size_before_resize_.size();
   for (std::vector<std::string>::reverse_iterator iter =
            inputs_.reshape_order_.rbegin();
        iter != inputs_.reshape_order_.rend();
        ++iter) {
     if (*iter == "padding") {
-      auto padding_w = inputs_.im_size_before_padding_[0];
-      auto padding_h = inputs_.im_size_before_padding_[1];
+      auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
+      inputs_.im_size_before_resize_.pop_back();
+      auto padding_w = before_shape[0];
+      auto padding_h = before_shape[1];
       mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h));
       mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h));
     } else if (*iter == "resize") {
-      auto resize_w = inputs_.im_size_before_resize_[0];
-      auto resize_h = inputs_.im_size_before_resize_[1];
+      auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
+      inputs_.im_size_before_resize_.pop_back();
+      auto resize_w = before_shape[0];
+      auto resize_h = before_shape[1];
       cv::resize(mask_label,
                  mask_label,
                  cv::Size(resize_h, resize_w),
@@ -312,6 +326,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                  0,
                  cv::INTER_NEAREST);
     }
+    ++idx;
   }
   result->label_map.data.assign(mask_label.begin<uint8_t>(),
                                 mask_label.end<uint8_t>());

deploy/cpp/src/segmenter.cpp (+3 −1)

@@ -24,6 +24,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -44,7 +45,8 @@ int main(int argc, char** argv) {
 
   // Load the model
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
+
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   // Run prediction
   if (FLAGS_image_list != "") {

deploy/cpp/src/transforms.cpp (+7 −10)

@@ -56,8 +56,7 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) {
 }
 
 bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
 
   float scale = GenerateScale(*im);
@@ -88,21 +87,21 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
 }
 
 bool Padding::Run(cv::Mat* im, ImageBlob* data) {
-  data->im_size_before_padding_[0] = im->rows;
-  data->im_size_before_padding_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("padding");
 
   int padding_w = 0;
   int padding_h = 0;
-  if (width_ > 0 & height_ > 0) {
+  if (width_ > 1 & height_ > 1) {
     padding_w = width_ - im->cols;
     padding_h = height_ - im->rows;
-  } else if (coarsest_stride_ > 0) {
+  } else if (coarsest_stride_ > 1) {
     padding_h =
         ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
     padding_w =
         ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
   }
+
   if (padding_h < 0 || padding_w < 0) {
     std::cerr << "[Padding] Computed padding_h=" << padding_h
               << ", padding_w=" << padding_w
@@ -122,8 +121,7 @@ bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
               << std::endl;
     return false;
   }
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
   int origin_w = im->cols;
   int origin_h = im->rows;
@@ -149,8 +147,7 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) {
               << std::endl;
     return false;
   }
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
 
   cv::resize(

docs/apis/load_model.md (+1 −1)

@@ -34,7 +34,7 @@ pred_result = model.predict('./xiaoduxiong_ins_det/JPEGImages/WechatIMG114.jpeg'
 
 # Evaluate on the validation set
 eval_reader = pdx.cv.datasets.CocoDetection(data_dir=data_dir,
-                                            ann_file=ann_file
+                                            ann_file=ann_file,
                                             transforms=model.eval_transforms)
 eval_result = model.evaluate(eval_reader, batch_size=1)
 ```

docs/apis/models.md (+17 −11)

@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 #### Classifier training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **Parameters:**
@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 > > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### Classifier evaluation API
 
@@ -89,7 +90,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_masks=None, ignore_threshold=0.7, nms_score_threshold=0.01, nms_topk=1000, nms_keep_topk=100, nms_iou_threshold=0.45, label_smooth=False, train_random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
 ```
 
-Builds a YOLOv3 detector and implements its training, evaluation and prediction.
+Builds a YOLOv3 detector and implements its training, evaluation and prediction. **Note: in YOLOv3, num_classes must not include the background class; e.g. with the two target classes human and dog, set num_classes to 2. This differs from FasterRCNN/MaskRCNN.**
 
 **Parameters:**
 
@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 #### YOLOv3 training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **Parameters:**
@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### YOLOv3 evaluation API
 
@@ -177,12 +179,12 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 
 ```
 
-Builds a FasterRCNN detector and implements its training, evaluation and prediction.
+Builds a FasterRCNN detector and implements its training, evaluation and prediction. **Note: in FasterRCNN, num_classes must be set to the number of classes plus the background class; e.g. with the two target classes human and dog, set num_classes to 3, the extra one being the background class.**
 
 **Parameters:**
 
 > - **num_classes** (int): Number of classes, including the background class. Defaults to 81.
-> - **backbone** (str): Backbone network of FasterRCNN, one of ['ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd']. Defaults to 'ResNet50'.
+> - **backbone** (str): Backbone network of FasterRCNN, one of ['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']. Defaults to 'ResNet50'.
 > - **with_fpn** (bool): Whether to use an FPN. Defaults to True.
 > - **aspect_ratios** (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
 > - **anchor_sizes** (list): Candidate sizes for anchor generation. Defaults to [32, 64, 128, 256, 512].
@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 #### FasterRCNN training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### FasterRCNN evaluation API
 
@@ -257,12 +260,12 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 
 ```
 
-Builds a MaskRCNN detector and implements its training, evaluation and prediction.
+Builds a MaskRCNN detector and implements its training, evaluation and prediction. **Note: in MaskRCNN, num_classes must be set to the number of classes plus the background class; e.g. with the two target classes human and dog, set num_classes to 3, the extra one being the background class.**
 
 **Parameters:**
 
 > - **num_classes** (int): Number of classes, including the background class. Defaults to 81.
-> - **backbone** (str): Backbone network of MaskRCNN, one of ['ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd']. Defaults to 'ResNet50'.
+> - **backbone** (str): Backbone network of MaskRCNN, one of ['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']. Defaults to 'ResNet50'.
 > - **with_fpn** (bool): Whether to use an FPN. Defaults to True.
 > - **aspect_ratios** (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
 > - **anchor_sizes** (list): Candidate sizes for anchor generation. Defaults to [32, 64, 128, 256, 512].
@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 #### MaskRCNN training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### MaskRCNN evaluation API
 
@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 #### DeepLabv3 training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 >
 > ```
 >
@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### DeepLabv3 evaluation API
 
@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 #### Unet training API
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 > ```
 >
 > **Parameters:**
@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 > > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
 > > - **early_stop** (float): Whether to use early stopping during training. Defaults to False.
 > > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
 
 #### Unet evaluation API
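Taken together, the num_classes notes and the new `resume_checkpoint` argument added above can be exercised as in the following minimal sketch. It assumes a PaddleX 0.1.7 install; the dataset paths and the checkpoint directory `output/yolov3/epoch_20` are hypothetical placeholders, and the per-epoch checkpoint layout under `save_dir` is an assumption based on `save_interval_epochs`.

```python
import paddlex as pdx
from paddlex.det import transforms

# Hypothetical VOC-format dataset; all paths are placeholders.
train_transforms = transforms.Compose(
    [transforms.Resize(target_size=608), transforms.Normalize()])
train_dataset = pdx.datasets.VOCDetection(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms)

# num_classes convention from the notes above: YOLOv3 excludes the
# background class; FasterRCNN/MaskRCNN add one class for it.
yolo = pdx.det.YOLOv3(num_classes=2)       # e.g. human, dog
frcnn = pdx.det.FasterRCNN(num_classes=3)  # human, dog + background

# Resume an interrupted run from a previously saved epoch directory
# (assumed layout: one subdirectory per saved epoch under save_dir).
yolo.train(
    num_epochs=270,
    train_dataset=train_dataset,
    save_dir='output/yolov3',
    resume_checkpoint='output/yolov3/epoch_20')
```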
 
 

docs/apis/transforms/cls_transforms.md (+3 −1)

@@ -109,7 +109,9 @@ paddlex.cls.transforms.RandomDistort(brightness_range=0.9, brightness_prob=0.5,
 
 Randomly alters image pixel values with a given probability; a data augmentation op used during model training.
 1. Randomize the order of the transform operations.
-2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+
+[Note] This augmentation must be applied before the Normalize augmentation.
 
 ### Parameters
 * **brightness_range** (float): Range of the brightness factor. Defaults to 0.9.

docs/apis/transforms/det_transforms.md (+9 −3)

@@ -85,7 +85,9 @@ paddlex.det.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5,
 
 Randomly alters image pixel values with a given probability; a data augmentation op used during model training.
 1. Randomize the order of the transform operations.
-2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+
+[Note] This augmentation must be applied before the Normalize augmentation.
 
 ### Parameters
 * **brightness_range** (float): Range of the brightness factor. Defaults to 0.5.
@@ -135,7 +137,9 @@ paddlex.det.transforms.RandomExpand(ratio=4., prob=0.5, fill_value=[123.675, 116
 ### Parameters
 * **ratio** (float): Maximum expansion ratio of the image. Defaults to 4.0.
 * **prob** (float): Probability of random expansion. Defaults to 0.5.
-* **fill_value** (list): Initial fill value (0-255) for the expanded image. Defaults to [123.675, 116.28, 103.53].
+* **fill_value** (list): Initial fill value (0-255) for the expanded image. Defaults to [123.675, 116.28, 103.53].
+
+[Note] This augmentation must be applied before the Resize and ResizeByShort augmentations.
 
 ## RandomCrop class
 ```python
@@ -152,7 +156,9 @@ paddlex.det.transforms.RandomCrop(aspect_ratio=[.5, 2.], thresholds=[.0, .1, .3,
     (4) If cover_all_box is True and any ground-truth box has an IoU below thresh, go back to step 3.
     (5) Keep the ground-truth boxes that lie inside the candidate crop; if no valid box remains, go back to step 3, otherwise continue to step 4.
 4. Convert the coordinates of the valid ground-truth boxes to be relative to the candidate crop.
-5. Convert the coordinates of the valid segmentation regions to be relative to the candidate crop.
+5. Convert the coordinates of the valid segmentation regions to be relative to the candidate crop.
+
+[Note] This augmentation must be applied before the Resize and ResizeByShort augmentations.
 
 ### Parameters
 * **aspect_ratio** (list): Value range of the short-side scaling ratio after cropping, as [min, max]. Defaults to [.5, 2.].
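The ordering constraints called out in the three notes above translate into pipelines like the following sketch, assuming the standard `paddlex.det.transforms` classes; the particular target size and flip op are illustrative.

```python
from paddlex.det import transforms

# Order matters: pixel-level distortion must precede Normalize, and
# expansion/cropping must precede any resize op.
train_transforms = transforms.Compose([
    transforms.RandomDistort(),   # before Normalize
    transforms.RandomExpand(),    # before Resize/ResizeByShort
    transforms.RandomCrop(),      # before Resize/ResizeByShort
    transforms.Resize(target_size=608, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])
```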

docs/apis/transforms/seg_transforms.md (+3 −1)

@@ -153,7 +153,9 @@ paddlex.seg.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5,
 Randomly alters image pixel values with a given probability; a data augmentation op used during model training.
 
 1. Randomize the order of the transform operations.
-2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+2. In the order from step 1, randomly alter pixel values within [-range, range] with a given probability.
+
+[Note] This augmentation must be applied before the Normalize augmentation.
 
 ### Parameters
 * **brightness_range** (float): Range of the brightness factor. Defaults to 0.5.

docs/deploy/deploy.md (+14 −5)

@@ -7,20 +7,29 @@
 ### Exporting an inference model
 
 A model deployed server-side must first be exported in the inference format, producing three files: `__model__`, `__params__` and `model.yml` (the network structure, the model weights, and the model config, including preprocessing parameters). After installing PaddleX, run the following command in a terminal to export the model into `inference_model` under the current directory.
+> To try out this document's workflow you can download the Xiaodu bear sorting model [xiaoduxiong_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)
 
-> To try out this document's workflow you can download the garbage detection model [garbage_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/garbage_epoch_12.tar.gz)
+```
+paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model
+```
+
+When predicting with TensorRT, the model's image input shape [w,h] must be specified.
+**Note**:
+- For classification models, keep the shape consistent with the input shape used during training.
+- When specifying [w,h], separate w and h with a comma only; spaces or other characters are not allowed.
 
 ```
-paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model
+paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model --fixed_input_shape=[640,960]
 ```
 
 ### Python deployment
 PaddleX ships a high-performance Python prediction API. After installing PaddleX, you can run prediction following the code example below. See [paddlex.deploy](apis/deploy.md) for the API documentation.
-> Click to download the test image [garbage.bmp](https://bj.bcebos.com/paddlex/datasets/garbage.bmp)
+> Click to download the test images [xiaoduxiong_test_image.tar.gz](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_test_image.tar.gz)
+
 ```
 import paddlex as pdx
-predictorpdx.deploy.create_predictor('./inference_model')
-result = predictor.predict(image='garbage.bmp')
+predictor = pdx.deploy.create_predictor('./inference_model')
+result = predictor.predict(image='xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg')
 ```
 
 ### C++ deployment
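Continuing the Python deployment snippet above, the prediction result can also be rendered with the PaddleX visualization helper; a hedged sketch, where the 0.5 threshold and the output directory are illustrative:

```python
import paddlex as pdx

predictor = pdx.deploy.create_predictor('./inference_model')
image = 'xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg'
result = predictor.predict(image=image)

# Draw boxes/masks whose score exceeds the threshold and write the
# rendered image into save_dir (detection/instance-segmentation models).
pdx.det.visualize(image, result, threshold=0.5, save_dir='./output')
```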

docs/deploy/deploy_cpp_linux.md (+34 −14)

@@ -19,8 +19,18 @@
 
 ### Step2: Download the PaddlePaddle C++ inference library fluid_inference
 
-PaddlePaddle provides different prebuilt C++ inference libraries for different `CPU`/`CUDA` configurations and TensorRT support; download according to your setup: [C++ inference library download list](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1)
+PaddlePaddle provides different prebuilt C++ inference libraries for different `CPU`/`CUDA` configurations and TensorRT support. PaddleX currently depends on Paddle 1.7; several Paddle inference library builds are provided below:
 
+|  Build   | Inference library (version 1.7.2)  |
+|  ----  | ----  |
+| ubuntu14.04_cpu_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-mkl/fluid_inference.tgz) |
+| ubuntu14.04_cpu_avx_openblas  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-openblas/fluid_inference.tgz) |
+| ubuntu14.04_cpu_noavx_openblas  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-noavx-openblas/fluid_inference.tgz) |
+| ubuntu14.04_cuda9.0_cudnn7_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda9-cudnn7-avx-mkl/fluid_inference.tgz) |
+| ubuntu14.04_cuda10.0_cudnn7_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10-cudnn7-avx-mkl/fluid_inference.tgz ) |
+| ubuntu14.04_cuda10.1_cudnn7.6_avx_mkl_trt6  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10.1-cudnn7.6-avx-mkl-trt6%2Ffluid_inference.tgz) |
+
+For more and newer versions, download according to your setup: [C++ inference library download list](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html#id1)
 
 After downloading and extracting, the `/root/projects/fluid_inference` directory contains:
 ```
@@ -40,17 +50,24 @@ fluid_inference
 The `cmake` build command lives in `scripts/build.sh`; adjust its main parameters to your setup. Its main contents are:
 ```
 # Whether to use GPU (i.e., whether to use CUDA)
-WITH_GPU=ON
+WITH_GPU=OFF
+# Use MKL or OpenBLAS
+WITH_MKL=ON
 # Whether to integrate TensorRT (only effective when WITH_GPU=ON)
 WITH_TENSORRT=OFF
-# Path to the Paddle inference library downloaded in the previous step
-PADDLE_DIR=/root/projects/deps/fluid_inference/
+# Path to the TensorRT libs
+TENSORRT_DIR=/path/to/TensorRT/
+# Path to the Paddle inference library
+PADDLE_DIR=/path/to/fluid_inference/
+# Whether to link the Paddle inference library statically
+# When using TensorRT, the Paddle inference library is usually a shared library
+WITH_STATIC_LIB=ON
 # Path to the CUDA libs
-CUDA_LIB=/usr/local/cuda/lib64/
+CUDA_LIB=/path/to/cuda/lib/
 # Path to the cuDNN libs
-CUDNN_LIB=/usr/local/cudnn/lib64/
+CUDNN_LIB=/path/to/cudnn/lib/
 
-# OPENCV path; no need to set it when using the bundled prebuilt version
+# OPENCV path; no need to modify it when using the bundled prebuilt version
 OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
 sh $(pwd)/scripts/bootstrap.sh
 
@@ -60,8 +77,11 @@ mkdir -p build
 cd build
 cmake .. \
     -DWITH_GPU=${WITH_GPU} \
+    -DWITH_MKL=${WITH_MKL} \
     -DWITH_TENSORRT=${WITH_TENSORRT} \
+    -DTENSORRT_DIR=${TENSORRT_DIR} \
     -DPADDLE_DIR=${PADDLE_DIR} \
+    -DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
     -DCUDA_LIB=${CUDA_LIB} \
    -DCUDNN_LIB=${CUDNN_LIB} \
    -DOPENCV_DIR=${OPENCV_DIR}
@@ -83,19 +103,20 @@ make
 | image  | Path of the image file to predict |
 | image_list  | A .txt file listing image paths line by line |
 | use_gpu  | Whether to predict with GPU, 0 or 1 (default 0) |
+| use_trt  | Whether to predict with TensorRT, 0 or 1 (default 0) |
 | gpu_id  | GPU device ID, default 0 |
 | save_dir | Directory for saving visualized results, default "output"; classifier has no such parameter |
 
 ## Examples
 
-You can run prediction with the `inference_model` generated by the [garbage detection model](deploy.md#导出inference模型) and the test images.
+You can run prediction with the `inference_model` exported from the [Xiaodu bear recognition model](deploy.md#导出inference模型) and the test images.
 
 `Example 1`:
 
-Predict the image `/path/to/garbage.bmp` without `GPU`:
+Predict the image `/path/to/xiaoduxiong.jpeg` without `GPU`:
 
 ```shell
-./build/detector --model_dir=/path/to/inference_model --image=/path/to/garbage.bmp --save_dir=output
+./build/detector --model_dir=/path/to/inference_model --image=/path/to/xiaoduxiong.jpeg --save_dir=output
 ```
 The `visualized prediction result` of the image is saved under the directory set by `save_dir`.
 
@@ -104,13 +125,12 @@
 
 Predict multiple images listed in `/path/to/image_list.txt` with `GPU`; the format of image_list.txt is:
 ```
-/path/to/images/garbage1.jpeg
-/path/to/images/garbage2.jpeg
+/path/to/images/xiaoduxiong1.jpeg
+/path/to/images/xiaoduxiong2.jpeg
 ...
-/path/to/images/garbagen.jpeg
+/path/to/images/xiaoduxiongn.jpeg
 ```
 ```shell
 ./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output
 ```
 The `visualized prediction result` of the image is saved under the directory set by `save_dir`.
-

docs/deploy/deploy_cpp_win_vs2019.md (+18 −8)

@@ -27,7 +27,18 @@ git clone https://github.com/PaddlePaddle/PaddleX.git
 
 ### Step2: Download the PaddlePaddle C++ inference library fluid_inference
 
-PaddlePaddle provides different prebuilt C++ inference libraries for different `CPU` and `CUDA` versions; download according to your setup: [C++ inference library download list](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html)
+PaddlePaddle provides different prebuilt C++ inference libraries for different `CPU`/`CUDA` configurations and TensorRT support. PaddleX currently depends on Paddle 1.7; several Paddle inference library builds are provided below:
+
+|  Build   | Inference library (version 1.7.2)  | Compiler | Build tool | cuDNN | CUDA
+|  ----  |  ----  |  ----  |  ----  | ---- | ---- |
+| cpu_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
+| cpu_avx_openblas  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
+| cuda9.0_cudnn7_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
+| cuda9.0_cudnn7_avx_openblas  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
+| cuda10.0_cudnn7_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post107/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.5.0 | 9.0 |
+
+
+For more and newer versions, download according to your setup: [C++ inference library download list](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1)
 
 After extraction, the `D:\projects\fluid_inference*\` directory mainly contains:
 ```
@@ -109,14 +120,14 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
 
 ## Examples
 
-You can run prediction with the `inference_model` generated by the [garbage detection model](deploy.md#导出inference模型) and the test images.
+You can run prediction with the `inference_model` exported from the [Xiaodu bear recognition model](deploy.md#导出inference模型) and the test images.
 
 `Example 1`:
 
-Predict the image `\\path\\to\\garbage.bmp` without `GPU`:
+Predict the image `\\path\\to\\xiaoduxiong.jpeg` without `GPU`:
 
 ```shell
-.\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\garbage.bmp --save_dir=output
+.\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\xiaoduxiong.jpeg --save_dir=output
 
 ```
 The `visualized prediction result` of the image is saved under the directory set by `save_dir`.
@@ -126,13 +137,12 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
 
 Predict multiple images listed in `\\path\\to\\image_list.txt` with `GPU`; the format of image_list.txt is:
 ```
-\\path\\to\\images\\garbage1.jpeg
-\\path\\to\\images\\garbage2.jpeg
+\\path\\to\\images\\xiaoduxiong1.jpeg
+\\path\\to\\images\\xiaoduxiong2.jpeg
 ...
-\\path\\to\\images\\garbagen.jpeg
+\\path\\to\\images\\xiaoduxiongn.jpeg
 ```
 ```shell
 .\detector --model_dir=\\path\\to\\inference_model --image_list=\\path\\to\\images_list.txt --use_gpu=1 --save_dir=output
 ```
 The `visualized prediction result` of the image is saved under the directory set by `save_dir`.
-

docs/index.rst (+1 −1)

@@ -22,7 +22,7 @@ PaddleX is a full-pipeline deep learning development tool built on the PaddlePaddle ecosystem.
    client_use.md
    FAQ.md
 
-* PaddleX version: v0.1.6
+* PaddleX version: v0.1.7
 * Official site: http://www.paddlepaddle.org.cn/paddle/paddlex
 * Project GitHub: https://github.com/PaddlePaddle/PaddleX/tree/develop
 * Official QQ user group: 1045148026

paddlex/__init__.py (+19 −7)

@@ -13,27 +13,39 @@
 # limitations under the License.
 
 from __future__ import absolute_import
+import os
+if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
+    os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
+if 'FLAGS_allocator_strategy' not in os.environ:
+    os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+if "CUDA_VISIBLE_DEVICES" in os.environ:
+    if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0:
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 from .utils.utils import get_environ_info
 from . import cv
 from . import det
 from . import seg
 from . import cls
 from . import slim
+from . import tools
 
 try:
     import pycocotools
 except:
-    print("[WARNING] pycocotools is not installed, detection model is not available now.")
-    print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md")
-
-import paddlehub as hub
-if hub.version.hub_version < '1.6.2':
-    raise Exception("[ERROR] paddlehub >= 1.6.2 is required")
+    print(
+        "[WARNING] pycocotools is not installed, detection model is not available now."
+    )
+    print(
+        "[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
+    )
 
+#import paddlehub as hub
+#if hub.version.hub_version < '1.6.2':
+#    raise Exception("[ERROR] paddlehub >= 1.6.2 is required")
 
 env_info = get_environ_info()
 load_model = cv.models.load_model
 datasets = cv.datasets
 
 log_level = 2
-__version__ = '0.1.6.github'
+__version__ = '0.1.7.github'
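A small sketch of the observable effect of the new environment handling in `paddlex/__init__.py`, as patched above:

```python
import os

# Any "-1" entry now disables GPU visibility entirely: on import,
# paddlex rewrites CUDA_VISIBLE_DEVICES to an empty string.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import paddlex as pdx

print(pdx.__version__)                     # '0.1.7.github'
print(os.environ["CUDA_VISIBLE_DEVICES"])  # '' -> CPU-only run
```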

paddlex/command.py (+20 −2)

@@ -29,7 +29,11 @@ def arg_parser():
         action="store_true",
         default=False,
         help="export inference model for C++/Python deployment")
-
+    parser.add_argument(
+        "--fixed_input_shape",
+        "-fs",
+        default=None,
+        help="export inference model with fixed input shape:[w,h]")
     return parser
 
 
@@ -53,9 +57,23 @@ def main():
     if args.export_inference:
         assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
         assert args.save_dir is not None, "--save_dir should be defined to save inference model"
-        model = pdx.load_model(args.model_dir)
+        fixed_input_shape = eval(args.fixed_input_shape)
+        assert len(
+            fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        model = pdx.load_model(args.model_dir, fixed_input_shape)
         model.export_inference_model(args.save_dir)
 
+    if args.export_onnx:
+        assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
+        assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
+        fixed_input_shape = eval(args.fixed_input_shape)
+        assert len(
+            fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        model = pdx.load_model(args.model_dir, fixed_input_shape)
+        model.export_onnx_model(args.save_dir)
+
 
 if __name__ == "__main__":
     main()
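For reference, the new `--fixed_input_shape` flag drives the same code path as this Python equivalent; the model directory and shape are the ones used in the docs above and are illustrative:

```python
# Equivalent of:
#   paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 \
#           --save_dir=./inference_model --fixed_input_shape=[640,960]
import paddlex as pdx

fixed_input_shape = [640, 960]  # [w, h]; the CLI parses the flag with eval()
model = pdx.load_model('./xiaoduxiong_epoch_12', fixed_input_shape)
model.export_inference_model('./inference_model')
```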

paddlex/cv/datasets/coco.py (+7 −3)

@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
             gt_score = np.ones((num_bbox, 1), dtype=np.float32)
             is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
             difficult = np.zeros((num_bbox, 1), dtype=np.int32)
-            gt_poly = [None] * num_bbox
+            gt_poly = None
 
             for i, box in enumerate(bboxes):
                 catid = box['category_id']
@@ -108,20 +108,24 @@ class CocoDetection(VOCDetection):
                 gt_bbox[i, :] = box['clean_bbox']
                 is_crowd[i][0] = box['iscrowd']
                 if 'segmentation' in box:
+                    if gt_poly is None:
+                        gt_poly = [None] * num_bbox
                     gt_poly[i] = box['segmentation']
 
             im_info = {
                 'im_id': np.array([img_id]).astype('int32'),
-                'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                'image_shape': np.array([im_h, im_w]).astype('int32'),
             }
             label_info = {
                 'is_crowd': is_crowd,
                 'gt_class': gt_class,
                 'gt_bbox': gt_bbox,
                 'gt_score': gt_score,
-                'gt_poly': gt_poly,
                 'difficult': difficult
             }
+            if gt_poly is not None:
+                label_info['gt_poly'] = gt_poly
+
             coco_rec = (im_info, label_info)
             self.file_list.append([im_fname, coco_rec])
 

+ 8 - 0
paddlex/cv/datasets/dataset.py

@@ -254,3 +254,11 @@ class Dataset:
             buffer_size=self.buffer_size,
             buffer_size=self.buffer_size,
             batch_size=batch_size,
             batch_size=batch_size,
             drop_last=drop_last)
             drop_last=drop_last)
+
+    def set_num_samples(self, num_samples):
+        if num_samples > len(self.file_list):
+            logging.warning(
+                "You want set num_samples to {}, but your dataset only has {} samples, so we will keep your dataset num_samples as {}"
+                .format(num_samples, len(self.file_list), len(self.file_list)))
+            num_samples = len(self.file_list)
+        self.num_samples = num_samples
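A short usage sketch of the new hook (the ImageNet-format dataset construction is illustrative and assumed from the usual PaddleX dataset API; as the code above shows, values larger than the dataset size are clamped with a warning):

    import paddlex as pdx
    from paddlex.cls import transforms

    eval_transforms = transforms.Compose([transforms.Normalize()])
    eval_dataset = pdx.datasets.ImageNet(
        data_dir='dataset',
        file_list='dataset/val_list.txt',
        label_list='dataset/labels.txt',
        transforms=eval_transforms)
    eval_dataset.set_num_samples(500)  # evaluate on at most 500 samples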

+ 1 - 1
paddlex/cv/datasets/easydata_det.py

@@ -143,7 +143,7 @@ class EasyDataDet(VOCDetection):
                     ann_ct += 1
                 im_info = {
                     'im_id': im_id,
-                    'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array([im_h, im_w]).astype('int32'),
                 }
                 label_info = {
                     'is_crowd': is_crowd,

+ 1 - 2
paddlex/cv/datasets/voc.py

@@ -146,14 +146,13 @@ class VOCDetection(Dataset):
 
                 im_info = {
                     'im_id': im_id,
-                    'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array([im_h, im_w]).astype('int32'),
                 }
                 label_info = {
                     'is_crowd': is_crowd,
                     'gt_class': gt_class,
                     'gt_bbox': gt_bbox,
                     'gt_score': gt_score,
-                    'gt_poly': [],
                     'difficult': difficult
                 }
                 voc_rec = (im_info, label_info)

+ 59 - 23
paddlex/cv/models/base.py

@@ -70,6 +70,8 @@ class BaseAPI:
         self.sync_bn = False
         # Current model status
         self.status = 'Normal'
+        # Number of completed epochs; used as the starting epoch when resuming training
+        self.completed_epochs = 0
 
     def _get_single_card_bs(self, batch_size):
         if batch_size % len(self.places) == 0:
@@ -182,35 +184,62 @@ class BaseAPI:
                        fuse_bn=False,
                        save_dir='.',
                        sensitivities_file=None,
-                       eval_metric_loss=0.05):
-        pretrain_dir = osp.join(save_dir, 'pretrain')
-        if not os.path.isdir(pretrain_dir):
-            if os.path.exists(pretrain_dir):
-                os.remove(pretrain_dir)
-            os.makedirs(pretrain_dir)
-        if hasattr(self, 'backbone'):
-            backbone = self.backbone
-        else:
-            backbone = self.__class__.__name__
-        pretrain_weights = get_pretrain_weights(
-            pretrain_weights, self.model_type, backbone, pretrain_dir)
+                       eval_metric_loss=0.05,
+                       resume_checkpoint=None):
+        if not resume_checkpoint:
+            pretrain_dir = osp.join(save_dir, 'pretrain')
+            if not os.path.isdir(pretrain_dir):
+                if os.path.exists(pretrain_dir):
+                    os.remove(pretrain_dir)
+                os.makedirs(pretrain_dir)
+            if hasattr(self, 'backbone'):
+                backbone = self.backbone
+            else:
+                backbone = self.__class__.__name__
+            pretrain_weights = get_pretrain_weights(
+                pretrain_weights, self.model_type, backbone, pretrain_dir)
         if startup_prog is None:
             startup_prog = fluid.default_startup_program()
         self.exe.run(startup_prog)
-        if pretrain_weights is not None:
+        if resume_checkpoint:
             logging.info(
-                "Load pretrain weights from {}.".format(pretrain_weights))
+                "Resume checkpoint from {}.".format(resume_checkpoint),
+                use_color=True)
+            paddlex.utils.utils.load_pretrain_weights(
+                self.exe, self.train_prog, resume_checkpoint, resume=True)
+            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
+                raise Exception(
+                    "There is no model.yml in {}".format(resume_checkpoint))
+            with open(osp.join(resume_checkpoint, "model.yml")) as f:
+                info = yaml.load(f.read(), Loader=yaml.Loader)
+                self.completed_epochs = info['completed_epochs']
+        elif pretrain_weights is not None:
+            logging.info(
+                "Load pretrain weights from {}.".format(pretrain_weights),
+                use_color=True)
             paddlex.utils.utils.load_pretrain_weights(
                 self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # Prune the model
         if sensitivities_file is not None:
+            import paddleslim
             from .slim.prune_config import get_sensitivities
             sensitivities_file = get_sensitivities(sensitivities_file, self,
                                                    save_dir)
             from .slim.prune import get_params_ratios, prune_program
+            logging.info(
+                "Start pruning program with eval_metric_loss={}".format(
+                    eval_metric_loss),
+                use_color=True)
+            origin_flops = paddleslim.analysis.flops(self.test_prog)
             prune_params_ratios = get_params_ratios(
                 sensitivities_file, eval_metric_loss=eval_metric_loss)
             prune_program(self, prune_params_ratios)
+            current_flops = paddleslim.analysis.flops(self.test_prog)
+            remaining_ratio = current_flops / origin_flops
+            logging.info(
+                "Finished pruning program; FLOPs before: {}, after: {}, remaining ratio: {}"
+                .format(origin_flops, current_flops, remaining_ratio),
+                use_color=True)
             self.status = 'Prune'
 
     def get_model_info(self):
@@ -248,6 +277,7 @@ class BaseAPI:
                     name = op.__class__.__name__
                     attr = op.__dict__
                     info['Transforms'].append({name: attr})
+        info['completed_epochs'] = self.completed_epochs
         return info
 
     def save_model(self, save_dir):
@@ -255,7 +285,10 @@ class BaseAPI:
             if osp.exists(save_dir):
                 os.remove(save_dir)
             os.makedirs(save_dir)
-        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        if self.train_prog is not None:
+            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        else:
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
         model_info = self.get_model_info()
         model_info['status'] = self.status
         with open(
@@ -317,11 +350,11 @@ class BaseAPI:
         model_info['_ModelInputsOutputs']['test_outputs'] = [
             [k, v.name] for k, v in self.test_outputs.items()
         ]
-
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 mode='w') as f:
             yaml.dump(model_info, f)
+
         # Marker file indicating the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info(
@@ -404,8 +437,9 @@ class BaseAPI:
             earlystop = EarlyStop(early_stop_patience, thresh)
         best_accuracy_key = ""
         best_accuracy = -1.0
-        best_model_epoch = 1
-        for i in range(num_epochs):
+        best_model_epoch = -1
+        start_epoch = self.completed_epochs
+        for i in range(start_epoch, num_epochs):
             records = list()
             step_start_time = time.time()
             epoch_start_time = time.time()
@@ -477,7 +511,7 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
-                if eval_dataset is not None:
+                if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
                         batch_size=eval_batch_size,
@@ -485,6 +519,7 @@ class BaseAPI:
                         return_details=True)
                     logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                         i + 1, dict2str(self.eval_metrics)))
+                    self.completed_epochs += 1
                     # Save the best model
                     best_accuracy_key = list(self.eval_metrics.keys())[0]
                     current_accuracy = self.eval_metrics[best_accuracy_key]
@@ -509,10 +544,11 @@ class BaseAPI:
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
-                logging.info(
-                    'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
-                    .format(best_model_epoch, best_accuracy_key,
-                            best_accuracy))
+                if best_model_epoch > 0:
+                    logging.info(
+                        'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
+                        .format(best_model_epoch, best_accuracy_key,
+                                best_accuracy))
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
                         break
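A training-time sketch of the pruning path instrumented above (model choice, datasets, and paths are illustrative; train_dataset and eval_dataset are assumed to be prepared PaddleX datasets, and sensitivities_file='DEFAULT' downloads a prebuilt sensitivities file):

    import paddlex as pdx

    model = pdx.cls.MobileNetV2(num_classes=4)
    model.train(
        num_epochs=10,
        train_dataset=train_dataset,          # prepared elsewhere
        train_batch_size=32,
        eval_dataset=eval_dataset,            # prepared elsewhere
        save_dir='output/mobilenetv2_prune',
        sensitivities_file='DEFAULT',         # download sensitivities automatically
        eval_metric_loss=0.05)                # allowed metric drop when picking prune ratios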

+ 15 - 5
paddlex/cv/models/classifier.py

@@ -46,10 +46,18 @@ class BaseClassifier(BaseAPI):
         self.model_name = model_name
         self.labels = None
         self.num_classes = num_classes
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
-        image = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            image = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            image = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if mode != 'test':
             label = fluid.data(dtype='int64', shape=[None, 1], name='label')
         model = getattr(paddlex.cv.nets, str.lower(self.model_name))
@@ -104,7 +112,8 @@ class BaseClassifier(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -129,6 +138,7 @@ class BaseClassifier(BaseAPI):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The model was loaded from an inference model.
@@ -152,8 +162,8 @@ class BaseClassifier(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
-
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # Train
         self.train_loop(
             num_epochs=num_epochs,
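A minimal resume sketch for the new argument (model choice, datasets, and paths are illustrative; resume_checkpoint points at a checkpoint directory written by a previous run, e.g. one of the per-epoch folders saved by save_model):

    import paddlex as pdx

    model = pdx.cls.ResNet50(num_classes=1000)
    model.train(
        num_epochs=20,
        train_dataset=train_dataset,          # prepared elsewhere
        train_batch_size=32,
        eval_dataset=eval_dataset,            # prepared elsewhere
        save_dir='output/resnet50',
        resume_checkpoint='output/resnet50/epoch_10')  # continue from epoch 10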

+ 16 - 11
paddlex/cv/models/deeplabv3p.py

@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI):
             the weights are computed automatically; each class weight is its class ratio * num_classes. With the default class_weight=None, every class has weight 1,
             i.e. the ordinary cross-entropy loss.
         ignore_index (int): Value in the label to ignore; pixels labeled ignore_index are excluded from the loss. Defaults to 255.
-
     Raises:
         ValueError: use_bce_loss or use_dice_loss is True while num_classes > 2.
         ValueError: backbone is not one of ['Xception65', 'Xception41', 'MobileNetV2_x0.25',
@@ -118,6 +117,7 @@ class DeepLabv3p(BaseAPI):
         self.enable_decoder = enable_decoder
         self.labels = None
         self.sync_bn = True
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone):
         def mobilenetv2(backbone):
@@ -182,7 +182,8 @@ class DeepLabv3p(BaseAPI):
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
+            ignore_index=self.ignore_index,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()
@@ -233,7 +234,8 @@ class DeepLabv3p(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -257,6 +259,7 @@ class DeepLabv3p(BaseAPI):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The model was loaded from an inference model.
@@ -283,7 +286,8 @@ class DeepLabv3p(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # Train
         self.train_loop(
             num_epochs=num_epochs,
@@ -396,13 +400,14 @@ class DeepLabv3p(BaseAPI):
             fetch_list=list(self.test_outputs.values()))
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
-        keys = list(im_info.keys())
-        for k in keys[::-1]:
-            if k == 'shape_before_resize':
-                h, w = im_info[k][0], im_info[k][1]
+        for info in im_info[::-1]:
+            if info[0] == 'resize':
+                w, h = info[1][1], info[1][0]
                 pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
-            elif k == 'shape_before_padding':
-                h, w = im_info[k][0], im_info[k][1]
+            elif info[0] == 'padding':
+                w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
-
+            else:
+                raise Exception("Unexpected info '{}' in im_info".format(
+                    info[0]))
         return {'label_map': pred, 'score_map': result[1]}

+ 13 - 7
paddlex/cv/models/faster_rcnn.py

@@ -32,7 +32,7 @@ class FasterRCNN(BaseAPI):
     Args:
         num_classes (int): Number of classes, including the background class. Defaults to 81.
         backbone (str): Backbone network of FasterRCNN, one of ['ResNet18', 'ResNet50',
-            'ResNet50vd', 'ResNet101', 'ResNet101vd']. Defaults to 'ResNet50'.
+            'ResNet50_vd', 'ResNet101', 'ResNet101_vd']. Defaults to 'ResNet50'.
         with_fpn (bool): Whether to use the FPN structure. Defaults to True.
         aspect_ratios (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
         anchor_sizes (list): Candidate sizes for anchor generation. Defaults to [32, 64, 128, 256, 512].
@@ -47,7 +47,7 @@ class FasterRCNN(BaseAPI):
         self.init_params = locals()
         super(FasterRCNN, self).__init__('detector')
         backbones = [
-            'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd'
+            'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
            backbones)
@@ -57,6 +57,7 @@ class FasterRCNN(BaseAPI):
         self.aspect_ratios = aspect_ratios
         self.anchor_sizes = anchor_sizes
         self.labels = None
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone_name):
         norm_type = None
@@ -66,7 +67,7 @@ class FasterRCNN(BaseAPI):
         elif backbone_name == 'ResNet50':
             layers = 50
             variant = 'b'
-        elif backbone_name == 'ResNet50vd':
+        elif backbone_name == 'ResNet50_vd':
             layers = 50
             variant = 'd'
             norm_type = 'affine_channel'
@@ -74,7 +75,7 @@ class FasterRCNN(BaseAPI):
             layers = 101
             variant = 'b'
             norm_type = 'affine_channel'
-        elif backbone_name == 'ResNet101vd':
+        elif backbone_name == 'ResNet101_vd':
             layers = 101
             variant = 'd'
             norm_type = 'affine_channel'
@@ -109,7 +110,8 @@ class FasterRCNN(BaseAPI):
             aspect_ratios=self.aspect_ratios,
             anchor_sizes=self.anchor_sizes,
             train_pre_nms_top_n=train_pre_nms_top_n,
-            test_pre_nms_top_n=test_pre_nms_top_n)
+            test_pre_nms_top_n=test_pre_nms_top_n,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         if mode == 'train':
             model_out = model.build_net(inputs)
@@ -165,7 +167,8 @@ class FasterRCNN(BaseAPI):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -191,6 +194,7 @@ class FasterRCNN(BaseAPI):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The evaluation metric is not in the allowed list.
@@ -229,7 +233,9 @@ class FasterRCNN(BaseAPI):
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
             fuse_bn=fuse_bn,
-            save_dir=save_dir)
+            save_dir=save_dir,
+            resume_checkpoint=resume_checkpoint)
+
         # Train
         self.train_loop(
             num_epochs=num_epochs,
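Note that the backbone keys are renamed here ('ResNet50vd' becomes 'ResNet50_vd', 'ResNet101vd' becomes 'ResNet101_vd'), so existing scripts need the new spelling; a minimal construction sketch (num_classes is illustrative):

    import paddlex as pdx

    # old: backbone='ResNet50vd'  ->  new: backbone='ResNet50_vd'
    model = pdx.det.FasterRCNN(num_classes=81, backbone='ResNet50_vd')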

+ 31 - 1
paddlex/cv/models/load_model.py

@@ -23,7 +23,7 @@ import paddlex
 import paddlex.utils.logging as logging
 
 
-def load_model(model_dir):
+def load_model(model_dir, fixed_input_shape=None):
     if not osp.exists(osp.join(model_dir, "model.yml")):
         raise Exception("There is no model.yml in {}".format(model_dir))
     with open(osp.join(model_dir, "model.yml")) as f:
@@ -44,6 +44,7 @@ def load_model(model_dir):
     else:
         model = getattr(paddlex.cv.models,
                         info['Model'])(**info['_init_params'])
+    model.fixed_input_shape = fixed_input_shape
     if status == "Normal" or \
             status == "Prune" or status == "fluid.save":
         startup_prog = fluid.Program()
@@ -78,6 +79,8 @@ def load_model(model_dir):
             model.test_outputs[var_desc[0]] = out
     if 'Transforms' in info:
         transforms_mode = info.get('TransformsMode', 'RGB')
+        # Fix the model's input shape
+        fix_input_shape(info, fixed_input_shape=fixed_input_shape)
         if transforms_mode == 'RGB':
             to_rgb = True
         else:
@@ -102,6 +105,33 @@ def load_model(model_dir):
     return model
 
 
+def fix_input_shape(info, fixed_input_shape=None):
+    if fixed_input_shape is not None:
+        resize = {'ResizeByShort': {}}
+        padding = {'Padding': {}}
+        if info['_Attributes']['model_type'] == 'classifier':
+            crop_size = 0
+            for transform in info['Transforms']:
+                if 'CenterCrop' in transform:
+                    crop_size = transform['CenterCrop']['crop_size']
+                    break
+            if crop_size == 0:
+                logging.warning(
+                    "fixed_input_shape should match the input shape used during training")
+            else:
+                assert crop_size == fixed_input_shape[
+                    0], "fixed_input_shape must equal CenterCrop crop_size: {}".format(
+                        crop_size)
+                assert crop_size == fixed_input_shape[
+                    1], "fixed_input_shape must equal CenterCrop crop_size: {}".format(
+                        crop_size)
+        else:
+            resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
+            resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
+            padding['Padding']['target_size'] = list(fixed_input_shape)
+            info['Transforms'].append(resize)
+            info['Transforms'].append(padding)
+
+
 def build_transforms(model_type, transforms_info, to_rgb=True):
     if model_type == "classifier":
         import paddlex.cv.transforms.cls_transforms as T
+ 12 - 6
paddlex/cv/models/mask_rcnn.py

@@ -32,7 +32,7 @@ class MaskRCNN(FasterRCNN):
     Args:
         num_classes (int): Number of classes, including the background class. Defaults to 81.
         backbone (str): Backbone network of MaskRCNN, one of ['ResNet18', 'ResNet50',
-            'ResNet50vd', 'ResNet101', 'ResNet101vd']. Defaults to 'ResNet50'.
+            'ResNet50_vd', 'ResNet101', 'ResNet101_vd']. Defaults to 'ResNet50'.
         with_fpn (bool): Whether to use the FPN structure. Defaults to True.
         aspect_ratios (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
         anchor_sizes (list): Candidate sizes for anchor generation. Defaults to [32, 64, 128, 256, 512].
@@ -46,7 +46,7 @@ class MaskRCNN(FasterRCNN):
                  anchor_sizes=[32, 64, 128, 256, 512]):
         self.init_params = locals()
         backbones = [
-            'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd'
+            'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
@@ -60,6 +60,7 @@ class MaskRCNN(FasterRCNN):
             self.mask_head_resolution = 28
         else:
             self.mask_head_resolution = 14
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
         train_pre_nms_top_n = 2000 if self.with_fpn else 12000
@@ -73,7 +74,8 @@ class MaskRCNN(FasterRCNN):
             train_pre_nms_top_n=train_pre_nms_top_n,
             test_pre_nms_top_n=test_pre_nms_top_n,
             num_convs=num_convs,
-            mask_head_resolution=self.mask_head_resolution)
+            mask_head_resolution=self.mask_head_resolution,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         if mode == 'train':
             model_out = model.build_net(inputs)
@@ -130,7 +132,8 @@ class MaskRCNN(FasterRCNN):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -156,6 +159,7 @@ class MaskRCNN(FasterRCNN):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The evaluation metric is not in the allowed list.
@@ -167,7 +171,8 @@ class MaskRCNN(FasterRCNN):
                 metric = 'COCO'
             else:
                 raise Exception(
-                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet.")
+                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
+                )
         assert metric in ['COCO', 'VOC'], "Metric only supports 'VOC' or 'COCO'"
         self.metric = metric
         if not self.trainable:
@@ -195,7 +200,8 @@ class MaskRCNN(FasterRCNN):
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
             fuse_bn=fuse_bn,
-            save_dir=save_dir)
+            save_dir=save_dir,
+            resume_checkpoint=resume_checkpoint)
         # Train
         self.train_loop(
             num_epochs=num_epochs,

+ 20 - 16
paddlex/cv/models/slim/prune_config.py

@@ -15,7 +15,7 @@
 import numpy as np
 import os.path as osp
 import paddle.fluid as fluid
-import paddlehub as hub
+#import paddlehub as hub
 import paddlex
 
 sensitivities_data = {
@@ -105,22 +105,26 @@ def get_sensitivities(flag, model, save_dir):
             model_type)
         url = sensitivities_data[model_type]
         fname = osp.split(url)[-1]
-        try:
-            hub.download(fname, save_path=save_dir)
-        except Exception as e:
-            if isinstance(e, hub.ResourceNotFoundError):
-                raise Exception(
-                    "Resource for model {}(key='{}') not found".format(
-                        model_type, fname))
-            elif isinstance(e, hub.ServerConnectionError):
-                raise Exception(
-                    "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
-                    .format(model_type, fname))
-            else:
-                raise Exception(
-                    "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
-                    format(str(e)))
+        paddlex.utils.download(url, path=save_dir)
         return osp.join(save_dir, fname)
+
+
+#        try:
+#            hub.download(fname, save_path=save_dir)
+#        except Exception as e:
+#            if isinstance(e, hub.ResourceNotFoundError):
+#                raise Exception(
+#                    "Resource for model {}(key='{}') not found".format(
+#                        model_type, fname))
+#            elif isinstance(e, hub.ServerConnectionError):
+#                raise Exception(
+#                    "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
+#                    .format(model_type, fname))
+#            else:
+#                raise Exception(
+#                    "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
+#                    format(str(e)))
+#        return osp.join(save_dir, fname)
     else:
         raise Exception(
             "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."

+ 12 - 9
paddlex/cv/models/unet.py

@@ -77,6 +77,7 @@ class UNet(DeepLabv3p):
         self.class_weight = class_weight
         self.ignore_index = ignore_index
         self.labels = None
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.UNet(
@@ -86,7 +87,8 @@ class UNet(DeepLabv3p):
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
+            ignore_index=self.ignore_index,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()
@@ -119,7 +121,8 @@ class UNet(DeepLabv3p):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -143,14 +146,14 @@ class UNet(DeepLabv3p):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The model was loaded from an inference model.
         """
-        return super(
-            UNet,
-            self).train(num_epochs, train_dataset, train_batch_size,
-                        eval_dataset, save_interval_epochs, log_interval_steps,
-                        save_dir, pretrain_weights, optimizer, learning_rate,
-                        lr_decay_power, use_vdl, sensitivities_file,
-                        eval_metric_loss, early_stop, early_stop_patience)
+        return super(UNet, self).train(
+            num_epochs, train_dataset, train_batch_size, eval_dataset,
+            save_interval_epochs, log_interval_steps, save_dir,
+            pretrain_weights, optimizer, learning_rate, lr_decay_power,
+            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
+            early_stop_patience, resume_checkpoint)

+ 44 - 31
paddlex/cv/models/utils/pretrain_weights.py

@@ -1,5 +1,5 @@
 import paddlex
-import paddlehub as hub
+#import paddlehub as hub
 import os
 import os.path as osp
 
@@ -85,40 +85,53 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
                 backbone = 'DetResNet50'
         assert backbone in image_pretrain, "There are no ImageNet pretrain weights for {}, you may try COCO.".format(
             backbone)
-        try:
-            hub.download(backbone, save_path=new_save_dir)
-        except Exception as e:
-            if isinstance(e, hub.ResourceNotFoundError):
-                raise Exception(
-                    "Resource for backbone {} not found".format(backbone))
-            elif isinstance(e, hub.ServerConnectionError):
-                raise Exception(
-                    "Cannot get reource for backbone {}, please check your internet connecgtion"
-                    .format(backbone))
-            else:
-                raise Exception(
-                    "Unexpected error, please make sure paddlehub >= 1.6.2")
-        return osp.join(new_save_dir, backbone)
+        url = image_pretrain[backbone]
+        fname = osp.split(url)[-1].split('.')[0]
+        paddlex.utils.download_and_decompress(url, path=new_save_dir)
+        return osp.join(new_save_dir, fname)
+#        try:
+#            hub.download(backbone, save_path=new_save_dir)
+#        except Exception as e:
+#            if isinstance(e, hub.ResourceNotFoundError):
+#                raise Exception(
+#                    "Resource for backbone {} not found".format(backbone))
+#            elif isinstance(e, hub.ServerConnectionError):
+#                raise Exception(
+#                    "Cannot get reource for backbone {}, please check your internet connecgtion"
+#                    .format(backbone))
+#            else:
+#                raise Exception(
+#                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+#        return osp.join(new_save_dir, backbone)
     elif flag == 'COCO':
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
-        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
-            backbone)
-        try:
-            hub.download(backbone, save_path=new_save_dir)
-        except Exception as e:
-            if isinstance(hub.ResourceNotFoundError):
-                raise Exception(
-                    "Resource for backbone {} not found".format(backbone))
-            elif isinstance(hub.ServerConnectionError):
-                raise Exception(
-                    "Cannot get reource for backbone {}, please check your internet connecgtion"
-                    .format(backbone))
-            else:
-                raise Exception(
-                    "Unexpected error, please make sure paddlehub >= 1.6.2")
-        return osp.join(new_save_dir, backbone)
+        url = coco_pretrain[backbone]
+        fname = osp.split(url)[-1].split('.')[0]
+        paddlex.utils.download_and_decompress(url, path=new_save_dir)
+        return osp.join(new_save_dir, fname)
+
+
+#        new_save_dir = save_dir
+#        if hasattr(paddlex, 'pretrain_dir'):
+#            new_save_dir = paddlex.pretrain_dir
+#        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
+#            backbone)
+#        try:
+#            hub.download(backbone, save_path=new_save_dir)
+#        except Exception as e:
+#            if isinstance(hub.ResourceNotFoundError):
+#                raise Exception(
+#                    "Resource for backbone {} not found".format(backbone))
+#            elif isinstance(hub.ServerConnectionError):
+#                raise Exception(
+#                    "Cannot get reource for backbone {}, please check your internet connecgtion"
+#                    .format(backbone))
+#            else:
+#                raise Exception(
+#                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+#        return osp.join(new_save_dir, backbone)
     else:
         raise Exception(
             "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."

+ 14 - 4
paddlex/cv/models/utils/visualize.py

@@ -16,6 +16,7 @@ import os
 import cv2
 import colorsys
 import numpy as np
+import time
 import paddlex.utils.logging as logging
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
 
@@ -25,8 +26,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
         Visualize bbox and mask results
     """
 
-    image_name = os.path.split(image)[-1]
-    image = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image_name = os.path.split(image)[-1]
+        image = cv2.imread(image)
+
     image = draw_bbox_mask(image, result, threshold=threshold)
     if save_dir is not None:
         if not os.path.exists(save_dir):
@@ -56,13 +61,18 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
     c3 = cv2.LUT(label_map, color_map[:, 2])
     pseudo_img = np.dstack((c1, c2, c3))
 
-    im = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        im = image
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image_name = os.path.split(image)[-1]
+        im = cv2.imread(image)
+
     vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
 
     if save_dir is not None:
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
-        image_name = os.path.split(image)[-1]
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
         cv2.imwrite(out_path, vis_result)
         logging.info('The visualized result is saved as {}'.format(out_path))
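With this change both visualizers accept an already-decoded image array, in which case the output file name falls back to a timestamp. A sketch (the model path is hypothetical; pdx.det.visualize is assumed to wrap visualize_detection, and predict is assumed to accept an ndarray as well):

    import cv2
    import paddlex as pdx

    model = pdx.load_model('output/yolov3_mobilenetv1/best_model')
    img = cv2.imread('test.jpg')   # pass the decoded array instead of a path
    result = model.predict(img)
    pdx.det.visualize(img, result, threshold=0.5, save_dir='./output')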

+ 8 - 3
paddlex/cv/models/yolo_v3.py

@@ -80,6 +80,7 @@ class YOLOv3(BaseAPI):
         self.label_smooth = label_smooth
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -113,7 +114,8 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes)
+            train_random_shapes=self.train_random_shapes,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict([('bbox', model_out)])
@@ -164,7 +166,8 @@ class YOLOv3(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """Train the model.
 
         Args:
@@ -193,6 +196,7 @@ class YOLOv3(BaseAPI):
             early_stop (bool): Whether to use the early-stopping strategy. Defaults to False.
             early_stop_patience (int): When early stopping is enabled, training stops if the validation
                 accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
+            resume_checkpoint (str): Path to a previously saved model to resume training from. If None, training is not resumed. Defaults to None.
 
         Raises:
             ValueError: The evaluation metric is not in the allowed list.
@@ -234,7 +238,8 @@ class YOLOv3(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # Train
         self.train_loop(
             num_epochs=num_epochs,

+ 13 - 3
paddlex/cv/nets/detection/faster_rcnn.py

@@ -76,7 +76,8 @@ class FasterRCNN(object):
             fg_thresh=.5,
             bg_thresh_hi=.5,
             bg_thresh_lo=0.,
-            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]):
+            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+            fixed_input_shape=None):
         super(FasterRCNN, self).__init__()
         self.backbone = backbone
         self.mode = mode
@@ -148,6 +149,7 @@ class FasterRCNN(object):
         self.bg_thresh_lo = bg_thresh_lo
         self.bbox_reg_weights = bbox_reg_weights
         self.rpn_only = rpn_only
+        self.fixed_input_shape = fixed_input_shape
 
     def build_net(self, inputs):
         im = inputs['image']
@@ -219,8 +221,16 @@ class FasterRCNN(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['im_info'] = fluid.data(
                 dtype='float32', shape=[None, 3], name='im_info')
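One note on ordering that applies to all of the generate_inputs changes in this commit: fixed_input_shape is [width, height], while the placeholder is laid out NCHW, so the resulting shape is [None, 3, height, width]. A tiny sketch with hypothetical values:

    import paddle.fluid as fluid

    fixed_input_shape = [608, 416]  # [width, height]
    input_shape = [None, 3, fixed_input_shape[1], fixed_input_shape[0]]  # NCHW
    image = fluid.data(dtype='float32', shape=input_shape, name='image')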

+ 13 - 3
paddlex/cv/nets/detection/mask_rcnn.py

@@ -86,7 +86,8 @@ class MaskRCNN(object):
             fg_thresh=.5,
             bg_thresh_hi=.5,
             bg_thresh_lo=0.,
-            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]):
+            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+            fixed_input_shape=None):
         super(MaskRCNN, self).__init__()
         self.backbone = backbone
         self.mode = mode
@@ -167,6 +168,7 @@ class MaskRCNN(object):
         self.bg_thresh_lo = bg_thresh_lo
         self.bbox_reg_weights = bbox_reg_weights
         self.rpn_only = rpn_only
+        self.fixed_input_shape = fixed_input_shape
 
     def build_net(self, inputs):
         im = inputs['image']
@@ -306,8 +308,16 @@ class MaskRCNN(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['im_info'] = fluid.data(
                 dtype='float32', shape=[None, 3], name='im_info')

+ 12 - 3
paddlex/cv/nets/detection/yolo_v3.py

@@ -33,7 +33,8 @@ class YOLOv3:
                  nms_iou_threshold=0.45,
                  train_random_shapes=[
                      320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+                 ],
+                 fixed_input_shape=None):
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -54,6 +55,7 @@ class YOLOv3:
         self.norm_decay = 0.0
         self.prefix_name = ''
         self.train_random_shapes = train_random_shapes
+        self.fixed_input_shape = fixed_input_shape
 
     def _head(self, feats):
         outputs = []
@@ -247,8 +249,15 @@ class YOLOv3:
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['gt_box'] = fluid.data(
                 dtype='float32', shape=[None, None, 4], name='gt_box')

+ 14 - 3
paddlex/cv/nets/segmentation/deeplabv3p.py

@@ -61,6 +61,7 @@ class DeepLabv3p(object):
             computes the weights automatically, where the weight of each class is: ratio of that class * num_classes. When class_weight takes its default value None, each class has weight 1,
             i.e. the ordinary cross-entropy loss.
         ignore_index (int): value ignored in the label; pixels whose label equals ignore_index do not take part in the loss computation.
+        fixed_input_shape (list): 1-D list of length 2, e.g. [640, 720], used to fix the shape of the model input 'image'. Defaults to None.

     Raises:
         ValueError: use_bce_loss or use_dice_loss is True and num_classes > 2.
@@ -81,7 +82,8 @@ class DeepLabv3p(object):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 fixed_input_shape=None):
         # dice_loss and bce_loss only apply to two-class segmentation
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise ValueError(
@@ -115,6 +117,7 @@ class DeepLabv3p(object):
         self.decoder_use_sep_conv = decoder_use_sep_conv
         self.encoder_with_aspp = encoder_with_aspp
         self.enable_decoder = enable_decoder
+        self.fixed_input_shape = fixed_input_shape

     def _encoder(self, input):
         # Encoder configuration: ASPP architecture, pooling + 1x1 conv + three parallel dilated convolutions at different scales, then a 1x1 conv after concat
@@ -310,8 +313,16 @@ class DeepLabv3p(object):

     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 14 - 3
paddlex/cv/nets/segmentation/unet.py

@@ -54,6 +54,7 @@ class UNet(object):
                 computes the weights automatically, where the weight of each class is: ratio of that class * num_classes. When class_weight takes its default value None, each class has weight 1,
                 i.e. the ordinary cross-entropy loss.
             ignore_index (int): value ignored in the label; pixels whose label equals ignore_index do not take part in the loss computation.
+            fixed_input_shape (list): 1-D list of length 2, e.g. [640, 720], used to fix the shape of the model input 'image'. Defaults to None.

         Raises:
             ValueError: use_bce_loss or use_dice_loss is True and num_classes > 2.
@@ -69,7 +70,8 @@ class UNet(object):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 fixed_input_shape=None):
         # dice_loss and bce_loss only apply to two-class segmentation
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise Exception(
@@ -97,6 +99,7 @@ class UNet(object):
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
         self.ignore_index = ignore_index
+        self.fixed_input_shape = fixed_input_shape

     def _double_conv(self, data, out_ch):
         param_attr = fluid.ParamAttr(
@@ -226,8 +229,16 @@ class UNet(object):

     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

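Both segmentation nets above keep the same restriction that the optional BCE/DICE losses only apply to two-class segmentation. A compact sketch of that check in plain Python (validate_losses is a hypothetical helper mirroring the constructors, not PaddleX API):

    def validate_losses(num_classes, use_bce_loss=False, use_dice_loss=False):
        # dice_loss and bce_loss are only applicable to two-class segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                'dice loss and bce loss are only applicable to binary classification')

    validate_losses(2, use_dice_loss=True)    # OK: two-class segmentation
    validate_losses(21)                       # OK: plain cross-entropy
    # validate_losses(21, use_bce_loss=True)  # would raise ValueError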
+ 50 - 18
paddlex/cv/transforms/cls_transforms.py

@@ -13,13 +13,22 @@
 # limitations under the License.

 from .ops import *
+from .imgaug_support import execute_imgaug
 import random
 import os.path as osp
 import numpy as np
 from PIL import Image, ImageEnhance


-class Compose:
+class ClsTransform:
+    """Base class for classification transforms.
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(ClsTransform):
     """Applies the data preprocessing/augmentation operators to the input data.
        All operators take image arrays of shape [H, W, C], where H is the image height, W the image width and C the number of channels.

@@ -39,6 +48,15 @@ class Compose:
                             'must be equal or larger than 1!')
         self.transforms = transforms

+        # Check the operators in transforms; currently PaddleX-defined and imgaug operators are supported
+        for op in self.transforms:
+            if not isinstance(op, ClsTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.cls.transforms' or be of class imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )
+
     def __call__(self, im, label=None):
         """
         Args:
@@ -48,20 +66,34 @@ class Compose:
             tuple: a tuple composed of the fields required by the network;
                 the fields are determined by the last preprocessing operator in transforms.
         """
-        try:
-            im = cv2.imread(im).astype('float32')
-        except:
-            raise TypeError('Can\'t read the image file {}!'.format(im))
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensional, but it is {}-dimensional".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im).astype('float32')
+            except:
+                raise TypeError('Can\'t read the image file {}!'.format(im))
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         for op in self.transforms:
-            outputs = op(im, label)
-            im = outputs[0]
-            if len(outputs) == 2:
-                label = outputs[1]
+            if isinstance(op, ClsTransform):
+                outputs = op(im, label)
+                im = outputs[0]
+                if len(outputs) == 2:
+                    label = outputs[1]
+            else:
+                import imgaug.augmenters as iaa
+                if isinstance(op, iaa.Augmenter):
+                    im, = execute_imgaug(op, im)
+                outputs = (im, )
+                if label is not None:
+                    outputs = (im, label)
         return outputs


-class RandomCrop:
+class RandomCrop(ClsTransform):
     """Randomly crops the image; a data augmentation operation for model training.

     1. Computes the height and width of the random crop from lower_scale, lower_ratio and upper_ratio.
@@ -104,7 +136,7 @@ class RandomCrop:
             return (im, label)


-class RandomHorizontalFlip:
+class RandomHorizontalFlip(ClsTransform):
     """Randomly flips the image horizontally with a given probability; a data augmentation operation for model training.

     Args:
@@ -132,7 +164,7 @@ class RandomHorizontalFlip:
             return (im, label)


-class RandomVerticalFlip:
+class RandomVerticalFlip(ClsTransform):
     """Randomly flips the image vertically with a given probability; a data augmentation operation for model training.

     Args:
@@ -160,7 +192,7 @@ class RandomVerticalFlip:
             return (im, label)


-class Normalize:
+class Normalize(ClsTransform):
     """Normalizes the image.

     1. Rescales the image to the range [0.0, 1.0].
@@ -195,7 +227,7 @@ class Normalize:
             return (im, label)


-class ResizeByShort:
+class ResizeByShort(ClsTransform):
     """Resizes the image according to its short side.

     1. Gets the lengths of the image's long and short sides.
@@ -242,7 +274,7 @@ class ResizeByShort:
             return (im, label)


-class CenterCrop:
+class CenterCrop(ClsTransform):
     """Crops a square of side `crop_size` centered at the image center.

     1. Computes the starting point of the crop.
@@ -272,7 +304,7 @@ class CenterCrop:
             return (im, label)


-class RandomRotate:
+class RandomRotate(ClsTransform):
     def __init__(self, rotate_range=30, prob=0.5):
         """Rotates the image within the angle range [-rotate_range, rotate_range] with a given probability; a data augmentation operation for model training.

@@ -306,7 +338,7 @@ class RandomRotate:
             return (im, label)


-class RandomDistort:
+class RandomDistort(ClsTransform):
     """Randomly perturbs pixel content of the image with a given probability; a data augmentation operation for model training.

     1. Randomizes the order of the perturbation operations.
@@ -397,7 +429,7 @@ class RandomDistort:
             return (im, label)


-class ArrangeClassifier:
+class ArrangeClassifier(ClsTransform):
     """Collects the information needed for training/evaluation/prediction. Note: users do not need to call this operator explicitly.

     Args:

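With the new ClsTransform base class, Compose accepts imgaug augmenters alongside PaddleX operators and dispatches them through execute_imgaug. A minimal usage sketch, assuming imgaug is installed and the paddlex package is importable (the operator arguments are illustrative only):

    import imgaug.augmenters as iaa
    import paddlex as pdx

    train_transforms = pdx.cls.transforms.Compose([
        iaa.GaussianBlur(sigma=(0.0, 1.5)),        # imgaug operator
        pdx.cls.transforms.RandomHorizontalFlip(), # PaddleX operator
        pdx.cls.transforms.Normalize(),
    ])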
+ 139 - 72
paddlex/cv/transforms/det_transforms.py

@@ -24,11 +24,20 @@ import numpy as np
 import cv2
 from PIL import Image, ImageEnhance

+from .imgaug_support import execute_imgaug
 from .ops import *
 from .box_utils import *


-class Compose:
+class DetTransform:
+    """Base class for detection transforms.
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(DetTransform):
     """Applies the list of data preprocessing/augmentation operators to the input data.
        All operators take image arrays of shape [H, W, C], where H is the image height, W the image width and C the number of channels.

@@ -49,8 +58,16 @@ class Compose:
         self.transforms = transforms
         self.use_mixup = False
         for t in self.transforms:
-            if t.__class__.__name__ == 'MixupImage':
+            if type(t).__name__ == 'MixupImage':
                 self.use_mixup = True
+        # Check the operators in transforms; currently PaddleX-defined and imgaug operators are supported
+        for op in self.transforms:
+            if not isinstance(op, DetTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.det.transforms' or be of class imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )

     def __call__(self, im, im_info=None, label_info=None):
         """
@@ -58,8 +75,8 @@ class Compose:
             im (str/np.ndarray): image path / image np.ndarray data.
             im_info (dict): stores image-related information; its fields are:
                 - im_id (np.ndarray): image index, of shape (1,).
-                - origin_shape (np.ndarray): original image size, of shape (2,),
-                                        origin_shape[0] is the height, origin_shape[1] the width.
+                - image_shape (np.ndarray): original image size, of shape (2,),
+                                        image_shape[0] is the height, image_shape[1] the width.
                 - mixup (list): a list [im, im_info, label_info] holding the image np.ndarray data,
                                 image information and box annotation information of the image to be mixed up with the current one;
                                 note that this field is absent if no mixup is needed in the current epoch.
@@ -84,18 +101,24 @@ class Compose:
         def decode_image(im_file, im_info, label_info):
             if im_info is None:
                 im_info = dict()
-            try:
-                im = cv2.imread(im_file).astype('float32')
-            except:
-                raise TypeError(
-                    'Can\'t read the image file {}!'.format(im_file))
+            if isinstance(im_file, np.ndarray):
+                if len(im_file.shape) != 3:
+                    raise Exception(
+                        "im should be 3-dimensional, but it is {}-dimensional".
+                        format(len(im_file.shape)))
+                im = im_file
+            else:
+                try:
+                    im = cv2.imread(im_file).astype('float32')
+                except:
+                    raise TypeError(
+                        'Can\'t read the image file {}!'.format(im_file))
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
                 [im.shape[0], im.shape[1], 1.], dtype=np.float32)
-            # copy augment_shape from origin_shape
-            im_info['augment_shape'] = np.array([im.shape[0],
-                                                 im.shape[1]]).astype('int32')
+            im_info['image_shape'] = np.array([im.shape[0],
+                                               im.shape[1]]).astype('int32')
             if not self.use_mixup:
                 if 'mixup' in im_info:
                     del im_info['mixup']
@@ -118,12 +141,28 @@ class Compose:
         for op in self.transforms:
             if im is None:
                 return None
-            outputs = op(im, im_info, label_info)
-            im = outputs[0]
+            if isinstance(op, DetTransform):
+                outputs = op(im, im_info, label_info)
+                im = outputs[0]
+            else:
+                if label_info is not None:
+                    gt_poly = label_info.get('gt_poly', None)
+                    gt_bbox = label_info['gt_bbox']
+                    if gt_poly is None:
+                        im, aug_bbox = execute_imgaug(op, im, bboxes=gt_bbox)
+                    else:
+                        im, aug_bbox, aug_poly = execute_imgaug(
+                            op, im, bboxes=gt_bbox, polygons=gt_poly)
+                        label_info['gt_poly'] = aug_poly
+                    label_info['gt_bbox'] = aug_bbox
+                    outputs = (im, im_info, label_info)
+                else:
+                    im, = execute_imgaug(op, im)
+                    outputs = (im, im_info)
         return outputs


-class ResizeByShort:
+class ResizeByShort(DetTransform):
     """Resizes the image according to its short side.

     1. Gets the lengths of the image's long and short sides.
@@ -195,12 +234,17 @@ class ResizeByShort:
             return (im, im_info, label_info)


-class Padding:
-    """Pads the height and width of the image to a multiple of coarsest_stride. E.g. for an input image of [300, 640]
+class Padding(DetTransform):
+    """1. Pads the height and width of the image to a multiple of coarsest_stride. E.g. for an input image of [300, 640]
       and `coarsest_stride` 32, since 300 is not a multiple of 32, the image is padded with zeros
       on the right and bottom, and the final output image is [320, 640].
+       2. Alternatively, pads the height and width of the image to the shape given by target_size. E.g. for an input image of [300, 640],
+         a. `target_size` = 960: the image is padded with zeros on the right and bottom, and the final output
+            image is [960, 960].
+         b. `target_size` = [640, 960]: the image is padded with zeros on the right and bottom, and the final
+            output image is [640, 960].

-    1. Returns directly if coarsest_stride is 1.
+    1. Returns directly if coarsest_stride is 1 and target_size is None.
     2. Gets the height H and width W of the image.
     3. Computes the padded height H_new and width W_new.
     4. Builds an np.ndarray of size (H_new, W_new, 3) with pixel value 0,
@@ -208,10 +252,26 @@ class Padding:

     Args:
         coarsest_stride (int): the padded image height and width are multiples of this parameter. Defaults to 1.
+        target_size (int|list|tuple): the padded image height and width. Defaults to None; when set, target_size takes precedence over coarsest_stride.
+
+    Raises:
+        TypeError: the data type of the parameter `target_size` does not meet the requirements.
+        ValueError: the length of the parameter `target_size` does not meet the requirements when it is a list or tuple.
     """

-    def __init__(self, coarsest_stride=1):
+    def __init__(self, coarsest_stride=1, target_size=None):
         self.coarsest_stride = coarsest_stride
+        if target_size is not None:
+            if not isinstance(target_size, int):
+                if not isinstance(target_size, tuple) and not isinstance(
+                        target_size, list):
+                    raise TypeError(
+                        "Padding: Type of target_size must be in (int|list|tuple)."
+                    )
+                elif len(target_size) != 2:
+                    raise ValueError(
+                        "Padding: Length of target_size must equal 2.")
+        self.target_size = target_size

     def __call__(self, im, im_info=None, label_info=None):
         """
@@ -228,13 +288,9 @@ class Padding:
         Raises:
             TypeError: the data types of the parameters do not meet the requirements.
             ValueError: data lengths do not match.
+            ValueError: exactly one of coarsest_stride and target_size needs to be specified.
+            ValueError: target_size is smaller than the original image size.
         """
-
-        if self.coarsest_stride == 1:
-            if label_info is None:
-                return (im, im_info)
-            else:
-                return (im, im_info, label_info)
         if im_info is None:
             im_info = dict()
         if not isinstance(im, np.ndarray):
@@ -242,11 +298,29 @@ class Padding:
         if len(im.shape) != 3:
             raise ValueError('Padding: image is not 3-dimensional.')
         im_h, im_w, im_c = im.shape[:]
-        if self.coarsest_stride > 1:
+
+        if isinstance(self.target_size, int):
+            padding_im_h = self.target_size
+            padding_im_w = self.target_size
+        elif isinstance(self.target_size, list) or isinstance(
+                self.target_size, tuple):
+            padding_im_w = self.target_size[0]
+            padding_im_h = self.target_size[1]
+        elif self.coarsest_stride > 0:
             padding_im_h = int(
                 np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
             padding_im_w = int(
                 np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
+        else:
+            raise ValueError(
+                "coarsest_stride (>1) or target_size (list|int) needs to be set in the Padding transform"
+            )
+        pad_height = padding_im_h - im_h
+        pad_width = padding_im_w - im_w
+        if pad_height < 0 or pad_width < 0:
+            raise ValueError(
+                'the size of the image should be less than target_size, but the image size ({}, {}) is larger than target_size ({}, {})'
+                .format(im_w, im_h, padding_im_w, padding_im_h))
         padding_im = np.zeros((padding_im_h, padding_im_w, im_c),
                               dtype=np.float32)
         padding_im[:im_h, :im_w, :] = im
@@ -256,7 +330,7 @@ class Padding:
             return (padding_im, im_info, label_info)


-class Resize:
+class Resize(DetTransform):
     """Resizes the image.

     - When the target size (target_size) is of type int, depending on the interpolation method,
@@ -335,7 +409,7 @@ class Resize:
             return (im, im_info, label_info)


-class RandomHorizontalFlip:
+class RandomHorizontalFlip(DetTransform):
     """Randomly flips the image, annotation boxes and segmentation information; a data augmentation operation for model training.

     1. Samples a random number between 0 and 1; when it is smaller than the horizontal flip probability,
@@ -387,16 +461,13 @@ class RandomHorizontalFlip:
             raise TypeError(
                 'Cannot do RandomHorizontalFlip! ' +
                 'Because the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomHorizontalFlip! ' + \
-                            'Because augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info:
             raise TypeError('Cannot do RandomHorizontalFlip! ' + \
                             'Because gt_bbox is not in label_info!')
-        augment_shape = im_info['augment_shape']
+        image_shape = im_info['image_shape']
         gt_bbox = label_info['gt_bbox']
-        height = augment_shape[0]
-        width = augment_shape[1]
+        height = image_shape[0]
+        width = image_shape[1]

         if np.random.uniform(0, 1) < self.prob:
             im = horizontal_flip(im)
@@ -416,7 +487,7 @@ class RandomHorizontalFlip:
             return (im, im_info, label_info)


-class Normalize:
+class Normalize(DetTransform):
     """Normalizes the image.

     1. Rescales the image to the range [0.0, 1.0].
@@ -460,7 +531,7 @@ class Normalize:
             return (im, im_info, label_info)


-class RandomDistort:
+class RandomDistort(DetTransform):
     """Randomly perturbs pixel content of the image with a given probability; a data augmentation operation for model training.

     1. Randomizes the order of the perturbation operations.
@@ -545,7 +616,7 @@ class RandomDistort:
             params = params_dict[ops[id].__name__]
             prob = prob_dict[ops[id].__name__]
             params['im'] = im
-            
+
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
         if label_info is None:
@@ -554,7 +625,7 @@ class RandomDistort:
             return (im, im_info, label_info)


-class MixupImage:
+class MixupImage(DetTransform):
     """Applies mixup to the image; a data augmentation operation for model training. Currently only the YOLOv3 model supports this transform.

     Returns directly when the mixup field is absent from label_info; otherwise performs the following:
@@ -567,7 +638,7 @@ class MixupImage:
             (2) Concatenates the original image's annotation boxes with the mixup image's annotation boxes.
             (3) Concatenates the original image's box classes with the mixup image's box classes.
             (4) Multiplies the original boxes' mixup scores by factor and the mixup image boxes' scores by (1-factor), then sums the two results.
-    3. Updates the augment_shape information in im_info.
+    3. Updates the image_shape information in im_info.

     Args:
         alpha (float): lower bound of the random beta distribution. Defaults to 1.5.
@@ -610,7 +681,7 @@ class MixupImage:
                    when label_info is not None, the returned tuple is (im, im_info, label_info), corresponding to the image np.ndarray data
                    and the dict storing box annotation information.
                    The updated im_info field is:
-                       - augment_shape (np.ndarray): np.ndarray of the mixed-up image's height and width, of shape (2,).
+                       - image_shape (np.ndarray): np.ndarray of the mixed-up image's height and width, of shape (2,).
                    The field removed from im_info is:
                        - mixup (list): information about the image mixed up with the current one.
                    The updated label_info fields are:
@@ -674,8 +745,8 @@ class MixupImage:
         label_info['gt_score'] = gt_score
         label_info['gt_class'] = gt_class
         label_info['is_crowd'] = is_crowd
-        im_info['augment_shape'] = np.array([im.shape[0],
-                                             im.shape[1]]).astype('int32')
+        im_info['image_shape'] = np.array([im.shape[0],
+                                           im.shape[1]]).astype('int32')
         im_info.pop('mixup')
         if label_info is None:
             return (im, im_info)
@@ -683,7 +754,7 @@ class MixupImage:
             return (im, im_info, label_info)


-class RandomExpand:
+class RandomExpand(DetTransform):
     """Randomly expands the image; a data augmentation operation for model training.
     1. Randomly selects an expansion ratio (expansion happens only when the ratio is greater than 1).
     2. Computes the expanded image size.
@@ -721,7 +792,7 @@ class RandomExpand:
                    when label_info is not None, the returned tuple is (im, im_info, label_info), corresponding to the image np.ndarray data
                    and the dict storing box annotation information.
                    The updated im_info field is:
-                       - augment_shape (np.ndarray): np.ndarray of the expanded image's height and width, of shape (2,).
+                       - image_shape (np.ndarray): np.ndarray of the expanded image's height and width, of shape (2,).
                    The updated label_info field is:
                        - gt_bbox (np.ndarray): ground-truth box coordinates after random expansion, of shape (n, 4),
                                           where n is the number of ground-truth boxes.
@@ -734,9 +805,6 @@ class RandomExpand:
             raise TypeError(
                 'Cannot do RandomExpand! ' +
                 'Because the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomExpand! ' + \
-                            'Because augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info or \
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomExpand! ' + \
@@ -744,9 +812,9 @@ class RandomExpand:
         if np.random.uniform(0., 1.) < self.prob:
             return (im, im_info, label_info)

-        augment_shape = im_info['augment_shape']
-        height = int(augment_shape[0])
-        width = int(augment_shape[1])
+        image_shape = im_info['image_shape']
+        height = int(image_shape[0])
+        width = int(image_shape[1])

         expand_ratio = np.random.uniform(1., self.ratio)
         h = int(height * expand_ratio)
@@ -759,7 +827,7 @@ class RandomExpand:
         canvas *= np.array(self.fill_value, dtype=np.float32)
         canvas[y:y + height, x:x + width, :] = im

-        im_info['augment_shape'] = np.array([h, w]).astype('int32')
+        im_info['image_shape'] = np.array([h, w]).astype('int32')
         if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
             label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
         if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
@@ -768,7 +836,7 @@ class RandomExpand:
         return (canvas, im_info, label_info)


-class RandomCrop:
+class RandomCrop(DetTransform):
    """Randomly crops the image.
     1. If allow_no_crop is True, adds 'no_crop' to thresholds.
     2. Randomly shuffles thresholds.
@@ -815,12 +883,14 @@ class RandomCrop:
             tuple: when label_info is None, the returned tuple is (im, im_info), corresponding to the image np.ndarray data and the dict storing image-related information;
                    when label_info is not None, the returned tuple is (im, im_info, label_info), corresponding to the image np.ndarray data
                    and the dict storing box annotation information.
-                   The updated label_info fields are:
-                       - gt_bbox (np.ndarray): ground-truth box coordinates after random cropping, of shape (n, 4),
+                   The updated im_info field is:
+                           - image_shape (np.ndarray): np.ndarray of the cropped image's height and width, of shape (2,).
+                       The updated label_info fields are:
+                           - gt_bbox (np.ndarray): ground-truth box coordinates after random cropping, of shape (n, 4),
                                           where n is the number of ground-truth boxes.
-                       - gt_class (np.ndarray): class index of each ground-truth box after random cropping, of shape (n, 1),
+                           - gt_class (np.ndarray): class index of each ground-truth box after random cropping, of shape (n, 1),
                                            where n is the number of ground-truth boxes.
-                       - gt_score (np.ndarray): mixup score of each ground-truth box after random cropping, of shape (n, 1),
+                           - gt_score (np.ndarray): mixup score of each ground-truth box after random cropping, of shape (n, 1),
                                            where n is the number of ground-truth boxes.

         Raises:
@@ -830,9 +900,6 @@ class RandomCrop:
             raise TypeError(
                 'Cannot do RandomCrop! ' +
                 'Because the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomCrop! ' + \
-                            'Because augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info or \
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomCrop! ' + \
@@ -841,9 +908,9 @@ class RandomCrop:
         if len(label_info['gt_bbox']) == 0:
             return (im, im_info, label_info)

-        augment_shape = im_info['augment_shape']
-        w = augment_shape[1]
-        h = augment_shape[0]
+        image_shape = im_info['image_shape']
+        w = image_shape[1]
+        h = image_shape[0]
         gt_bbox = label_info['gt_bbox']
         thresholds = list(self.thresholds)
         if self.allow_no_crop:
@@ -902,7 +969,7 @@ class RandomCrop:
                 label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
                 label_info['gt_class'] = np.take(
                     label_info['gt_class'], valid_ids, axis=0)
-                im_info['augment_shape'] = np.array(
+                im_info['image_shape'] = np.array(
                     [crop_box[3] - crop_box[1],
                      crop_box[2] - crop_box[0]]).astype('int32')
                 if 'gt_score' in label_info:
@@ -917,7 +984,7 @@ class RandomCrop:
         return (im, im_info, label_info)


-class ArrangeFasterRCNN:
+class ArrangeFasterRCNN(DetTransform):
     """Collects the information needed for FasterRCNN training/evaluation/prediction.

     Args:
@@ -973,7 +1040,7 @@ class ArrangeFasterRCNN:
             im_resize_info = im_info['im_resize_info']
             im_id = im_info['im_id']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             gt_bbox = label_info['gt_bbox']
             gt_class = label_info['gt_class']
@@ -986,13 +1053,13 @@ class ArrangeFasterRCNN:
                                 'Because the im_info can not be None!')
             im_resize_info = im_info['im_resize_info']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             outputs = (im, im_resize_info, im_shape)
         return outputs


-class ArrangeMaskRCNN:
+class ArrangeMaskRCNN(DetTransform):
     """Collects the information needed for MaskRCNN training/evaluation/prediction.

     Args:
@@ -1066,7 +1133,7 @@ class ArrangeMaskRCNN:
                                 'Because the im_info can not be None!')
             im_resize_info = im_info['im_resize_info']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             if self.mode == 'eval':
                 im_id = im_info['im_id']
@@ -1076,7 +1143,7 @@ class ArrangeMaskRCNN:
         return outputs


-class ArrangeYOLOv3:
+class ArrangeYOLOv3(DetTransform):
     """Collects the information needed for YOLOv3 training/evaluation/prediction.

     Args:
@@ -1117,7 +1184,7 @@ class ArrangeYOLOv3:
                 raise TypeError(
                     'Cannot do ArrangeYolov3! ' +
                     'Because the im_info and label_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             if len(label_info['gt_bbox']) != len(label_info['gt_class']):
                 raise ValueError("gt num mismatch: bbox and class.")
             if len(label_info['gt_bbox']) != len(label_info['gt_score']):
@@ -1141,7 +1208,7 @@ class ArrangeYOLOv3:
                 raise TypeError(
                     'Cannot do ArrangeYolov3! ' +
                     'Because the im_info and label_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             if len(label_info['gt_bbox']) != len(label_info['gt_class']):
                 raise ValueError("gt num mismatch: bbox and class.")
             im_id = im_info['im_id']
@@ -1160,6 +1227,6 @@ class ArrangeYOLOv3:
             if im_info is None:
                 raise TypeError('Cannot do ArrangeYolov3! ' +
                                 'Because the im_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             outputs = (im, im_shape)
         return outputs

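The size computation behind the new target_size mode of Padding can be sketched in plain numpy (compute_pad_shape is a hypothetical helper, not PaddleX API; per the __call__ code above, target_size takes precedence, with target_size[0] used as the padded width and target_size[1] as the padded height):

    import numpy as np

    def compute_pad_shape(im_h, im_w, coarsest_stride=1, target_size=None):
        # Returns (padded_h, padded_w), mirroring the branch order in Padding.__call__
        if isinstance(target_size, int):
            return target_size, target_size
        if isinstance(target_size, (list, tuple)):
            return target_size[1], target_size[0]
        if coarsest_stride > 0:
            # Round H and W up to the next multiple of coarsest_stride
            return (int(np.ceil(im_h / coarsest_stride) * coarsest_stride),
                    int(np.ceil(im_w / coarsest_stride) * coarsest_stride))
        raise ValueError("set coarsest_stride (>1) or target_size")

    print(compute_pad_shape(300, 640, coarsest_stride=32))      # (320, 640)
    print(compute_pad_shape(300, 640, target_size=[640, 960]))  # (960, 640)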
+ 131 - 0
paddlex/cv/transforms/imgaug_support.py

@@ -0,0 +1,131 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
+                   segment_map=None):
+    # Preprocessing: convert bboxes and polygons to the imgaug format
+    import imgaug.augmentables.polys as polys
+    import imgaug.augmentables.bbs as bbs
+
+    aug_im = im.astype('uint8')
+
+    aug_bboxes = None
+    if bboxes is not None:
+        aug_bboxes = list()
+        for i in range(len(bboxes)):
+            x1 = bboxes[i, 0] - 1
+            y1 = bboxes[i, 1]
+            x2 = bboxes[i, 2]
+            y2 = bboxes[i, 3]
+            aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2))
+
+    aug_polygons = None
+    lod_info = list()
+    if polygons is not None:
+        aug_polygons = list()
+        for i in range(len(polygons)):
+            num = len(polygons[i])
+            lod_info.append(num)
+            for j in range(num):
+                points = np.reshape(polygons[i][j], (-1, 2))
+                aug_polygons.append(polys.Polygon(points))
+
+    aug_segment_map = None
+    if segment_map is not None:
+        if len(segment_map.shape) == 2:
+            h, w = segment_map.shape
+            aug_segment_map = np.reshape(segment_map, (1, h, w, 1))
+        elif len(segment_map.shape) == 3:
+            h, w, c = segment_map.shape
+            aug_segment_map = np.reshape(segment_map, (1, h, w, c))
+        else:
+            raise Exception(
+                "Only 2-dimensional or 3-dimensional segment_map is supported")
+
+    aug_im, aug_bboxes, aug_polygons, aug_seg_map = augmenter.augment(
+        image=aug_im,
+        bounding_boxes=aug_bboxes,
+        polygons=aug_polygons,
+        segmentation_maps=aug_segment_map)
+
+    aug_im = aug_im.astype('float32')
+
+    if aug_polygons is not None:
+        assert len(aug_bboxes) == len(
+            lod_info
+        ), "Number of aug_bboxes should be equal to the number of polygon groups"
+
+    if aug_bboxes is not None:
+        # Clip away the parts of bboxes and polygons that lie outside the image
+        for i in range(len(aug_bboxes)):
+            aug_bboxes[i] = aug_bboxes[i].clip_out_of_image(aug_im)
+        if aug_polygons is not None:
+            for i in range(len(aug_polygons)):
+                aug_polygons[i] = aug_polygons[i].clip_out_of_image(aug_im)
+
+        # Filter out invalid bboxes and polygons and convert them to the training data format
+        converted_bboxes = list()
+        converted_polygons = list()
+        poly_index = 0
+        for i in range(len(aug_bboxes)):
+            # Filter out boxes whose width or height is less than 1 pixel
+            if aug_bboxes[i].width < 1 or aug_bboxes[i].height < 1:
+                continue
+            if aug_polygons is None:
+                converted_bboxes.append([
+                    aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
+                    aug_bboxes[i].y2
+                ])
+                continue
+
+            # If polygons are present, the code below continues to run
+            polygons_this_box = list()
+            for ps in aug_polygons[poly_index:poly_index + lod_info[i]]:
+                if len(ps) == 0:
+                    continue
+                for p in ps:
+                    # Polygons with fewer than 3 points are filtered out
+                    if len(p.exterior) < 3:
+                        continue
+                    polygons_this_box.append(p.exterior.flatten().tolist())
+            poly_index += lod_info[i]
+
+            if len(polygons_this_box) == 0:
+                continue
+            converted_bboxes.append([
+                aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
+                aug_bboxes[i].y2
+            ])
+            converted_polygons.append(polygons_this_box)
+        if len(converted_bboxes) == 0:
+            aug_im = im
+            converted_bboxes = bboxes
+            converted_polygons = polygons
+
+    result = [aug_im]
+    if bboxes is not None:
+        result.append(np.array(converted_bboxes))
+    if polygons is not None:
+        result.append(converted_polygons)
+    if segment_map is not None:
+        n, h, w, c = aug_seg_map.shape
+        if len(segment_map.shape) == 2:
+            aug_seg_map = np.reshape(aug_seg_map, (h, w))
+        elif len(segment_map.shape) == 3:
+            aug_seg_map = np.reshape(aug_seg_map, (h, w, c))
+        result.append(aug_seg_map)
+    return result

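A minimal usage sketch for execute_imgaug (untested; the array values are illustrative), showing the image-plus-bboxes path that the detection Compose uses:

    import numpy as np
    import imgaug.augmenters as iaa
    from paddlex.cv.transforms.imgaug_support import execute_imgaug

    im = np.random.uniform(0, 255, (320, 320, 3)).astype('float32')
    gt_bbox = np.array([[40., 40., 120., 160.]], dtype='float32')

    aug = iaa.Fliplr(1.0)  # always flip horizontally
    aug_im, aug_bbox = execute_imgaug(aug, im, bboxes=gt_bbox)
    print(aug_im.shape, aug_bbox.shape)  # (320, 320, 3) (1, 4)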
+ 198 - 47
paddlex/cv/transforms/seg_transforms.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from .ops import *
 from .ops import *
+from .imgaug_support import execute_imgaug
 import random
 import random
 import os.path as osp
 import os.path as osp
 import numpy as np
 import numpy as np
@@ -22,7 +23,15 @@ import cv2
 from collections import OrderedDict
 from collections import OrderedDict
 
 
 
 
-class Compose:
+class SegTransform:
+    """ 分割transform基类
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(SegTransform):
     """根据数据预处理/增强算子对输入数据进行操作。
     """根据数据预处理/增强算子对输入数据进行操作。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
 
 
@@ -43,14 +52,23 @@ class Compose:
                             'must be equal or larger than 1!')
                             'must be equal or larger than 1!')
         self.transforms = transforms
         self.transforms = transforms
         self.to_rgb = False
         self.to_rgb = False
+        # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
+        for op in self.transforms:
+            if not isinstance(op, SegTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )
 
 
     def __call__(self, im, im_info=None, label=None):
     def __call__(self, im, im_info=None, label=None):
         """
         """
         Args:
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息,dict中的字段如下:
-                - shape_before_resize (tuple): 图像resize之前的大小(h, w)。
-                - shape_before_padding (tuple): 图像padding之前的大小(h, w)。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
             label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
 
 
         Returns:
         Returns:
@@ -58,27 +76,41 @@ class Compose:
         """
         """
 
 
         if im_info is None:
         if im_info is None:
-            im_info = dict()
-        try:
-            im = cv2.imread(im).astype('float32')
-        except:
-            raise ValueError('Can\'t read The image file {}!'.format(im))
+            im_info = list()
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensions, but now is {}-dimensions".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im).astype('float32')
+            except:
+                raise ValueError('Can\'t read The image file {}!'.format(im))
         if self.to_rgb:
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if label is not None:
         if label is not None:
             if not isinstance(label, np.ndarray):
             if not isinstance(label, np.ndarray):
                 label = np.asarray(Image.open(label))
                 label = np.asarray(Image.open(label))
         for op in self.transforms:
         for op in self.transforms:
-            outputs = op(im, im_info, label)
-            im = outputs[0]
-            if len(outputs) >= 2:
-                im_info = outputs[1]
-            if len(outputs) == 3:
-                label = outputs[2]
+            if isinstance(op, SegTransform):
+                outputs = op(im, im_info, label)
+                im = outputs[0]
+                if len(outputs) >= 2:
+                    im_info = outputs[1]
+                if len(outputs) == 3:
+                    label = outputs[2]
+            else:
+                if label is not None:
+                    im, label = execute_imgaug(op, im, segment_map=label)
+                    outputs = (im, im_info, label)
+                else:
+                    im, = execute_imgaug(op, im)
+                    outputs = (im, im_info)
         return outputs
         return outputs
 
 
 
 
-class RandomHorizontalFlip:
+class RandomHorizontalFlip(SegTransform):
     """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
     """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
 
 
     Args:
     Args:
@@ -93,7 +125,10 @@ class RandomHorizontalFlip:
         """
         """
         Args:
         Args:
             im (np.ndarray): 图像np.ndarray数据。
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
             label (np.ndarray): 标注图像np.ndarray数据。
 
 
         Returns:
         Returns:
@@ -111,7 +146,7 @@ class RandomHorizontalFlip:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomVerticalFlip:
+class RandomVerticalFlip(SegTransform):
     """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
     """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
 
 
     Args:
     Args:
@@ -125,7 +160,10 @@ class RandomVerticalFlip:
         """
         """
         Args:
         Args:
             im (np.ndarray): 图像np.ndarray数据。
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
             label (np.ndarray): 标注图像np.ndarray数据。
 
 
         Returns:
         Returns:
@@ -143,7 +181,7 @@ class RandomVerticalFlip:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Resize:
+class Resize(SegTransform):
     """调整图像大小(resize),当存在标注图像时,则同步进行处理。
     """调整图像大小(resize),当存在标注图像时,则同步进行处理。
 
 
     - 当目标大小(target_size)类型为int时,根据插值方式,
     - 当目标大小(target_size)类型为int时,根据插值方式,
@@ -191,7 +229,10 @@ class Resize:
         """
         """
         Args:
         Args:
             im (np.ndarray): 图像np.ndarray数据。
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
             label (np.ndarray): 标注图像np.ndarray数据。
 
 
         Returns:
         Returns:
@@ -208,7 +249,7 @@ class Resize:
         """
         """
         if im_info is None:
         if im_info is None:
             im_info = OrderedDict()
             im_info = OrderedDict()
-        im_info['shape_before_resize'] = im.shape[:2]
+        im_info.append(('resize', im.shape[:2]))
 
 
         if not isinstance(im, np.ndarray):
         if not isinstance(im, np.ndarray):
             raise TypeError("ResizeImage: image type is not np.ndarray.")
             raise TypeError("ResizeImage: image type is not np.ndarray.")
@@ -250,7 +291,7 @@ class Resize:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeByLong:
+class ResizeByLong(SegTransform):
     """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
     """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
 
 
     Args:
     Args:
@@ -264,7 +305,10 @@ class ResizeByLong:
         """
         """
         Args:
         Args:
             im (np.ndarray): 图像np.ndarray数据。
             im (np.ndarray): 图像np.ndarray数据。
-            im_info (dict): 存储与图像相关的信息。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
             label (np.ndarray): 标注图像np.ndarray数据。
             label (np.ndarray): 标注图像np.ndarray数据。
 
 
         Returns:
         Returns:
@@ -272,12 +316,12 @@ class ResizeByLong:
                 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
                 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
                 存储与图像相关信息的字典和标注图像np.ndarray数据。
                 存储与图像相关信息的字典和标注图像np.ndarray数据。
                 其中,im_info新增字段为:
                 其中,im_info新增字段为:
-                    -shape_before_resize (tuple): 保存resize之前图像的形状(h, w
+                    -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)
         """
         """
         if im_info is None:
         if im_info is None:
             im_info = OrderedDict()
             im_info = OrderedDict()
 
 
-        im_info['shape_before_resize'] = im.shape[:2]
+        im_info.append(('resize', im.shape[:2]))
         im = resize_long(im, self.long_size)
         im = resize_long(im, self.long_size)
         if label is not None:
         if label is not None:
             label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
             label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
@@ -288,7 +332,84 @@ class ResizeByLong:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeRangeScaling:
+class ResizeByShort(SegTransform):
+    """根据图像的短边调整图像大小(resize)。
+
+    1. Get the lengths of the long and short sides of the image.
+    2. From the ratio of short_size to the short side, compute the target
+       length of the long side; the resize ratio for height and width is
+       short_size / (short side of the original image).
+    3. If max_size > 0, adjust the resize ratio: if the target length of the
+       long side exceeds max_size, the resize ratio for height and width
+       becomes max_size / (long side of the original image).
+    4. Resize the image by the resulting ratio.
+
+    Args:
+        short_size (int): Target length of the short side. Defaults to 800.
+        max_size (int): Upper limit on the target length of the long side. Defaults to 1333.
+
+    Raises:
+        TypeError: A parameter has an invalid type.
+    """
+
+    def __init__(self, short_size=800, max_size=1333):
+        self.max_size = int(max_size)
+        if not isinstance(short_size, int):
+            raise TypeError(
+                "Type of short_size is invalid. Must be Integer, now is {}".
+                format(type(short_size)))
+        self.short_size = short_size
+        if not (isinstance(self.max_size, int)):
+            raise TypeError("max_size: input type is invalid.")
+
+    def __call__(self, im, im_info=None, label=None):
+        """
+        Args:
+            im (np.ndarray): Image data as np.ndarray.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
+            label (np.ndarray): Annotation image data as np.ndarray.
+
+        Returns:
+            tuple: When label is None, returns (im, im_info), corresponding to
+                   the image np.ndarray data and the im_info structure; when
+                   label is not None, returns (im, im_info, label), additionally
+                   containing the annotation np.ndarray data. A ('resize', (h, w))
+                   record holding the pre-resize shape is appended to im_info.
+
+        Raises:
+            TypeError: A parameter has an invalid type.
+            ValueError: Data lengths do not match.
+        """
+        if im_info is None:
+            im_info = list()
+        if not isinstance(im, np.ndarray):
+            raise TypeError("ResizeByShort: image type is not np.ndarray.")
+        if len(im.shape) != 3:
+            raise ValueError('ResizeByShort: image is not 3-dimensional.')
+        im_info.append(('resize', im.shape[:2]))
+        im_short_size = min(im.shape[0], im.shape[1])
+        im_long_size = max(im.shape[0], im.shape[1])
+        scale = float(self.short_size) / im_short_size
+        if self.max_size > 0 and np.round(
+                scale * im_long_size) > self.max_size:
+            scale = float(self.max_size) / float(im_long_size)
+        resized_width = int(round(im.shape[1] * scale))
+        resized_height = int(round(im.shape[0] * scale))
+        im = cv2.resize(
+            im, (resized_width, resized_height),
+            interpolation=cv2.INTER_NEAREST)
+        if label is not None:
+            label = cv2.resize(
+                label, (resized_width, resized_height),
+                interpolation=cv2.INTER_NEAREST)
+        if label is None:
+            return (im, im_info)
+        else:
+            return (im, im_info, label)
+
+
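
The short-side rule in steps 1-4 of ResizeByShort is easy to sanity-check in isolation. Below is a minimal standalone sketch of the same arithmetic (illustrative only, not part of this diff; the function name is made up):

    def short_side_scale(h, w, short_size=800, max_size=1333):
        # Scale so that the short side reaches short_size ...
        scale = float(short_size) / min(h, w)
        # ... unless that would push the long side past max_size.
        if max_size > 0 and round(scale * max(h, w)) > max_size:
            scale = float(max_size) / max(h, w)
        return int(round(w * scale)), int(round(h * scale))

    # For a 480x1280 image, scaling the short side to 800 would stretch the
    # long side to 2133 > 1333, so the long side caps the scale instead.
    print(short_side_scale(480, 1280))  # (1333, 500)
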
+class ResizeRangeScaling(SegTransform):
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
 
 
     Args:
     Args:
@@ -311,7 +432,10 @@ class ResizeRangeScaling:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -334,7 +458,7 @@ class ResizeRangeScaling:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeStepScaling:
+class ResizeStepScaling(SegTransform):
     """对图像按照某一个比例resize,这个比例以scale_step_size为步长
     """对图像按照某一个比例resize,这个比例以scale_step_size为步长
     在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
     在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
 
 
@@ -364,7 +488,10 @@ class ResizeStepScaling:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -406,7 +533,7 @@ class ResizeStepScaling:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Normalize:
+class Normalize(SegTransform):
     """对图像进行标准化。
     """对图像进行标准化。
     1.尺度缩放到 [0,1]。
     1.尺度缩放到 [0,1]。
     2.对图像进行减均值除以标准差操作。
     2.对图像进行减均值除以标准差操作。
@@ -432,7 +559,10 @@ class Normalize:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
          Returns:
          Returns:
@@ -451,7 +581,7 @@ class Normalize:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Padding:
+class Padding(SegTransform):
     """对图像或标注图像进行padding,padding方向为右和下。
     """对图像或标注图像进行padding,padding方向为右和下。
     根据提供的值对图像或标注图像进行padding操作。
     根据提供的值对图像或标注图像进行padding操作。
 
 
@@ -486,7 +616,10 @@ class Padding:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -501,7 +634,7 @@ class Padding:
         """
         """
         if im_info is None:
         if im_info is None:
-            im_info = OrderedDict()
+            im_info = list()
-        im_info['shape_before_padding'] = im.shape[:2]
+        im_info.append(('padding', im.shape[:2]))
 
 
         im_height, im_width = im.shape[0], im.shape[1]
         im_height, im_width = im.shape[0], im.shape[1]
         if isinstance(self.target_size, int):
         if isinstance(self.target_size, int):
@@ -540,7 +673,7 @@ class Padding:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomPaddingCrop:
+class RandomPaddingCrop(SegTransform):
     """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
     """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
 
 
     Args:
     Args:
@@ -574,7 +707,10 @@ class RandomPaddingCrop:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
          Returns:
          Returns:
@@ -636,7 +772,7 @@ class RandomPaddingCrop:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomBlur:
+class RandomBlur(SegTransform):
     """以一定的概率对图像进行高斯模糊。
     """以一定的概率对图像进行高斯模糊。
 
 
     Args:
     Args:
@@ -650,7 +786,10 @@ class RandomBlur:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -679,7 +818,7 @@ class RandomBlur:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomRotate:
+class RandomRotate(SegTransform):
     """对图像进行随机旋转, 模型训练时的数据增强操作。
     """对图像进行随机旋转, 模型训练时的数据增强操作。
     在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
     在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
     并对旋转后的图像和标注图像进行相应的padding。
     并对旋转后的图像和标注图像进行相应的padding。
@@ -703,7 +842,10 @@ class RandomRotate:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -748,7 +890,7 @@ class RandomRotate:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomScaleAspect:
+class RandomScaleAspect(SegTransform):
     """裁剪并resize回原始尺寸的图像和标注图像。
     """裁剪并resize回原始尺寸的图像和标注图像。
     按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
     按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
 
 
@@ -765,7 +907,10 @@ class RandomScaleAspect:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -808,7 +953,7 @@ class RandomScaleAspect:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomDistort:
+class RandomDistort(SegTransform):
     """对图像进行随机失真。
     """对图像进行随机失真。
 
 
     1. 对变换的操作顺序进行随机化操作。
     1. 对变换的操作顺序进行随机化操作。
@@ -847,7 +992,10 @@ class RandomDistort:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:
@@ -901,7 +1049,7 @@ class RandomDistort:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ArrangeSegmenter:
+class ArrangeSegmenter(SegTransform):
     """获取训练/验证/预测所需的信息。
     """获取训练/验证/预测所需的信息。
 
 
     Args:
     Args:
@@ -922,7 +1070,10 @@ class ArrangeSegmenter:
         """
         """
         Args:
         Args:
             im (np.ndarray): Image data as np.ndarray.
             im (np.ndarray): Image data as np.ndarray.
-            im_info (dict): Stores information related to the image.
+            im_info (list): Stores the image shapes recorded before each
+                resize or padding op, e.g. [('resize', [200, 300]),
+                ('padding', [400, 600])] means the image shape was (200, 300)
+                before resize and (400, 600) before padding.
             label (np.ndarray): Annotation image data as np.ndarray.
             label (np.ndarray): Annotation image data as np.ndarray.
 
 
         Returns:
         Returns:

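Taken together, the seg_transforms changes above replace the old shape_before_resize / shape_before_padding dict keys with an ordered list of (op, shape) records, so a consumer can undo the ops in reverse order. A hedged sketch of such a consumer (illustrative only; restore_mask is not a function in this diff, and it assumes padding was applied bottom/right as in Padding above):

    import cv2
    import numpy as np

    def restore_mask(pred, im_info):
        # Walk the records in reverse, undoing each op in turn.
        for op, (h, w) in reversed(im_info):
            if op == 'padding':   # crop away the bottom/right padding
                pred = pred[:h, :w]
            elif op == 'resize':  # resize back to the pre-resize shape
                pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
        return pred

    pred = np.zeros((512, 640), dtype=np.uint8)  # padded network-output size
    im_info = [('resize', [200, 300]), ('padding', [400, 600])]
    print(restore_mask(pred, im_info).shape)     # (200, 300)
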
+ 24 - 0
paddlex/tools/__init__.py

@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .x2imagenet import EasyData2ImageNet
+from .x2coco import LabelMe2COCO
+from .x2coco import EasyData2COCO
+from .x2voc import LabelMe2VOC
+from .x2voc import EasyData2VOC
+from .x2seg import JingLing2Seg
+from .x2seg import LabelMe2Seg
+from .x2seg import EasyData2Seg

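A hedged usage sketch for the converters exported above (the three directory paths are placeholders; per the convert() implementations below, all of them must already exist):

    from paddlex.tools import LabelMe2COCO, LabelMe2VOC

    LabelMe2COCO().convert(
        image_dir='MyDataset/images',
        json_dir='MyDataset/labelme_jsons',
        dataset_save_dir='MyDataset/coco')

    LabelMe2VOC().convert(
        image_dir='MyDataset/images',
        json_dir='MyDataset/labelme_jsons',
        dataset_save_dir='MyDataset/voc')
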
+ 43 - 0
paddlex/tools/base.py

@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import chardet
+import numpy as np
+
+class MyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(MyEncoder, self).default(obj)
+        
+def is_pic(img_name):
+    valid_suffix = ["JPEG", "jpeg", "JPG", "jpg", "BMP", "bmp", "PNG", "png"]
+    suffix = img_name.split(".")[-1]
+    if suffix not in valid_suffix:
+        return False
+    return True
+
+def get_encoding(path):
+    with open(path, 'rb') as f:
+        data = f.read()
+    return chardet.detect(data).get('encoding')

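The two helpers above are small enough to demonstrate inline: MyEncoder lets json.dump/json.dumps accept the numpy scalars and arrays that end up in annotation dicts, and is_pic is a plain suffix filter (the values below are illustrative):

    import json
    import numpy as np
    from paddlex.tools.base import MyEncoder, is_pic

    record = {'image_id': np.int64(3), 'bbox': np.array([10., 20., 30., 40.])}
    # Without cls=MyEncoder, json.dumps raises TypeError on np.int64/np.ndarray.
    print(json.dumps(record, cls=MyEncoder))
    # {"image_id": 3, "bbox": [10.0, 20.0, 30.0, 40.0]}

    print(is_pic('cat.jpg'), is_pic('notes.txt'))  # True False
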
+ 257 - 0
paddlex/tools/x2coco.py

@@ -0,0 +1,257 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+import PIL.Image
+import PIL.ImageDraw
+from .base import MyEncoder, is_pic, get_encoding
+        
+        
+class X2COCO(object):
+    def __init__(self):
+        self.images_list = []
+        self.categories_list = []
+        self.annotations_list = []
+    
+    def generate_categories_field(self, label, labels_list):
+        category = {}
+        category["supercategory"] = "component"
+        category["id"] = len(labels_list) + 1
+        category["name"] = label
+        return category
+    
+    def generate_rectangle_anns_field(self, points, label, image_id, object_id, label_to_num):
+        annotation = {}
+        seg_points = np.asarray(points).copy()
+        seg_points[1, :] = np.asarray(points)[2, :]
+        seg_points[2, :] = np.asarray(points)[1, :]
+        annotation["segmentation"] = [list(seg_points.flatten())]
+        annotation["iscrowd"] = 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(
+            map(float, [
+                points[0][0], points[0][1], points[1][0] - points[0][0], points[1][
+                    1] - points[0][1]
+            ]))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+    
+    def convert(self, image_dir, json_dir, dataset_save_dir):
+        """转换。
+        Args:
+            image_dir (str): 图像文件存放的路径。
+            json_dir (str): 与每张图像对应的json文件的存放路径。
+            dataset_save_dir (str): 转换后数据集存放路径。
+        """
+        assert osp.exists(image_dir), "he image folder does not exist!"
+        assert osp.exists(json_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        # Convert the image files.
+        new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
+        if osp.exists(new_image_dir):
+            shutil.rmtree(new_image_dir)
+        os.makedirs(new_image_dir)
+        for img_name in os.listdir(image_dir):
+            if is_pic(img_name):
+                shutil.copyfile(
+                            osp.join(image_dir, img_name),
+                            osp.join(new_image_dir, img_name))
+        # Convert the json files.
+        self.parse_json(new_image_dir, json_dir)
+        coco_data = {}
+        coco_data["images"] = self.images_list
+        coco_data["categories"] = self.categories_list
+        coco_data["annotations"] = self.annotations_list
+        json_path = osp.join(dataset_save_dir, "annotations.json")
+        json.dump(
+            coco_data,
+            open(json_path, "w"),
+            indent=4,
+            cls=MyEncoder)
+    
+    
+class LabelMe2COCO(X2COCO):
+    """将使用LabelMe标注的数据集转换为COCO数据集。
+    """
+    def __init__(self):
+        super(LabelMe2COCO, self).__init__()
+        
+    def generate_images_field(self, json_info, image_id):
+        image = {}
+        image["height"] = json_info["imageHeight"]
+        image["width"] = json_info["imageWidth"]
+        image["id"] = image_id + 1
+        image["file_name"] = json_info["imagePath"].split("/")[-1]
+        return image
+    
+    def generate_polygon_anns_field(self, height, width, 
+                                    points, label, image_id, 
+                                    object_id, label_to_num):
+        annotation = {}
+        annotation["segmentation"] = [list(np.asarray(points).flatten())]
+        annotation["iscrowd"] = 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(map(float, self.get_bbox(height, width, points)))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+    
+    def get_bbox(self, height, width, points):
+        polygons = points
+        mask = np.zeros([height, width], dtype=np.uint8)
+        mask = PIL.Image.fromarray(mask)
+        xy = list(map(tuple, polygons))
+        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
+        mask = np.array(mask, dtype=bool)
+        index = np.argwhere(mask == 1)
+        rows = index[:, 0]
+        cols = index[:, 1]
+        left_top_r = np.min(rows)
+        left_top_c = np.min(cols)
+        right_bottom_r = np.max(rows)
+        right_bottom_c = np.max(cols)
+        return [
+            left_top_c, left_top_r, right_bottom_c - left_top_c,
+            right_bottom_r - left_top_r
+        ]
+    
+    def parse_json(self, img_dir, json_dir):
+        image_id = -1
+        object_id = -1
+        labels_list = []
+        label_to_num = {}
+        for img_file in os.listdir(img_dir):
+            img_name_part = osp.splitext(img_file)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(img_dir, img_file))
+                continue
+            image_id = image_id + 1
+            with open(json_file, mode='r', \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                img_info = self.generate_images_field(json_info, image_id)
+                self.images_list.append(img_info)
+                for shapes in json_info["shapes"]:
+                    object_id = object_id + 1
+                    label = shapes["label"]
+                    if label not in labels_list:
+                        self.categories_list.append(\
+                            self.generate_categories_field(label, labels_list))
+                        labels_list.append(label)
+                        label_to_num[label] = len(labels_list)
+                    points = shapes["points"]
+                    p_type = shapes["shape_type"]
+                    if p_type == "polygon":
+                        self.annotations_list.append(
+                            self.generate_polygon_anns_field(json_info["imageHeight"], json_info[
+                                "imageWidth"], points, label, image_id,
+                                                object_id, label_to_num))
+                    if p_type == "rectangle":
+                        points.append([points[0][0], points[1][1]])
+                        points.append([points[1][0], points[0][1]])
+                        self.annotations_list.append(
+                            self.generate_rectangle_anns_field(points, label, image_id,
+                                                  object_id, label_to_num))
+                        
+    
+class EasyData2COCO(X2COCO):
+    """将使用EasyData标注的检测或分割数据集转换为COCO数据集。
+    """
+    def __init__(self):
+        super(EasyData2COCO, self).__init__()        
+    
+    def generate_images_field(self, img_path, image_id):
+        image = {}
+        img = cv2.imread(img_path)
+        image["height"] = img.shape[0]
+        image["width"] = img.shape[1]
+        image["id"] = image_id + 1
+        image["file_name"] = osp.split(img_path)[-1]
+        return image
+    
+    def generate_polygon_anns_field(self, points, segmentation, 
+                                    label, image_id, object_id,
+                                    label_to_num):
+        annotation = {}
+        annotation["segmentation"] = segmentation
+        annotation["iscrowd"] = 1 if len(segmentation) > 1 else 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(map(float, [
+                points[0][0], points[0][1], points[1][0] - points[0][0], points[1][
+                    1] - points[0][1]
+            ]))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+        
+    def parse_json(self, img_dir, json_dir):
+        from pycocotools.mask import decode
+        image_id = -1
+        object_id = -1
+        labels_list = []
+        label_to_num = {}
+        for img_file in os.listdir(img_dir):
+            img_name_part = osp.splitext(img_file)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(img_dir, img_file))
+                continue
+            image_id = image_id + 1
+            with open(json_file, mode='r', \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                img_info = self.generate_images_field(osp.join(img_dir, img_file), image_id)
+                self.images_list.append(img_info)
+                for shapes in json_info["labels"]:
+                    object_id = object_id + 1
+                    label = shapes["name"]
+                    if label not in labels_list:
+                        self.categories_list.append(\
+                            self.generate_categories_field(label, labels_list))
+                        labels_list.append(label)
+                        label_to_num[label] = len(labels_list)
+                    points = [[shapes["x1"], shapes["y1"]],
+                              [shapes["x2"], shapes["y2"]]]
+                    if "mask" not in shapes:
+                        points.append([points[0][0], points[1][1]])
+                        points.append([points[1][0], points[0][1]])
+                        self.annotations_list.append(
+                            self.generate_rectangle_anns_field(points, label, image_id,
+                                                  object_id, label_to_num))
+                    else:
+                        mask_dict = {}
+                        mask_dict['size'] = [img_info["height"], img_info["width"]]
+                        mask_dict['counts'] = shapes['mask'].encode()
+                        mask = decode(mask_dict)
+                        contours, hierarchy = cv2.findContours(
+                                (mask).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                        segmentation = []
+                        for contour in contours:
+                            contour_list = contour.flatten().tolist()
+                            if len(contour_list) > 4:
+                                segmentation.append(contour_list)
+                        self.annotations_list.append(
+                            self.generate_polygon_anns_field(points, segmentation, label, image_id, object_id,
+                                                label_to_num))

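To make the rectangle handling above concrete: parse_json expands a two-point LabelMe rectangle to four corners, and generate_rectangle_anns_field swaps rows 1 and 2 so the segmentation traces the box contiguously. A standalone walk-through of that bookkeeping:

    import numpy as np

    points = [[10, 20], [50, 80]]                # top-left, bottom-right
    points.append([points[0][0], points[1][1]])  # bottom-left  [10, 80]
    points.append([points[1][0], points[0][1]])  # top-right    [50, 20]

    bbox = [points[0][0], points[0][1],
            points[1][0] - points[0][0], points[1][1] - points[0][1]]
    print(bbox)  # [10, 20, 40, 60] -> x, y, w, h

    seg = np.asarray(points).copy()
    seg[1, :] = np.asarray(points)[2, :]
    seg[2, :] = np.asarray(points)[1, :]
    print(list(seg.flatten()))  # [10, 20, 10, 80, 50, 80, 50, 20]
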
+ 58 - 0
paddlex/tools/x2imagenet.py

@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+from .base import MyEncoder, is_pic, get_encoding
+
+class EasyData2ImageNet(object):
+    """将使用EasyData标注的分类数据集转换为COCO数据集。
+    """
+    def __init__(self):
+        pass
+    
+    def convert(self, image_dir, json_dir, dataset_save_dir):
+        """转换。
+        Args:
+            image_dir (str): 图像文件存放的路径。
+            json_dir (str): 与每张图像对应的json文件的存放路径。
+            dataset_save_dir (str): 转换后数据集存放路径。
+        """
+        assert osp.exists(image_dir), "The image folder does not exist!"
+        assert osp.exists(json_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        assert len(os.listdir(dataset_save_dir)) == 0, "The save folder must be empty!"
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                continue
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                for output in json_info['labels']:
+                    cls_name = output['name']
+                    new_image_dir = osp.join(dataset_save_dir, cls_name)
+                    if not osp.exists(new_image_dir):
+                        os.makedirs(new_image_dir)
+                    if is_pic(img_name):
+                        shutil.copyfile(
+                                    osp.join(image_dir, img_name),
+                                    osp.join(new_image_dir, img_name))

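A hedged usage sketch (paths are placeholders): the converter copies each image into one sub-folder per class name found in its EasyData json, which is the layout ImageNet-style classification loaders expect:

    from paddlex.tools import EasyData2ImageNet

    EasyData2ImageNet().convert(
        image_dir='MyDataset/images',
        json_dir='MyDataset/easydata_jsons',
        dataset_save_dir='MyDataset/imagenet')
    # Resulting layout (class names depend on the annotations):
    #   MyDataset/imagenet/cat/0001.jpg
    #   MyDataset/imagenet/dog/0002.jpg
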
+ 332 - 0
paddlex/tools/x2seg.py

@@ -0,0 +1,332 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import uuid
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+import math
+import PIL.Image
+import PIL.ImageDraw
+from .base import MyEncoder, is_pic, get_encoding
+
+class X2Seg(object):
+    def __init__(self):
+        self.labels2ids = {'_background_': 0}
+        
+    def shapes_to_label(self, img_shape, shapes, label_name_to_value):
+        # This function is based on https://github.com/wkentaro/labelme/blob/master/labelme/utils/shape.py.
+        def shape_to_mask(img_shape, points, shape_type=None,
+                  line_width=10, point_size=5):
+            mask = np.zeros(img_shape[:2], dtype=np.uint8)
+            mask = PIL.Image.fromarray(mask)
+            draw = PIL.ImageDraw.Draw(mask)
+            xy = [tuple(point) for point in points]
+            if shape_type == 'circle':
+                assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
+                (cx, cy), (px, py) = xy
+                d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
+                draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
+            elif shape_type == 'rectangle':
+                assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
+                draw.rectangle(xy, outline=1, fill=1)
+            elif shape_type == 'line':
+                assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
+                draw.line(xy=xy, fill=1, width=line_width)
+            elif shape_type == 'linestrip':
+                draw.line(xy=xy, fill=1, width=line_width)
+            elif shape_type == 'point':
+                assert len(xy) == 1, 'Shape of shape_type=point must have 1 point'
+                cx, cy = xy[0]
+                r = point_size
+                draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
+            else:
+                assert len(xy) > 2, 'Polygon must have points more than 2'
+                draw.polygon(xy=xy, outline=1, fill=1)
+            mask = np.array(mask, dtype=bool)
+            return mask
+        cls = np.zeros(img_shape[:2], dtype=np.int32)
+        ins = np.zeros_like(cls)
+        instances = []
+        for shape in shapes:
+            points = shape['points']
+            label = shape['label']
+            group_id = shape.get('group_id')
+            if group_id is None:
+                group_id = uuid.uuid1()
+            shape_type = shape.get('shape_type', None)
+
+            cls_name = label
+            instance = (cls_name, group_id)
+
+            if instance not in instances:
+                instances.append(instance)
+            ins_id = instances.index(instance) + 1
+            cls_id = label_name_to_value[cls_name]
+            mask = shape_to_mask(img_shape[:2], points, shape_type)
+            cls[mask] = cls_id
+            ins[mask] = ins_id
+        return cls, ins
+    
+    def get_color_map_list(self, num_classes):
+        color_map = num_classes * [0, 0, 0]
+        for i in range(0, num_classes):
+            j = 0
+            lab = i
+            while lab:
+                color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+                color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+                color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+                j += 1
+                lab >>= 3
+        return color_map
+    
+    def convert(self, image_dir, json_dir, dataset_save_dir):
+        """转换。
+        Args:
+            image_dir (str): 图像文件存放的路径。
+            json_dir (str): 与每张图像对应的json文件的存放路径。
+            dataset_save_dir (str): 转换后数据集存放路径。
+        """
+        assert osp.exists(image_dir), "The image folder does not exist!"
+        assert osp.exists(json_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        # Convert the image files.
+        new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
+        if osp.exists(new_image_dir):
+            shutil.rmtree(new_image_dir)
+        os.makedirs(new_image_dir)
+        for img_name in os.listdir(image_dir):
+            if is_pic(img_name):
+                shutil.copyfile(
+                            osp.join(image_dir, img_name),
+                            osp.join(new_image_dir, img_name))
+        # Convert the json files.
+        png_dir = osp.join(dataset_save_dir, "Annotations")
+        if osp.exists(png_dir):
+            shutil.rmtree(png_dir)
+        os.makedirs(png_dir)
+        self.get_labels2ids(new_image_dir, json_dir)
+        self.json2png(new_image_dir, json_dir, png_dir)
+        # Generate the labels.txt
+        ids2labels = {v : k for k, v in self.labels2ids.items()}
+        with open(osp.join(dataset_save_dir, 'labels.txt'), 'w') as fw:
+            for i in range(len(ids2labels)):
+                fw.write(ids2labels[i] + '\n')
+        
+
+class JingLing2Seg(X2Seg):
+    """将使用标注精灵标注的数据集转换为Seg数据集。
+    """
+    def __init__(self):
+        super(JingLing2Seg, self).__init__() 
+        
+    def get_labels2ids(self, image_dir, json_dir):
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                if 'outputs' in json_info:
+                    for output in json_info['outputs']['object']:
+                        cls_name = output['name']
+                        if cls_name not in self.labels2ids:
+                            self.labels2ids[cls_name] =  len(self.labels2ids)
+    
+    def json2png(self, image_dir, json_dir, png_dir):
+        color_map = self.get_color_map_list(256)
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                data_shapes = []
+                if 'outputs' in json_info:
+                    for output in json_info['outputs']['object']:
+                        if 'polygon' in output.keys():
+                            polygon = output['polygon']
+                            name = output['name']
+                            points = []
+                            for i in range(1, int(len(polygon) / 2) + 1):
+                                points.append(
+                                    [polygon['x' + str(i)], polygon['y' + str(i)]])
+                            shape = {
+                                'label': name,
+                                'points': points,
+                                'shape_type': 'polygon'
+                            }
+                            data_shapes.append(shape)
+                if 'size' not in json_info:
+                    continue
+            img_shape = (json_info['size']['height'], 
+                         json_info['size']['width'],
+                         json_info['size']['depth'])
+            lbl, _ = self.shapes_to_label(
+                img_shape=img_shape,
+                shapes=data_shapes,
+                label_name_to_value=self.labels2ids,
+            )
+            out_png_file = osp.join(png_dir, img_name_part + '.png')
+            if lbl.min() >= 0 and lbl.max() <= 255:
+                lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
+                lbl_pil.putpalette(color_map)
+                lbl_pil.save(out_png_file)
+            else:
+                raise ValueError(
+                    '[%s] Cannot save the pixel-wise class label as PNG. '
+                    'Please consider using the .npy format.' % out_png_file)
+                
+                
+class LabelMe2Seg(X2Seg):
+    """将使用LabelMe标注的数据集转换为Seg数据集。
+    """
+    def __init__(self):
+        super(LabelMe2Seg, self).__init__()
+    
+    def get_labels2ids(self, image_dir, json_dir):
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                for shape in json_info['shapes']:
+                    cls_name = shape['label']
+                    if cls_name not in self.labels2ids:
+                        self.labels2ids[cls_name] =  len(self.labels2ids)
+                     
+    def json2png(self, image_dir, json_dir, png_dir):
+        color_map = self.get_color_map_list(256)
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            img_file = osp.join(image_dir, img_name)
+            img = np.asarray(PIL.Image.open(img_file))
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+            lbl, _ = self.shapes_to_label(
+                img_shape=img.shape,
+                shapes=json_info['shapes'],
+                label_name_to_value=self.labels2ids,
+            )
+            out_png_file = osp.join(png_dir, img_name_part + '.png')
+            if lbl.min() >= 0 and lbl.max() <= 255:
+                lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
+                lbl_pil.putpalette(color_map)
+                lbl_pil.save(out_png_file)
+            else:
+                raise ValueError(
+                    '[%s] Cannot save the pixel-wise class label as PNG. '
+                    'Please consider using the .npy format.' % out_png_file)
+                
+                            
+class EasyData2Seg(X2Seg):
+    """将使用EasyData标注的分割数据集转换为Seg数据集。
+    """
+    def __init__(self):
+        super(EasyData2Seg, self).__init__()
+    
+    def get_labels2ids(self, image_dir, json_dir):
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                for shape in json_info["labels"]:
+                    cls_name = shape['name']
+                    if cls_name not in self.labels2ids:
+                        self.labels2ids[cls_name] =  len(self.labels2ids)
+                        
+    def mask2polygon(self, mask, label):
+        contours, hierarchy = cv2.findContours(
+            (mask).astype(np.uint8), cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
+        segmentation = []
+        for contour in contours:
+            contour_list = contour.flatten().tolist()
+            if len(contour_list) > 4:
+                points = []
+                for i in range(0, len(contour_list), 2):
+                    points.append(
+                                [contour_list[i], contour_list[i + 1]])
+                shape = {
+                    'label': label,
+                    'points': points,
+                    'shape_type': 'polygon'
+                }
+                segmentation.append(shape)
+        return segmentation
+    
+    def json2png(self, image_dir, json_dir, png_dir):
+        from pycocotools.mask import decode
+        color_map = self.get_color_map_list(256)
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            img_file = osp.join(image_dir, img_name)
+            img = np.asarray(PIL.Image.open(img_file))
+            img_h = img.shape[0]
+            img_w = img.shape[1]
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                data_shapes = []
+                for shape in json_info['labels']:
+                    mask_dict = {}
+                    mask_dict['size'] = [img_h, img_w]
+                    mask_dict['counts'] = shape['mask'].encode()
+                    mask = decode(mask_dict)
+                    polygon = self.mask2polygon(mask, shape["name"])
+                    data_shapes.extend(polygon)
+            lbl, _ = self.shapes_to_label(
+                img_shape=img.shape,
+                shapes=data_shapes,
+                label_name_to_value=self.labels2ids,
+            )
+            out_png_file = osp.join(png_dir, img_name_part + '.png')
+            if lbl.min() >= 0 and lbl.max() <= 255:
+                lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
+                lbl_pil.putpalette(color_map)
+                lbl_pil.save(out_png_file)
+            else:
+                raise ValueError(
+                    '[%s] Cannot save the pixel-wise class label as PNG. '
+                    'Please consider using the .npy format.' % out_png_file)
+            
+
+

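get_color_map_list above is the standard PASCAL VOC palette: the bits of label i are spread across the high bits of the R, G and B channels, so nearby class ids get visually distinct colors. For reference, its first entries (computed with the method above):

    from paddlex.tools.x2seg import X2Seg

    color_map = X2Seg().get_color_map_list(4)
    print(color_map)
    # [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0]
    # i.e. label 0 black (background), 1 dark red, 2 dark green, 3 dark yellow
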
+ 199 - 0
paddlex/tools/x2voc.py

@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+from .base import MyEncoder, is_pic, get_encoding
+
+class X2VOC(object):
+    def __init__(self):
+        pass
+    
+    def convert(self, image_dir, json_dir, dataset_save_dir):
+        """转换。
+        Args:
+            image_dir (str): 图像文件存放的路径。
+            json_dir (str): 与每张图像对应的json文件的存放路径。
+            dataset_save_dir (str): 转换后数据集存放路径。
+        """
+        assert osp.exists(image_dir), "The image folder does not exist!"
+        assert osp.exists(json_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        # Convert the image files.
+        new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
+        if osp.exists(new_image_dir):
+            shutil.rmtree(new_image_dir)
+        os.makedirs(new_image_dir)
+        for img_name in os.listdir(image_dir):
+            if is_pic(img_name):
+                shutil.copyfile(
+                            osp.join(image_dir, img_name),
+                            osp.join(new_image_dir, img_name))
+        # Convert the json files.
+        xml_dir = osp.join(dataset_save_dir, "Annotations")
+        if osp.exists(xml_dir):
+            shutil.rmtree(xml_dir)
+        os.makedirs(xml_dir)
+        self.json2xml(new_image_dir, json_dir, xml_dir)
+        
+        
+class LabelMe2VOC(X2VOC):
+    """将使用LabelMe标注的数据集转换为VOC数据集。
+    """
+    def __init__(self):
+        pass
+    
+    def json2xml(self, image_dir, json_dir, xml_dir):
+        import xml.dom.minidom as minidom
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            xml_doc = minidom.Document() 
+            root = xml_doc.createElement("annotation") 
+            xml_doc.appendChild(root)
+            node_folder = xml_doc.createElement("folder")
+            node_folder.appendChild(xml_doc.createTextNode("JPEGImages"))
+            root.appendChild(node_folder)
+            node_filename = xml_doc.createElement("filename")
+            node_filename.appendChild(xml_doc.createTextNode(img_name))
+            root.appendChild(node_filename)
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                h = json_info["imageHeight"]
+                w = json_info["imageWidth"]
+                node_size = xml_doc.createElement("size")
+                node_width = xml_doc.createElement("width")
+                node_width.appendChild(xml_doc.createTextNode(str(w)))
+                node_size.appendChild(node_width)
+                node_height = xml_doc.createElement("height")
+                node_height.appendChild(xml_doc.createTextNode(str(h)))
+                node_size.appendChild(node_height)
+                node_depth = xml_doc.createElement("depth")
+                node_depth.appendChild(xml_doc.createTextNode(str(3)))
+                node_size.appendChild(node_depth)
+                root.appendChild(node_size)
+                for shape in json_info["shapes"]:
+                    if shape["shape_type"] != "rectangle":
+                        continue
+                    label = shape["label"]
+                    (xmin, ymin), (xmax, ymax) = shape["points"]
+                    xmin, xmax = sorted([xmin, xmax])
+                    ymin, ymax = sorted([ymin, ymax])
+                    node_obj = xml_doc.createElement("object")
+                    node_name = xml_doc.createElement("name")
+                    node_name.appendChild(xml_doc.createTextNode(label))
+                    node_obj.appendChild(node_name)
+                    node_diff = xml_doc.createElement("difficult")
+                    node_diff.appendChild(xml_doc.createTextNode(str(0)))
+                    node_obj.appendChild(node_diff)
+                    node_box = xml_doc.createElement("bndbox")
+                    node_xmin = xml_doc.createElement("xmin")
+                    node_xmin.appendChild(xml_doc.createTextNode(str(xmin)))
+                    node_box.appendChild(node_xmin)
+                    node_ymin = xml_doc.createElement("ymin")
+                    node_ymin.appendChild(xml_doc.createTextNode(str(ymin)))
+                    node_box.appendChild(node_ymin)
+                    node_xmax = xml_doc.createElement("xmax")
+                    node_xmax.appendChild(xml_doc.createTextNode(str(xmax)))
+                    node_box.appendChild(node_xmax)
+                    node_ymax = xml_doc.createElement("ymax")
+                    node_ymax.appendChild(xml_doc.createTextNode(str(ymax)))
+                    node_box.appendChild(node_ymax)
+                    node_obj.appendChild(node_box)
+                    root.appendChild(node_obj)
+            with open(osp.join(xml_dir, img_name_part + ".xml"), 'w') as fxml:
+                xml_doc.writexml(fxml, indent='\t', addindent='\t', newl='\n', encoding="utf-8")
+                    
+                    
+class EasyData2VOC(X2VOC):
+    """将使用EasyData标注的分割数据集转换为VOC数据集。
+    """
+    def __init__(self):
+        pass
+    
+    def json2xml(self, image_dir, json_dir, xml_dir):
+        import xml.dom.minidom as minidom
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(image_dir, img_name))
+                continue
+            xml_doc = minidom.Document() 
+            root = xml_doc.createElement("annotation") 
+            xml_doc.appendChild(root)
+            node_folder = xml_doc.createElement("folder")
+            node_folder.appendChild(xml_doc.createTextNode("JPEGImages"))
+            root.appendChild(node_folder)
+            node_filename = xml_doc.createElement("filename")
+            node_filename.appendChild(xml_doc.createTextNode(img_name))
+            root.appendChild(node_filename)
+            img = cv2.imread(osp.join(image_dir, img_name))
+            h = img.shape[0]
+            w = img.shape[1]
+            node_size = xml_doc.createElement("size")
+            node_width = xml_doc.createElement("width")
+            node_width.appendChild(xml_doc.createTextNode(str(w)))
+            node_size.appendChild(node_width)
+            node_height = xml_doc.createElement("height")
+            node_height.appendChild(xml_doc.createTextNode(str(h)))
+            node_size.appendChild(node_height)
+            node_depth = xml_doc.createElement("depth")
+            node_depth.appendChild(xml_doc.createTextNode(str(3)))
+            node_size.appendChild(node_depth)
+            root.appendChild(node_size)
+            with open(json_file, mode="r", \
+                              encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                for shape in json_info["labels"]:
+                    label = shape["name"]
+                    xmin = shape["x1"]
+                    ymin = shape["y1"]
+                    xmax = shape["x2"]
+                    ymax = shape["y2"]
+                    node_obj = xml_doc.createElement("object")
+                    node_name = xml_doc.createElement("name")
+                    node_name.appendChild(xml_doc.createTextNode(label))
+                    node_obj.appendChild(node_name)
+                    node_diff = xml_doc.createElement("difficult")
+                    node_diff.appendChild(xml_doc.createTextNode(str(0)))
+                    node_obj.appendChild(node_diff)
+                    node_box = xml_doc.createElement("bndbox")
+                    node_xmin = xml_doc.createElement("xmin")
+                    node_xmin.appendChild(xml_doc.createTextNode(str(xmin)))
+                    node_box.appendChild(node_xmin)
+                    node_ymin = xml_doc.createElement("ymin")
+                    node_ymin.appendChild(xml_doc.createTextNode(str(ymin)))
+                    node_box.appendChild(node_ymin)
+                    node_xmax = xml_doc.createElement("xmax")
+                    node_xmax.appendChild(xml_doc.createTextNode(str(xmax)))
+                    node_box.appendChild(node_xmax)
+                    node_ymax = xml_doc.createElement("ymax")
+                    node_ymax.appendChild(xml_doc.createTextNode(str(ymax)))
+                    node_box.appendChild(node_ymax)
+                    node_obj.appendChild(node_box)
+                    root.appendChild(node_obj)
+            with open(osp.join(xml_dir, img_name_part + ".xml"), 'w') as fxml:
+                xml_doc.writexml(fxml, indent='\t', addindent='\t', newl='\n', encoding="utf-8")                    
+                    

+ 115 - 13
paddlex/utils/utils.py

@@ -31,18 +31,7 @@ def seconds_to_hms(seconds):
     return hms_str
     return hms_str
 
 
 
 
-def setting_environ_flags():
-    if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
-        os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
-    if 'FLAGS_allocator_strategy' not in os.environ:
-        os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0:
-            os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-
 def get_environ_info():
 def get_environ_info():
-    setting_environ_flags()
     import paddle.fluid as fluid
     import paddle.fluid as fluid
     info = dict()
     info = dict()
     info['place'] = 'cpu'
     info['place'] = 'cpu'
@@ -181,11 +170,85 @@ def load_pdparams(exe, main_prog, model_dir):
             len(vars_to_load), model_dir))
             len(vars_to_load), model_dir))
 
 
 
 
-def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
+def is_persistable(var):
+    import paddle.fluid as fluid
+    from paddle.fluid.proto.framework_pb2 import VarType
+
+    if var.desc.type() == fluid.core.VarDesc.VarType.FEED_MINIBATCH or \
+        var.desc.type() == fluid.core.VarDesc.VarType.FETCH_LIST or \
+        var.desc.type() == fluid.core.VarDesc.VarType.READER:
+        return False
+    return var.persistable
+
+
+def is_belong_to_optimizer(var):
+    import paddle.fluid as fluid
+    from paddle.fluid.proto.framework_pb2 import VarType
+
+    if not (isinstance(var, fluid.framework.Parameter)
+            or var.desc.need_check_feed()):
+        return is_persistable(var)
+    return False
+
+
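
The two predicates above split a program's persistable variables into model parameters and optimizer state (moments, accumulators, learning-rate counters). A hedged sketch of their intended use, assuming a paddle.fluid 1.x program with an optimizer already applied and the helpers above in scope:

    import paddle.fluid as fluid

    main_prog = fluid.default_main_program()
    optimizer_vars = list(filter(is_belong_to_optimizer, main_prog.list_vars()))
    print([v.name for v in optimizer_vars])
    # e.g. ['learning_rate_0', 'fc_0.w_0_moment1_0', ...]
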
+def load_pdopt(exe, main_prog, model_dir):
+    import paddle.fluid as fluid
+
+    optimizer_var_list = list()
+    vars_to_load = list()
+    import pickle
+    with open(osp.join(model_dir, 'model.pdopt'), 'rb') as f:
+        opt_dict = pickle.load(f) if six.PY2 else pickle.load(
+            f, encoding='latin1')
+    optimizer_var_list = list(
+        filter(is_belong_to_optimizer, main_prog.list_vars()))
+    exception_message = "the training process cannot be resumed because the optimizer used now differs from the one used last time. Recommend using `pretrain_weights` instead of `resume_checkpoint`"
+    if len(optimizer_var_list) > 0:
+        for var in optimizer_var_list:
+            if var.name not in opt_dict:
+                raise Exception(
+                    "{} is not in saved paddlex optimizer, {}".format(
+                        var.name, exception_message))
+            if var.shape != opt_dict[var.name].shape:
+                raise Exception(
+                    "Shape of optimizer variable {} doesn't match. (Last: {}, Now: {}) {}"
+                    .format(var.name, opt_dict[var.name].shape,
+                            var.shape, exception_message))
+        optimizer_varname_list = [var.name for var in optimizer_var_list]
+        for k, v in opt_dict.items():
+            if k not in optimizer_varname_list:
+                raise Exception(
+                    "{} in saved paddlex optimizer is not in the model, {}".
+                    format(k, exception_message))
+        fluid.io.set_program_state(main_prog, opt_dict)
+
+    if len(optimizer_var_list) == 0:
+        raise Exception(
+            "There are no optimizer parameters in the model, please set the optimizer!"
+        )
+    else:
+        logging.info(
+            "Loaded {} optimizer parameters from {}.".format(
+                len(optimizer_var_list), model_dir))
+
+
+def load_pretrain_weights(exe,
+                          main_prog,
+                          weights_dir,
+                          fuse_bn=False,
+                          resume=False):
     if not osp.exists(weights_dir):
         raise Exception("Path {} not exists.".format(weights_dir))
     if osp.exists(osp.join(weights_dir, "model.pdparams")):
-        return load_pdparams(exe, main_prog, weights_dir)
+        load_pdparams(exe, main_prog, weights_dir)
+        if resume:
+            if osp.exists(osp.join(weights_dir, "model.pdopt")):
+                load_pdopt(exe, main_prog, weights_dir)
+            else:
+                raise Exception(
+                    "Optimizer file {} does not exist, so training cannot be resumed. Recommend using `pretrain_weights` instead of `resume_checkpoint`"
+                    .format(osp.join(weights_dir, "model.pdopt")))
+        return
     import paddle.fluid as fluid
     vars_to_load = list()
     for var in main_prog.list_vars():
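The `model.pdopt` file read here is a pickled dict mapping optimizer-variable names to their saved values, restored into the program via `fluid.io.set_program_state`. A minimal sketch of the cross-version pickle handling used in `load_pdopt` (the checkpoint path is hypothetical):

    import pickle

    import six

    def load_opt_state(path):
        # Pickles written by Python 2 need encoding='latin1' when read under
        # Python 3; Python 2's pickle.load accepts no encoding argument.
        with open(path, 'rb') as f:
            return pickle.load(f) if six.PY2 else pickle.load(f, encoding='latin1')

    opt_dict = load_opt_state('output/epoch_10/model.pdopt')  # hypothetical path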
@@ -220,6 +283,45 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
             len(vars_to_load), weights_dir))
     if fuse_bn:
         fuse_bn_weights(exe, main_prog, weights_dir)
+    if resume:
+        exception_message = "the training process cannot be resumed because the current optimizer configuration differs from the one saved last time. Recommend using `pretrain_weights` instead of `resume_checkpoint`"
+        optimizer_var_list = list(
+            filter(is_belong_to_optimizer, main_prog.list_vars()))
+        if len(optimizer_var_list) > 0:
+            for var in optimizer_var_list:
+                if not osp.exists(osp.join(weights_dir, var.name)):
+                    raise Exception(
+                        "Optimizer parameter {} doesn't exist, {}".format(
+                            osp.join(weights_dir, var.name),
+                            exception_message))
+                pretrained_shape = parse_param_file(
+                    osp.join(weights_dir, var.name))
+                actual_shape = tuple(var.shape)
+                if pretrained_shape != actual_shape:
+                    raise Exception(
+                        "Shape of optimizer variable {} doesn't match. (Last: {}, Now: {}), {}"
+                        .format(var.name, pretrained_shape,
+                                actual_shape, exception_message))
+            optimizer_varname_list = [var.name for var in optimizer_var_list]
+            if osp.exists(osp.join(weights_dir, 'learning_rate')
+                         ) and 'learning_rate' not in optimizer_varname_list:
+                raise Exception(
+                    "Optimizer parameter {}/learning_rate is not in the model, {}"
+                    .format(weights_dir, exception_message))
+            fluid.io.load_vars(
+                executor=exe,
+                dirname=weights_dir,
+                main_program=main_prog,
+                vars=optimizer_var_list)
+
+        if len(optimizer_var_list) == 0:
+            raise Exception(
+                "There are no optimizer parameters in the model, please set the optimizer!"
+            )
+        else:
+            logging.info(
+                "Loaded {} optimizer parameters from {}.".format(
+                    len(optimizer_var_list), weights_dir))
 
 
 class EarlyStop:
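Taken together, the two branches above give `resume_checkpoint` its semantics: weights are always restored, and `resume=True` additionally restores optimizer state (from `model.pdopt` or per-variable files) and fails fast if the optimizer configuration changed. A minimal usage sketch, assuming a program already built for the model and a checkpoint directory previously saved by PaddleX (paths and setup are illustrative):

    import paddle.fluid as fluid
    from paddlex.utils.utils import load_pretrain_weights

    exe = fluid.Executor(fluid.CPUPlace())
    main_prog = fluid.default_main_program()  # assumed to hold the model's variables
    # resume=True restores optimizer state alongside the weights and raises
    # if the current optimizer setup differs from the saved one.
    load_pretrain_weights(exe, main_prog, './output/epoch_10', resume=True)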

+ 3 - 4
setup.py

@@ -19,7 +19,7 @@ long_description = "PaddleX. A end-to-end deeplearning model development toolkit
 
 setuptools.setup(
     name="paddlex",
-    version='0.1.6',
+    version='0.1.7',
     author="paddlex",
     author_email="paddlex@baidu.com",
     description=long_description,
@@ -29,9 +29,8 @@ setuptools.setup(
     packages=setuptools.find_packages(),
     setup_requires=['cython', 'numpy', 'sklearn'],
     install_requires=[
-        "pycocotools;platform_system!='Windows'", 
-        'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0',
-        'paddleslim==1.0.1', 'paddlehub>=1.6.2'
+        "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
+        'visualdl==1.3.0', 'paddleslim==1.0.1'
     ],
     classifiers=[
         "Programming Language :: Python :: 3",
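The `pycocotools;platform_system!='Windows'` requirement uses a PEP 508 environment marker, which pip evaluates at install time so the dependency is skipped on Windows. One way to check how the marker resolves on a given machine, using the third-party `packaging` library:

    from packaging.markers import Marker

    # Evaluates the marker against the current interpreter's environment.
    marker = Marker("platform_system != 'Windows'")
    print(marker.evaluate())  # True on Linux/macOS, so pycocotools is installed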