base_seg_config.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. from urllib.parse import urlparse
  12. import yaml
  13. from paddleseg.utils import NoAliasDumper
  14. from paddleseg.cvlibs.config import parse_from_yaml, merge_config_dicts
  15. from ..base import BaseConfig
  16. from ...utils.misc import abspath
  17. class BaseSegConfig(BaseConfig):
  18. """ BaseSegConfig """
  19. def update(self, dict_like_obj):
  20. """ update """
  21. dict_ = merge_config_dicts(dict_like_obj, self.dict)
  22. self.reset_from_dict(dict_)
  23. def load(self, config_path):
  24. """ load """
  25. dict_ = parse_from_yaml(config_path)
  26. if not isinstance(dict_, dict):
  27. raise TypeError
  28. self.reset_from_dict(dict_)
  29. def dump(self, config_path):
  30. """ dump """
  31. with open(config_path, 'w', encoding='utf-8') as f:
  32. yaml.dump(self.dict, f, Dumper=NoAliasDumper)
  33. def update_learning_rate(self, learning_rate):
  34. """ update_learning_rate """
  35. if 'lr_scheduler' not in self:
  36. raise RuntimeError(
  37. "Not able to update learning rate, because no LR scheduler config was found."
  38. )
  39. self.lr_scheduler['learning_rate'] = learning_rate
  40. def update_batch_size(self, batch_size, mode='train'):
  41. """ update_batch_size """
  42. if mode == 'train':
  43. self.set_val('batch_size', batch_size)
  44. else:
  45. raise ValueError(
  46. f"Setting `batch_size` in {repr(mode)} mode is not supported.")
  47. def update_pretrained_weights(self, weight_path, is_backbone=False):
  48. """ update_pretrained_weights """
  49. if 'model' not in self:
  50. raise RuntimeError(
  51. "Not able to update pretrained weight path, because no model config was found."
  52. )
  53. if isinstance(weight_path, str):
  54. if urlparse(weight_path).scheme == '':
  55. # If `weight_path` is a string but not URL (with scheme present),
  56. # it will be recognized as a local file path.
  57. weight_path = abspath(weight_path)
  58. else:
  59. if weight_path is not None:
  60. raise TypeError("`weight_path` must be string or None.")
  61. if is_backbone:
  62. if 'backbone' not in self.model:
  63. raise RuntimeError(
  64. "Not able to update pretrained weight path of backbone, because no backbone config was found."
  65. )
  66. self.model['backbone']['pretrained'] = weight_path
  67. else:
  68. self.model['pretrained'] = weight_path
  69. def update_dy2st(self, dy2st):
  70. """ update_dy2st """
  71. self.set_val('to_static_training', dy2st)
  72. def update_dataset(self, dataset_dir, dataset_type=None):
  73. """ update_dataset """
  74. raise NotImplementedError
  75. def get_epochs_iters(self):
  76. """ get_epochs_iters """
  77. raise NotImplementedError
  78. def get_learning_rate(self):
  79. """ get_learning_rate """
  80. raise NotImplementedError
  81. def get_batch_size(self, mode='train'):
  82. """ get_batch_size """
  83. raise NotImplementedError
  84. def get_qat_epochs_iters(self):
  85. """ get_qat_epochs_iters """
  86. return self.get_epochs_iters() // 2
  87. def get_qat_learning_rate(self):
  88. """ get_qat_learning_rate """
  89. return self.get_learning_rate() / 2