config.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ....utils.misc import abspath
  15. from ..pp3d_config import PP3DConfig
  16. class BEVFusionConfig(PP3DConfig):
  17. def update_dataset(
  18. self, dataset_dir, datart_prefix=True, dataset_type=None, *, version=None
  19. ):
  20. dataset_dir = abspath(dataset_dir)
  21. if dataset_type is None:
  22. dataset_type = "NuscenesMMDataset"
  23. if dataset_type == "NuscenesMMDataset":
  24. ds_cfg = self._make_nuscenes_mm_dataset_config(
  25. dataset_dir, datart_prefix, version=version
  26. )
  27. else:
  28. raise ValueError(f"{dataset_type} is not supported.")
  29. # Prune old config
  30. keys_to_keep = ("transforms", "mode", "class_names", "modality")
  31. if "train_dataset" in self:
  32. for key in list(k for k in self.train_dataset if k not in keys_to_keep):
  33. self.train_dataset.pop(key)
  34. if "val_dataset" in self:
  35. for key in list(k for k in self.val_dataset if k not in keys_to_keep):
  36. self.val_dataset.pop(key)
  37. self.update(ds_cfg)
  38. def _make_nuscenes_mm_dataset_config(
  39. self, dataset_root_path, datart_prefix, version
  40. ):
  41. if version is None:
  42. # Default version
  43. version = "trainval"
  44. if version == "trainval":
  45. train_mode = "train"
  46. val_mode = "val"
  47. elif version == "mini":
  48. train_mode = "mini_train"
  49. val_mode = "mini_val"
  50. else:
  51. raise ValueError("Unsupported version.")
  52. return {
  53. "train_dataset": {
  54. "type": "NuscenesMMDataset",
  55. "data_root": dataset_root_path,
  56. "ann_file": f"{dataset_root_path}/nuscenes_infos_train.pkl",
  57. "mode": train_mode,
  58. "datart_prefix": datart_prefix,
  59. },
  60. "val_dataset": {
  61. "type": "NuscenesMMDataset",
  62. "data_root": dataset_root_path,
  63. "ann_file": f"{dataset_root_path}/nuscenes_infos_val.pkl",
  64. "mode": val_mode,
  65. "datart_prefix": datart_prefix,
  66. },
  67. }
  68. def _update_amp(self, amp):
  69. # XXX: Currently, we hard-code the AMP settings according to
  70. # https://github.com/PaddlePaddle/Paddle3D/blob/3cf884ecbc94330be0e2db780434bb60b9b4fe8c/configs/smoke/smoke_dla34_no_dcn_kitti_amp.yml#L6
  71. amp_cfg = {
  72. "amp_cfg": {
  73. "use_amp": False,
  74. "enable": False,
  75. "level": amp,
  76. "scaler": {"init_loss_scaling": 512.0},
  77. "custom_black_list": ["matmul_v2", "elementwise_mul"],
  78. }
  79. }
  80. self.update(amp_cfg)
  81. def update_class_names(self, class_names):
  82. if "train_dataset" in self and "transforms" in getattr(self, "train_dataset"):
  83. self.train_dataset["class_names"] = class_names
  84. # TODO: Provide another method to customize `SampleNameFilter` classes names
  85. # TODO: Give an explicit warning for the implicit behavior
  86. tf_cfg_list = self.train_dataset["transforms"]
  87. for tf_cfg in tf_cfg_list:
  88. if tf_cfg["type"] == "SampleNameFilter":
  89. tf_cfg["classes"] = class_names
  90. # We assume that there is at most one `SampleNameFilter` in `tf_cfg_list`
  91. break
  92. if "val_dataset" in self:
  93. self.val_dataset["class_names"] = class_names
  94. def update_pretrained_model(self, load_cam_from: str, load_lidar_from: str):
  95. """update model pretrained weight
  96. Args:
  97. load_cam_from (str): the path to cam weight file of model.
  98. load_lidar_from (str): the path to lidar weight file of model.
  99. """
  100. self.model["load_cam_from"] = load_cam_from
  101. self.model["load_lidar_from"] = load_lidar_from
  102. def update_weights(self, weight_path: str):
  103. """update model weight
  104. Args:
  105. weight_path (str): the path to weight file of model.
  106. """
  107. self["weights"] = weight_path