wangsiyuan06 4 years ago
parent
commit
d102b3dfbc
1 changed files with 40 additions and 4 deletions
  1. 40 4
      paddlex/cv/models/load_model.py

+ 40 - 4
paddlex/cv/models/load_model.py

@@ -55,6 +55,17 @@ def load_model(model_dir, fixed_input_shape=None):
                              format(fixed_input_shape))
                 model.fixed_input_shape = fixed_input_shape
 
+    if info['Model'].count('RCNN') > 0:
+        if info['_init_params']['with_fpn']:
+            if model.fixed_input_shape[0] % 32 > 0:
+                raise Exception(
+                    "The first value in fixed_input_shape must be a multiple of 32, but recieved {}.".
+                    format(model.fixed_input_shape[0]))
+            if model.fixed_input_shape[1] % 32 > 0:
+                raise Exception(
+                    "The second value in fixed_input_shape must be a multiple of 32, but recieved {}.".
+                    format(model.fixed_input_shape[1]))
+
     with fluid.scope_guard(model_scope):
         if status == "Normal" or \
                 status == "Prune" or status == "fluid.save":
@@ -137,12 +148,37 @@ def fix_input_shape(info, fixed_input_shape=None):
                 if list(info['Transforms'][i].keys())[0] == 'Resize':
                     resize_op_index = i
             if resize_op_index is not None:
-                info['Transforms'][resize_op_index]['Resize']['target_size'] = fixed_input_shape[0]
-        else:
+                info['Transforms'][resize_op_index]['Resize'][
+                    'target_size'] = fixed_input_shape[0]
+        elif info['Model'].count('RCNN') > 0:
+            resize_op_index = None
+            for i in range(len(info['Transforms'])):
+                if list(info['Transforms'][i].keys())[0] == 'ResizeByShort':
+                    resize_op_index = i
+            if resize_op_index is not None:
+                info['Transforms'][resize_op_index]['ResizeByShort'][
+                    'short_size'] = min(fixed_input_shape)
+                info['Transforms'][resize_op_index]['ResizeByShort'][
+                    'max_size'] = max(fixed_input_shape)
+            else:
+                resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
+                resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
+                info['Transforms'].append(resize)
+
+            padding_op_index = None
+            for i in range(len(info['Transforms'])):
+                if list(info['Transforms'][i].keys())[0] == 'Padding':
+                    padding_op_index = i
+            if padding_op_index is not None:
+                info['Transforms'][padding_op_index]['Padding'][
+                    'target_size'] = list(fixed_input_shape)
+            else:
+                padding['Padding']['target_size'] = list(fixed_input_shape)
+                info['Transforms'].append(padding)
+        elif info['_Attributes']['model_type'] == 'segmenter':
             resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
             resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
             padding['Padding']['target_size'] = list(fixed_input_shape)
-            if info['_Attributes']['model_type'] == 'segmenter':
-                padding['Padding']['im_padding_value'] = [0.] * input_channel
+            padding['Padding']['im_padding_value'] = [0.] * input_channel
             info['Transforms'].append(resize)
             info['Transforms'].append(padding)