pp3d_config.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs

import yaml

from ...utils.misc import abspath
from ..base import BaseConfig


class PP3DConfig(BaseConfig):
    # Refer to https://github.com/PaddlePaddle/Paddle3D/blob/release/1.0/paddle3d/apis/config.py
    def update(self, dict_like_obj):
        def _merge_config_dicts(dict_from, dict_to):
            # According to
            # https://github.com/PaddlePaddle/Paddle3D/blob/3cf884ecbc94330be0e2db780434bb60b9b4fe8c/paddle3d/apis/config.py#L90
            for key, val in dict_from.items():
                if isinstance(val, dict) and key in dict_to:
                    dict_to[key] = _merge_config_dicts(val, dict_to[key])
                else:
                    dict_to[key] = val
            return dict_to

        dict_ = _merge_config_dicts(dict_like_obj, self.dict)
        self.reset_from_dict(dict_)
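
    # Illustrative note (not part of the original source): unlike dict.update,
    # the merge above recurses into nested dicts. For example, applying
    #   {"lr_scheduler": {"learning_rate": 0.001}}
    # to a config holding
    #   {"lr_scheduler": {"type": "OneCycle", "lr_max": 0.003}}
    # yields
    #   {"lr_scheduler": {"type": "OneCycle", "lr_max": 0.003, "learning_rate": 0.001}}
    # i.e. sibling keys under "lr_scheduler" are preserved rather than replaced.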

    def load(self, config_path):
        with codecs.open(config_path, "r", "utf-8") as file:
            dict_ = yaml.load(file, Loader=yaml.FullLoader)
        self.reset_from_dict(dict_)

    def dump(self, config_path):
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(self.dict, f)

    def update_learning_rate(self, learning_rate):
        if "lr_scheduler" not in self:
            raise RuntimeError(
                "Not able to update learning rate, because no LR scheduler config was found."
            )
        # Some lr_scheduler types in Paddle3D do not have a `learning_rate`
        # parameter; OneCycle and OneCycleWarmupDecayLr name it differently.
        if self.lr_scheduler["type"] == "OneCycle":
            self.lr_scheduler["lr_max"] = learning_rate
        elif self.lr_scheduler["type"] == "OneCycleWarmupDecayLr":
            self.lr_scheduler["base_learning_rate"] = learning_rate
        else:
            self.lr_scheduler["learning_rate"] = learning_rate

    def update_batch_size(self, batch_size, mode="train"):
        if mode == "train":
            self.set_val("batch_size", batch_size)
        else:
            raise ValueError(
                f"Setting `batch_size` in {repr(mode)} mode is not supported."
            )

    def update_epochs(self, epochs, mode="train"):
        if mode == "train":
            self.set_val("epochs", epochs)
        else:
            raise ValueError(f"Setting `epochs` in {repr(mode)} mode is not supported.")

    def update_pretrained_weights(self, weight_path, is_backbone=False):
        raise NotImplementedError

    def get_epochs_iters(self):
        if "iters" in self:
            return self.iters
        else:
            assert "epochs" in self
            return self.epochs

    def get_learning_rate(self):
        if "lr_scheduler" not in self or "learning_rate" not in self.lr_scheduler:
            # Default lr
            return 0.0001
        else:
            lr = self.lr_scheduler["learning_rate"]
            while isinstance(lr, dict):
                lr = lr["learning_rate"]
            return lr
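
    # Illustrative note (not part of the original source): the while-loop above
    # unwraps warmup-style schedulers that nest the base scheduler, e.g. a
    # config shaped like (scheduler names illustrative):
    #   lr_scheduler:
    #     type: LinearWarmup
    #     learning_rate:
    #       type: MultiStepDecay
    #       learning_rate: 0.001
    # for which get_learning_rate() returns 0.001.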

    def get_batch_size(self, mode="train"):
        if "batch_size" in self:
            return self.batch_size
        else:
            # Default batch size
            return 1

    def get_qat_epochs_iters(self):
        assert (
            "finetune_config" in self
        ), "QAT training yaml should contain finetune_config key"
        if "iters" in self.finetune_config:
            return self.finetune_config["iters"]
        else:
            assert "epochs" in self.finetune_config
            return self.finetune_config["epochs"]

    def get_qat_learning_rate(self):
        assert (
            "finetune_config" in self
        ), "QAT training yaml should contain finetune_config key"
        cfg = self.finetune_config
        if "lr_scheduler" not in cfg or "learning_rate" not in cfg["lr_scheduler"]:
            # Default lr
            return 1.25e-4
        else:
            lr = cfg["lr_scheduler"]["learning_rate"]
            while isinstance(lr, dict):
                lr = lr["learning_rate"]
            return lr

    def update_warmup_steps(self, steps):
        self.lr_scheduler["warmup_steps"] = steps

    def update_end_lr(self, learning_rate):
        self.lr_scheduler["end_lr"] = learning_rate

    def update_iters(self, iters):
        self.set_val("iters", iters)
        if "epochs" in self:
            self.set_val("epochs", None)

    def update_finetune_iters(self, iters):
        self.finetune_config["iters"] = iters
        if "epochs" in self.finetune_config:
            self.finetune_config["epochs"] = None

    def update_save_dir(self, save_dir: str):
        self["save_dir"] = abspath(save_dir)
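
# Usage sketch (illustrative, not from the original source; assumes a
# PP3DConfig instance `cfg` obtained from the surrounding PaddleX machinery,
# since BaseConfig's constructor is not shown in this file):
#
#     cfg.load("path/to/config.yml")          # read a Paddle3D YAML config
#     cfg.update_learning_rate(2e-4)          # dispatches on lr_scheduler["type"]
#     cfg.update_batch_size(4)                # train mode only
#     cfg.update_iters(20000)                 # switch from epoch- to iter-based schedule
#     cfg.dump("path/to/updated_config.yml")  # write the merged config back out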