|
|
@@ -77,6 +77,7 @@ class HRNet(DeepLabv3p):
|
|
|
self.class_weight = class_weight
|
|
|
self.ignore_index = ignore_index
|
|
|
self.labels = None
|
|
|
+ self.fixed_input_shape = None
|
|
|
|
|
|
def build_net(self, mode='train'):
|
|
|
model = paddlex.cv.nets.segmentation.HRNet(
|
|
|
@@ -86,7 +87,8 @@ class HRNet(DeepLabv3p):
|
|
|
use_bce_loss=self.use_bce_loss,
|
|
|
use_dice_loss=self.use_dice_loss,
|
|
|
class_weight=self.class_weight,
|
|
|
- ignore_index=self.ignore_index)
|
|
|
+ ignore_index=self.ignore_index,
|
|
|
+ fixed_input_shape=self.fixed_input_shape)
|
|
|
inputs = model.generate_inputs()
|
|
|
model_out = model.build_net(inputs)
|
|
|
outputs = OrderedDict()
|
|
|
@@ -170,6 +172,6 @@ class HRNet(DeepLabv3p):
|
|
|
return super(HRNet, self).train(
|
|
|
num_epochs, train_dataset, train_batch_size, eval_dataset,
|
|
|
save_interval_epochs, log_interval_steps, save_dir,
|
|
|
- pretrain_weights, optimizer, learning_rate, lr_decay_power,
|
|
|
- use_vdl, sensitivities_file, eval_metric_loss, early_stop,
|
|
|
+ pretrain_weights, optimizer, learning_rate, lr_decay_power, use_vdl,
|
|
|
+ sensitivities_file, eval_metric_loss, early_stop,
|
|
|
early_stop_patience, resume_checkpoint)
|