
Merge pull request #350 from FlyingQianMM/develop_qh

support multi-channel inputs for openvino deployment
Jason, 5 years ago
Parent
Current commit
f3b97ee3a2

+ 30 - 0
deploy/openvino/include/paddlex/transforms.h

@@ -71,6 +71,16 @@ class Normalize : public Transform {
   virtual void Init(const YAML::Node& item) {
     mean_ = item["mean"].as<std::vector<float>>();
     std_ = item["std"].as<std::vector<float>>();
+    if (item["min_val"].IsDefined()) {
+      min_val_ = item["min_val"].as<std::vector<float>>();
+    } else {
+      min_val_ = std::vector<float>(mean_.size(), 0.);
+    }
+    if (item["max_val"].IsDefined()) {
+      max_val_ = item["max_val"].as<std::vector<float>>();
+    } else {
+      max_val_ = std::vector<float>(mean_.size(), 255.);
+    }
   }
 
   virtual bool Run(cv::Mat* im, ImageBlob* data);
@@ -78,6 +88,8 @@ class Normalize : public Transform {
  private:
   std::vector<float> mean_;
   std::vector<float> std_;
+  std::vector<float> min_val_;
+  std::vector<float> max_val_;
 };
 
 class ResizeByShort : public Transform {
@@ -209,6 +221,24 @@ class Padding : public Transform {
   std::vector<float> im_value_;
 };
 
+/*
+ * @brief
+ * This class executes the clip operation on an image matrix
+ * */
+class Clip : public Transform {
+ public:
+  virtual void Init(const YAML::Node& item) {
+    min_val_ = item["min_val"].as<std::vector<float>>();
+    max_val_ = item["max_val"].as<std::vector<float>>();
+  }
+
+  virtual bool Run(cv::Mat* im, ImageBlob* data);
+
+ private:
+  std::vector<float> min_val_;
+  std::vector<float> max_val_;
+};
+
 class Transforms {
  public:
   void Init(const YAML::Node& node, std::string type, bool to_rgb = true);
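For reference, a minimal Python sketch (not part of the commit; assumes PyYAML) of the fallback this `Init` implements: when `model.yml` carries no `min_val`/`max_val`, per-channel defaults of 0 and 255 reproduce the old `x / 255.0` behavior for 8-bit inputs:

```
import yaml

# A hypothetical 4-channel Normalize entry without min_val/max_val:
item = yaml.safe_load("""
mean: [0.5, 0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5, 0.5]
""")
min_val = item.get("min_val", [0.0] * len(item["mean"]))    # -> [0.0, 0.0, 0.0, 0.0]
max_val = item.get("max_val", [255.0] * len(item["mean"]))  # -> [255.0, 255.0, 255.0, 255.0]
```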

+ 9 - 1
deploy/openvino/python/converter.py

@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import os
+import os.path as osp
 from six import text_type as _text_type
 import argparse
 import sys
+import yaml
 import paddlex as pdx
 
 
@@ -72,7 +74,13 @@ def export_openvino_model(model, args):
     onnx_parser.set_defaults(input_model=onnx_input)
     onnx_parser.set_defaults(output_dir=args.save_dir)
     shape_list = args.fixed_input_shape[1:-1].split(',')
-    shape = '[1,3,' + shape_list[1] + ',' + shape_list[0] + ']'
+    with open(osp.join(args.model_dir, "model.yml")) as f:
+        info = yaml.load(f.read(), Loader=yaml.Loader)
+    input_channel = 3
+    if 'input_channel' in info['_init_params']:
+        input_channel = info['_init_params']['input_channel']
+    shape = '[1,{},' + shape_list[1] + ',' + shape_list[0] + ']'
+    shape = shape.format(input_channel)
     if model.__class__.__name__ == "YOLOV3":
         shape = shape + ",[1,2]"
         inputs = "image,im_size"

+ 0 - 2
deploy/openvino/python/deploy.py

@@ -173,7 +173,6 @@ class Predictor:
             'category': self.labels[l],
             'score': preds[output_name][0][l],
         } for l in pred_label]
-        print(result)
         return result
 
     def segmenter_postprocess(self, preds, preprocessed_inputs):
@@ -212,7 +211,6 @@ class Predictor:
                 result.append(out.tolist())
             else:
                 pass
-        print(result)
         return result
 
     def predict(self, image, topk=1, threshold=0.5):

+ 10 - 5
deploy/openvino/python/transforms/ops.py

@@ -18,11 +18,15 @@ import numpy as np
 from PIL import Image, ImageEnhance
 
 
-def normalize(im, mean, std):
-    im = im / 255.0
+def normalize(im, mean, std, min_value=[0, 0, 0], max_value=[255, 255, 255]):
+    # Rescaling (min-max normalization)
+    range_value = [max_value[i] - min_value[i] for i in range(len(max_value))]
+    im = (im - min_value) / range_value
+
+    # Standardization (Z-score Normalization)
     im -= mean
     im /= std
-    return im
+    return im.astype('float32')
 
 
 def permute(im, to_bgr=False):
@@ -69,8 +73,8 @@ def random_crop(im,
                 (float(im.shape[1]) / im.shape[0]) / (w**2))
     scale_max = min(scale[1], bound)
     scale_min = min(scale[0], bound)
-    target_area = im.shape[0] * im.shape[1] * np.random.uniform(
-        scale_min, scale_max)
+    target_area = im.shape[0] * im.shape[1] * np.random.uniform(scale_min,
+                                                                scale_max)
     target_size = math.sqrt(target_area)
     w = int(target_size * w)
     h = int(target_size * h)
@@ -146,6 +150,7 @@ def brightness(im, brightness_lower, brightness_upper):
     im += delta
     return im
 
+
 def rotate(im, rotate_lower, rotate_upper):
     rotate_delta = np.random.uniform(rotate_lower, rotate_upper)
     im = im.rotate(int(rotate_delta))
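A quick sanity check of the two-stage `normalize` (a sketch with illustrative values; assumes `from transforms.ops import normalize`):

```
import numpy as np

im = np.random.uniform(0, 1024, (32, 32, 5)).astype('float32')  # 5-band image
out = normalize(im,
                mean=np.array([0.5] * 5),
                std=np.array([0.25] * 5),
                min_value=[0.0] * 5,
                max_value=[1024.0] * 5)
print(out.dtype)  # float32; values rescaled to [0, 1] first, then standardized
```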

+ 197 - 72
deploy/openvino/python/transforms/seg_transforms.py

@@ -19,6 +19,7 @@ import os.path as osp
 import numpy as np
 from PIL import Image
 import cv2
+import imghdr
 from collections import OrderedDict
 
 
@@ -52,6 +53,62 @@ class Compose(SegTransform):
         self.transforms = transforms
         self.to_rgb = False
 
+    @staticmethod
+    def read_img(img_path):
+        img_format = imghdr.what(img_path)
+        name, ext = osp.splitext(img_path)
+        if img_format == 'tiff' or ext == '.img':
+            try:
+                import gdal
+            except ImportError:
+                # gdal is optional and only needed for TIFF / .img inputs
+                raise Exception(
+                    "Please refer to https://github.com/PaddlePaddle/PaddleX/tree/develop/examples/multi-channel_remote_sensing/README.md to install gdal"
+                )
+
+            dataset = gdal.Open(img_path)
+            if dataset is None:
+                raise Exception('Cannot open {}'.format(img_path))
+            im_data = dataset.ReadAsArray()
+            return im_data.transpose((1, 2, 0))
+        elif img_format in ['jpeg', 'bmp', 'png']:
+            return cv2.imread(img_path)
+        elif ext == '.npy':
+            return np.load(img_path)
+        else:
+            raise Exception('Image format {} is not supported!'.format(ext))
+
+    @staticmethod
+    def decode_image(im, label):
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensional, but it is {}-dimensional".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = Compose.read_img(im)
+            except:
+                raise ValueError('Cannot read the image file {}!'.format(im))
+        im = im.astype('float32')
+        if label is not None:
+            if isinstance(label, np.ndarray):
+                if len(label.shape) != 2:
+                    raise Exception(
+                        "label should be 2-dimensional, but it is {}-dimensional".
+                        format(len(label.shape)))
+
+            else:
+                try:
+                    label = np.asarray(Image.open(label))
+                except:
+                    raise ValueError('Cannot read the label file {}!'.format(label))
+            im_height, im_width, _ = im.shape
+            label_height, label_width = label.shape
+            if im_height != label_height or im_width != label_width:
+                raise Exception(
+                    "The height or width of the image is not same as the label")
+        return (im, label)
 
     def __call__(self, im, im_info=None, label=None):
         """
@@ -67,23 +124,13 @@ class Compose(SegTransform):
             tuple: A tuple of the fields the network needs; which fields is decided by the last preprocessing op in transforms.
         """
 
-        if im_info is None:
-            im_info = list()
-        if isinstance(im, np.ndarray):
-            if len(im.shape) != 3:
-                raise Exception(
-                    "im should be 3-dimensions, but now is {}-dimensions".
-                    format(len(im.shape)))
-        else:
-            try:
-                im = cv2.imread(im).astype('float32')
-            except:
-                raise ValueError('Can\'t read The image file {}!'.format(im))
+        im, label = self.decode_image(im, label)
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        if im_info is None:
+            im_info = [('origin_shape', im.shape[0:2])]
         if label is not None:
-            if not isinstance(label, np.ndarray):
-                label = np.asarray(Image.open(label))
+            origin_label = label.copy()
         for op in self.transforms:
             if isinstance(op, SegTransform):
                 outputs = op(im, im_info, label)
@@ -98,6 +145,10 @@ class Compose(SegTransform):
                     outputs = (im, im_info, label)
                 else:
                     outputs = (im, im_info)
+        if self.transforms[-1].__class__.__name__ == 'ArrangeSegmenter':
+            if self.transforms[-1].mode == 'eval':
+                if label is not None:
+                    outputs = (im, im_info, origin_label)
         return outputs
 
     def add_augmenters(self, augmenters):
@@ -107,7 +158,9 @@ class Compose(SegTransform):
         transform_names = [type(x).__name__ for x in self.transforms]
         for aug in augmenters:
             if type(aug).__name__ in transform_names:
-                print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__))
+                print(
+                    "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
+                    format(type(aug).__name__))
         self.transforms = augmenters + self.transforms
 
 
@@ -536,22 +589,35 @@ class ResizeStepScaling(SegTransform):
 
 class Normalize(SegTransform):
     """对图像进行标准化。
-    1.尺度缩放到 [0,1]。
-    2.对图像进行减均值除以标准差操作。
+    1.像素值减去min_val
+    2.像素值除以(max_val-min_val)
+    3.对图像进行减均值除以标准差操作。
 
     Args:
         mean (list): Mean of the image dataset. Default: [0.5, 0.5, 0.5].
         std (list): Standard deviation of the image dataset. Default: [0.5, 0.5, 0.5].
+        min_val (list): Minimum value of the image dataset. Default: [0, 0, 0].
+        max_val (list): Maximum value of the image dataset. Default: [255.0, 255.0, 255.0].
 
     Raises:
         ValueError: mean or std is not a list object, or std contains 0.
     """
 
-    def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+    def __init__(self,
+                 mean=[0.5, 0.5, 0.5],
+                 std=[0.5, 0.5, 0.5],
+                 min_val=[0, 0, 0],
+                 max_val=[255.0, 255.0, 255.0]):
+        self.min_val = min_val
+        self.max_val = max_val
         self.mean = mean
         self.std = std
         if not (isinstance(self.mean, list) and isinstance(self.std, list)):
             raise ValueError("{}: input type is invalid.".format(self))
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
         from functools import reduce
         if reduce(lambda x, y: x * y, self.std) == 0:
             raise ValueError('{}: std is invalid!'.format(self))
@@ -574,7 +640,8 @@ class Normalize(SegTransform):
 
         mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
         std = np.array(self.std)[np.newaxis, np.newaxis, :]
-        im = normalize(im, mean, std)
+        im = normalize(im, mean, std, self.min_val, self.max_val)
+        im = im.astype('float32')
 
         if label is None:
             return (im, im_info)
@@ -646,28 +713,29 @@ class Padding(SegTransform):
             target_width = self.target_size[0]
         pad_height = target_height - im_height
         pad_width = target_width - im_width
-        if pad_height < 0 or pad_width < 0:
-            raise ValueError(
-                'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
-                .format(im_width, im_height, target_width, target_height))
-        else:
-            im = cv2.copyMakeBorder(
-                im,
-                0,
-                pad_height,
-                0,
-                pad_width,
-                cv2.BORDER_CONSTANT,
-                value=self.im_padding_value)
+        pad_height = max(pad_height, 0)
+        pad_width = max(pad_width, 0)
+        if (pad_height > 0 or pad_width > 0):
+            im_channel = im.shape[2]
+            import copy
+            orig_im = copy.deepcopy(im)
+            im = np.zeros((im_height + pad_height, im_width + pad_width,
+                           im_channel)).astype(orig_im.dtype)
+            for i in range(im_channel):
+                im[:, :, i] = np.pad(
+                    orig_im[:, :, i],
+                    pad_width=((0, pad_height), (0, pad_width)),
+                    mode='constant',
+                    constant_values=(self.im_padding_value[i],
+                                     self.im_padding_value[i]))
+
             if label is not None:
-                label = cv2.copyMakeBorder(
-                    label,
-                    0,
-                    pad_height,
-                    0,
-                    pad_width,
-                    cv2.BORDER_CONSTANT,
-                    value=self.label_padding_value)
+                label = np.pad(label,
+                               pad_width=((0, pad_height), (0, pad_width)),
+                               mode='constant',
+                               constant_values=(self.label_padding_value,
+                                                self.label_padding_value))
+
         if label is None:
             return (im, im_info)
         else:
@@ -738,23 +806,26 @@ class RandomPaddingCrop(SegTransform):
             pad_height = max(crop_height - img_height, 0)
             pad_width = max(crop_width - img_width, 0)
             if (pad_height > 0 or pad_width > 0):
-                im = cv2.copyMakeBorder(
-                    im,
-                    0,
-                    pad_height,
-                    0,
-                    pad_width,
-                    cv2.BORDER_CONSTANT,
-                    value=self.im_padding_value)
+                img_channel = im.shape[2]
+                import copy
+                orig_im = copy.deepcopy(im)
+                im = np.zeros((img_height + pad_height, img_width + pad_width,
+                               img_channel)).astype(orig_im.dtype)
+                for i in range(img_channel):
+                    im[:, :, i] = np.pad(
+                        orig_im[:, :, i],
+                        pad_width=((0, pad_height), (0, pad_width)),
+                        mode='constant',
+                        constant_values=(self.im_padding_value[i],
+                                         self.im_padding_value[i]))
+
                 if label is not None:
-                    label = cv2.copyMakeBorder(
-                        label,
-                        0,
-                        pad_height,
-                        0,
-                        pad_width,
-                        cv2.BORDER_CONSTANT,
-                        value=self.label_padding_value)
+                    label = np.pad(label,
+                                   pad_width=((0, pad_height), (0, pad_width)),
+                                   mode='constant',
+                                   constant_values=(self.label_padding_value,
+                                                    self.label_padding_value))
+
                 img_height = im.shape[0]
                 img_width = im.shape[1]
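Both `Padding` and `RandomPaddingCrop` now pad channel by channel with `np.pad`, presumably because `cv2.copyMakeBorder` takes its constant border as a `cv::Scalar`, which holds at most four values. In numpy terms (a sketch):

```
import numpy as np

im = np.ones((2, 3, 6), dtype='float32')          # a 6-channel image
pad_value = [0.0] * 6
padded = np.stack([
    np.pad(im[:, :, c], ((0, 2), (0, 1)), mode='constant',
           constant_values=pad_value[c])           # pad bottom by 2, right by 1
    for c in range(im.shape[2])
], axis=-1)
print(padded.shape)                                # (4, 4, 6)
```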
 
@@ -819,8 +890,6 @@ class RandomBlur(SegTransform):
             return (im, im_info, label)
 
 
-
-
 class RandomScaleAspect(SegTransform):
     """裁剪并resize回原始尺寸的图像和标注图像。
     按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
@@ -974,6 +1043,34 @@ class RandomDistort(SegTransform):
             params['im'] = im
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
+        im = im.astype('float32')
+        if label is None:
+            return (im, im_info)
+        else:
+            return (im, im_info, label)
+
+
+class Clip(SegTransform):
+    """
+    对图像上超出一定范围的数据进行截断。
+
+    Args:
+        min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
+        max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
+    """
+
+    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
+        self.min_val = min_val
+        self.max_val = max_val
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
+    def __call__(self, im, im_info=None, label=None):
+        for k in range(im.shape[2]):
+            np.clip(
+                im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
+
         if label is None:
             return (im, im_info)
         else:
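Usage sketch for the new `Clip` transform (illustrative values; per-band bounds for a 2-band image):

```
import numpy as np

im = np.array([[[-5.0, 2000.0],
                [10.0, 900.0]]], dtype='float32')   # shape (1, 2, 2): H=1, W=2, C=2
clip = Clip(min_val=[0.0, 0.0], max_val=[255.0, 1023.0])
im, _ = clip(im)
print(im[..., 0].max(), im[..., 1].max())           # 10.0 1023.0
```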
@@ -1013,9 +1110,12 @@ class ArrangeSegmenter(SegTransform):
                In 'quant' mode, the returned tuple is (im,), the image as np.ndarray data.
         """
         im = permute(im, False)
-        if self.mode == 'train' or self.mode == 'eval':
+        if self.mode == 'train':
             label = label[np.newaxis, :, :]
             return (im, label)
+        if self.mode == 'eval':
+            label = label[np.newaxis, :, :]
+            return (im, im_info, label)
         elif self.mode == 'test':
             return (im, im_info)
         else:
@@ -1025,30 +1125,55 @@ class ArrangeSegmenter(SegTransform):
 class ComposedSegTransforms(Compose):
     """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
         训练阶段:
-        1. 随机对图像以0.5的概率水平翻转
-        2. 按不同的比例随机Resize原图
+        1. 随机对图像以0.5的概率水平翻转,若random_horizontal_flip为False,则跳过此步骤
+        2. 按不同的比例随机Resize原图, 处理方式参考[paddlex.seg.transforms.ResizeRangeScaling](#resizerangescaling)。若min_max_size为None,则跳过此步骤
         3. 从原图中随机crop出大小为train_crop_size大小的子图,如若crop出来的图小于train_crop_size,则会将图padding到对应大小
         4. 图像归一化
-        预测阶段:
-        1. 图像归一化
+       预测阶段:
+        1. 将图像的最长边resize至(min_max_size[0] + min_max_size[1])//2, 短边按比例resize。若min_max_size为None,则跳过此步骤
+        2. 图像归一化
 
         Args:
-            mode(str): 图像处理所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
-            train_crop_size(list): 模型训练阶段,随机从原图crop的大小
-            mean(list): 图像均值
-            std(list): 图像方差
+            mode(str): Transforms所处的阶段,包括`train', 'eval'或'test'
+            min_max_size(list): 用于对图像进行resize,具体作用参见上述步骤。
+            train_crop_size(list): 训练过程中随机裁剪原图用于训练,具体作用参见上述步骤。此参数仅在mode为`train`时生效。
+            mean(list): 图像均值, 默认为[0.485, 0.456, 0.406]。
+            std(list): 图像方差,默认为[0.229, 0.224, 0.225]。
+            random_horizontal_flip(bool): 数据增强,是否随机水平翻转图像,此参数仅在mode为`train`时生效。
     """
 
     def __init__(self,
                  mode,
-                 train_crop_size=[769, 769],
+                 min_max_size=[400, 600],
+                 train_crop_size=[512, 512],
                  mean=[0.5, 0.5, 0.5],
-                 std=[0.5, 0.5, 0.5]):
+                 std=[0.5, 0.5, 0.5],
+                 random_horizontal_flip=True):
         if mode == 'train':
             # Training transforms, including data augmentation
-            pass
+            if min_max_size is None:
+                transforms = [
+                    RandomPaddingCrop(crop_size=train_crop_size), Normalize(
+                        mean=mean, std=std)
+                ]
+            else:
+                transforms = [
+                    ResizeRangeScaling(
+                        min_value=min(min_max_size),
+                        max_value=max(min_max_size)),
+                    RandomPaddingCrop(crop_size=train_crop_size), Normalize(
+                        mean=mean, std=std)
+                ]
+            if random_horizontal_flip:
+                transforms.insert(0, RandomHorizontalFlip())
         else:
             # Eval/prediction transforms
-            transforms = [Normalize(mean=mean, std=std)]
-
+            if min_max_size is None:
+                transforms = [Normalize(mean=mean, std=std)]
+            else:
+                long_size = (min(min_max_size) + max(min_max_size)) // 2
+                transforms = [
+                    ResizeByLong(long_size=long_size), Normalize(
+                        mean=mean, std=std)
+                ]
         super(ComposedSegTransforms, self).__init__(transforms)
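A usage sketch: at eval/test time with `min_max_size=[400, 600]`, the pipeline resizes the long side to `(400 + 600) // 2 = 500` and then normalizes:

```
transforms = ComposedSegTransforms(
    mode='eval',
    min_max_size=[400, 600],
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5])   # -> [ResizeByLong(long_size=500), Normalize(...)]
```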

+ 81 - 19
deploy/openvino/src/transforms.cpp

@@ -31,21 +31,25 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
                                              {"LANCZOS4", cv::INTER_LANCZOS4}};
 
 bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
-  for (int h = 0; h < im->rows; h++) {
-    for (int w = 0; w < im->cols; w++) {
-      im->at<cv::Vec3f>(h, w)[0] =
-          (im->at<cv::Vec3f>(h, w)[0] / 255.0 - mean_[0]) / std_[0];
-      im->at<cv::Vec3f>(h, w)[1] =
-          (im->at<cv::Vec3f>(h, w)[1] / 255.0 - mean_[1]) / std_[1];
-      im->at<cv::Vec3f>(h, w)[2] =
-          (im->at<cv::Vec3f>(h, w)[2] / 255.0 - mean_[2]) / std_[2];
-    }
+  std::vector<float> range_val;
+  for (int c = 0; c < im->channels(); c++) {
+    range_val.push_back(max_val_[c] - min_val_[c]);
+  }
+
+  std::vector<cv::Mat> split_im;
+  cv::split(*im, split_im);
+  for (int c = 0; c < im->channels(); c++) {
+    cv::subtract(split_im[c], cv::Scalar(min_val_[c]), split_im[c]);
+    cv::divide(split_im[c], cv::Scalar(range_val[c]), split_im[c]);
+    cv::subtract(split_im[c], cv::Scalar(mean_[c]), split_im[c]);
+    cv::divide(split_im[c], cv::Scalar(std_[c]), split_im[c]);
   }
+  cv::merge(split_im, *im);
+
   return true;
 }
 
 
-
 float ResizeByShort::GenerateScale(const cv::Mat& im) {
   int origin_w = im.cols;
   int origin_h = im.rows;
@@ -71,6 +75,7 @@ bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;
   data->scale = scale;
+
   return true;
 }
 
@@ -115,11 +120,22 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
               << ", but they should be greater than 0." << std::endl;
     return false;
   }
-  cv::Scalar value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2]);
-  cv::copyMakeBorder(
-      *im, *im, 0, padding_h, 0, padding_w, cv::BORDER_CONSTANT, value);
+  std::vector<cv::Mat> padded_im_per_channel;
+  for (size_t i = 0; i < im->channels(); i++) {
+    const cv::Mat per_channel = cv::Mat(im->rows + padding_h,
+                                        im->cols + padding_w,
+                                        CV_32FC1,
+                                        cv::Scalar(im_value_[i]));
+    padded_im_per_channel.push_back(per_channel);
+  }
+  cv::Mat padded_im;
+  cv::merge(padded_im_per_channel, padded_im);
+  cv::Rect im_roi = cv::Rect(0, 0, im->cols, im->rows);
+  im->copyTo(padded_im(im_roi));
+  *im = padded_im;
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;
+
   return true;
 }
 
@@ -165,6 +181,22 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) {
   return true;
 }
 
+bool Clip::Run(cv::Mat* im, ImageBlob* data) {
+  std::vector<cv::Mat> split_im;
+  cv::split(*im, split_im);
+  for (int c = 0; c < im->channels(); c++) {
+    cv::threshold(split_im[c], split_im[c], max_val_[c], max_val_[c],
+                  cv::THRESH_TRUNC);
+    cv::subtract(cv::Scalar(0), split_im[c], split_im[c]);
+    cv::threshold(split_im[c], split_im[c], -min_val_[c], -min_val_[c],
+                  cv::THRESH_TRUNC);  // lower clip on the negated channel
+    cv::divide(split_im[c], cv::Scalar(-1), split_im[c]);
+  }
+  cv::merge(split_im, *im);
+
+  return true;
+}
+
 void Transforms::Init(
   const YAML::Node& transforms_node, std::string type, bool to_rgb) {
   transforms_.clear();
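`Clip::Run` builds both bounds out of `THRESH_TRUNC`, which only clips from above: the lower bound is obtained by negating, truncating at the negated minimum, and negating back. The numpy equivalent (a sketch):

```
import numpy as np

x = np.array([-20.0, 100.0, 400.0], dtype='float32')
upper = np.minimum(x, 255.0)            # THRESH_TRUNC at max_val
lower = -np.minimum(-upper, -0.0)       # negate, truncate at -min_val, negate back
print(lower)                            # [  0. 100. 255.]
```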
@@ -172,6 +204,21 @@ void Transforms::Init(
   type_ = type;
   for (const auto& item : transforms_node) {
     std::string name = item.begin()->first.as<std::string>();
+    if (name == "ArrangeClassifier") {
+      continue;
+    }
+    if (name == "ArrangeSegmenter") {
+      continue;
+    }
+    if (name == "ArrangeFasterRCNN") {
+      continue;
+    }
+    if (name == "ArrangeMaskRCNN") {
+      continue;
+    }
+    if (name == "ArrangeYOLOv3") {
+      continue;
+    }
     std::cout << "trans name: " << name << std::endl;
     std::shared_ptr<Transform> transform = CreateTransform(name);
     transform->Init(item.begin()->second);
@@ -193,6 +240,8 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
     return std::make_shared<Padding>();
   } else if (transform_name == "ResizeByLong") {
     return std::make_shared<ResizeByLong>();
+  } else if (transform_name == "Clip") {
+    return std::make_shared<Clip>();
   } else {
     std::cerr << "There's unexpected transform(name='" << transform_name
               << "')." << std::endl;
@@ -205,7 +254,7 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
   if (to_rgb_) {
     cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
   }
-  (*im).convertTo(*im, CV_32FC3);
+  (*im).convertTo(*im, CV_32FC(im->channels()));
   if (type_ == "detector") {
     InferenceEngine::LockedMemory<void> input2Mapped =
       InferenceEngine::as<InferenceEngine::MemoryBlob>(
@@ -234,14 +283,27 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
     InferenceEngine::as<InferenceEngine::MemoryBlob>(data->blob);
   auto mblobHolder = mblob->wmap();
   float *blob_data = mblobHolder.as<float *>();
-  for (size_t c = 0; c < channels; c++) {
+  if (channels == 3) {
+    for (size_t c = 0; c < channels; c++) {
       for (size_t  h = 0; h < height; h++) {
-          for (size_t w = 0; w < width; w++) {
-              blob_data[c * width * height + h * width + w] =
-                      im->at<cv::Vec3f>(h, w)[c];
-          }
+        for (size_t w = 0; w < width; w++) {
+          blob_data[c * width * height + h * width + w] =
+            im->at<cv::Vec3f>(h, w)[c];
+        }
       }
+    }
+  } else {
+    for (size_t  h = 0; h < height; h++) {
+      float *pixelPtr = im->ptr<float>(h);
+      for (size_t w = 0; w < width; w++) {
+        for (size_t c = 0; c < channels; c++) {
+          blob_data[c * width * height + h * width + w] =
+            pixelPtr[w*channels + c];
+        }
+      }
+    }
   }
+
   return true;
 }
 }  // namespace PaddleX
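The generic (non-3-channel) branch walks the row pointer and scatters pixels into a channel-major blob; in numpy terms this is just an HWC-to-CHW transpose (a sketch):

```
import numpy as np

im = np.random.rand(4, 5, 10).astype('float32')    # H=4, W=5, C=10
blob = np.ascontiguousarray(np.transpose(im, (2, 0, 1))).ravel()
# blob[c * 4 * 5 + h * 5 + w] == im[h, w, c]
```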

+ 4 - 1
docs/deploy/openvino/export_openvino_model.md

@@ -20,7 +20,10 @@ paddlex --export_inference --model_dir=/path/to/paddle_model --save_dir=./infere
 ## Export the OpenVINO model
 
 ```
-cd /root/projects/python
+mkdir -p /root/projects
+cd /root/projects
+git clone https://github.com/PaddlePaddle/PaddleX.git
+cd PaddleX/deploy/openvino/python
 
 python converter.py --model_dir /path/to/inference_model --save_dir /path/to/openvino_model --fixed_input_shape [w,h]
 ```

+ 5 - 4
docs/deploy/openvino/introduction.md

@@ -1,5 +1,5 @@
 # Introduction to OpenVINO deployment
-PaddleX supports accelerating inference for trained Paddle models with OpenVINO. For OpenVINO details and installation, see [OpenVINO](https://docs.openvinotoolkit.org/latest/index.html)
+PaddleX supports accelerating inference for trained Paddle models with OpenVINO. For OpenVINO details and installation, see [OpenVINO](https://docs.openvinotoolkit.org/latest/index.html). This document was tested with OpenVINO 2020.4.
 
 ## Deployment support
 The table below lists PaddleX's support for OpenVINO acceleration in different environments  
@@ -13,20 +13,21 @@ PaddleX supports accelerating inference for trained Paddle models with OpenVINO
 **Note**: Raspbian OS is the Raspberry Pi operating system. Detection models only support YOLOV3; because OpenVINO does not support the ONNX resize-11 OP, Paddle segmentation models are not supported yet
 
 ## Deployment workflow
-**Deploying a PaddleX model to OpenVINO takes the following two steps**: 
+**Deploying a PaddleX model to OpenVINO takes the following two steps**:
 
   * **Model conversion**: convert the Paddle model into an OpenVINO Inference Engine model
   * **Inference deployment**: run inference with the Inference Engine
 
-## Model conversion 
+## Model conversion
 **For model conversion, see [Model conversion](./export_openvino_model.md)**  
 **Note**: OpenVINO model conversion works the same way across software and hardware platforms, so later documents do not repeat how to convert a model.
 
 ## Inference deployment
 Since deploying OpenVINO inference differs somewhat across software and hardware platforms, see:  
+
 **[Linux](./linux.md)**: OpenVINO-accelerated inference with PaddleX on Linux or Raspbian OS, in C++, on
 CPU or VPU  
 
 **[Windows](./windows.md)**: OpenVINO-accelerated inference with PaddleX on Windows, in C++, on CPU or VPU  
 
-**[Python](./python.md)**: OpenVINO-accelerated inference with PaddleX from Python
+**[Python](./python.md)**: OpenVINO-accelerated inference with PaddleX from Python

+ 3 - 3
docs/deploy/openvino/linux.md

@@ -29,7 +29,8 @@ git clone https://github.com/PaddlePaddle/PaddleX.git
 **Note**: the C++ inference code lives in the PaddleX/deploy/openvino directory, which does not depend on any other directory under PaddleX.
 
 ### Step2: Software dependencies
-Prebuilt packages and one-click builds of the third-party dependencies are provided, so users do not need to download or build them separately. To build the third-party dependencies yourself, see:
+
+The build script in Step3 installs prebuilt packages of the third-party dependencies automatically, so users do not need to download or build them separately. To build the third-party dependencies yourself, see:
 
 - gflags: see the [build documentation](https://gflags.github.io/gflags/#download)  
 
@@ -37,7 +38,6 @@ git clone https://github.com/PaddlePaddle/PaddleX.git
 [build documentation](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html)
 
 
-
 ### Step3: Build
 The `cmake` build commands are in `scripts/build.sh`. To build on a Raspberry Pi (Raspbian OS), change the ARCH parameter from x86 to armv7; if you compiled the third-party dependencies yourself, adjust the main parameters to match how they were actually built in Step2. The main contents are described below:
 ```
@@ -68,7 +68,7 @@ ARCH=x86
 | --image_list  | A .txt file listing image paths, one per line |
 | --device  | Target platform, one of {"CPU","MYRIAD"}; default "CPU". Use "MYRIAD" on a VPU |
 | --cfg_file | The .yml config file of the PaddleX model |
-| --save_dir | Directory for saving visualized result images (detection tasks only); default " " saves no visualization |
+| --save_dir | Directory for saving visualized result images (detection tasks only); the default " " means no visualization is saved |
 
 ### Examples
 `Example 1`:

+ 1 - 1
docs/deploy/openvino/windows.md

@@ -83,7 +83,7 @@ cd D:\projects\PaddleX\deploy\openvino\out\build\x64-Release
 | --image_list  | A .txt file listing image paths, one per line |
 | --device  | Target platform, one of {"CPU","MYRIAD"}; default "CPU". Use "MYRIAD" on a VPU |
 | --cfg_file | The .yml config file of the PaddleX model |
-| --save_dir | Directory for saving visualized result images (detection tasks only); default " " saves no visualization |
+| --save_dir | Directory for saving visualized result images (detection tasks only); the default " " means no visualization is saved |
 
 ### Examples
 `Example 1`:

+ 3 - 1
docs/deploy/raspberry/Raspberry.md

@@ -23,7 +23,8 @@ sudo apt-get upgrade
 ```
 
 ## Paddle-Lite deployment
-Paddle-Lite-based deployment currently supports PaddleX classification, segmentation, and detection models; in fact detection models only support YOLOV3  
+Paddle-Lite-based deployment currently supports PaddleX classification, segmentation, and detection models; among these, detection models only support YOLOV3  
+
 The workflow is: convert the PaddleX model, then deploy the converted model  
 
 **Note**: for PaddleX installation see [PaddleX](https://paddlex.readthedocs.io/zh_CN/develop/install.html); for Paddle-Lite details see [Paddle-Lite](https://paddle-lite.readthedocs.io/zh/latest/index.html)
@@ -81,6 +82,7 @@ OPENCV_DIR=$(pwd)/deps/opencv/
 ### Step3: Inference
 
 After a successful build, the inference executable is `classifier` for classification tasks, `segmenter` for segmentation tasks, and `detector` for detection tasks. Their main command-line parameters are:  
+
 |  Parameter   | Description  |
 |  ----  | ----  |
 | --model_dir  | Path to the .xml file generated by model conversion; make sure the three files generated by the conversion are in the same directory |

+ 2 - 0
paddlex/cv/models/base.py

@@ -319,6 +319,8 @@ class BaseAPI:
                 info['Transforms'] = list()
                 for op in self.test_transforms.transforms:
                     name = op.__class__.__name__
+                    if name.startswith('Arrange'):
+                        continue
                     attr = op.__dict__
                     info['Transforms'].append({name: attr})
         info['completed_epochs'] = self.completed_epochs
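A hedged sketch (made-up attrs) of what this filter changes in the saved `model.yml`: Arrange* ops are runtime plumbing and are no longer serialized into the Transforms list:

```
ops = [('Normalize', {'mean': [0.5]}), ('ArrangeSegmenter', {'mode': 'test'})]
saved = [{name: attr} for name, attr in ops if not name.startswith('Arrange')]
print(saved)   # [{'Normalize': {'mean': [0.5]}}]
```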

+ 11 - 3
paddlex/cv/models/load_model.py

@@ -68,8 +68,11 @@ def load_model(model_dir, fixed_input_shape=None):
             model.exe.run(startup_prog)
             if status == "Prune":
                 from .slim.prune import update_program
-                model.test_prog = update_program(model.test_prog, model_dir,
-                                                 model.places[0], scope=model_scope)
+                model.test_prog = update_program(
+                    model.test_prog,
+                    model_dir,
+                    model.places[0],
+                    scope=model_scope)
             import pickle
             with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
                 load_dict = pickle.load(f)
@@ -92,7 +95,7 @@ def load_model(model_dir, fixed_input_shape=None):
     if 'Transforms' in info:
         transforms_mode = info.get('TransformsMode', 'RGB')
         # 固定模型的输入shape
-        fix_input_shape(info, fixed_input_shape=fixed_input_shape)
+        fix_input_shape(info, fixed_input_shape=model.fixed_input_shape)
         if transforms_mode == 'RGB':
             to_rgb = True
         else:
@@ -121,6 +124,9 @@ def load_model(model_dir, fixed_input_shape=None):
 
 def fix_input_shape(info, fixed_input_shape=None):
     if fixed_input_shape is not None:
+        input_channel = 3
+        if 'input_channel' in info['_init_params']:
+            input_channel = info['_init_params']['input_channel']
         resize = {'ResizeByShort': {}}
         padding = {'Padding': {}}
         if info['_Attributes']['model_type'] == 'classifier':
@@ -129,5 +135,7 @@ def fix_input_shape(info, fixed_input_shape=None):
             resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
             resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
             padding['Padding']['target_size'] = list(fixed_input_shape)
+            if info['_Attributes']['model_type'] == 'segmenter':
+                padding['Padding']['im_padding_value'] = [0.] * input_channel
             info['Transforms'].append(resize)
             info['Transforms'].append(padding)
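Illustrative outcome (reconstructed, not program output) for a hypothetical 10-channel segmenter with `fixed_input_shape=[512, 512]`: the appended transforms now carry one padding value per input channel:

```
resize = {'ResizeByShort': {'short_size': 512, 'max_size': 512}}
padding = {'Padding': {'target_size': [512, 512],
                       'im_padding_value': [0.0] * 10}}
```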