evaluator.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. from pathlib import Path
  13. from abc import ABC, abstractmethod
  14. from .build_model import build_model
  15. from ...utils.device import get_device
  16. from ...utils.misc import AutoRegisterABCMetaClass
  17. from ...utils.config import AttrDict
  18. from ...utils.logging import *
  19. def build_evaluater(config: AttrDict) -> "BaseEvaluator":
  20. """build model evaluater
  21. Args:
  22. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  23. Returns:
  24. BaseEvaluator: the evaluater, which is subclass of BaseEvaluator.
  25. """
  26. model_name = config.Global.model
  27. return BaseEvaluator.get(model_name)(config)
  28. class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
  29. """ Base Model Evaluator """
  30. __is_base = True
  31. def __init__(self, config):
  32. """Initialize the instance.
  33. Args:
  34. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  35. """
  36. super().__init__()
  37. self.global_config = config.Global
  38. self.eval_config = config.Evaluate
  39. config_path = self.get_config_path(self.eval_config.weight_path)
  40. if not config_path.exists():
  41. warning(
  42. f"The config file(`{config_path}`) related to weight file(`{self.eval_config.weight_path}`) is not exist, use default instead."
  43. )
  44. config_path = None
  45. self.pdx_config, self.pdx_model = build_model(
  46. self.global_config.model, config_path=config_path)
  47. def get_config_path(self, weight_path):
  48. """
  49. get config path
  50. Args:
  51. weight_path (str): The path to the weight
  52. Returns:
  53. config_path (str): The path to the config
  54. """
  55. config_path = Path(weight_path).parent / "config.yaml"
  56. return config_path
  57. def check_return(self, metrics: dict) -> bool:
  58. """check evaluation metrics
  59. Args:
  60. metrics (dict): evaluation output metrics
  61. Returns:
  62. bool: whether the format of evaluation metrics is legal
  63. """
  64. if not isinstance(metrics, dict):
  65. return False
  66. for metric in metrics:
  67. val = metrics[metric]
  68. if not isinstance(val, float):
  69. return False
  70. return True
  71. def evaluate(self) -> dict:
  72. """execute model training
  73. Returns:
  74. dict: the evaluation metrics
  75. """
  76. self.update_config()
  77. # self.dump_config()
  78. evaluate_result = self.pdx_model.evaluate(**self.get_eval_kwargs())
  79. assert evaluate_result.returncode == 0, f"Encountered an unexpected error({evaluate_result.returncode}) in \
  80. evaling!"
  81. metrics = evaluate_result.metrics
  82. assert self.check_return(
  83. metrics
  84. ), f"The return value({metrics}) of Evaluator.eval() is illegal!"
  85. return {"metrics": metrics}
  86. def dump_config(self, config_file_path=None):
  87. """dump the config
  88. Args:
  89. config_file_path (str, optional): the path to save dumped config.
  90. Defaults to None, means that save in `Global.output` as `config.yaml`.
  91. """
  92. if config_file_path is None:
  93. config_file_path = os.path.join(self.global_config.output,
  94. "config.yaml")
  95. self.pdx_config.dump(config_file_path)
  96. def get_device(self, using_device_number: int=None) -> str:
  97. """get device setting from config
  98. Args:
  99. using_device_number (int, optional): specify device number to use.
  100. Defaults to None, means that base on config setting.
  101. Returns:
  102. str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
  103. """
  104. return get_device(
  105. self.global_config.device, using_device_number=using_device_number)
  106. @abstractmethod
  107. def update_config(self):
  108. """update evalution config
  109. """
  110. raise NotImplementedError
  111. @abstractmethod
  112. def get_eval_kwargs(self):
  113. """get key-value arguments of model evalution function
  114. """
  115. raise NotImplementedError