浏览代码

Merge pull request #866 from FlyingQianMM/develop_qh

append Pad() to test_transforms when exporting inference model for RCNN_FPN
FlyingQianMM 4 年之前
父节点
当前提交
5827b04d01
共有 1 个文件被更改,包括 43 次插入13 次删除
  1. 43 13
      dygraph/paddlex/cv/models/detector.py

+ 43 - 13
dygraph/paddlex/cv/models/detector.py

@@ -60,18 +60,7 @@ class BaseDetector(BaseModel):
     def _fix_transforms_shape(self, image_shape):
     def _fix_transforms_shape(self, image_shape):
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
 
 
-    def _get_test_inputs(self, image_shape):
-        if image_shape is not None:
-            if len(image_shape) == 2:
-                image_shape = [None, 3] + image_shape
-            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
-                raise Exception(
-                    "Height and width in fixed_input_shape must be a multiple of 32, but recieved is {}.".
-                    format(image_shape[-2:]))
-            self._fix_transforms_shape(image_shape[-2:])
-        else:
-            image_shape = [None, 3, -1, -1]
-
+    def _define_input_spec(self, image_shape):
         input_spec = [{
         input_spec = [{
             "image": InputSpec(
             "image": InputSpec(
                 shape=image_shape, name='image', dtype='float32'),
                 shape=image_shape, name='image', dtype='float32'),
@@ -82,9 +71,26 @@ class BaseDetector(BaseModel):
                 name='scale_factor',
                 name='scale_factor',
                 dtype='float32')
                 dtype='float32')
         }]
         }]
-
         return input_spec
         return input_spec
 
 
+    def _check_image_shape(self, image_shape):
+        if len(image_shape) == 2:
+            image_shape = [None, 3] + image_shape
+            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
+                raise Exception(
+                    "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
+                    format(image_shape[-2:]))
+        return image_shape
+
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            image_shape = self._check_image_shape(image_shape)
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+
+        return self._define_input_spec(image_shape)
+
     def _get_backbone(self, backbone_name, **params):
     def _get_backbone(self, backbone_name, **params):
         backbone = getattr(ppdet.modeling, backbone_name)(**params)
         backbone = getattr(ppdet.modeling, backbone_name)(**params)
         return backbone
         return backbone
@@ -1004,6 +1010,18 @@ class FasterRCNN(BaseDetector):
                 self.test_transforms.transforms.append(
                 self.test_transforms.transforms.append(
                     Padding(im_padding_value=[0., 0., 0.]))
                     Padding(im_padding_value=[0., 0., 0.]))
 
 
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            image_shape = self._check_image_shape(image_shape)
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+            if self.with_fpn:
+                self.test_transforms.transforms.append(
+                    Padding(im_padding_value=[0., 0., 0.]))
+
+        return self._define_input_spec(image_shape)
+
 
 
 class PPYOLO(YOLOv3):
 class PPYOLO(YOLOv3):
     def __init__(self,
     def __init__(self,
@@ -1713,3 +1731,15 @@ class MaskRCNN(BaseDetector):
                         interp='CUBIC')
                         interp='CUBIC')
                 self.test_transforms.transforms.append(
                 self.test_transforms.transforms.append(
                     Padding(im_padding_value=[0., 0., 0.]))
                     Padding(im_padding_value=[0., 0., 0.]))
+
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            image_shape = self._check_image_shape(image_shape)
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+            if self.with_fpn:
+                self.test_transforms.transforms.append(
+                    Padding(im_padding_value=[0., 0., 0.]))
+
+        return self._define_input_spec(image_shape)