Browse Source

Merge pull request #23 from Channingss/cpp_trt

support deploy with TensorRT
Jason 5 years ago
parent
commit
76b49aa7eb

+ 3 - 1
deploy/README.md

@@ -1,3 +1,5 @@
 # 模型部署
 
-本目录为PaddleX模型部署代码。
+本目录为PaddleX模型部署代码, 编译和使用的教程参考:
+
+- [C++部署文档](../docs/deploy/deploy.md#C部署)

+ 8 - 8
deploy/cpp/CMakeLists.txt

@@ -3,9 +3,10 @@ project(PaddleX CXX C)
 
 option(WITH_MKL        "Compile demo with MKL/OpenBlas support,defaultuseMKL."          ON)
 option(WITH_GPU        "Compile demo with GPU/CPU, default use CPU."                    ON)
-option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static."   ON)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static."   OFF)
 option(WITH_TENSORRT "Compile demo with TensorRT."   OFF)
 
+SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
 SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
 SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
 SET(CUDA_LIB "" CACHE PATH "Location of libraries")
@@ -111,8 +112,8 @@ endif()
 
 if (NOT WIN32)
   if (WITH_TENSORRT AND WITH_GPU)
-      include_directories("${PADDLE_DIR}/third_party/install/tensorrt/include")
-      link_directories("${PADDLE_DIR}/third_party/install/tensorrt/lib")
+      include_directories("${TENSORRT_DIR}/include")
+      link_directories("${TENSORRT_DIR}/lib")
   endif()
 endif(NOT WIN32)
 
@@ -169,7 +170,7 @@ endif()
 
 if (NOT WIN32)
     set(DEPS ${DEPS}
-        ${MATH_LIB} ${MKLDNN_LIB} 
+        ${MATH_LIB} ${MKLDNN_LIB}
         glog gflags protobuf z xxhash yaml-cpp
         )
     if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
@@ -194,8 +195,8 @@ endif(NOT WIN32)
 if(WITH_GPU)
   if(NOT WIN32)
     if (WITH_TENSORRT)
-      set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
-      set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
     endif()
     set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
     set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
@@ -211,7 +212,7 @@ if (NOT WIN32)
     set(DEPS ${DEPS} ${EXTERNAL_LIB})
 endif()
 
-set(DEPS ${DEPS} ${OpenCV_LIBS}) 
+set(DEPS ${DEPS} ${OpenCV_LIBS})
 add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
 ADD_DEPENDENCIES(classifier ext-yaml-cpp)
 target_link_libraries(classifier ${DEPS})
@@ -251,4 +252,3 @@ if (WIN32 AND WITH_MKL)
     )
 
 endif()
-

+ 3 - 1
deploy/cpp/include/paddlex/paddlex.h

@@ -38,12 +38,14 @@ class Model {
  public:
   void Init(const std::string& model_dir,
             bool use_gpu = false,
+            bool use_trt = false,
             int gpu_id = 0) {
-    create_predictor(model_dir, use_gpu, gpu_id);
+    create_predictor(model_dir, use_gpu, use_trt, gpu_id);
   }
 
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
+                        bool use_trt = false,
                         int gpu_id = 0);
 
   bool load_config(const std::string& model_dir);

+ 4 - 6
deploy/cpp/include/paddlex/transforms.h

