| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- import paddle.fluid as fluid
- import paddlex
- from collections import OrderedDict
- from .deeplabv3p import DeepLabv3p
- class FastSCNN(DeepLabv3p):
- """实现Fast SCNN网络的构建并进行训练、评估、预测和模型导出。
- Args:
- num_classes (int): 类别数。
- use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
- use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
- 当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
- class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
- num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
- 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
- 即平时使用的交叉熵损失函数。
- ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
- multi_loss_weight (list): 多分支上的loss权重。默认计算一个分支上的loss,即默认值为[1.0]。
- 也支持计算两个分支或三个分支上的loss,权重按[fusion_branch_weight, higher_branch_weight, lower_branch_weight]排列,
- fusion_branch_weight为空间细节分支和全局上下文分支融合后的分支上的loss权重,higher_branch_weight为空间细节分支上的loss权重,
- lower_branch_weight为全局上下文分支上的loss权重,若higher_branch_weight和lower_branch_weight未设置则不会计算这两个分支上的loss。
- Raises:
- ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
- ValueError: class_weight为list, 但长度不等于num_class。
- class_weight为str, 但class_weight.low()不等于dynamic。
- TypeError: class_weight不为None时,其类型不是list或str。
- TypeError: multi_loss_weight不为list。
- ValueError: multi_loss_weight为list但长度小于0或者大于3。
- """
- def __init__(self,
- num_classes=2,
- use_bce_loss=False,
- use_dice_loss=False,
- class_weight=None,
- ignore_index=255,
- multi_loss_weight=[1.0]):
- self.init_params = locals()
- super(DeepLabv3p, self).__init__('segmenter')
- # dice_loss或bce_loss只适用两类分割中
- if num_classes > 2 and (use_bce_loss or use_dice_loss):
- raise ValueError(
- "dice loss and bce loss is only applicable to binary classfication"
- )
- if class_weight is not None:
- if isinstance(class_weight, list):
- if len(class_weight) != num_classes:
- raise ValueError(
- "Length of class_weight should be equal to number of classes"
- )
- elif isinstance(class_weight, str):
- if class_weight.lower() != 'dynamic':
- raise ValueError(
- "if class_weight is string, must be dynamic!")
- else:
- raise TypeError(
- 'Expect class_weight is a list or string but receive {}'.
- format(type(class_weight)))
- if not isinstance(multi_loss_weight, list):
- raise TypeError(
- 'Expect multi_loss_weight is a list but receive {}'.format(
- type(multi_loss_weight)))
- if len(multi_loss_weight) > 3 or len(multi_loss_weight) < 0:
- raise ValueError(
- "Length of multi_loss_weight should be lower than or equal to 3 but greater than 0."
- )
- self.num_classes = num_classes
- self.use_bce_loss = use_bce_loss
- self.use_dice_loss = use_dice_loss
- self.class_weight = class_weight
- self.multi_loss_weight = multi_loss_weight
- self.ignore_index = ignore_index
- self.labels = None
- self.fixed_input_shape = None
- def build_net(self, mode='train'):
- model = paddlex.cv.nets.segmentation.FastSCNN(
- self.num_classes,
- mode=mode,
- use_bce_loss=self.use_bce_loss,
- use_dice_loss=self.use_dice_loss,
- class_weight=self.class_weight,
- ignore_index=self.ignore_index,
- multi_loss_weight=self.multi_loss_weight,
- fixed_input_shape=self.fixed_input_shape)
- inputs = model.generate_inputs()
- model_out = model.build_net(inputs)
- outputs = OrderedDict()
- if mode == 'train':
- self.optimizer.minimize(model_out)
- outputs['loss'] = model_out
- else:
- outputs['pred'] = model_out[0]
- outputs['logit'] = model_out[1]
- return inputs, outputs
- def train(self,
- num_epochs,
- train_dataset,
- train_batch_size=2,
- eval_dataset=None,
- save_interval_epochs=1,
- log_interval_steps=2,
- save_dir='output',
- pretrain_weights='CITYSCAPES',
- optimizer=None,
- learning_rate=0.01,
- lr_decay_power=0.9,
- use_vdl=False,
- sensitivities_file=None,
- eval_metric_loss=0.05,
- early_stop=False,
- early_stop_patience=5,
- resume_checkpoint=None):
- """训练。
- Args:
- num_epochs (int): 训练迭代轮数。
- train_dataset (paddlex.datasets): 训练数据读取器。
- train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
- eval_dataset (paddlex.datasets): 评估数据读取器。
- save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
- log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
- save_dir (str): 模型保存路径。默认'output'。
- pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'CITYSCAPES'
- 则自动下载在CITYSCAPES图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'CITYSCAPES'。
- optimizer (paddle.fluid.optimizer): 优化器。当改参数为None时,使用默认的优化器:使用
- fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
- learning_rate (float): 默认优化器的初始学习率。默认0.01。
- lr_decay_power (float): 默认优化器学习率多项式衰减系数。默认0.9。
- use_vdl (bool): 是否使用VisualDL进行可视化。默认False。
- sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
- 则自动下载在Cityscapes图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
- eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
- early_stop (bool): 是否使用提前终止训练策略。默认值为False。
- early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
- 连续下降或持平,则终止训练。默认值为5。
- resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
- Raises:
- ValueError: 模型从inference model进行加载。
- """
- return super(FastSCNN, self).train(
- num_epochs, train_dataset, train_batch_size, eval_dataset,
- save_interval_epochs, log_interval_steps, save_dir,
- pretrain_weights, optimizer, learning_rate, lr_decay_power,
- use_vdl, sensitivities_file, eval_metric_loss, early_stop,
- early_stop_patience, resume_checkpoint)
|