# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import yaml

from ...base import BaseConfig
from ....utils.misc import abspath


class ClsConfig(BaseConfig):
    """Image Classification Task Config"""

    def update(self, list_like_obj: list):
        """update self

        Args:
            list_like_obj (list): list of pairs (key0.key1.idx.key2=value), such as:
                [
                    'topk=2',
                    'VALID.transforms.1.ResizeImage.resize_short=300'
                ]
        """
        from paddleclas.ppcls.utils.config import override_config

        dict_ = override_config(self.dict, list_like_obj)
        self.reset_from_dict(dict_)

    def load(self, config_file_path: str):
        """load config from yaml file

        Args:
            config_file_path (str): the path of the yaml file.

        Raises:
            TypeError: the content of the yaml file `config_file_path` is not a dict.
        """
        with open(config_file_path, "rb") as f:
            dict_ = yaml.load(f, Loader=yaml.Loader)
        if not isinstance(dict_, dict):
            raise TypeError
        self.reset_from_dict(dict_)

    def dump(self, config_file_path: str):
        """dump self to yaml file

        Args:
            config_file_path (str): the path to save self as a yaml file.
        """
        with open(config_file_path, "w", encoding="utf-8") as f:
            yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)

    def update_dataset(
        self,
        dataset_path: str,
        dataset_type: str = None,
        *,
        train_list_path: str = None,
    ):
        """update dataset settings

        Args:
            dataset_path (str): the root path of the dataset.
            dataset_type (str, optional): dataset type. Defaults to None.
            train_list_path (str, optional): the path of the train dataset annotation file. Defaults to None.

        Raises:
            ValueError: `dataset_type` is not supported.
        """
        dataset_path = abspath(dataset_path)
        if dataset_type is None:
            dataset_type = "ClsDataset"
        if not train_list_path:
            train_list_path = f"{dataset_path}/train.txt"

        if dataset_type in ["ClsDataset", "MLClsDataset"]:
            ds_cfg = [
                f"DataLoader.Train.dataset.name={dataset_type}",
                f"DataLoader.Train.dataset.image_root={dataset_path}",
                f"DataLoader.Train.dataset.cls_label_path={train_list_path}",
                f"DataLoader.Eval.dataset.name={dataset_type}",
                f"DataLoader.Eval.dataset.image_root={dataset_path}",
                f"DataLoader.Eval.dataset.cls_label_path={dataset_path}/val.txt",
                f"Infer.PostProcess.class_id_map_file={dataset_path}/label.txt",
            ]
        else:
            raise ValueError(f"{repr(dataset_type)} is not supported.")
        self.update(ds_cfg)

    def update_batch_size(self, batch_size: int, mode: str = "train"):
        """update batch size setting

        Args:
            batch_size (int): the batch size number to set.
            mode (str, optional): the mode for which to set the batch size, must be one of 'train',
                'eval', 'test'. Defaults to 'train'.

        Raises:
            ValueError: `mode` error.
""" if mode == "train": if self.DataLoader["Train"]["sampler"].get("batch_size", False): _cfg = [f"DataLoader.Train.sampler.batch_size={batch_size}"] else: _cfg = [f"DataLoader.Train.sampler.first_bs={batch_size}"] _cfg = [f"DataLoader.Train.dataset.name=MultiScaleDataset"] elif mode == "eval": _cfg = [f"DataLoader.Eval.sampler.batch_size={batch_size}"] elif mode == "test": _cfg = [f"DataLoader.Infer.batch_size={batch_size}"] else: raise ValueError("The input `mode` should be train, eval or test.") self.update(_cfg) def update_learning_rate(self, learning_rate: float): """update learning rate Args: learning_rate (float): the learning rate value to set. """ if self._dict["Optimizer"]["lr"].get("learning_rate", None) is not None: _cfg = [f"Optimizer.lr.learning_rate={learning_rate}"] elif self._dict["Optimizer"]["lr"].get("max_learning_rate", None) is not None: _cfg = [f"Optimizer.lr.max_learning_rate={learning_rate}"] else: raise ValueError("unsupported lr format") self.update(_cfg) def update_warmup_epochs(self, warmup_epochs: int): """update warmup epochs Args: warmup_epochs (int): the warmup epochs value to set. """ _cfg = [f"Optimizer.lr.warmup_epoch={warmup_epochs}"] self.update(_cfg) def update_pretrained_weights(self, pretrained_model: str): """update pretrained weight path Args: pretrained_model (str): the local path or url of pretrained weight file to set. """ assert isinstance( pretrained_model, (str, type(None)) ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \ indicating that no pretrained model to be used." if pretrained_model is None: self.update(["Global.pretrained_model=None"]) self.update(["Arch.pretrained=False"]) else: if pretrained_model.lower() == "default": self.update(["Global.pretrained_model=None"]) self.update(["Arch.pretrained=True"]) else: if not pretrained_model.startswith(("http://", "https://")): pretrained_model = abspath( pretrained_model.replace(".pdparams", "") ) self.update([f"Global.pretrained_model={pretrained_model}"]) def update_num_classes(self, num_classes: int): """update classes number Args: num_classes (int): the classes number value to set. """ update_str_list = [f"Arch.class_num={num_classes}"] if self._get_arch_name() == "DistillationModel": update_str_list.append(f"Arch.models.0.Teacher.class_num={num_classes}") update_str_list.append(f"Arch.models.1.Student.class_num={num_classes}") ml_decoder = self.dict.get("MLDecoder", None) if ml_decoder is not None: self.update_ml_query_num(num_classes) self.update_ml_class_num(num_classes) self.update(update_str_list) def update_ml_query_num(self, query_num: int): """update MLDecoder query number Args: query_num (int): the query number value to set,qury_num should be less than or equal to num_classes. """ base_query_num = self.dict.get("MLDecoder", {}).get("query_num", None) if base_query_num is not None: _cfg = [f"MLDecoder.query_num={query_num}"] self.update(_cfg) def update_ml_class_num(self, class_num: int): """update MLDecoder query number Args: num_classes (int): the classes number value to set. """ base_class_num = self.dict.get("MLDecoder", {}).get("class_num", None) if base_class_num is not None: _cfg = [f"MLDecoder.class_num={class_num}"] self.update(_cfg) def _update_slim_config(self, slim_config_path: str): """update slim settings Args: slim_config_path (str): the path to slim config yaml file. 
""" slim_config = yaml.load(open(slim_config_path, "rb"), Loader=yaml.Loader)[ "Slim" ] self.update([f"Slim={slim_config}"]) def _update_amp(self, amp: Union[None, str]): """update AMP settings Args: amp (None | str): the AMP settings. Raises: ValueError: AMP setting `amp` error, missing field `AMP`. """ if amp is None or amp == "OFF": if "AMP" in self.dict: self._dict.pop("AMP") else: if "AMP" not in self.dict: raise ValueError("Config must have AMP information.") _cfg = ["AMP.use_amp=True", f"AMP.level={amp}"] self.update(_cfg) def update_num_workers(self, num_workers: int): """update workers number of train and eval dataloader Args: num_workers (int): the value of train and eval dataloader workers number to set. """ _cfg = [ f"DataLoader.Train.loader.num_workers={num_workers}", f"DataLoader.Eval.loader.num_workers={num_workers}", ] self.update(_cfg) def update_shared_memory(self, shared_memeory: bool): """update shared memory setting of train and eval dataloader Args: shared_memeory (bool): whether or not to use shared memory """ assert isinstance(shared_memeory, bool), "shared_memeory should be a bool" _cfg = [ f"DataLoader.Train.loader.use_shared_memory={shared_memeory}", f"DataLoader.Eval.loader.use_shared_memory={shared_memeory}", ] self.update(_cfg) def update_shuffle(self, shuffle: bool): """update shuffle setting of train and eval dataloader Args: shuffle (bool): whether or not to shuffle the data """ assert isinstance(shuffle, bool), "shuffle should be a bool" _cfg = [ f"DataLoader.Train.loader.shuffle={shuffle}", f"DataLoader.Eval.loader.shuffle={shuffle}", ] self.update(_cfg) def update_dali(self, dali: bool): """enable DALI setting of train and eval dataloader Args: dali (bool): whether or not to use DALI """ assert isinstance(dali, bool), "dali should be a bool" _cfg = [ f"Global.use_dali={dali}", f"Global.use_dali={dali}", ] self.update(_cfg) def update_seed(self, seed: int): """update seed Args: seed (int): the random seed value to set """ _cfg = [f"Global.seed={seed}"] self.update(_cfg) def update_device(self, device: str): """update device setting Args: device (str): the running device to set """ device = device.split(":")[0] _cfg = [f"Global.device={device}"] self.update(_cfg) def update_label_dict_path(self, dict_path: str): """update label dict file path Args: dict_path (str): the path of label dict file to set """ _cfg = [ f"PostProcess.Topk.class_id_map_file={abspath(dict_path)}", ] self.update(_cfg) def _update_to_static(self, dy2st: bool): """update config to set dynamic to static mode Args: dy2st (bool): whether or not to use the dynamic to static mode. """ self.update([f"Global.to_static={dy2st}"]) def _update_use_vdl(self, use_vdl: bool): """update config to set VisualDL Args: use_vdl (bool): whether or not to use VisualDL. """ self.update([f"Global.use_visualdl={use_vdl}"]) def _update_epochs(self, epochs: int): """update epochs setting Args: epochs (int): the epochs number value to set """ self.update([f"Global.epochs={epochs}"]) def _update_checkpoints(self, resume_path: Union[None, str]): """update checkpoint setting Args: resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise, train from checkpoint file that path is `.pdparams` file. """ if resume_path is not None: resume_path = resume_path.replace(".pdparams", "") self.update([f"Global.checkpoints={resume_path}"]) def _update_output_dir(self, save_dir: str): """update output directory Args: save_dir (str): the path to save outputs. 
""" self.update([f"Global.output_dir={abspath(save_dir)}"]) def update_log_interval(self, log_interval: int): """update log interval(steps) Args: log_interval (int): the log interval value to set. """ self.update([f"Global.print_batch_step={log_interval}"]) def update_eval_interval(self, eval_interval: int): """update eval interval(epochs) Args: eval_interval (int): the eval interval value to set. """ self.update([f"Global.eval_interval={eval_interval}"]) def update_save_interval(self, save_interval: int): """update eval interval(epochs) Args: save_interval (int): the save interval value to set. """ self.update([f"Global.save_interval={save_interval}"]) def update_log_ranks(self, device): """update log ranks Args: device (str): the running device to set """ log_ranks = device.split(":")[1] self.update([f'Global.log_ranks="{log_ranks}"']) def update_print_mem_info(self, print_mem_info: bool): """setting print memory info""" assert isinstance(print_mem_info, bool), "print_mem_info should be a bool" self.update([f"Global.print_mem_info={print_mem_info}"]) def _update_predict_img(self, infer_img: str, infer_list: str = None): """update image to be predicted Args: infer_img (str): the path to image that to be predicted. infer_list (str, optional): the path to file that images. Defaults to None. """ if infer_list: self.update([f"Infer.infer_list={infer_list}"]) self.update([f"Infer.infer_imgs={infer_img}"]) def _update_save_inference_dir(self, save_inference_dir: str): """update directory path to save inference model files Args: save_inference_dir (str): the directory path to set. """ self.update([f"Global.save_inference_dir={abspath(save_inference_dir)}"]) def _update_inference_model_dir(self, model_dir: str): """update inference model directory Args: model_dir (str): the directory path of inference model fils that used to predict. """ self.update([f"Global.inference_model_dir={abspath(model_dir)}"]) def _update_infer_img(self, infer_img: str): """update path of image that would be predict Args: infer_img (str): the image path. """ self.update([f"Global.infer_imgs={infer_img}"]) def _update_infer_device(self, device: str): """update the device used in predicting Args: device (str): the running device setting """ self.update([f'Global.use_gpu={device.split(":")[0]=="gpu"}']) def _update_enable_mkldnn(self, enable_mkldnn: bool): """update whether to enable MKLDNN Args: enable_mkldnn (bool): `True` is enable, otherwise is disable. """ self.update([f"Global.enable_mkldnn={enable_mkldnn}"]) def _update_infer_img_shape(self, img_shape: str): """update image cropping shape in the preprocessing Args: img_shape (str): the shape of cropping in the preprocessing, i.e. `PreProcess.transform_ops.1.CropImage.size`. """ self.update([f"PreProcess.transform_ops.1.CropImage.size={img_shape}"]) def _update_save_predict_result(self, save_dir: str): """update directory that save predicting output Args: save_dir (str): the dicrectory path that save predicting output. 
""" self.update([f"Infer.save_dir={save_dir}"]) def update_model(self, **kwargs): """update model settings""" for k in kwargs: v = kwargs[k] self.update([f"Arch.{k}={v}"]) def update_teacher_model(self, **kwargs): """update teacher model settings""" for k in kwargs: v = kwargs[k] self.update([f"Arch.models.0.Teacher.{k}={v}"]) def update_student_model(self, **kwargs): """update student model settings""" for k in kwargs: v = kwargs[k] self.update([f"Arch.models.1.Student.{k}={v}"]) def get_epochs_iters(self) -> int: """get epochs Returns: int: the epochs value, i.e., `Global.epochs` in config. """ return self.dict["Global"]["epochs"] def get_log_interval(self) -> int: """get log interval(steps) Returns: int: the log interval value, i.e., `Global.print_batch_step` in config. """ return self.dict["Global"]["print_batch_step"] def get_eval_interval(self) -> int: """get eval interval(epochs) Returns: int: the eval interval value, i.e., `Global.eval_interval` in config. """ return self.dict["Global"]["eval_interval"] def get_save_interval(self) -> int: """get save interval(epochs) Returns: int: the save interval value, i.e., `Global.save_interval` in config. """ return self.dict["Global"]["save_interval"] def get_learning_rate(self) -> float: """get learning rate Returns: float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config. """ return self.dict["Optimizer"]["lr"]["learning_rate"] def get_warmup_epochs(self) -> int: """get warmup epochs Returns: int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epochs` in config. """ return self.dict["Optimizer"]["lr"]["warmup_epoch"] def get_label_dict_path(self) -> str: """get label dict file path Returns: str: the label dict file path, i.e., `PostProcess.Topk.class_id_map_file` in config. """ return self.dict["PostProcess"]["Topk"]["class_id_map_file"] def get_batch_size(self, mode="train") -> int: """get batch size Args: mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'. Defaults to 'train'. Returns: int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config. """ return self.dict["DataLoader"]["Train"]["sampler"]["batch_size"] def get_qat_epochs_iters(self) -> int: """get qat epochs Returns: int: the epochs value. """ return self.get_epochs_iters() def get_qat_learning_rate(self) -> float: """get qat learning rate Returns: float: the learning rate value. """ return self.get_learning_rate() def _get_arch_name(self) -> str: """get architecture name of model Returns: str: the model arch name, i.e., `Arch.name` in config. """ return self.dict["Arch"]["name"] def _get_dataset_root(self) -> str: """get root directory of dataset, i.e. `DataLoader.Train.dataset.image_root` Returns: str: the root directory of dataset """ return self.dict["DataLoader"]["Train"]["dataset"]["image_root"] def get_train_save_dir(self) -> str: """get the directory to save output Returns: str: the directory to save output """ return self["Global"]["output_dir"]