cityscapes.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import glob
  16. from . import fd_logging as logging
  17. # import fd_logging as logging
  18. class Cityscapes(object):
  19. """
  20. Cityscapes dataset `https://www.cityscapes-dataset.com/`.
  21. The folder structure is as follow:
  22. cityscapes
  23. |
  24. |--leftImg8bit
  25. | |--train
  26. | |--val
  27. | |--test
  28. |
  29. |--gtFine
  30. | |--train
  31. | |--val
  32. | |--test
  33. Args:
  34. dataset_root (str): Cityscapes dataset directory.
  35. """
  36. NUM_CLASSES = 19
  37. def __init__(self, dataset_root, mode):
  38. self.dataset_root = dataset_root
  39. self.file_list = list()
  40. mode = mode.lower()
  41. self.mode = mode
  42. self.num_classes = self.NUM_CLASSES
  43. self.ignore_index = 255
  44. img_dir = os.path.join(self.dataset_root, "leftImg8bit")
  45. label_dir = os.path.join(self.dataset_root, "gtFine")
  46. if (
  47. self.dataset_root is None
  48. or not os.path.isdir(self.dataset_root)
  49. or not os.path.isdir(img_dir)
  50. or not os.path.isdir(label_dir)
  51. ):
  52. raise ValueError(
  53. "The dataset is not Found or the folder structure is nonconfoumance."
  54. )
  55. label_files = sorted(
  56. glob.glob(os.path.join(label_dir, mode, "*", "*_gtFine_labelTrainIds.png"))
  57. )
  58. img_files = sorted(
  59. glob.glob(os.path.join(img_dir, mode, "*", "*_leftImg8bit.png"))
  60. )
  61. self.file_list = [
  62. [img_path, label_path]
  63. for img_path, label_path in zip(img_files, label_files)
  64. ]
  65. self.num_samples = len(self.file_list)
  66. logging.info("{} samples in file {}".format(self.num_samples, img_dir))