coco_eval.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import sys
  13. import argparse
  14. import numpy as np
  15. from pycocotools.coco import COCO
  16. from pycocotools.cocoeval import COCOeval
  17. def parse_args():
  18. """ Parse input arguments """
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. '--prediction_json_path', type=str, default='./bbox.json')
  22. parser.add_argument(
  23. '--gt_json_path', type=str, default='./instance_val.json')
  24. args = parser.parse_args()
  25. return args
  26. def json_eval_results(args):
  27. """
  28. cocoapi eval with already exists bbox.json
  29. """
  30. prediction_json_path = args.prediction_json_path
  31. gt_json_path = args.gt_json_path
  32. assert os.path.exists(
  33. prediction_json_path), "The json directory:{} does not exist".format(
  34. prediction_json_path)
  35. cocoapi_eval(prediction_json_path, "bbox", anno_file=gt_json_path)
  36. def cocoapi_eval(jsonfile,
  37. style,
  38. coco_gt=None,
  39. anno_file=None,
  40. max_dets=(100, 300, 1000),
  41. sigmas=None,
  42. use_area=True):
  43. """
  44. Args:
  45. jsonfile (str): Evaluation json file, eg: bbox.json
  46. style (str): COCOeval style, can be `bbox`
  47. coco_gt (str): Whether to load COCOAPI through anno_file,
  48. eg: coco_gt = COCO(anno_file)
  49. anno_file (str): COCO annotations file.
  50. max_dets (tuple): COCO evaluation maxDets.
  51. sigmas (nparray): keypoint labelling sigmas.
  52. use_area (bool): If gt annotations (eg. CrowdPose, AIC)
  53. do not have 'area', please set use_area=False.
  54. """
  55. assert coco_gt is not None or anno_file is not None
  56. if coco_gt is None:
  57. coco_gt = COCO(anno_file)
  58. coco_dt = coco_gt.loadRes(jsonfile)
  59. if style == 'proposal':
  60. coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
  61. coco_eval.params.useCats = 0
  62. coco_eval.params.maxDets = list(max_dets)
  63. elif style == 'keypoints_crowd':
  64. coco_eval = COCOeval(coco_gt, coco_dt, style, sigmas, use_area)
  65. else:
  66. coco_eval = COCOeval(coco_gt, coco_dt, style)
  67. coco_eval.evaluate()
  68. coco_eval.accumulate()
  69. coco_eval.summarize()
  70. # flush coco evaluation result
  71. sys.stdout.flush()
  72. return coco_eval.stats
  73. if __name__ == "__main__":
  74. args = parse_args()
  75. json_eval_results(args)