# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from urllib.parse import urlparse import yaml from paddleseg.utils import NoAliasDumper from paddleseg.cvlibs.config import parse_from_yaml, merge_config_dicts from ..base import BaseConfig from ...utils.misc import abspath class BaseSegConfig(BaseConfig): """ BaseSegConfig """ def update(self, dict_like_obj): """ update """ dict_ = merge_config_dicts(dict_like_obj, self.dict) self.reset_from_dict(dict_) def load(self, config_path): """ load """ dict_ = parse_from_yaml(config_path) if not isinstance(dict_, dict): raise TypeError self.reset_from_dict(dict_) def dump(self, config_path): """ dump """ with open(config_path, 'w', encoding='utf-8') as f: yaml.dump(self.dict, f, Dumper=NoAliasDumper) def update_learning_rate(self, learning_rate): """ update_learning_rate """ if 'lr_scheduler' not in self: raise RuntimeError( "Not able to update learning rate, because no LR scheduler config was found." ) self.lr_scheduler['learning_rate'] = learning_rate def update_batch_size(self, batch_size, mode='train'): """ update_batch_size """ if mode == 'train': self.set_val('batch_size', batch_size) else: raise ValueError( f"Setting `batch_size` in {repr(mode)} mode is not supported.") def update_log_ranks(self, device): """update log ranks Args: device (str): the running device to set """ log_ranks = device.split(':')[1] self.set_val('log_ranks', log_ranks) def update_print_mem_info(self, print_mem_info: bool): """setting print memory info""" assert isinstance(print_mem_info, bool), "print_mem_info should be a bool" self.set_val('print_mem_info', print_mem_info) def update_pretrained_weights(self, weight_path, is_backbone=False): """ update_pretrained_weights """ if 'model' not in self: raise RuntimeError( "Not able to update pretrained weight path, because no model config was found." ) if isinstance(weight_path, str): if urlparse(weight_path).scheme == '': # If `weight_path` is a string but not URL (with scheme present), # it will be recognized as a local file path. weight_path = abspath(weight_path) else: if weight_path is not None: raise TypeError("`weight_path` must be string or None.") if is_backbone: if 'backbone' not in self.model: raise RuntimeError( "Not able to update pretrained weight path of backbone, because no backbone config was found." ) self.model['backbone']['pretrained'] = weight_path else: self.model['pretrained'] = weight_path def update_dy2st(self, dy2st): """ update_dy2st """ self.set_val('to_static_training', dy2st) def update_dataset(self, dataset_dir, dataset_type=None): """ update_dataset """ raise NotImplementedError def get_epochs_iters(self): """ get_epochs_iters """ raise NotImplementedError def get_learning_rate(self): """ get_learning_rate """ raise NotImplementedError def get_batch_size(self, mode='train'): """ get_batch_size """ raise NotImplementedError def get_qat_epochs_iters(self): """ get_qat_epochs_iters """ return self.get_epochs_iters() // 2 def get_qat_learning_rate(self): """ get_qat_learning_rate """ return self.get_learning_rate() / 2