@@ -35,10 +35,8 @@ class ImageBlob {
   std::vector<int> ori_im_size_ = std::vector<int>(2);
   // Newest image height and width after process
   std::vector<int> new_im_size_ = std::vector<int>(2);
-  // Image height and width before padding
-  std::vector<int> im_size_before_padding_ = std::vector<int>(2);
   // Image height and width before resize
-  std::vector<int> im_size_before_resize_ = std::vector<int>(2);
+  std::vector<std::vector<int>> im_size_before_resize_;
   // Reshape order
   std::vector<std::string> reshape_order_;
   // Resize scale
@@ -49,7 +47,6 @@ class ImageBlob {
   void clear() {
     ori_im_size_.clear();
     new_im_size_.clear();
-    im_size_before_padding_.clear();
     im_size_before_resize_.clear();
     reshape_order_.clear();
     im_data_.clear();
@@ -155,12 +152,13 @@ class Padding : public Transform {
   virtual void Init(const YAML::Node& item) {
     if (item["coarsest_stride"].IsDefined()) {
       coarsest_stride_ = item["coarsest_stride"].as<int>();
-      if (coarsest_stride_ <= 1) {
+      if (coarsest_stride_ < 1) {
         std::cerr << "[Padding] coarest_stride should greater than 0"
                   << std::endl;
         exit(-1);
       }
-    } else {
+    }
+    if (item["target_size"].IsDefined()) {
       if (item["target_size"].IsScalar()) {
         width_ = item["target_size"].as<int>();
         height_ = item["target_size"].as<int>();

+ 11 - 1
deploy/cpp/scripts/build.sh

@@ -1,9 +1,16 @@
 # 是否使用GPU(即是否使用 CUDA)
-WITH_GPU=ON
+WITH_GPU=OFF
+# 使用MKL or openblas
+WITH_MKL=ON
 # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
 WITH_TENSORRT=OFF
+# TensorRT 的lib路径
+TENSORRT_DIR=/path/to/TensorRT/
 # Paddle 预测库路径
 PADDLE_DIR=/path/to/fluid_inference/
+# Paddle 的预测库是否使用静态库来编译
+# 使用TensorRT时,Paddle的预测库通常为动态库
+WITH_STATIC_LIB=OFF
 # CUDA 的 lib 路径
 CUDA_LIB=/path/to/cuda/lib/
 # CUDNN 的 lib 路径
@@ -19,8 +26,11 @@ mkdir -p build
 cd build
 cmake .. \
     -DWITH_GPU=${WITH_GPU} \
+    -DWITH_MKL=${WITH_MKL} \
     -DWITH_TENSORRT=${WITH_TENSORRT} \
+    -DTENSORRT_DIR=${TENSORRT_DIR} \
     -DPADDLE_DIR=${PADDLE_DIR} \
+    -DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
     -DCUDA_LIB=${CUDA_LIB} \
     -DCUDNN_LIB=${CUDNN_LIB} \
     -DOPENCV_DIR=${OPENCV_DIR}

+ 2 - 1
deploy/cpp/src/classifier.cpp

@@ -23,6 +23,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -42,7 +43,7 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
 
   // 进行预测
   if (FLAGS_image_list != "") {

+ 4 - 3
deploy/cpp/src/detector.cpp

@@ -24,6 +24,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -44,7 +45,7 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
 
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   std::string save_dir = "output";
@@ -68,7 +69,7 @@ int main(int argc, char** argv) {
                   << result.boxes[i].coordinate[0] << ", "
                   << result.boxes[i].coordinate[1] << ", "
                   << result.boxes[i].coordinate[2] << ", "
-                  << result.boxes[i].coordinate[3] << std::endl;
+                  << result.boxes[i].coordinate[3] << ")" << std::endl;
       }
 
       // 可视化
@@ -91,7 +92,7 @@ int main(int argc, char** argv) {
                 << result.boxes[i].coordinate[0] << ", "
                 << result.boxes[i].coordinate[1] << ", "
                 << result.boxes[i].coordinate[2] << ", "
-                << result.boxes[i].coordinate[3] << std::endl;
+                << result.boxes[i].coordinate[3] << ")" << std::endl;
     }
 
     // 可视化

+ 21 - 6
deploy/cpp/src/paddlex.cpp

@@ -18,6 +18,7 @@ namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
+                             bool use_trt,
                              int gpu_id) {
   // 读取配置文件
   if (!load_config(model_dir)) {
@@ -37,6 +38,15 @@ void Model::create_predictor(const std::string& model_dir,
   config.SwitchSpecifyInputNames(true);
   // 开启内存优化
   config.EnableMemoryOptim();
+  if (use_trt) {
+    config.EnableTensorRtEngine(
+        1 << 20 /* workspace_size*/,
+        32 /* max_batch_size*/,
+        20 /* min_subgraph_size*/,
+        paddle::AnalysisConfig::Precision::kFloat32 /* precision*/,
+        true /* use_static*/,
+        false /* use_calib_mode*/);
+  }
   predictor_ = std::move(CreatePaddlePredictor(config));
 }
 
@@ -246,7 +256,6 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
   auto im_tensor = predictor_->GetInputTensor("image");
   im_tensor->Reshape({1, 3, h, w});
   im_tensor->copy_from_cpu(inputs_.im_data_.data());
-  std::cout << "input image: " << h << " " << w << std::endl;
 
   // 使用加载的模型进行预测
   predictor_->ZeroCopyRun();
@@ -286,19 +295,24 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                      result->score_map.shape[3],
                      CV_32FC1,
                      result->score_map.data.data());
-
+  int idx = 1;
+  int len_postprocess = inputs_.im_size_before_resize_.size();
   for (std::vector<std::string>::reverse_iterator iter =
            inputs_.reshape_order_.rbegin();
        iter != inputs_.reshape_order_.rend();
        ++iter) {
     if (*iter == "padding") {
-      auto padding_w = inputs_.im_size_before_padding_[0];
-      auto padding_h = inputs_.im_size_before_padding_[1];
+      auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
+      inputs_.im_size_before_resize_.pop_back();
+      auto padding_w = before_shape[0];
+      auto padding_h = before_shape[1];
       mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h));
       mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h));
     } else if (*iter == "resize") {
-      auto resize_w = inputs_.im_size_before_resize_[0];
-      auto resize_h = inputs_.im_size_before_resize_[1];
+      auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
+      inputs_.im_size_before_resize_.pop_back();
+      auto resize_w = before_shape[0];
+      auto resize_h = before_shape[1];
       cv::resize(mask_label,
                  mask_label,
                  cv::Size(resize_h, resize_w),
@@ -312,6 +326,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                  0,
                  cv::INTER_NEAREST);
     }
+    ++idx;
   }
   result->label_map.data.assign(mask_label.begin<uint8_t>(),
                                 mask_label.end<uint8_t>());

+ 3 - 1
deploy/cpp/src/segmenter.cpp

@@ -24,6 +24,7 @@
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_trt, false, "Infering with TensorRT");
 DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
@@ -44,7 +45,8 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
+
   auto colormap = PaddleX::GenerateColorMap(model.labels.size());
   // 进行预测
   if (FLAGS_image_list != "") {

+ 7 - 10
deploy/cpp/src/transforms.cpp

@@ -56,8 +56,7 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) {
 }
 
 bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
 
   float scale = GenerateScale(*im);
@@ -88,21 +87,21 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
 }
 
 bool Padding::Run(cv::Mat* im, ImageBlob* data) {
-  data->im_size_before_padding_[0] = im->rows;
-  data->im_size_before_padding_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("padding");
 
   int padding_w = 0;
   int padding_h = 0;
-  if (width_ > 0 & height_ > 0) {
+  if (width_ > 1 & height_ > 1) {
     padding_w = width_ - im->cols;
     padding_h = height_ - im->rows;
-  } else if (coarsest_stride_ > 0) {
+  } else if (coarsest_stride_ > 1) {
     padding_h =
         ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
     padding_w =
         ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
   }
+
   if (padding_h < 0 || padding_w < 0) {
     std::cerr << "[Padding] Computed padding_h=" << padding_h
               << ", padding_w=" << padding_w
@@ -122,8 +121,7 @@ bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
               << std::endl;
     return false;
   }
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
   int origin_w = im->cols;
   int origin_h = im->rows;
@@ -149,8 +147,7 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) {
               << std::endl;
     return false;
   }
-  data->im_size_before_resize_[0] = im->rows;
-  data->im_size_before_resize_[1] = im->cols;
+  data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
 
   cv::resize(

+ 14 - 5
docs/deploy/deploy.md

@@ -7,20 +7,29 @@
 ### 导出inference模型
 
 在服务端部署的模型需要首先将模型导出为inference格式模型,导出的模型将包括`__model__`、`__params__`和`model.yml`三个文名,分别为模型的网络结构,模型权重和模型的配置文件(包括数据预处理参数等等)。在安装完PaddleX后,在命令行终端使用如下命令导出模型到当前目录`inferece_model`下。
+> 可直接下载小度熊分拣模型测试本文档的流程[xiaoduxiong_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)
 
-> 可直接下载垃圾检测模型测试本文档的流程[garbage_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/garbage_epoch_12.tar.gz)
+```
+paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model
+```
+
+使用TensorRT预测时,需指定模型的图像输入shape:[w,h]。
+**注**:
+- 分类模型请保持于训练时输入的shape一致。
+- 指定[w,h]时,w和h中间逗号隔开,不允许存在空格等其他字符
 
 ```
-paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model
+paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model --fixed_input_shape=[640,960]
 ```
 
 ### Python部署
 PaddleX已经集成了基于Python的高性能预测接口,在安装PaddleX后,可参照如下代码示例,进行预测。相关的接口文档可参考[paddlex.deploy](apis/deploy.md)
-> 点击下载测试图片 [garbage.bmp](https://bj.bcebos.com/paddlex/datasets/garbage.bmp)
+> 点击下载测试图片 [xiaoduxiong_test_image.tar.gz](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_test_image.tar.gz)
+
 ```
 import paddlex as pdx
-predictorpdx.deploy.create_predictor('./inference_model')
-result = predictor.predict(image='garbage.bmp')
+predictor = pdx.deploy.create_predictor('./inference_model')
+result = predictor.predict(image='xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg')
 ```
 
 ### C++部署

+ 34 - 14
docs/deploy/deploy_cpp_linux.md

@@ -19,8 +19,18 @@
 
 ### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
 
-PaddlePaddle C++ 预测库针对不同的`CPU`,`CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,请根据实际情况下载:  [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1)
+PaddlePaddle C++ 预测库针对不同的`CPU`,`CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,目前PaddleX依赖于Paddle1.7版本,以下提供了多个不同版本的Paddle预测库:
 
+|  版本说明   | 预测库(1.7.2版本)  |
+|  ----  | ----  |
+| ubuntu14.04_cpu_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-mkl/fluid_inference.tgz) |
+| ubuntu14.04_cpu_avx_openblas  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-openblas/fluid_inference.tgz) |
+| ubuntu14.04_cpu_noavx_openblas  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-noavx-openblas/fluid_inference.tgz) |
+| ubuntu14.04_cuda9.0_cudnn7_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda9-cudnn7-avx-mkl/fluid_inference.tgz) |
+| ubuntu14.04_cuda10.0_cudnn7_avx_mkl  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10-cudnn7-avx-mkl/fluid_inference.tgz ) |
+| ubuntu14.04_cuda10.1_cudnn7.6_avx_mkl_trt6  | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10.1-cudnn7.6-avx-mkl-trt6%2Ffluid_inference.tgz) |
+
+更多和更新的版本,请根据实际情况下载:  [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html#id1)
 
 下载并解压后`/root/projects/fluid_inference`目录包含内容为:
 ```
@@ -40,17 +50,24 @@ fluid_inference
 编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下:
 ```
 # 是否使用GPU(即是否使用 CUDA)
-WITH_GPU=ON
+WITH_GPU=OFF
+# 使用MKL or openblas
+WITH_MKL=ON
 # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
 WITH_TENSORRT=OFF
-# 上一步下载的 Paddle 预测库路径
-PADDLE_DIR=/root/projects/deps/fluid_inference/
+# TensorRT 的lib路径
+TENSORRT_DIR=/path/to/TensorRT/
+# Paddle 预测库路径
+PADDLE_DIR=/path/to/fluid_inference/
+# Paddle 的预测库是否使用静态库来编译
+# 使用TensorRT时,Paddle的预测库通常为动态库
+WITH_STATIC_LIB=ON
 # CUDA 的 lib 路径
-CUDA_LIB=/usr/local/cuda/lib64/
+CUDA_LIB=/path/to/cuda/lib/
 # CUDNN 的 lib 路径
-CUDNN_LIB=/usr/local/cudnn/lib64/
+CUDNN_LIB=/path/to/cudnn/lib/
 
-# OPENCV 路径, 如果使用自带预编译版本可不设置
+# OPENCV 路径, 如果使用自带预编译版本可不修改
 OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
 sh $(pwd)/scripts/bootstrap.sh
 
@@ -60,8 +77,11 @@ mkdir -p build
 cd build
 cmake .. \
     -DWITH_GPU=${WITH_GPU} \
+    -DWITH_MKL=${WITH_MKL} \
     -DWITH_TENSORRT=${WITH_TENSORRT} \
+    -DTENSORRT_DIR=${TENSORRT_DIR} \
     -DPADDLE_DIR=${PADDLE_DIR} \
+    -DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
     -DCUDA_LIB=${CUDA_LIB} \
     -DCUDNN_LIB=${CUDNN_LIB} \
     -DOPENCV_DIR=${OPENCV_DIR}
@@ -83,19 +103,20 @@ make
 | image  | 要预测的图片文件路径 |
 | image_list  | 按行存储图片路径的.txt文件 |
 | use_gpu  | 是否使用 GPU 预测, 支持值为0或1(默认值为0) |
+| use_trt  | 是否使用 TensorTr 预测, 支持值为0或1(默认值为0) |
 | gpu_id  | GPU 设备ID, 默认值为0 |
 | save_dir | 保存可视化结果的路径, 默认值为"output",classfier无该参数 |
 
 ## 样例
 
-可使用[垃圾检测模型](deploy.md#导出inference模型)中生成的`inference_model`模型和测试图片进行预测。
+可使用[小度熊识别模型](deploy.md#导出inference模型)中导出的`inference_model`和测试图片进行预测。
 
 `样例一`:
 
-不使用`GPU`测试图片 `/path/to/garbage.bmp`  
+不使用`GPU`测试图片 `/path/to/xiaoduxiong.jpeg`  
 
 ```shell
-./build/detector --model_dir=/path/to/inference_model --image=/path/to/garbage.bmp --save_dir=output
+./build/detector --model_dir=/path/to/inference_model --image=/path/to/xiaoduxiong.jpeg --save_dir=output
 ```
 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
 
@@ -104,13 +125,12 @@ make
 
 使用`GPU`预测多个图片`/path/to/image_list.txt`,image_list.txt内容的格式如下:
 ```
-/path/to/images/garbage1.jpeg
-/path/to/images/garbage2.jpeg
+/path/to/images/xiaoduxiong1.jpeg
+/path/to/images/xiaoduxiong2.jpeg
 ...
-/path/to/images/garbagen.jpeg
+/path/to/images/xiaoduxiongn.jpeg
 ```
 ```shell
 ./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output
 ```
 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
-

+ 18 - 8
docs/deploy/deploy_cpp_win_vs2019.md

@@ -27,7 +27,18 @@ git clone https://github.com/PaddlePaddle/PaddleX.git
 
 ### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
 
-PaddlePaddle C++ 预测库针对不同的`CPU`和`CUDA`版本提供了不同的预编译版本,请根据实际情况下载:  [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html)
+PaddlePaddle C++ 预测库针对不同的`CPU`,`CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,目前PaddleX依赖于Paddle1.7版本,以下提供了多个不同版本的Paddle预测库:
+
+|  版本说明   | 预测库(1.7.2版本)  | 编译器 | 构建工具| cuDNN | CUDA
+|  ----  |  ----  |  ----  |  ----  | ---- | ---- |
+| cpu_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
+| cpu_avx_openblas  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
+| cuda9.0_cudnn7_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
+| cuda9.0_cudnn7_avx_openblas  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
+| cuda10.0_cudnn7_avx_mkl  | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post107/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.5.0 | 9.0 |
+
+
+更多和更新的版本,请根据实际情况下载:  [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1)
 
 解压后`D:\projects\fluid_inference*\`目录下主要包含的内容为:
 ```
@@ -109,14 +120,14 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
 
 ## 样例
 
-可使用[垃圾检测模型](deploy.md#导出inference模型)中生成的`inference_model`模型和测试图片进行预测。
+可使用[小度熊识别模型](deploy.md#导出inference模型)中导出的`inference_model`和测试图片进行预测。
 
 `样例一`:
 
-不使用`GPU`测试图片  `\\path\\to\\garbage.bmp`  
+不使用`GPU`测试图片  `\\path\\to\\xiaoduxiong.jpeg`  
 
 ```shell
-.\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\garbage.bmp --save_dir=output
+.\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\xiaoduxiong.jpeg --save_dir=output
 
 ```
 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
@@ -126,13 +137,12 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
 
 使用`GPU`预测多个图片`\\path\\to\\image_list.txt`,image_list.txt内容的格式如下:
 ```
-\\path\\to\\images\\garbage1.jpeg
-\\path\\to\\images\\garbage2.jpeg
+\\path\\to\\images\\xiaoduxiong1.jpeg
+\\path\\to\\images\\xiaoduxiong2.jpeg
 ...
-\\path\\to\\images\\garbagen.jpeg
+\\path\\to\\images\\xiaoduxiongn.jpeg
 ```
 ```shell
 .\detector --model_dir=\\path\\to\\inference_model --image_list=\\path\\to\\images_list.txt --use_gpu=1 --save_dir=output
 ```
 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
-

+ 20 - 2
paddlex/command.py

@@ -29,7 +29,11 @@ def arg_parser():
         action="store_true",
         default=False,
         help="export inference model for C++/Python deployment")
-
+    parser.add_argument(
+        "--fixed_input_shape",
+        "-fs",
+        default=None,
+        help="export inference model with fixed input shape:[w,h]")
     return parser
 
 
@@ -53,9 +57,23 @@ def main():
     if args.export_inference:
         assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
         assert args.save_dir is not None, "--save_dir should be defined to save inference model"
-        model = pdx.load_model(args.model_dir)
+        fixed_input_shape = eval(args.fixed_input_shape)
+        assert len(
+            fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        model = pdx.load_model(args.model_dir, fixed_input_shape)
         model.export_inference_model(args.save_dir)
 
+    if args.export_onnx:
+        assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
+        assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
+        fixed_input_shape = eval(args.fixed_input_shape)
+        assert len(
+            fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        model = pdx.load_model(args.model_dir, fixed_input_shape)
+        model.export_onnx_model(args.save_dir)
+
 
 if __name__ == "__main__":
     main()

+ 1 - 1
paddlex/cv/models/base.py

@@ -317,11 +317,11 @@ class BaseAPI:
         model_info['_ModelInputsOutputs']['test_outputs'] = [
             [k, v.name] for k, v in self.test_outputs.items()
         ]
-
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 mode='w') as f:
             yaml.dump(model_info, f)
+
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info(

+ 10 - 2
paddlex/cv/models/classifier.py

@@ -46,10 +46,18 @@ class BaseClassifier(BaseAPI):
         self.model_name = model_name
         self.labels = None
         self.num_classes = num_classes
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
-        image = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            image = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            image = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if mode != 'test':
             label = fluid.data(dtype='int64', shape=[None, 1], name='label')
         model = getattr(paddlex.cv.nets, str.lower(self.model_name))

+ 3 - 2
paddlex/cv/models/deeplabv3p.py

@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI):
             自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1,
             即平时使用的交叉熵损失函数。
         ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
-
     Raises:
         ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
         ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
@@ -118,6 +117,7 @@ class DeepLabv3p(BaseAPI):
         self.enable_decoder = enable_decoder
         self.labels = None
         self.sync_bn = True
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone):
         def mobilenetv2(backbone):
@@ -182,7 +182,8 @@ class DeepLabv3p(BaseAPI):
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
+            ignore_index=self.ignore_index,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()

+ 3 - 1
paddlex/cv/models/faster_rcnn.py

@@ -57,6 +57,7 @@ class FasterRCNN(BaseAPI):
         self.aspect_ratios = aspect_ratios
         self.anchor_sizes = anchor_sizes
         self.labels = None
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone_name):
         norm_type = None
@@ -109,7 +110,8 @@ class FasterRCNN(BaseAPI):
             aspect_ratios=self.aspect_ratios,
             anchor_sizes=self.anchor_sizes,
             train_pre_nms_top_n=train_pre_nms_top_n,
-            test_pre_nms_top_n=test_pre_nms_top_n)
+            test_pre_nms_top_n=test_pre_nms_top_n,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         if mode == 'train':
             model_out = model.build_net(inputs)

+ 31 - 1
paddlex/cv/models/load_model.py

@@ -23,7 +23,7 @@ import paddlex
 import paddlex.utils.logging as logging
 
 
-def load_model(model_dir):
+def load_model(model_dir, fixed_input_shape=None):
     if not osp.exists(osp.join(model_dir, "model.yml")):
         raise Exception("There's not model.yml in {}".format(model_dir))
     with open(osp.join(model_dir, "model.yml")) as f:
@@ -44,6 +44,7 @@ def load_model(model_dir):
     else:
         model = getattr(paddlex.cv.models,
                         info['Model'])(**info['_init_params'])
+    model.fixed_input_shape = fixed_input_shape
     if status == "Normal" or \
             status == "Prune" or status == "fluid.save":
         startup_prog = fluid.Program()
@@ -78,6 +79,8 @@ def load_model(model_dir):
             model.test_outputs[var_desc[0]] = out
     if 'Transforms' in info:
         transforms_mode = info.get('TransformsMode', 'RGB')
+        # 固定模型的输入shape
+        fix_input_shape(info, fixed_input_shape=fixed_input_shape)
         if transforms_mode == 'RGB':
             to_rgb = True
         else:
@@ -102,6 +105,33 @@ def load_model(model_dir):
     return model
 
 
+def fix_input_shape(info, fixed_input_shape=None):
+    if fixed_input_shape is not None:
+        resize = {'ResizeByShort': {}}
+        padding = {'Padding': {}}
+        if info['_Attributes']['model_type'] == 'classifier':
+            crop_size = 0
+            for transform in info['Transforms']:
+                if 'CenterCrop' in transform:
+                    crop_size = transform['CenterCrop']['crop_size']
+                    break
+            assert crop_size == fixed_input_shape[
+                0], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
+                    crop_size)
+            assert crop_size == fixed_input_shape[
+                1], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
+                    crop_size)
+            if crop_size == 0:
+                logging.warning(
+                    "fixed_input_shape must == input shape when trainning")
+        else:
+            resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
+            resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
+            padding['Padding']['target_size'] = list(fixed_input_shape)
+            info['Transforms'].append(resize)
+            info['Transforms'].append(padding)
+
+
 def build_transforms(model_type, transforms_info, to_rgb=True):
     if model_type == "classifier":
         import paddlex.cv.transforms.cls_transforms as T

