
merge develop

will-jl944 4 years ago
parent
commit
96e1ecab37

+ 1 - 1
README.md

@@ -16,7 +16,7 @@
  ![QQGroup](https://img.shields.io/badge/QQ_Group-1045148026-52B6EF?style=social&logo=tencent-qq&logoColor=000&logoWidth=20)
 
 
-## PaddleX dygraph mode is ready! Static mode is set by default and dynamic graph code base is in [dygraph](https://github.com/PaddlePaddle/PaddleX/tree/develop/dygraph). If you want to use static mode, the version 1.3.10 can be installed by pip. The version 2.0.0rc0 corresponds to the dygraph mode.
+## PaddleX dynamic graph mode is ready! Static graph mode remains the default, and the dynamic graph code base lives in [dygraph](https://github.com/PaddlePaddle/PaddleX/tree/develop/dygraph). To use static graph mode, install version 1.3.11 via pip; version 2.0.0rc0 corresponds to the dynamic graph mode.
 
 
 :hugs:  PaddleX integrates the capabilities of **Image classification**, **Object detection**, **Semantic segmentation**, and **Instance segmentation** from the Paddle CV toolkits, and covers the whole development process from **Data preparation** and **Model training and optimization** to **Multi-end deployment**. At the same time, PaddleX provides **Succinct APIs** and a **Graphical User Interface**. Developers can quickly complete end-to-end development with Paddle in a **low-code** form without installing different libraries.
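Concretely, `pip install paddlex==1.3.11` installs the static graph release described above, while `pip install paddlex==2.0.0rc0` installs the dynamic graph one.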

+ 20 - 9
dygraph/deploy/cpp/model_deploy/ppseg/src/seg_postprocess.cpp

@@ -59,29 +59,40 @@ void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
     score_mat->begin<float>(), score_mat->end<float>());
 }
 
+// For PaddleSeg version >= 2.1, the output shape is [b, h, w]
 bool SegPostprocess::RunV2(const DataBlob& output,
                            const std::vector<ShapeInfo>& shape_infos,
                            std::vector<Result>* results, int thread_num) {
   int batch_size = shape_infos.size();
-  std::vector<int> score_map_shape = output.shape;
-  int score_map_size = std::accumulate(output.shape.begin() + 1,
-                                       output.shape.end(), 1,
-                                       std::multiplies<int>());
-  const uint8_t* score_map_data =
-          reinterpret_cast<const uint8_t*>(output.data.data());
-  int num_map_pixels = output.shape[1] * output.shape[2];
+  int label_map_size = output.shape[1] * output.shape[2];
+  const uint8_t* label_data;
+  std::vector<uint8_t> label_vector;
+  if (output.dtype == INT64) {  // int64
+    const int64_t* output_data =
+          reinterpret_cast<const int64_t*>(output.data.data());
+    std::transform(output_data, output_data + label_map_size * batch_size,
+                   std::back_inserter(label_vector),
+                   [](int64_t x) { return (uint8_t)x;});
+    label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
+  } else if (output.dtype == INT8) {  // uint8
+    label_data = reinterpret_cast<const uint8_t*>(output.data.data());
+  } else {
+    std::cerr << "Output dtype is not supported in seg postprocess: "
+              << output.dtype << std::endl;
+    return false;
+  }
 
   for (int i = 0; i < batch_size; ++i) {
     (*results)[i].model_type = "seg";
     (*results)[i].seg_result = new SegResult();
-    const uint8_t* current_start_ptr = score_map_data + i * score_map_size;
+    const uint8_t* current_start_ptr = label_data + i * label_map_size;
     cv::Mat score_mat(output.shape[1], output.shape[2],
                       CV_32FC1, cv::Scalar(1.0));
     cv::Mat label_mat(output.shape[1], output.shape[2],
                       CV_8UC1, const_cast<uint8_t*>(current_start_ptr));
 
     RestoreSegMap(shape_infos[i], &label_mat,
-                &score_mat, (*results)[i].seg_result);
+                 &score_mat, (*results)[i].seg_result);
   }
   return true;
 }
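
The new branch narrows the int64 per-pixel labels to uint8 before wrapping them in a `CV_8UC1` Mat, and fixes every score to 1.0 since the model emits labels only. A minimal NumPy sketch of the same transformation (shapes here are assumptions; the cast mirrors the `std::transform` above):

```python
import numpy as np

# Mirror of RunV2 for INT64 outputs: one int64 class id per pixel,
# narrowed to uint8 so it can back a CV_8UC1 Mat. Assumes ids < 256.
batch_size, h, w = 2, 512, 512                     # assumed shapes
output = np.random.randint(0, 19, (batch_size, h, w), dtype=np.int64)

label_maps = output.astype(np.uint8)               # the per-pixel cast
for i in range(batch_size):
    label_mat = label_maps[i]                      # per-image label map
    score_mat = np.ones((h, w), np.float32)        # RunV2 hardcodes 1.0
```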

+ 0 - 1
dygraph/paddlex/cv/datasets/voc.py

@@ -14,7 +14,6 @@
 
 from __future__ import absolute_import
 import copy
-import os
 import os.path as osp
 import random
 import re

+ 21 - 12
dygraph/paddlex/cv/models/classifier.py

@@ -546,10 +546,13 @@ class AlexNet(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]. '
+                +
+                'Please check that the image shape after transforms is [3, 224, 224]; if not, '
+                + 'fixed_input_shape should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
 
         input_spec = [
@@ -751,10 +754,13 @@ class ShuffleNetV2(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]. '
+                +
+                'Please check that the image shape after transforms is [3, 224, 224]; if not, '
+                + 'fixed_input_shape should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(
@@ -774,10 +780,13 @@ class ShuffleNetV2_swish(BaseClassifier):
                 image_shape = [None, 3] + image_shape
         else:
             image_shape = [None, 3, 224, 224]
-            logging.info('When exporting inference model for {},'.format(
-                self.__class__.__name__
-            ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
-                         )
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]. '
+                +
+                'Please check that the image shape after transforms is [3, 224, 224]; if not, '
+                + 'fixed_input_shape should be specified manually.')
         self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(
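
All three hunks above change the same fallback path; the control flow they guard, condensed (grounded in the diff, with nothing PaddleX-specific assumed):

```python
# Condensed view of the export-shape fallback the warnings guard:
# without a user-supplied fixed_input_shape the exporter silently
# assumes [None, 3, 224, 224], which only matches models whose
# transforms really produce 224x224 inputs.
def resolve_export_shape(image_shape=None):
    if image_shape is not None:
        if len(image_shape) == 2:          # [H, W] -> [None, 3, H, W]
            image_shape = [None, 3] + image_shape
    else:
        image_shape = [None, 3, 224, 224]  # forced default, now warned
    return image_shape

print(resolve_export_shape([256, 256]))   # [None, 3, 256, 256]
print(resolve_export_shape())             # [None, 3, 224, 224]
```

Hence the upgraded warning: if the transforms emit anything other than 224x224, the exported model's declared input shape is simply wrong.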

+ 34 - 1
dygraph/paddlex/cv/models/detector.py

@@ -848,7 +848,8 @@ class FasterRCNN(BaseDetector):
                 if test_pre_nms_top_n is None else test_pre_nms_top_n,
                 'post_nms_top_n': test_post_nms_top_n
             }
-            head = ppdet.modeling.TwoFCHead(out_channel=1024)
+            head = ppdet.modeling.TwoFCHead(
+                in_channel=neck.out_shape[0].channels, out_channel=1024)
             roi_extractor_cfg = {
                 'resolution': 7,
                 'spatial_scale': [1. / i.stride for i in neck.out_shape],
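
The `TwoFCHead` change sizes the head's input from whatever the neck actually reports rather than the head's built-in default, so a non-default FPN channel count no longer mismatches. A self-contained sketch of the pattern (`ShapeSpec` and the neck are illustrative stand-ins; only `neck.out_shape[0].channels` comes from the diff):

```python
from collections import namedtuple

# Stand-in for ppdet's ShapeSpec-style objects; only `.channels`
# matters for this fix.
ShapeSpec = namedtuple('ShapeSpec', ['channels', 'stride'])

class DummyNeck:                      # illustrative neck reporting its
    out_shape = [ShapeSpec(192, 4),   # output shapes, e.g. an FPN built
                 ShapeSpec(192, 8)]   # with a non-default channel count

neck = DummyNeck()
in_channel = neck.out_shape[0].channels   # 192, fed to the head above
print(in_channel)
```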
@@ -1378,6 +1379,38 @@ class PPYOLOv2(YOLOv3):
         self.downsample_ratios = downsample_ratios
         self.model_name = 'PPYOLOv2'
 
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            if len(image_shape) == 2:
+                image_shape = [None, 3] + image_shape
+            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
+                raise Exception(
+                    "Height and width in fixed_input_shape must be a multiple of 32, but recieved is {}.".
+                    format(image_shape[-2:]))
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 608, 608]. '
+                +
+                'Please check that the image shape after transforms is [3, 608, 608]; if not, '
+                + 'fixed_input_shape should be specified manually.')
+            image_shape = [None, 3, 608, 608]
+
+        input_spec = [{
+            "image": InputSpec(
+                shape=image_shape, name='image', dtype='float32'),
+            "im_shape": InputSpec(
+                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
+            "scale_factor": InputSpec(
+                shape=[image_shape[0], 2],
+                name='scale_factor',
+                dtype='float32')
+        }]
+
+        return input_spec
+
 
 class MaskRCNN(BaseDetector):
     def __init__(self,
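
The overridden `_get_test_inputs` validates that a user-supplied shape is divisible by 32 before building the input spec, and falls back to 608x608 with the same loud warning as the classifiers. The validation in isolation (the 32-pixel constraint and the 608 default come from the hunk; the wrapper is scaffolding):

```python
def check_ppyolov2_shape(image_shape=None):
    """Mimics the shape validation added in _get_test_inputs above."""
    if image_shape is not None:
        if len(image_shape) == 2:
            image_shape = [None, 3] + image_shape
        if image_shape[-2] % 32 or image_shape[-1] % 32:
            raise ValueError(
                'Height and width in fixed_input_shape must be a '
                'multiple of 32, but received {}.'.format(image_shape[-2:]))
    else:
        image_shape = [None, 3, 608, 608]   # PPYOLOv2's forced default
    return image_shape

print(check_ppyolov2_shape([608, 608]))     # [None, 3, 608, 608]
try:
    check_ppyolov2_shape([600, 400])        # 600 % 32 != 0
except ValueError as e:
    print(e)
```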

+ 7 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -463,7 +463,13 @@ class BaseSegmenter(BaseModel):
         label_map = label_map.numpy().astype('uint8')
         score_map = outputs['score_map']
         score_map = score_map.numpy().astype('float32')
-        return {'label_map': label_map, 'score_map': score_map}
+        prediction = [{
+            'label_map': l,
+            'score_map': s
+        } for l, s in zip(label_map, score_map)]
+        if isinstance(img_file, (str, np.ndarray)):
+            prediction = prediction[0]
+        return prediction
 
     def _preprocess(self, images, transforms, model_type):
         arrange_transforms(
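
After this change, prediction returns one dict per input image and unwraps the list when the caller passed a single path or array, instead of returning one batched pair of maps. The new contract in isolation (the function wrapper is scaffolding; the body is the diff's logic verbatim):

```python
import numpy as np

def unbatch(label_map, score_map, img_file):
    # One dict per image; unwrap when a single image was passed in.
    prediction = [{'label_map': l, 'score_map': s}
                  for l, s in zip(label_map, score_map)]
    if isinstance(img_file, (str, np.ndarray)):
        prediction = prediction[0]
    return prediction

labels = np.zeros((2, 4, 4), 'uint8')
scores = np.ones((2, 4, 4), 'float32')
print(type(unbatch(labels, scores, 'one.jpg')))           # <class 'dict'>
print(type(unbatch(labels, scores, ['a.jpg', 'b.jpg'])))  # <class 'list'>
```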

+ 10 - 4
dygraph/paddlex/cv/models/utils/det_metrics/metrics.py

@@ -61,6 +61,11 @@ class VOCMetric(Metric):
                  classwise=False):
         self.cid2cname = {i: name for i, name in enumerate(labels)}
         self.coco_gt = coco_gt
+        self.clsid2catid = {
+            i: cat['id']
+            for i, cat in enumerate(
+                self.coco_gt.loadCats(self.coco_gt.getCatIds()))
+        }
         self.overlap_thresh = overlap_thresh
         self.map_type = map_type
         self.evaluate_difficult = evaluate_difficult
@@ -80,9 +85,10 @@ class VOCMetric(Metric):
         self.detection_map.reset()
 
     def update(self, inputs, outputs):
-        bboxes = outputs['bbox'][:, 2:].numpy()
-        scores = outputs['bbox'][:, 1].numpy()
-        labels = outputs['bbox'][:, 0].numpy()
+        bbox_np = outputs['bbox'].numpy()
+        bboxes = bbox_np[:, 2:]
+        scores = bbox_np[:, 1]
+        labels = bbox_np[:, 0]
         bbox_lengths = outputs['bbox_num'].numpy()
 
         if bboxes.shape == (1, 1) or bboxes is None:
@@ -121,7 +127,7 @@ class VOCMetric(Metric):
                 bbox = [xmin, ymin, w, h]
                 coco_res = {
                     'image_id': int(inputs['im_id']),
-                    'category_id': int(l + 1),
+                    'category_id': self.clsid2catid[int(l)],
                     'bbox': bbox,
                     'score': float(s)
                 }
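
The `clsid2catid` map matters because COCO category ids need not be contiguous or 1-based, so the old `int(l + 1)` could point at a wrong or nonexistent category. A stand-in demonstration (the category list is made up; the mapping comprehension is the diff's):

```python
# COCO-style categories with a gap in the ids: `label + 1` is wrong.
categories = [{'id': 1, 'name': 'person'},
              {'id': 3, 'name': 'car'},     # no category with id 2
              {'id': 7, 'name': 'train'}]

clsid2catid = {i: cat['id'] for i, cat in enumerate(categories)}

label = 1                    # the model's index for 'car'
print(label + 1)             # 2  -> old code: nonexistent category
print(clsid2catid[label])    # 3  -> correct COCO category id
```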

+ 8 - 5
dygraph/paddlex/cv/transforms/operators.py

@@ -858,13 +858,16 @@ class RandomScaleAspect(Transform):
 
 class RandomExpand(Transform):
     """
-    Randomly expand the input by padding to the lower right side of the image(s) in input.
+    Randomly expand the input by padding according to random offsets.
 
     Args:
         upper_ratio(float, optional): The maximum ratio to which the original image is expanded. Defaults to 4.
         prob(float, optional): The probability of applying expansion. Defaults to .5.
         im_padding_value(List[float] or Tuple[float], optional): RGB filling value for the image. Defaults to (127.5, 127.5, 127.5).
         label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
+
+    See Also:
+        paddlex.transforms.Padding
     """
 
     def __init__(self,
@@ -992,10 +995,10 @@ class Padding(Transform):
             ), 'target size ({}, {}) cannot be less than image size ({}, {})'\
                 .format(h, w, im_h, im_w)
         else:
-            h = (np.ceil(im_h // self.coarsest_stride) *
-                 self.coarsest_stride).astype(int)
-            w = (np.ceil(im_w / self.coarsest_stride) *
-                 self.coarsest_stride).astype(int)
+            h = (np.ceil(im_h / self.size_divisor) *
+                 self.size_divisor).astype(int)
+            w = (np.ceil(im_w / self.size_divisor) *
+                 self.size_divisor).astype(int)
 
         if h == im_h and w == im_w:
             return sample
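
The `Padding` fix replaces `im_h // self.coarsest_stride` (a floor division whose result `np.ceil` then leaves untouched, so heights were rounded *down*, yielding negative padding) with true division, and renames the attribute to `size_divisor`. The corrected arithmetic in isolation:

```python
import numpy as np

size_divisor = 32
im_h, im_w = 500, 700

# Old: np.ceil(im_h // size_divisor) == 15 -> h = 480 < im_h (broken).
# New: true division rounds each side up to the next multiple of 32.
h = int(np.ceil(im_h / size_divisor) * size_divisor)   # 512
w = int(np.ceil(im_w / size_divisor) * size_divisor)   # 704
print(h, w)
```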