Jelajahi Sumber

fixed input shape for hrnet

jiangjiajun 5 tahun lalu
induk
melakukan
e3f56c10c3
2 mengubah file dengan 18 tambahan dan 6 penghapusan
  1. 5 3
      paddlex/cv/models/hrnet.py
  2. 13 3
      paddlex/cv/nets/segmentation/hrnet.py

+ 5 - 3
paddlex/cv/models/hrnet.py

@@ -77,6 +77,7 @@ class HRNet(DeepLabv3p):
         self.class_weight = class_weight
         self.ignore_index = ignore_index
         self.labels = None
+        self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.HRNet(
@@ -86,7 +87,8 @@ class HRNet(DeepLabv3p):
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
+            ignore_index=self.ignore_index,
+            fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()
@@ -170,6 +172,6 @@ class HRNet(DeepLabv3p):
         return super(HRNet, self).train(
             num_epochs, train_dataset, train_batch_size, eval_dataset,
             save_interval_epochs, log_interval_steps, save_dir,
-            pretrain_weights, optimizer, learning_rate, lr_decay_power,
-            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
+            pretrain_weights, optimizer, learning_rate, lr_decay_power, use_vdl,
+            sensitivities_file, eval_metric_loss, early_stop,
             early_stop_patience, resume_checkpoint)

+ 13 - 3
paddlex/cv/nets/segmentation/hrnet.py

@@ -38,7 +38,8 @@ class HRNet(object):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 fixed_input_shape=None):
         # dice_loss或bce_loss只适用两类分割中
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise ValueError(
@@ -66,6 +67,7 @@ class HRNet(object):
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
         self.ignore_index = ignore_index
+        self.fixed_input_shape = fixed_input_shape
         self.backbone = paddlex.cv.nets.hrnet.HRNet(
             width=width, feature_maps="stage4")
 
@@ -131,8 +133,16 @@ class HRNet(object):
 
     def generate_inputs(self):
         inputs = OrderedDict()
-        inputs['image'] = fluid.data(
-            dtype='float32', shape=[None, 3, None, None], name='image')
+
+        if self.fixed_input_shape is not None:
+            input_shape = [
+                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+            ]
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
+        else:
+            inputs['image'] = fluid.data(
+                dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')