+ 3 - 1
paddlex/cv/models/mask_rcnn.py

@@ -60,6 +60,7 @@ class MaskRCNN(FasterRCNN):
             self.mask_head_resolution = 28
         else:
             self.mask_head_resolution = 14
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
         train_pre_nms_top_n = 2000 if self.with_fpn else 12000
@@ -73,7 +74,8 @@ class MaskRCNN(FasterRCNN):
             train_pre_nms_top_n=train_pre_nms_top_n,
             test_pre_nms_top_n=test_pre_nms_top_n,
             num_convs=num_convs,
-            mask_head_resolution=self.mask_head_resolution)
+            mask_head_resolution=self.mask_head_resolution,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         if mode == 'train':
             model_out = model.build_net(inputs)

+ 3 - 1
paddlex/cv/models/unet.py

@@ -77,6 +77,7 @@ class UNet(DeepLabv3p):
         self.class_weight = class_weight
         self.ignore_index = ignore_index
         self.labels = None
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.UNet(
@@ -86,7 +87,8 @@ class UNet(DeepLabv3p):
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
+            ignore_index=self.ignore_index,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()

+ 3 - 1
paddlex/cv/models/yolo_v3.py

@@ -80,6 +80,7 @@ class YOLOv3(BaseAPI):
         self.label_smooth = label_smooth
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
+        self.fixed_input_shape = None
 
     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -113,7 +114,8 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes)
