# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import json import argparse from paddle import nn import paddle from ....utils import logging def parse_args(): """Parse all arguments """ parser = argparse.ArgumentParser() parser.add_argument( '--prediction_json_path', type=str, default='./pre_res.json') parser.add_argument('--gt_val_path', type=str, default='./val.txt') parser.add_argument('--image_dir', type=str) parser.add_argument('--num_classes', type=int) args = parser.parse_args() return args class AvgMetrics(nn.Layer): """ Average metrics """ def __init__(self): super().__init__() self.avg_meters = {} @property def avg(self): """ Return average value of each metric """ if self.avg_meters: for metric_key in self.avg_meters: return self.avg_meters[metric_key].avg @property def avg_info(self): """ Return a formatted string of average values and names """ return ", ".join( [self.avg_meters[key].avg_info for key in self.avg_meters]) class TopkAcc(AvgMetrics): """ Top-k accuracy metric """ def __init__(self, topk=(1, 5)): super().__init__() assert isinstance(topk, (int, list, tuple)) if isinstance(topk, int): topk = [topk] self.topk = topk self.warned = False def forward(self, x, label): """ forward function """ if isinstance(x, dict): x = x["logits"] output_dims = x.shape[-1] metric_dict = dict() for idx, k in enumerate(self.topk): if output_dims < k: if not self.warned: msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless." logging.info(msg) self.warned = True metric_dict[f"top{k}"] = 1 else: metric_dict[f"top{k}"] = paddle.metric.accuracy( x, label, k=k).item() return metric_dict def prase_pt_info(pt_info, num_classes): """ Parse prediction information to probability vector """ pre_list = [0.0] * num_classes for idx, val in zip(pt_info["class_ids"], pt_info["scores"]): pre_list[idx] = val return pre_list def main(args): """ main function """ with open(args.prediction_json_path, 'r') as fp: predication_result = json.load(fp) gt_info = {} pred = [] label = [] for line in open(args.gt_val_path): img_file, gt_label = line.strip().split(" ") img_file = img_file.split('/')[-1] gt_info[img_file] = int(gt_label) for pt_info in predication_result: img_file = os.path.relpath(pt_info['file_name'], args.image_dir) pred.append(prase_pt_info(pt_info, args.num_classes)) label.append([gt_info[img_file]]) metric_dict = TopkAcc()(paddle.to_tensor(pred), paddle.to_tensor(label)) logging.info(metric_dict) if __name__ == "__main__": args = parse_args() main(args)