| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- from abc import ABC, abstractmethod
- from pathlib import Path
- from ...utils.config import AttrDict
- from ...utils.device import (
- check_supported_device,
- set_env_for_device,
- update_device_num,
- )
- from ...utils.logging import *
- from ...utils.misc import AutoRegisterABCMetaClass
- from .build_model import build_model
def build_evaluator(config: AttrDict) -> "BaseEvaluator":
    """Build the model evaluator for the model named in the config.

    Args:
        config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.

    Returns:
        BaseEvaluator: the evaluator, which is subclass of BaseEvaluator.
    """
    # The previous `try: pass / except ModuleNotFoundError: pass` was a no-op
    # (nothing inside it could raise), so it has been dropped.
    # `BaseEvaluator.get` resolves the evaluator class from the model name;
    # the resolved class is then instantiated with the whole pipeline config.
    evaluator_cls = BaseEvaluator.get(config.Global.model)
    return evaluator_cls(config)
class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
    """Base Model Evaluator.

    Subclasses implement `update_config` and `get_eval_kwargs`; `evaluate`
    calls both and then runs the evaluation of the underlying model.
    """

    # Stored name-mangled as `_BaseEvaluator__is_base`.
    # NOTE(review): presumably read by AutoRegisterABCMetaClass to mark this
    # class as a non-registered base — confirm against the metaclass.
    __is_base = True

    def __init__(self, config):
        """Initialize the instance.

        Args:
            config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
        """
        super().__init__()
        self.global_config = config.Global
        self.eval_config = config.Evaluate

        # Default: look for a config file next to the weight file. This is
        # resolved first (it may log a warning) and only then overridden.
        config_path = self.get_config_path(self.eval_config.weight_path)
        # An explicitly configured, non-empty `basic_config_path` takes precedence.
        basic_config_path = self.eval_config.get("basic_config_path", None)
        if basic_config_path:
            config_path = basic_config_path

        self.pdx_config, self.pdx_model = build_model(
            self.global_config.model, config_path=config_path
        )

    def get_config_path(self, weight_path):
        """Locate the `config.yaml` that belongs to a weight file.

        The directory containing the weight file is checked first, then its
        parent directory.

        Args:
            weight_path (str): The path to the weight

        Returns:
            pathlib.Path | None: The path to the config, or None (after
                logging a warning) when neither location has a `config.yaml`.
        """
        config_path = Path(weight_path).parent / "config.yaml"
        if not config_path.exists():
            # Fall back to one directory level above the weight's directory.
            config_path = config_path.parent.parent / "config.yaml"
            if not config_path.exists():
                warning(
                    f"The config file (`{config_path}`) related to weight file (`{weight_path}`) does not exist. Using default instead."
                )
                config_path = None
        return config_path

    def check_return(self, metrics: dict) -> bool:
        """check evaluation metrics

        Args:
            metrics (dict): evaluation output metrics

        Returns:
            bool: whether the format of evaluation metrics is legal, i.e.
                `metrics` is a dict whose values are all `float` or `int`.
                An empty dict is considered legal.
        """
        if not isinstance(metrics, dict):
            return False
        return all(isinstance(val, (float, int)) for val in metrics.values())

    def evaluate(self) -> dict:
        """execute model evaluating

        Returns:
            dict: the evaluation metrics, wrapped as `{"metrics": metrics}`

        Raises:
            AssertionError: if the evaluation return code is non-zero, or if
                the returned metrics fail `check_return`.
        """
        self.update_config()
        evaluate_result = self.pdx_model.evaluate(**self.get_eval_kwargs())

        # Raised explicitly instead of using `assert`, so these checks still
        # run under `python -O`. The exception type is kept as AssertionError
        # so existing callers see the same error as before.
        if evaluate_result.returncode != 0:
            raise AssertionError(
                f"Encountered an unexpected error({evaluate_result.returncode}) in "
                "evaling!"
            )

        metrics = evaluate_result.metrics
        if not self.check_return(metrics):
            raise AssertionError(
                f"The return value({metrics}) of Evaluator.eval() is illegal!"
            )
        return {"metrics": metrics}

    def dump_config(self, config_file_path=None):
        """dump the config

        Args:
            config_file_path (str, optional): the path to save dumped config.
                Defaults to None, means that save in `Global.output` as `config.yaml`.
        """
        if config_file_path is None:
            config_file_path = os.path.join(self.global_config.output, "config.yaml")
        self.pdx_config.dump(config_file_path)

    def get_device(self, using_device_number: int = None) -> str:
        """get device setting from config

        Args:
            using_device_number (int, optional): specify device number to use.
                Defaults to None, means that base on config setting. Any falsy
                value (None or 0) keeps the configured device unchanged.

        Returns:
            str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
        """
        check_supported_device(self.global_config.device, self.global_config.model)
        set_env_for_device(self.global_config.device)
        device_setting = (
            update_device_num(self.global_config.device, using_device_number)
            if using_device_number
            else self.global_config.device
        )
        # replace "dcu" with "gpu"
        device_setting = device_setting.replace("dcu", "gpu")
        return device_setting

    @abstractmethod
    def update_config(self):
        """update evaluation config"""
        raise NotImplementedError

    @abstractmethod
    def get_eval_kwargs(self):
        """get key-value arguments of model evaluation function"""
        raise NotImplementedError
|