+            train_random_shapes=self.train_random_shapes,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict([('bbox', model_out)])

+ 13 - 3
paddlex/cv/nets/detection/faster_rcnn.py

@@ -76,7 +76,8 @@ class FasterRCNN(object):
             fg_thresh=.5,
             bg_thresh_hi=.5,
             bg_thresh_lo=0.,
-            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]):
+            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+            fixed_input_shape=None):
         super(FasterRCNN, self).__init__()
         self.backbone = backbone
         self.mode = mode
@@ -148,6 +149,7 @@ class FasterRCNN(object):
         self.bg_thresh_lo = bg_thresh_lo
         self.bbox_reg_weights = bbox_reg_weights
         self.rpn_only = rpn_only
+        self.fixed_input_shape = fixed_input_shape
 
     def build_net(self, inputs):
         im = inputs['image']
@@ -219,8 +221,16 @@ class FasterRCNN(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['im_info'] = fluid.data(
                 dtype='float32', shape=[None, 3], name='im_info')

+ 13 - 3
paddlex/cv/nets/detection/mask_rcnn.py

@@ -86,7 +86,8 @@ class MaskRCNN(object):
             fg_thresh=.5,
             bg_thresh_hi=.5,
             bg_thresh_lo=0.,
-            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]):
+            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+            fixed_input_shape=None):
         super(MaskRCNN, self).__init__()
         self.backbone = backbone
         self.mode = mode
