Browse Source

fix input shape when exporting ppyolo serials

FlyingQianMM 4 năm trước cách đây
mục cha
commit
ca8f5b9cc2
1 tập tin đã thay đổi với 50 bổ sung0 xóa
  1. 50 0
      paddlex/cv/models/detector.py

+ 50 - 0
paddlex/cv/models/detector.py

@@ -1245,6 +1245,31 @@ class PPYOLO(YOLOv3):
         self.downsample_ratios = downsample_ratios
         self.model_name = 'PPYOLO'
 
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            image_shape = self._check_image_shape(image_shape)
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, 608, 608]
+            if hasattr(self, 'test_transforms'):
+                if self.test_transforms is not None:
+                    for idx, op in enumerate(self.test_transforms.transforms):
+                        name = op.__class__.__name__
+                        if name == 'Resize':
+                            image_shape = [None, 3] + list(
+                                self.test_transforms.transforms[
+                                    idx].target_size)
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
+                format(image_shape) +
+                'Please check image shape after transforms is {}, if not, fixed_input_shape '.
+                format(image_shape[1:]) + 'should be specified manually.')
+
+        self.fixed_input_shape = image_shape
+        return self._define_input_spec(image_shape)
+
 
 class PPYOLOTiny(YOLOv3):
     def __init__(self,
@@ -1353,6 +1378,31 @@ class PPYOLOTiny(YOLOv3):
         self.downsample_ratios = downsample_ratios
         self.model_name = 'PPYOLOTiny'
 
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            image_shape = self._check_image_shape(image_shape)
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, 320, 320]
+            if hasattr(self, 'test_transforms'):
+                if self.test_transforms is not None:
+                    for idx, op in enumerate(self.test_transforms.transforms):
+                        name = op.__class__.__name__
+                        if name == 'Resize':
+                            image_shape = [None, 3] + list(
+                                self.test_transforms.transforms[
+                                    idx].target_size)
+            logging.warning(
+                '[Important!!!] When exporting inference model for {},'.format(
+                    self.__class__.__name__) +
+                ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
+                format(image_shape) +
+                'Please check image shape after transforms is {}, if not, fixed_input_shape '.
+                format(image_shape[1:]) + 'should be specified manually.')
+
+        self.fixed_input_shape = image_shape
+        return self._define_input_spec(image_shape)
+
 
 class PPYOLOv2(YOLOv3):
     def __init__(self,