#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: PaddlePaddle Authors
"""
import os
from abc import ABC, abstractmethod

from ....utils.config import AttrDict
from ....utils.logging import info
from ....utils.misc import AutoRegisterABCMetaClass
from .utils import build_res_dict
  17. def build_dataset_checker(config: AttrDict) -> "BaseDatasetChecker":
  18. """build dataset checker
  19. Args:
  20. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  21. Returns:
  22. BaseDatasetChecker: the dataset checker, which is subclass of BaseDatasetChecker.
  23. """
  24. model_name = config.Global.model
  25. return BaseDatasetChecker.get(model_name)(config)
  26. class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
  27. """ Base Dataset Checker """
  28. __is_base = True
  29. def __init__(self, config):
  30. """Initialize the instance.
  31. Args:
  32. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  33. """
  34. super().__init__()
  35. self.global_config = config.Global
  36. self.check_dataset_config = config.CheckDataset
  37. self.output = os.path.join(self.global_config.output, "check_dataset")
  38. def check(self) -> dict:
  39. """execute dataset checking
  40. Returns:
  41. dict: the dataset checking result.
  42. """
  43. dataset_dir = self.get_dataset_root(self.global_config.dataset_dir)
  44. if not os.path.exists(self.output):
  45. os.makedirs(self.output)
  46. if self.check_dataset_config.get("convert", None):
  47. if self.check_dataset_config.convert.get("enable", False):
  48. self.convert_dataset(dataset_dir)
  49. info("Convert dataset successfully !")
  50. if self.check_dataset_config.get("split", None):
  51. if self.check_dataset_config.split.get("enable", False):
  52. self.split_dataset(dataset_dir)
  53. info("Split dataset successfully !")
  54. attrs = self.check_dataset(dataset_dir)
  55. analysis = self.analyse(dataset_dir)
  56. check_result = build_res_dict(True)
  57. check_result["attributes"] = attrs
  58. check_result["analysis"] = analysis
  59. check_result["dataset_path"] = self.global_config.dataset_dir
  60. check_result["show_type"] = self.get_show_type()
  61. check_result["dataset_type"] = self.get_dataset_type()
  62. info("Check dataset passed !")
  63. return check_result
  64. def get_dataset_root(self, dataset_dir: str) -> str:
  65. """find the dataset root dir
  66. Args:
  67. dataset_dir (str): the directory that contain dataset.
  68. Returns:
  69. str: the root directory of dataset.
  70. """
  71. # XXX: forward compatible
  72. # dataset_dir = [d for d in Path(dataset_dir).iterdir() if d.is_dir()]
  73. # assert len(dataset_dir) == 1
  74. # return dataset_dir[0].as_posix()
  75. return dataset_dir
  76. @abstractmethod
  77. def check_dataset(self, dataset_dir: str):
  78. """check if the dataset meets the specifications and get dataset summary
  79. Args:
  80. dataset_dir (str): the root directory of dataset.
  81. Raises:
  82. NotImplementedError
  83. """
  84. raise NotImplementedError
  85. def convert_dataset(self, src_dataset_dir: str) -> str:
  86. """convert the dataset from other type to specified type
  87. Args:
  88. src_dataset_dir (str): the root directory of dataset.
  89. Returns:
  90. str: the root directory of converted dataset.
  91. """
  92. dst_dataset_dir = src_dataset_dir
  93. return dst_dataset_dir
  94. def split_dataset(self, src_dataset_dir: str) -> str:
  95. """repartition the train and validation dataset
  96. Args:
  97. src_dataset_dir (str): the root directory of dataset.
  98. Returns:
  99. str: the root directory of splited dataset.
  100. """
  101. dst_dataset_dir = src_dataset_dir
  102. return dst_dataset_dir
  103. def analyse(self, dataset_dir: str) -> dict:
  104. """deep analyse dataset
  105. Args:
  106. dataset_dir (str): the root directory of dataset.
  107. Returns:
  108. dict: the deep analysis results.
  109. """
  110. return {}
  111. @abstractmethod
  112. def get_show_type(self):
  113. """return the dataset show type
  114. Raises:
  115. NotImplementedError
  116. """
  117. raise NotImplementedError
  118. @abstractmethod
  119. def get_dataset_type(self):
  120. """ return the dataset type
  121. Raises:
  122. NotImplementedError
  123. """
  124. raise NotImplementedError