@@ -167,6 +168,7 @@ class MaskRCNN(object):
         self.bg_thresh_lo = bg_thresh_lo
         self.bbox_reg_weights = bbox_reg_weights
         self.rpn_only = rpn_only
+        self.fixed_input_shape = fixed_input_shape
 
     def build_net(self, inputs):
         im = inputs['image']
@@ -306,8 +308,16 @@ class MaskRCNN(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['im_info'] = fluid.data(
                 dtype='float32', shape=[None, 3], name='im_info')

+ 12 - 3
paddlex/cv/nets/detection/yolo_v3.py

@@ -33,7 +33,8 @@ class YOLOv3:
                  nms_iou_threshold=0.45,
                  train_random_shapes=[
                      320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+                 ],
+                 fixed_input_shape=None):
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -54,6 +55,7 @@ class YOLOv3:
         self.norm_decay = 0.0
         self.prefix_name = ''
         self.train_random_shapes = train_random_shapes
+        self.fixed_input_shape = fixed_input_shape
 
     def _head(self, feats):
         outputs = []
@@ -247,8 +249,15 @@ class YOLOv3:
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['gt_box'] = fluid.data(
                 dtype='float32', shape=[None, None, 4], name='gt_box')

+ 14 - 3
paddlex/cv/nets/segmentation/deeplabv3p.py

