| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from ....utils.misc import abspath
- from ..pp3d_config import PP3DConfig
class BEVFusionConfig(PP3DConfig):
    """Config for BEVFusion models, built on top of `PP3DConfig`.

    Provides helpers that rewrite the dataset sections, the AMP section,
    the class names, and the pretrained/trained weight paths.
    """

    def update_dataset(
        self, dataset_dir, datart_prefix=True, dataset_type=None, *, version=None
    ):
        """Point the train/val dataset sections at a new dataset directory.

        Args:
            dataset_dir (str): Root directory of the dataset; resolved with
                `abspath` before use.
            datart_prefix (bool): Copied verbatim into the `datart_prefix`
                entry of both dataset sections.
            dataset_type (str|None): Dataset type. Defaults to
                "NuscenesMMDataset", which is the only supported value.
            version (str|None): Dataset version, "trainval" (default) or
                "mini"; see `_make_nuscenes_mm_dataset_config`.

        Raises:
            ValueError: If `dataset_type` or `version` is not supported.
        """
        dataset_dir = abspath(dataset_dir)
        if dataset_type is None:
            dataset_type = "NuscenesMMDataset"
        if dataset_type == "NuscenesMMDataset":
            ds_cfg = self._make_nuscenes_mm_dataset_config(
                dataset_dir, datart_prefix, version=version
            )
        else:
            raise ValueError(f"{dataset_type} is not supported.")

        # Prune old config: drop every entry except the ones listed below so
        # that settings tied to the previous dataset do not survive the merge
        # with `ds_cfg`.
        keys_to_keep = ("transforms", "mode", "class_names", "modality")
        for section in ("train_dataset", "val_dataset"):
            if section not in self:
                continue
            section_cfg = getattr(self, section)
            # Materialize the key list first: we mutate `section_cfg` below.
            stale_keys = [k for k in section_cfg if k not in keys_to_keep]
            for key in stale_keys:
                section_cfg.pop(key)
        self.update(ds_cfg)

    def _make_nuscenes_mm_dataset_config(
        self, dataset_root_path, datart_prefix, version
    ):
        """Build the train/val dataset sections for `NuscenesMMDataset`.

        Args:
            dataset_root_path (str): Dataset root. The annotation files are
                expected at `<root>/nuscenes_infos_train.pkl` and
                `<root>/nuscenes_infos_val.pkl`.
            datart_prefix (bool): Copied into both dataset sections.
            version (str|None): "trainval" (used when None) selects the
                "train"/"val" modes; "mini" selects "mini_train"/"mini_val".

        Returns:
            dict: `{"train_dataset": {...}, "val_dataset": {...}}`.

        Raises:
            ValueError: If `version` is neither "trainval" nor "mini".
        """
        if version is None:
            # Default version
            version = "trainval"
        if version == "trainval":
            train_mode, val_mode = "train", "val"
        elif version == "mini":
            train_mode, val_mode = "mini_train", "mini_val"
        else:
            raise ValueError("Unsupported version.")
        # NOTE(review): both versions use the same annotation file names;
        # confirm that "mini" preprocessing also writes these names.
        return {
            "train_dataset": {
                "type": "NuscenesMMDataset",
                "data_root": dataset_root_path,
                "ann_file": f"{dataset_root_path}/nuscenes_infos_train.pkl",
                "mode": train_mode,
                "datart_prefix": datart_prefix,
            },
            "val_dataset": {
                "type": "NuscenesMMDataset",
                "data_root": dataset_root_path,
                "ann_file": f"{dataset_root_path}/nuscenes_infos_val.pkl",
                "mode": val_mode,
                "datart_prefix": datart_prefix,
            },
        }

    def _update_amp(self, amp):
        """Write the `amp_cfg` section, using `amp` as the AMP level.

        Args:
            amp: Value stored as `amp_cfg.level`.
        """
        # XXX: Currently, we hard-code the AMP settings according to
        # https://github.com/PaddlePaddle/Paddle3D/blob/3cf884ecbc94330be0e2db780434bb60b9b4fe8c/configs/smoke/smoke_dla34_no_dcn_kitti_amp.yml#L6
        # NOTE(review): `use_amp` and `enable` are hard-coded to False, so
        # `amp` only sets the level and AMP stays off. Confirm this is
        # intentional for BEVFusion.
        amp_cfg = {
            "amp_cfg": {
                "use_amp": False,
                "enable": False,
                "level": amp,
                "scaler": {"init_loss_scaling": 512.0},
                "custom_black_list": ["matmul_v2", "elementwise_mul"],
            }
        }
        self.update(amp_cfg)

    def update_class_names(self, class_names):
        """Set the class names of the train and val dataset sections.

        On the training set, the `classes` of the first `SampleNameFilter`
        transform (if present) are overwritten too, so that sample filtering
        matches `class_names`.

        Args:
            class_names (list): Class names to use.
        """
        if "train_dataset" in self:
            train_cfg = self.train_dataset
            # Set class names even when the train section has no transform
            # pipeline, consistent with the val branch below.
            train_cfg["class_names"] = class_names
            if "transforms" in train_cfg:
                # TODO: Provide another method to customize `SampleNameFilter` classes names
                # TODO: Give an explicit warning for the implicit behavior
                for tf_cfg in train_cfg["transforms"]:
                    if tf_cfg.get("type") == "SampleNameFilter":
                        tf_cfg["classes"] = class_names
                        # We assume that there is at most one `SampleNameFilter`
                        # in the transform list.
                        break
        if "val_dataset" in self:
            self.val_dataset["class_names"] = class_names

    def update_pretrained_model(self, load_cam_from: str, load_lidar_from: str):
        """Set the pretrained weight paths of the camera and LiDAR branches.

        Args:
            load_cam_from (str): Path to the camera-branch weight file;
                stored as `model.load_cam_from`.
            load_lidar_from (str): Path to the LiDAR-branch weight file;
                stored as `model.load_lidar_from`.
        """
        self.model["load_cam_from"] = load_cam_from
        self.model["load_lidar_from"] = load_lidar_from

    def update_weights(self, weight_path: str):
        """Set the model weight file path (the top-level `weights` entry).

        Args:
            weight_path (str): Path to the weight file of the model.
        """
        self["weights"] = weight_path
|