yolo_cluster.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import os
  16. import numpy as np
  17. from tqdm import tqdm
  18. from scipy.cluster.vq import kmeans
  19. from paddlex.utils import logging
  20. __all__ = ['YOLOAnchorCluster']
  21. class BaseAnchorCluster(object):
  22. def __init__(self, num_anchors, cache, cache_path):
  23. """
  24. Base Anchor Cluster
  25. Args:
  26. num_anchors (int): number of clusters
  27. cache (bool): whether using cache
  28. cache_path (str): cache directory path
  29. """
  30. super(BaseAnchorCluster, self).__init__()
  31. self.num_anchors = num_anchors
  32. self.cache_path = cache_path
  33. self.cache = cache
  34. def print_result(self, centers):
  35. raise NotImplementedError('%s.print_result is not available' %
  36. self.__class__.__name__)
  37. def get_whs(self):
  38. whs_cache_path = os.path.join(self.cache_path, 'whs.npy')
  39. shapes_cache_path = os.path.join(self.cache_path, 'shapes.npy')
  40. if self.cache and os.path.exists(whs_cache_path) and os.path.exists(
  41. shapes_cache_path):
  42. self.whs = np.load(whs_cache_path)
  43. self.shapes = np.load(shapes_cache_path)
  44. return self.whs, self.shapes
  45. whs = np.zeros((0, 2))
  46. shapes = np.zeros((0, 2))
  47. samples = copy.deepcopy(self.dataset.file_list)
  48. for sample in tqdm(samples):
  49. im_h, im_w = sample['image_shape']
  50. bbox = sample['gt_bbox']
  51. wh = bbox[:, 2:4] - bbox[:, 0:2]
  52. wh = wh / np.array([[im_w, im_h]])
  53. shape = np.ones_like(wh) * np.array([[im_w, im_h]])
  54. whs = np.vstack((whs, wh))
  55. shapes = np.vstack((shapes, shape))
  56. if self.cache:
  57. os.makedirs(self.cache_path, exist_ok=True)
  58. np.save(whs_cache_path, whs)
  59. np.save(shapes_cache_path, shapes)
  60. self.whs = whs
  61. self.shapes = shapes
  62. return self.whs, self.shapes
  63. def calc_anchors(self):
  64. raise NotImplementedError('%s.calc_anchors is not available' %
  65. self.__class__.__name__)
  66. def __call__(self):
  67. self.get_whs()
  68. centers = self.calc_anchors()
  69. return centers
  70. class YOLOAnchorCluster(BaseAnchorCluster):
  71. def __init__(self,
  72. num_anchors,
  73. dataset,
  74. image_size,
  75. cache=True,
  76. cache_path=None,
  77. iters=300,
  78. gen_iters=1000,
  79. thresh=0.25):
  80. """
  81. YOLOv5 Anchor Cluster
  82. Reference:
  83. https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
  84. Args:
  85. num_anchors (int): number of clusters
  86. dataset (DataSet): DataSet instance, VOC or COCO
  87. image_size (list or int): [h, w], being an int means image height and image width are the same.
  88. cache (bool): whether using cache。 Defaults to True.
  89. cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset. Defaults to None.
  90. iters (int, optional): iters of kmeans algorithm. Defaults to 300.
  91. gen_iters (int, optional): iters of genetic algorithm. Defaults to 1000.
  92. thresh (float, optional): anchor scale threshold. Defaults to 0.25.
  93. """
  94. self.dataset = dataset
  95. if cache_path is None:
  96. cache_path = self.dataset.data_dir
  97. if isinstance(image_size, int):
  98. image_size = [image_size] * 2
  99. self.image_size = image_size
  100. self.iters = iters
  101. self.gen_iters = gen_iters
  102. self.thresh = thresh
  103. super(YOLOAnchorCluster, self).__init__(num_anchors, cache, cache_path)
  104. def print_result(self, centers):
  105. whs = self.whs
  106. x, best = self.metric(whs, centers)
  107. bpr, aat = (best > self.thresh).mean(), (
  108. x > self.thresh).mean() * self.num_anchors
  109. logging.info(
  110. 'thresh=%.2f: %.4f best possible recall, %.2f anchors past thr' %
  111. (self.thresh, bpr, aat))
  112. logging.info(
  113. 'n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thresh=%.3f-mean: '
  114. % (self.num_anchors, self.image_size, x.mean(), best.mean(),
  115. x[x > self.thresh].mean()))
  116. logging.info('%d anchor cluster result: [w, h]' % self.num_anchors)
  117. for w, h in centers:
  118. logging.info('[%d, %d]' % (w, h))
  119. def metric(self, whs, centers):
  120. r = whs[:, None] / centers[None]
  121. x = np.minimum(r, 1. / r).min(2)
  122. return x, x.max(1)
  123. def fitness(self, whs, centers):
  124. _, best = self.metric(whs, centers)
  125. return (best * (best > self.thresh)).mean()
  126. def calc_anchors(self):
  127. self.whs = self.whs * self.shapes / self.shapes.max(
  128. 1, keepdims=True) * np.array([self.image_size[::-1]])
  129. wh0 = self.whs
  130. i = (wh0 < 3.0).any(1).sum()
  131. if i:
  132. logging.warning('Extremely small objects found. %d of %d '
  133. 'labels are < 3 pixels in width or height' %
  134. (i, len(wh0)))
  135. wh = wh0[(wh0 >= 2.0).any(1)]
  136. logging.info('Running kmeans for %g anchors on %g points...' %
  137. (self.num_anchors, len(wh)))
  138. s = wh.std(0)
  139. centers, dist = kmeans(wh / s, self.num_anchors, iter=self.iters)
  140. centers *= s
  141. f, sh, mp, s = self.fitness(wh, centers), centers.shape, 0.9, 0.1
  142. pbar = tqdm(
  143. range(self.gen_iters),
  144. desc='Evolving anchors with Genetic Algorithm')
  145. for _ in pbar:
  146. v = np.ones(sh)
  147. while (v == 1).all():
  148. v = ((np.random.random(sh) < mp) * np.random.random() *
  149. np.random.randn(*sh) * s + 1).clip(0.3, 3.0)
  150. new_centers = (centers.copy() * v).clip(min=2.0)
  151. new_f = self.fitness(wh, new_centers)
  152. if new_f > f:
  153. f, centers = new_f, new_centers.copy()
  154. pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
  155. centers = np.round(centers[np.argsort(centers.prod(1))]).astype(
  156. int).tolist()
  157. return centers