@@ -0,0 +1,548 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import yaml
+from typing import Union
+
+from ...base import BaseConfig
+from ....utils.misc import abspath
+from ..config_utils import merge_config
+
+
+class VideoDetConfig(BaseConfig):
+    """Video Detection Task Config"""
+
+    def update(self, dict_like_obj: list):
+        """update self
+
+        Args:
+            dict_like_obj (list): list of pairs (key0.key1.idx.key2=value).
+        """
+        dict_ = merge_config(self.dict, dict_like_obj)
+        self.reset_from_dict(dict_)
+
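+    # A minimal usage sketch (hypothetical values): `merge_config` accepts both
+    # a dict keyed by dotted paths and a list of "key=value" strings, so either
+    # form below should work:
+    #
+    #     cfg.update({"DATASET.batch_size": 8})
+    #     cfg.update(["DATASET.batch_size=8"])
+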
+    def load(self, config_file_path: str):
+        """load config from yaml file
+
+        Args:
+            config_file_path (str): the path of yaml file.
+
+        Raises:
+            TypeError: the content of the yaml file `config_file_path` is not a dict.
+        """
+        with open(config_file_path, "rb") as f:
+            dict_ = yaml.load(f, Loader=yaml.Loader)
+        if not isinstance(dict_, dict):
+            raise TypeError(f"The content of {config_file_path!r} must be a dict.")
+        self.reset_from_dict(dict_)
+
+    def dump(self, config_file_path: str):
+        """dump self to yaml file
+
+        Args:
+            config_file_path (str): the path to save self as yaml file.
+        """
+        with open(config_file_path, "w", encoding="utf-8") as f:
+            yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
+
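+    # A round-trip sketch (hypothetical paths): load a config, tweak it, and
+    # write it back out:
+    #
+    #     cfg.load("configs/video_det.yaml")
+    #     cfg.update({"epochs": 10})
+    #     cfg.dump("output/video_det.yaml")
+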
+    def update_dataset(
+        self,
+        dataset_path: str,
+        dataset_type: str = None,
+        *,
+        train_list_path: str = None,
+    ):
+        """update dataset settings
+
+        Args:
+            dataset_path (str): the root path of dataset.
+            dataset_type (str, optional): dataset type. Defaults to None.
+            train_list_path (str, optional): the path of the train dataset annotation file. Defaults to None.
+
+        Raises:
+            ValueError: `dataset_type` is not supported.
+        """
+        dataset_path = abspath(dataset_path)
+        if dataset_type is None:
+            dataset_type = "VideoDetDataset"
+        if not train_list_path:
+            train_list_path = os.path.join(dataset_path, "train.txt")
+
+        if dataset_type == "VideoDetDataset":
+            _cfg = {
+                "DATASET.train.image_dir": dataset_path,
+                "DATASET.train.file_path": train_list_path,
+                "DATASET.valid.image_dir": dataset_path,
+                "DATASET.valid.file_path": os.path.join(dataset_path, "val.txt"),
+                "DATASET.test.image_dir": dataset_path,
+                "DATASET.test.file_path": os.path.join(dataset_path, "val.txt"),
+                "METRIC.gt_folder": os.path.join(dataset_path, "val.txt"),
+                "label_dict_path": os.path.join(dataset_path, "label_map.txt"),
+            }
+        else:
+            raise ValueError(f"{repr(dataset_type)} is not supported.")
+        self.update(_cfg)
+
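+    # Dataset layout implied by the keys above (a sketch, not a spec):
+    #
+    #     dataset_root/
+    #     ├── train.txt       # training annotation list
+    #     ├── val.txt         # validation/test annotation list
+    #     └── label_map.txt   # one "<label_name> <label_id>" per line
+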
+    def update_batch_size(self, batch_size: int, mode: str = "train"):
+        """update batch size setting
+
+        Args:
+            batch_size (int): the batch size number to set.
+            mode (str, optional): the mode whose batch size is to be set, must be one of 'train', 'eval'
+                and 'test'. Defaults to 'train'.
+
+        Raises:
+            ValueError: `mode` error.
+        """
+        if mode == "train":
+            _cfg = {"DATASET.batch_size": batch_size}
+        elif mode in ("eval", "test"):
+            # eval and test share the same batch size key
+            _cfg = {"DATASET.test_batch_size": batch_size}
+        else:
+            raise ValueError("The input `mode` should be train, eval or test.")
+        self.update(_cfg)
+
+    def update_learning_rate(self, learning_rate: float):
+        """update learning rate
+
+        Args:
+            learning_rate (float): the learning rate value to set.
+
+        Raises:
+            ValueError: the learning rate setting in config is in an unsupported format.
+        """
+        if (
+            self._dict["OPTIMIZER"]["learning_rate"].get("learning_rate", None)
+            is not None
+        ):
+            _cfg = {"OPTIMIZER.learning_rate.learning_rate": learning_rate}
+        else:
+            raise ValueError("unsupported learning rate format")
+        self.update(_cfg)
+
+    def update_num_classes(self, num_classes: int):
+        """update classes number
+
+        Args:
+            num_classes (int): the classes number value to set.
+        """
+        # note the key asymmetry: the backbone uses `num_class`, the loss `num_classes`
+        self.update({"MODEL.backbone.num_class": num_classes})
+        self.update({"MODEL.loss.num_classes": num_classes})
+
+    def update_label_list(self, label_path: str):
+        """update label list
+
+        Args:
+            label_path (str): the path of the label list file to set.
+        """
+        with open(label_path, "r") as f:
+            lines = [line.strip().split(" ") for line in f]
+        # sort labels by their numeric id (the second column) and keep only the names
+        sorted_lines = sorted(lines, key=lambda x: int(x[1]))
+        label_list = [line[0] for line in sorted_lines]
+
+        self.update({"label_list": label_list})
+
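+    # Worked example: for a label file containing (hypothetical contents)
+    #
+    #     cat 1
+    #     dog 0
+    #
+    # the resulting label_list is ["dog", "cat"], ordered by label id.
+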
+    def update_pretrained_weights(self, pretrained_model: str):
+        """update pretrained weight path
+
+        Args:
+            pretrained_model (str): the local path or url of pretrained weight file to set.
+        """
+        assert isinstance(
+            pretrained_model, (str, type(None))
+        ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
+indicating that no pretrained model is to be used."
+
+        if pretrained_model is None or pretrained_model.lower() == "default":
+            self.update({"Global.pretrained_model": None})
+        else:
+            if not pretrained_model.startswith(("http://", "https://")):
+                pretrained_model = abspath(pretrained_model)
+            self.update({"Global.pretrained_model": pretrained_model})
+
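+    # Sketch of the accepted forms (hypothetical paths):
+    #
+    #     cfg.update_pretrained_weights(None)                      # no pretrained weights
+    #     cfg.update_pretrained_weights("default")                 # treated the same as None here
+    #     cfg.update_pretrained_weights("weights/model.pdparams")  # resolved via abspath
+    #     cfg.update_pretrained_weights("https://example.com/w.pdparams")  # URL kept as-is
+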
+    def _update_slim_config(self, slim_config_path: str):
+        """update slim settings
+
+        Args:
+            slim_config_path (str): the path to slim config yaml file.
+        """
+        with open(slim_config_path, "rb") as f:
+            slim_config = yaml.load(f, Loader=yaml.Loader)["Slim"]
+        self.update({"Slim": slim_config})
+
+    def _update_amp(self, amp: Union[None, str]):
+        """update AMP settings
+
+        Args:
+            amp (None | str): the AMP settings.
+
+        Raises:
+            ValueError: the config is missing the `AMP` field needed to enable AMP.
+        """
+        if amp is None or amp == "OFF":
+            if "AMP" in self.dict:
+                self._dict.pop("AMP")
+        else:
+            if "AMP" not in self.dict:
+                raise ValueError("Config must have AMP information.")
+            _cfg = {"AMP.use_amp": True, "AMP.level": amp}
+            self.update(_cfg)
+
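+    # Sketch: assuming the loaded config already contains an `AMP` section,
+    #
+    #     cfg._update_amp("O1")   # -> AMP.use_amp: True, AMP.level: "O1"
+    #     cfg._update_amp(None)   # -> removes the AMP section entirely
+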
+    def update_num_workers(self, num_workers: int):
+        """update workers number of train and eval dataloader
+
+        Args:
+            num_workers (int): the value of train and eval dataloader workers number to set.
+        """
+        _cfg = {
+            "DATASET.num_workers": num_workers,
+        }
+        self.update(_cfg)
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory, bool), "shared_memory should be a bool"
+        _cfg = {
+            "DataLoader.Train.loader.use_shared_memory": shared_memory,
+            "DataLoader.Eval.loader.use_shared_memory": shared_memory,
+        }
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        _cfg = [
+            f"DataLoader.Train.loader.shuffle={shuffle}",
+            f"DataLoader.Eval.loader.shuffle={shuffle}",
+        ]
+        self.update(_cfg)
+
+    def update_dali(self, dali: bool):
+        """enable DALI setting of train and eval dataloader
+
+        Args:
+            dali (bool): whether or not to use DALI
+        """
+        assert isinstance(dali, bool), "dali should be a bool"
+        _cfg = [
+            f"Global.use_dali={dali}",
+        ]
+        self.update(_cfg)
+
+    def update_seed(self, seed: int):
+        """update seed
+
+        Args:
+            seed (int): the random seed value to set
+        """
+        _cfg = {"Global.seed": seed}
+        self.update(_cfg)
+
+    def update_device(self, device: str):
+        """update device setting
+
+        Args:
+            device (str): the running device to set
+        """
+        # keep only the device type, dropping any ":<ids>" suffix (e.g. "gpu:0,1" -> "gpu")
+        device = device.split(":")[0]
+        _cfg = {"Global.device": device}
+        self.update(_cfg)
+
+    def update_label_dict_path(self, dict_path: str):
+        """update label dict file path
+
+        Args:
+            dict_path (str): the path of label dict file to set
+        """
+        _cfg = {
+            "PostProcess.Topk.class_id_map_file": abspath(dict_path),
+        }
+        self.update(_cfg)
+
+    def _update_to_static(self, dy2st: bool):
+        """update config to set dynamic to static mode
+
+        Args:
+            dy2st (bool): whether or not to use the dynamic to static mode.
+        """
+        self.update({"Global.to_static": dy2st})
+
+    def _update_use_vdl(self, use_vdl: bool):
+        """update config to set VisualDL
+
+        Args:
+            use_vdl (bool): whether or not to use VisualDL.
+        """
+        self.update({"Global.use_visualdl": use_vdl})
+
+    def _update_epochs(self, epochs: int):
+        """update epochs setting
+
+        Args:
+            epochs (int): the epochs number value to set
+        """
+        self.update({"epochs": epochs})
+
+    def _update_checkpoints(self, resume_path: Union[None, str]):
+        """update checkpoint setting
+
+        Args:
+            resume_path (None | str): the resume training setting. If `None`, train from scratch; otherwise,
+                resume training from the checkpoint whose path is the given `.pdparams` file.
+        """
+        if resume_path is not None:
+            # the checkpoint is referenced by its prefix, without the ".pdparams" suffix
+            resume_path = resume_path.replace(".pdparams", "")
+        self.update({"Global.checkpoints": resume_path})
+
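+    # Sketch: the ".pdparams" suffix is stripped before the key is written
+    # (hypothetical path):
+    #
+    #     cfg._update_checkpoints("output/epoch_10.pdparams")
+    #     # -> Global.checkpoints: "output/epoch_10"
+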
+    def _update_output_dir(self, save_dir: str):
+        """update output directory
+
+        Args:
+            save_dir (str): the path to save outputs.
+        """
+        self.update({"output_dir": abspath(save_dir)})
+        self.update({"METRIC.result_path": abspath(save_dir)})
+
+    def update_log_interval(self, log_interval: int):
+        """update log interval (steps)
+
+        Args:
+            log_interval (int): the log interval value to set.
+        """
+        self.update({"log_interval": log_interval})
+
+    def update_eval_interval(self, eval_interval: int):
+        """update eval interval (epochs)
+
+        Args:
+            eval_interval (int): the eval interval value to set.
+        """
+        self.update({"val_interval": eval_interval})
+
+    def update_save_interval(self, save_interval: int):
+        """update save interval (epochs)
+
+        Args:
+            save_interval (int): the save interval value to set.
+        """
+        self.update({"save_interval": save_interval})
+
+    def update_log_ranks(self, device: str):
+        """update log ranks
+
+        Args:
+            device (str): the device setting in the form "<type>:<ids>"; the part after ":"
+                is used as the log ranks.
+        """
+        log_ranks = device.split(":")[1]
+        self.update({"Global.log_ranks": log_ranks})
+
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting whether to print memory info"""
+        assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
+        self.update({"Global.print_mem_info": print_mem_info})
+
+    def _update_predict_video(self, infer_video: str, infer_list: str = None):
+        """update video to be predicted
+
+        Args:
+            infer_video (str): the path of the video to be predicted.
+            infer_list (str, optional): the path to a file listing the videos to be predicted.
+                Defaults to None.
+        """
+        if infer_list:
+            self.update({"Infer.infer_list": infer_list})
+        self.update({"Infer.infer_videos": infer_video})
+
+    def _update_save_inference_dir(self, save_inference_dir: str):
+        """update directory path to save inference model files
+
+        Args:
+            save_inference_dir (str): the directory path to set.
+        """
+        self.update({"Global.save_inference_dir": abspath(save_inference_dir)})
+
+    def _update_inference_model_dir(self, model_dir: str):
+        """update inference model directory
+
+        Args:
+            model_dir (str): the directory path of inference model files used to predict.
+        """
+        self.update({"Global.inference_model_dir": abspath(model_dir)})
+
+    def _update_infer_video(self, infer_video: str):
+        """update the path of the video to be predicted
+
+        Args:
+            infer_video (str): the video path.
+        """
+        self.update({"Global.infer_videos": infer_video})
+
+    def _update_infer_device(self, device: str):
+        """update the device used in predicting
+
+        Args:
+            device (str): the running device setting
+        """
+        # any "gpu:*" device string enables GPU; everything else falls back to CPU
+        self.update({"Global.use_gpu": device.split(":")[0] == "gpu"})
+
+    def _update_enable_mkldnn(self, enable_mkldnn: bool):
+        """update whether to enable MKLDNN
+
+        Args:
+            enable_mkldnn (bool): `True` to enable, `False` to disable.
+        """
+        self.update({"Global.enable_mkldnn": enable_mkldnn})
+
+    def _update_infer_video_shape(self, img_shape: str):
+        """update the target frame size used in inference preprocessing
+
+        Args:
+            img_shape (str): the target size for resizing frames in preprocessing,
+                i.e. `INFERENCE.target_size` in config.
+        """
+        self.update({"INFERENCE.target_size": img_shape})
+
+    def _update_save_predict_result(self, save_dir: str):
+        """update directory that saves predicting output
+
+        Args:
+            save_dir (str): the directory path that saves predicting output.
+        """
+        self.update({"Infer.save_dir": save_dir})
+
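+    # A minimal inference-setup sketch (hypothetical paths) chaining the
+    # private helpers above:
+    #
+    #     cfg._update_inference_model_dir("output/inference")
+    #     cfg._update_infer_video("data/example.mp4")
+    #     cfg._update_infer_device("gpu:0")
+    #     cfg._update_save_predict_result("output/predict")
+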
+    def get_epochs_iters(self) -> int:
+        """get epochs
+
+        Returns:
+            int: the epochs value, i.e., `Global.epochs` in config.
+        """
+        return self.dict["Global"]["epochs"]
+
+    def get_log_interval(self) -> int:
+        """get log interval (steps)
+
+        Returns:
+            int: the log interval value, i.e., `Global.print_batch_step` in config.
+        """
+        return self.dict["Global"]["print_batch_step"]
+
+    def get_eval_interval(self) -> int:
+        """get eval interval (epochs)
+
+        Returns:
+            int: the eval interval value, i.e., `Global.eval_interval` in config.
+        """
+        return self.dict["Global"]["eval_interval"]
+
+    def get_save_interval(self) -> int:
+        """get save interval (epochs)
+
+        Returns:
+            int: the save interval value, i.e., `Global.save_interval` in config.
+        """
+        return self.dict["Global"]["save_interval"]
+
+    def get_learning_rate(self) -> float:
+        """get learning rate
+
+        Returns:
+            float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
+        """
+        return self.dict["Optimizer"]["lr"]["learning_rate"]
+
+    def get_warmup_epochs(self) -> int:
+        """get warmup epochs
+
+        Returns:
+            int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epoch` in config.
+        """
+        return self.dict["Optimizer"]["lr"]["warmup_epoch"]
+
+    def get_label_dict_path(self) -> str:
+        """get label dict file path
+
+        Returns:
+            str: the label dict file path, i.e., `PostProcess.Topk.class_id_map_file` in config.
+        """
+        return self.dict["PostProcess"]["Topk"]["class_id_map_file"]
+
+    def get_batch_size(self, mode="train") -> int:
+        """get batch size
+
+        Args:
+            mode (str, optional): the mode whose batch size is to be fetched, must be one of 'train',
+                'eval' and 'test'. Defaults to 'train'.
+
+        Returns:
+            int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
+        """
+        # NOTE: `mode` is currently not used; the train batch size is always returned.
+        return self.dict["DataLoader"]["Train"]["sampler"]["batch_size"]
+
+    def get_qat_epochs_iters(self) -> int:
+        """get qat epochs
+
+        Returns:
+            int: the epochs value.
+        """
+        return self.get_epochs_iters()
+
+    def get_qat_learning_rate(self) -> float:
+        """get qat learning rate
+
+        Returns:
+            float: the learning rate value.
+        """
+        return self.get_learning_rate()
+
+    def _get_arch_name(self) -> str:
+        """get architecture name of model
+
+        Returns:
+            str: the model arch name, i.e., `Arch.name` in config.
+        """
+        return self.dict["Arch"]["name"]
+
+    def _get_dataset_root(self) -> str:
+        """get root directory of dataset, i.e. `DataLoader.Train.dataset.video_root`
+
+        Returns:
+            str: the root directory of dataset
+        """
+        return self.dict["DataLoader"]["Train"]["dataset"]["video_root"]
+
+    def get_train_save_dir(self) -> str:
+        """get the directory to save output
+
+        Returns:
+            str: the directory to save output
+        """
+        return self["output_dir"]
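+
+
+# An end-to-end usage sketch (hypothetical wiring and paths); the constructor
+# comes from `BaseConfig`, which is not shown here:
+#
+#     cfg = VideoDetConfig(...)
+#     cfg.load("configs/video_det.yaml")
+#     cfg.update_dataset("data/my_video_dataset")
+#     cfg.update_batch_size(4, mode="train")
+#     cfg.update_num_classes(2)
+#     cfg.dump("output/config.yaml")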