# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC, abstractmethod

from ....utils.config import AttrDict
from ....utils.logging import info
from ....utils.misc import AutoRegisterABCMetaClass
from .utils import build_res_dict


def build_dataset_checker(config: AttrDict) -> "BaseDatasetChecker":
    """Build a dataset checker.

    Args:
        config (AttrDict): PaddleX pipeline config, loaded from the pipeline yaml file.

    Returns:
        BaseDatasetChecker: the dataset checker, a subclass of BaseDatasetChecker.
    """
    model_name = config.Global.model
    return BaseDatasetChecker.get(model_name)(config)
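

# Usage sketch (illustrative, not part of the original file): given a pipeline
# config loaded into an AttrDict with `Global` and `CheckDataset` sections, a
# registered checker is built and run as below. `load_pipeline_config` is a
# hypothetical loader standing in for PaddleX's actual config loading.
#
#     config = load_pipeline_config("pipeline.yaml")
#     checker = build_dataset_checker(config)
#     result = checker.check()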


class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
    """Base Dataset Checker"""

    __is_base = True

    def __init__(self, config):
        """Initialize the instance.

        Args:
            config (AttrDict): PaddleX pipeline config, loaded from the pipeline yaml file.
        """
        super().__init__()
        self.global_config = config.Global
        self.check_dataset_config = config.CheckDataset
        self.output = os.path.join(self.global_config.output, "check_dataset")

    def check(self) -> dict:
        """Execute dataset checking.

        Returns:
            dict: the dataset checking result.
        """
        dataset_dir = self.get_dataset_root(self.global_config.dataset_dir)

        if not os.path.exists(self.output):
            os.makedirs(self.output)

        # Optionally convert the dataset to the expected format before checking;
        # the (possibly new) root directory is used from here on.
        if self.check_dataset_config.get("convert", None):
            if self.check_dataset_config.convert.get("enable", False):
                dataset_dir = self.convert_dataset(dataset_dir)
                info("Converted dataset successfully!")

        # Optionally repartition the train/validation subsets.
        if self.check_dataset_config.get("split", None):
            if self.check_dataset_config.split.get("enable", False):
                dataset_dir = self.split_dataset(dataset_dir)
                info("Split dataset successfully!")

        attrs = self.check_dataset(dataset_dir)
        analysis = self.analyse(dataset_dir)

        check_result = build_res_dict(True)
        check_result["attributes"] = attrs
        check_result["analysis"] = analysis
        check_result["dataset_path"] = os.path.basename(dataset_dir)
        check_result["show_type"] = self.get_show_type()
        check_result["dataset_type"] = self.get_dataset_type()
        info("Dataset check passed!")
        return check_result
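
    # The convert/split steps in `check` are driven by the CheckDataset section
    # of the pipeline yaml. A minimal sketch of that section follows; only the
    # convert/split/enable fields are implied by the code above, anything
    # further would be an assumption.
    #
    #     CheckDataset:
    #       convert:
    #         enable: False
    #       split:
    #         enable: True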

    def get_dataset_root(self, dataset_dir: str) -> str:
        """Find the dataset root directory.

        Args:
            dataset_dir (str): the directory that contains the dataset.

        Returns:
            str: the root directory of the dataset.
        """
        # XXX: forward compatible
        # dataset_dir = [d for d in Path(dataset_dir).iterdir() if d.is_dir()]
        # assert len(dataset_dir) == 1
        # return dataset_dir[0].as_posix()
        return dataset_dir

    @abstractmethod
    def check_dataset(self, dataset_dir: str):
        """Check whether the dataset meets the specifications and return a dataset summary.

        Args:
            dataset_dir (str): the root directory of the dataset.

        Raises:
            NotImplementedError
        """
        raise NotImplementedError

    def convert_dataset(self, src_dataset_dir: str) -> str:
        """Convert the dataset from another format to the specified format.

        Args:
            src_dataset_dir (str): the root directory of the source dataset.

        Returns:
            str: the root directory of the converted dataset.
        """
        # No-op by default; subclasses override this to perform the conversion.
        dst_dataset_dir = src_dataset_dir
        return dst_dataset_dir

    def split_dataset(self, src_dataset_dir: str) -> str:
        """Repartition the train and validation subsets.

        Args:
            src_dataset_dir (str): the root directory of the dataset.

        Returns:
            str: the root directory of the split dataset.
        """
        # No-op by default; subclasses override this to perform the split.
        dst_dataset_dir = src_dataset_dir
        return dst_dataset_dir

    def analyse(self, dataset_dir: str) -> dict:
        """Perform a deep analysis of the dataset.

        Args:
            dataset_dir (str): the root directory of the dataset.

        Returns:
            dict: the deep analysis results.
        """
        return {}

    @abstractmethod
    def get_show_type(self):
        """Return the dataset show type.

        Raises:
            NotImplementedError
        """
        raise NotImplementedError

    @abstractmethod
    def get_dataset_type(self):
        """Return the dataset type.

        Raises:
            NotImplementedError
        """
        raise NotImplementedError
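

# Subclass sketch (hypothetical, for illustration only): a concrete checker
# must implement the three abstract methods; `convert_dataset`,
# `split_dataset`, and `analyse` are overridden only when the dataset type
# needs them. Registration under a model name is handled by
# AutoRegisterABCMetaClass; the class below and its attributes are
# assumptions, not the library's actual checkers.
#
#     class ImageClsDatasetChecker(BaseDatasetChecker):
#         entities = ["my_image_cls_model"]  # hypothetical registration attribute
#
#         def check_dataset(self, dataset_dir: str) -> dict:
#             # Validate the layout and return a summary of the dataset.
#             return {"num_classes": 2, "train_samples": 100}
#
#         def get_show_type(self) -> str:
#             return "image"
#
#         def get_dataset_type(self) -> str:
#             return "ClsDataset"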