FlyingQianMM 4 лет назад
Родитель
Сommit
c08c781909

+ 10 - 6
dygraph/paddlex/cv/models/classifier.py

@@ -529,16 +529,18 @@ class AlexNet(BaseClassifier):
         super(AlexNet, self).__init__(
             model_name='AlexNet', num_classes=num_classes)
 
-    def get_test_inputs(self, image_shape):
+    def _get_test_inputs(self, image_shape):
         if image_shape is not None:
             if len(image_shape) == 2:
                 image_shape = [None, 3] + image_shape
         else:
-            image_shape = [224, 224]
+            image_shape = [None, 3, 224, 224]
             logging.info('When exporting inference model for {},'.format(
                 self.__class__.__name__
             ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
                          )
+        self._fix_transforms_shape(image_shape[-2:])
+
         input_spec = [
             InputSpec(
                 shape=image_shape, name='image', dtype='float32')
@@ -732,16 +734,17 @@ class ShuffleNetV2(BaseClassifier):
         super(ShuffleNetV2, self).__init__(
             model_name=model_name, num_classes=num_classes)
 
-    def get_test_inputs(self, image_shape):
+    def _get_test_inputs(self, image_shape):
         if image_shape is not None:
             if len(image_shape) == 2:
                 image_shape = [None, 3] + image_shape
         else:
-            image_shape = [224, 224]
+            image_shape = [None, 3, 224, 224]
             logging.info('When exporting inference model for {},'.format(
                 self.__class__.__name__
             ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
                          )
+        self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(
                 shape=image_shape, name='image', dtype='float32')
@@ -754,16 +757,17 @@ class ShuffleNetV2_swish(BaseClassifier):
         super(ShuffleNetV2_swish, self).__init__(
             model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
 
-    def get_test_inputs(self, image_shape):
+    def _get_test_inputs(self, image_shape):
         if image_shape is not None:
             if len(image_shape) == 2:
                 image_shape = [None, 3] + image_shape
         else:
-            image_shape = [224, 224]
+            image_shape = [None, 3, 224, 224]
             logging.info('When exporting inference model for {},'.format(
                 self.__class__.__name__
             ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
                          )
+        self._fix_transforms_shape(image_shape[-2:])
         input_spec = [
             InputSpec(
                 shape=image_shape, name='image', dtype='float32')

+ 6 - 2
dygraph/paddlex/cv/models/detector.py

@@ -1640,9 +1640,13 @@ class MaskRCNN(BaseDetector):
                     self.test_transforms.transforms.insert(
                         normalize_op_idx,
                         Resize(
-                            target_size=image_shape, keep_ratio=True))
+                            target_size=image_shape,
+                            keep_ratio=True,
+                            interp='CUBIC'))
                 else:
                     self.test_transforms.transforms[resize_op_idx] = Resize(
-                        target_size=image_shape, keep_ratio=True)
+                        target_size=image_shape,
+                        keep_ratio=True,
+                        interp='CUBIC')
                 self.test_transforms.transforms.append(
                     Padding(im_padding_value=[0., 0., 0.]))

+ 25 - 36
dygraph/paddlex/cv/models/load_model.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os.path as osp
-
+import numpy as np
 import yaml
 import paddle
 import paddleslim
@@ -22,6 +22,29 @@ import paddlex.utils.logging as logging
 from paddlex.cv.transforms import build_transforms
 
 
+def load_rcnn_inference_model(model_dir):
+    paddle.enable_static()
+    exe = paddle.static.Executor(paddle.CPUPlace())
+    path_prefix = osp.join(model_dir, "model")
+    prog, _, _ = paddle.static.load_inference_model(path_prefix, exe)
+    paddle.disable_static()
+    extra_var_info = paddle.load(osp.join(model_dir, "model.pdiparams.info"))
+
+    net_state_dict = dict()
+    static_state_dict = dict()
+
+    for name, var in prog.state_dict().items():
+        static_state_dict[name] = np.array(var)
+    for var_name in static_state_dict:
+        if var_name not in extra_var_info:
+            continue
+        structured_name = extra_var_info[var_name].get('structured_name', None)
+        if structured_name is None:
+            continue
+        net_state_dict[structured_name] = static_state_dict[var_name]
+    return net_state_dict
+
+
 def load_model(model_dir):
     """
     Load saved model from a given directory.
@@ -85,41 +108,7 @@ def load_model(model_dir):
 
         if status == 'Infer':
             if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
-                #net_state_dict = paddle.load(
-                #    model_dir,
-                #    params_filename='model.pdiparams',
-                #    model_filename='model.pdmodel')
-
-                net = paddle.jit.load(osp.join(model_dir, 'model'))
-                #load_param_dict = paddle.load(osp.join(model_dir, 'model.pdiparams'))
-                #print(load_param_dict)
-
-                import pickle
-                var_info_path = osp.join(model_dir, 'model.pdiparams.info')
-                with open(var_info_path, 'rb') as f:
-                    extra_var_info = pickle.load(f)
-                net_state_dict = dict()
-                static_state_dict = dict()
-                for name, var in net.state_dict().items():
-                    print(name, var.name)
-                    static_state_dict[var.name] = var.numpy()
-                exit()
-                for var_name in static_state_dict:
-                    if var_name not in extra_var_info:
-                        print(var_name)
-                        continue
-                    structured_name = extra_var_info[var_name].get(
-                        'structured_name', None)
-                    if structured_name is None:
-                        continue
-                    net_state_dict[structured_name] = static_state_dict[
-                        var_name]
-
-                #model.net = paddle.jit.load(
-                #    model_dir,
-                #    params_filename='model.pdiparams',
-                #    model_filename='model.pdmodel')
-                #net_state_dict = paddle.load(osp.join(model_dir, 'model'))
+                net_state_dict = load_rcnn_inference_model(model_dir)
             else:
                 net_state_dict = paddle.load(osp.join(model_dir, 'model'))
         else:

+ 12 - 16
dygraph/paddlex/cv/transforms/operators.py

@@ -238,24 +238,20 @@ class Resize(Transform):
         self.interp = interp
         self.keep_ratio = keep_ratio
 
-    def apply_im(self, image, interp):
-        image = cv2.resize(
-            image, (self.target_size[1], self.target_size[0]),
-            interpolation=interp)
+    def apply_im(self, image, interp, target_size):
+        image = cv2.resize(image, target_size, interpolation=interp)
         return image
 
     def apply_mask(self, mask):
-        mask = cv2.resize(
-            mask, (self.target_size[1], self.target_size[0]),
-            interpolation=cv2.INTER_NEAREST)
+        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
         return mask
 
-    def apply_bbox(self, bbox, scale):
+    def apply_bbox(self, bbox, scale, target_size):
         im_scale_x, im_scale_y = scale
         bbox[:, 0::2] *= im_scale_x
         bbox[:, 1::2] *= im_scale_y
-        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, self.target_size[1])
-        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, self.target_size[0])
+        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
+        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
         return bbox
 
     def apply_segm(self, segms, im_size, scale):
@@ -284,22 +280,22 @@ class Resize(Transform):
 
         im_scale_y = self.target_size[0] / im_h
         im_scale_x = self.target_size[1] / im_w
-        target_size = list(self.target_size)
+        target_size = (self.target_size[1], self.target_size[0])
         if self.keep_ratio:
             scale = min(im_scale_y, im_scale_x)
             target_w = int(round(im_w * scale))
             target_h = int(round(im_h * scale))
-            target_size = [target_w, target_h]
+            target_size = (target_w, target_h)
             im_scale_y = target_h / im_h
             im_scale_x = target_w / im_w
 
-        sample['image'] = self.apply_im(sample['image'], interp)
+        sample['image'] = self.apply_im(sample['image'], interp, target_size)
 
         if 'mask' in sample:
-            sample['mask'] = self.apply_mask(sample['mask'])
+            sample['mask'] = self.apply_mask(sample['mask'], target_size)
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
-            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'],
-                                                [im_scale_x, im_scale_y])
+            sample['gt_bbox'] = self.apply_bbox(
+                sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])