coco_metrics.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import copy
  18. import sys
  19. from collections import OrderedDict
  20. from .coco_utils import get_infer_results, cocoapi_eval
  class COCOMetric(object):
      """Collect detection/segmentation results and score them with the COCO API.

      Usage: construct with a ground-truth COCO object, call ``update`` once per
      batch, then ``accumulate`` to run evaluation and ``get`` to read the
      headline mmAP values. ``reset`` clears everything for a new run.

      Args:
          coco_gt: ground-truth COCO API object. It must provide ``getCatIds()``,
              ``loadCats(ids)`` and a ``dataset`` attribute.
          **kwargs: optional settings.
              classwise: forwarded to ``cocoapi_eval``; defaults to False.
              bias: forwarded to ``get_infer_results``; defaults to 0.
                  NOTE(review): presumably a box-coordinate offset; confirm in
                  coco_utils.
      """

      def __init__(self, coco_gt, **kwargs):
          # Map position i in the loaded category list (ordered as getCatIds()
          # returns ids) to that category's COCO "id".
          self.clsid2catid = {
              i: cat["id"]
              for i, cat in enumerate(coco_gt.loadCats(coco_gt.getCatIds()))
          }
          self.coco_gt = coco_gt
          self.classwise = kwargs.get("classwise", False)
          # Previously hard-coded to 0. It is now configurable like "classwise",
          # and defaults to the old value so existing callers are unaffected.
          self.bias = kwargs.get("bias", 0)
          self.reset()

      def reset(self):
          """Discard all accumulated results and computed statistics."""
          # Only bbox and mask evaluation are supported currently.
          self.details = {
              # Independent copy of the ground-truth dataset dict, so later
              # changes to coco_gt.dataset are not reflected here (or vice versa).
              "gt": copy.deepcopy(self.coco_gt.dataset),
              "bbox": [],
              "mask": [],
          }
          self.eval_stats = {}

      def update(self, im_id, outputs):
          """Convert one batch of model outputs to COCO results and append them.

          Args:
              im_id: image id(s) for this batch; stored under ``"im_id"`` in the
                  dict handed to ``get_infer_results``.
              outputs: mapping of output name -> value. It is shallow-copied, so
                  the caller's mapping is never modified.
          """
          # Shallow copy so adding "im_id" does not mutate the caller's mapping.
          # NOTE(review): values are passed through unchanged (no Tensor ->
          # numpy conversion happens here); confirm get_infer_results accepts them.
          outs = dict(outputs)
          outs["im_id"] = im_id
          infer_results = get_infer_results(outs, self.clsid2catid, bias=self.bias)
          self.details["bbox"] += infer_results.get("bbox", [])
          self.details["mask"] += infer_results.get("mask", [])

      def accumulate(self):
          """Run COCO evaluation for every result type that has entries.

          The value returned by ``cocoapi_eval`` is stored in ``self.eval_stats``
          under ``"bbox"`` and/or ``"mask"``. Result types with no collected
          entries are skipped and leave ``eval_stats`` untouched.
          """
          # (key in self.details / self.eval_stats, iou type for cocoapi_eval);
          # the order bbox -> mask matches the original evaluation order.
          for key, iou_type in (("bbox", "bbox"), ("mask", "segm")):
              results = self.details[key]
              if not results:
                  continue
              # Hand the evaluator its own deep copy so nothing it does to the
              # result dicts can alter self.details.
              self.eval_stats[key] = cocoapi_eval(
                  copy.deepcopy(results),
                  iou_type,
                  coco_gt=self.coco_gt,
                  classwise=self.classwise,
              )
              sys.stdout.flush()

      def log(self):
          """No-op; present so callers can invoke ``log()`` uniformly."""
          pass

      def get(self):
          """Return the headline metric values.

          Returns:
              ``{"bbox_mmap": 0.0}`` if bbox evaluation has not produced stats.
              Otherwise an ordered mapping with ``"bbox_mmap"`` set to element 0
              of the bbox stats. When mask stats exist, ``"segm_mmap"`` is also
              set, to element 0 of the mask stats. (Element 0 is presumably
              COCOeval's AP@[.50:.95]; confirm in coco_utils.)

          NOTE(review): mask stats are not reported when bbox stats are absent;
          this preserves the original behavior.
          """
          if "bbox" not in self.eval_stats:
              return {"bbox_mmap": 0.0}
          metrics = OrderedDict([("bbox_mmap", self.eval_stats["bbox"][0])])
          if "mask" in self.eval_stats:
              metrics["segm_mmap"] = self.eval_stats["mask"][0]
          return metrics