@@ -61,6 +61,7 @@ class DeepLabv3p(object):
             自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
             即平时使用的交叉熵损失函数。
         ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
+        fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
 
     Raises:
         ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
@@ -81,7 +82,8 @@ class DeepLabv3p(object):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 fixed_input_shape=None):
         # dice_loss或bce_loss只适用两类分割中
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise ValueError(
@@ -115,6 +117,7 @@ class DeepLabv3p(object):
         self.decoder_use_sep_conv = decoder_use_sep_conv
         self.encoder_with_aspp = encoder_with_aspp
         self.enable_decoder = enable_decoder
+        self.fixed_input_shape = fixed_input_shape
 
     def _encoder(self, input):
         # 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
@@ -310,8 +313,16 @@ class DeepLabv3p(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 14 - 3
paddlex/cv/nets/segmentation/unet.py

@@ -54,6 +54,7 @@ class UNet(object):
                 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
                 即平时使用的交叉熵损失函数。
             ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
+            fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
 
         Raises:
             ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
@@ -69,7 +70,8 @@ class UNet(object):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 fixed_input_shape=None):
         # dice_loss或bce_loss只适用两类分割中
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise Exception(
@@ -97,6 +99,7 @@ class UNet(object):
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
         self.ignore_index = ignore_index
+        self.fixed_input_shape = fixed_input_shape
 
     def _double_conv(self, data, out_ch):
         param_attr = fluid.ParamAttr(
@@ -226,8 +229,16 @@ class UNet(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 48 - 11
paddlex/cv/transforms/det_transforms.py

@@ -93,6 +93,8 @@ class Compose:
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
                 [im.shape[0], im.shape[1], 1.], dtype=np.float32)
+            im_info['image_shape'] = np.array([im.shape[0],
+                                               im.shape[1]]).astype('int32')
             if not self.use_mixup:
                 if 'mixup' in im_info:
                     del im_info['mixup']
@@ -193,11 +195,16 @@ class ResizeByShort:
 
 
 class Padding:
-    """将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
+    """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
        `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
        进行padding,最终输出图像为[320, 640]。
+       2.或者,将图像的长和宽padding到target_size指定的shape,如输入的图像为[300,640],
+         a. `target_size` = 960,在图像最右和最下使用0值进行padding,最终输出
+            图像为[960, 960]。
+         b. `target_size` = [640, 960],在图像最右和最下使用0值进行padding,最终
+            输出图像为[640, 960]。
 
-    1. 如果coarsest_stride为1则直接返回。
+    1. 如果coarsest_stride为1,target_size为None则直接返回。
     2. 获取图像的高H、宽W。
     3. 计算填充后图像的高H_new、宽W_new。
     4. 构建大小为(H_new, W_new, 3)像素值为0的np.ndarray,
@@ -205,10 +212,26 @@ class Padding:
 
     Args:
         coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
+        target_size (int|list|tuple): 填充后的图像长、宽,默认为None,coarset_stride优先级更高。
+
+    Raises:
+        TypeError: 形参`target_size`数据类型不满足需求。
+        ValueError: 形参`target_size`为(list|tuple)时,长度不满足需求。
     """
 
-    def __init__(self, coarsest_stride=1):
+    def __init__(self, coarsest_stride=1, target_size=None):
         self.coarsest_stride = coarsest_stride
+        if target_size is not None:
+            if not isinstance(target_size, int):
+                if not isinstance(target_size, tuple) and not isinstance(
+                        target_size, list):
+                    raise TypeError(
+                        "Padding: Type of target_size must in (int|list|tuple)."
+                    )
+                elif len(target_size) != 2:
+                    raise ValueError(
+                        "Padding: Length of target_size must equal 2.")
+        self.target_size = target_size
 
     def __call__(self, im, im_info=None, label_info=None):
         """
@@ -225,13 +248,9 @@ class Padding:
         Raises:
             TypeError: 形参数据类型不满足需求。
             ValueError: 数据长度不匹配。
+            ValueError: coarsest_stride,target_size需有且只有一个被指定。
+            ValueError: target_size小于原图的大小。
         """
-
-        if self.coarsest_stride == 1:
-            if label_info is None:
-                return (im, im_info)
-            else:
-                return (im, im_info, label_info)
         if im_info is None:
             im_info = dict()
         if not isinstance(im, np.ndarray):
@@ -239,11 +258,29 @@ class Padding:
         if len(im.shape) != 3:
             raise ValueError('Padding: image is not 3-dimensional.')
         im_h, im_w, im_c = im.shape[:]
-        if self.coarsest_stride > 1:
+
+        if isinstance(self.target_size, int):
+            padding_im_h = self.target_size
+            padding_im_w = self.target_size
+        elif isinstance(self.target_size, list) or isinstance(
+                self.target_size, tuple):
+            padding_im_w = self.target_size[0]
+            padding_im_h = self.target_size[1]
+        elif self.coarsest_stride > 0:
             padding_im_h = int(
                 np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
             padding_im_w = int(
                 np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
+        else:
+            raise ValueError(
+                "coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
+            )
+        pad_height = padding_im_h - im_h
+        pad_width = padding_im_w - im_w
+        if pad_height < 0 or pad_width < 0:
+            raise ValueError(
+                'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
+                .format(im_w, im_h, padding_im_w, padding_im_h))
         padding_im = np.zeros((padding_im_h, padding_im_w, im_c),
                               dtype=np.float32)
         padding_im[:im_h, :im_w, :] = im
@@ -539,7 +576,7 @@ class RandomDistort:
             params = params_dict[ops[id].__name__]
             prob = prob_dict[ops[id].__name__]
             params['im'] = im
-            
+
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
         if label_info is None:

+ 78 - 1
paddlex/cv/transforms/seg_transforms.py

@@ -285,7 +285,7 @@ class ResizeByLong:
                 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
                 存储与图像相关信息的字典和标注图像np.ndarray数据。
                 其中,im_info新增字段为:
-                    -shape_before_resize (tuple): 保存resize之前图像的形状(h, w
+                    -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)
         """
         if im_info is None:
             im_info = OrderedDict()
@@ -301,6 +301,83 @@ class ResizeByLong:
             return (im, im_info, label)
 
 
+class ResizeByShort:
+    """根据图像的短边调整图像大小(resize)。
+
+    1. 获取图像的长边和短边长度。
+    2. 根据短边与short_size的比例,计算长边的目标长度,
+       此时高、宽的resize比例为short_size/原图短边长度。
+    3. 如果max_size>0,调整resize比例:
+       如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度。
+    4. 根据调整大小的比例对图像进行resize。
+
+    Args:
+        target_size (int): 短边目标长度。默认为800。
+        max_size (int): 长边目标长度的最大限制。默认为1333。
+
+     Raises:
+        TypeError: 形参数据类型不满足需求。
+    """
+
+    def __init__(self, short_size=800, max_size=1333):
+        self.max_size = int(max_size)
+        if not isinstance(short_size, int):
+            raise TypeError(
+                "Type of short_size is invalid. Must be Integer, now is {}".
+                format(type(short_size)))
+        self.short_size = short_size
+        if not (isinstance(self.max_size, int)):
+            raise TypeError("max_size: input type is invalid.")
+
+    def __call__(self, im, im_info=None, label=None):
+        """
+        Args:
+            im (numnp.ndarraypy): 图像np.ndarray数据。
+            im_info (list): 存储图像reisze或padding前的shape信息,如
+                [('resize', [200, 300]), ('padding', [400, 600])]表示
+                图像在过resize前shape为(200, 300), 过padding前shape为
+                (400, 600)
+            label (np.ndarray): 标注图像np.ndarray数据。
+
+        Returns:
+            tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
+                   当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
+                   存储与图像相关信息的字典和标注图像np.ndarray数据。
+                   其中,im_info更新字段为:
+                       -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
+
+        Raises:
+            TypeError: 形参数据类型不满足需求。
+            ValueError: 数据长度不匹配。
+        """
+        if im_info is None:
+            im_info = OrderedDict()
+        if not isinstance(im, np.ndarray):
+            raise TypeError("ResizeByShort: image type is not numpy.")
+        if len(im.shape) != 3:
+            raise ValueError('ResizeByShort: image is not 3-dimensional.')
+        im_info.append(('resize', im.shape[:2]))
+        im_short_size = min(im.shape[0], im.shape[1])
+        im_long_size = max(im.shape[0], im.shape[1])
+        scale = float(self.short_size) / im_short_size
+        if self.max_size > 0 and np.round(
+                scale * im_long_size) > self.max_size:
+            scale = float(self.max_size) / float(im_long_size)
+        resized_width = int(round(im.shape[1] * scale))
+        resized_height = int(round(im.shape[0] * scale))
+        im = cv2.resize(
+            im, (resized_width, resized_height),
+            interpolation=cv2.INTER_NEAREST)
+        if label is not None:
+            im = cv2.resize(
+                label, (resized_width, resized_height),
+                interpolation=cv2.INTER_NEAREST)
+        if label is None:
+            return (im, im_info)
+        else:
+            return (im, im_info, label)
+
+
 class ResizeRangeScaling:
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。