sort_boxes.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import List
  15. import numpy as np
  16. from .base_operator import BaseOperator
  17. class SortQuadBoxes(BaseOperator):
  18. """SortQuadBoxes Operator."""
  19. entities = "SortQuadBoxes"
  20. def __init__(self):
  21. """Initializes the class."""
  22. super().__init__()
  23. def __call__(self, dt_polys: List[np.ndarray]) -> np.ndarray:
  24. """
  25. Sort quad boxes in order from top to bottom, left to right
  26. args:
  27. dt_polys(ndarray):detected quad boxes with shape [4, 2]
  28. return:
  29. sorted boxes(ndarray) with shape [4, 2]
  30. """
  31. dt_boxes = np.array(dt_polys)
  32. num_boxes = dt_boxes.shape[0]
  33. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  34. _boxes = list(sorted_boxes)
  35. for i in range(num_boxes - 1):
  36. for j in range(i, -1, -1):
  37. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  38. _boxes[j + 1][0][0] < _boxes[j][0][0]
  39. ):
  40. tmp = _boxes[j]
  41. _boxes[j] = _boxes[j + 1]
  42. _boxes[j + 1] = tmp
  43. else:
  44. break
  45. return _boxes
  46. class SortPolyBoxes(BaseOperator):
  47. """SortPolyBoxes Operator."""
  48. entities = "SortPolyBoxes"
  49. def __init__(self):
  50. """Initializes the class."""
  51. super().__init__()
  52. def __call__(self, dt_polys: List[np.ndarray]) -> np.ndarray:
  53. """
  54. Sort poly boxes in order from top to bottom, left to right
  55. args:
  56. dt_polys(ndarray):detected poly boxes with a [N, 2] np.ndarray list
  57. return:
  58. sorted boxes(ndarray) with [N, 2] np.ndarray list
  59. """
  60. num_boxes = len(dt_polys)
  61. if num_boxes == 0:
  62. return dt_polys
  63. else:
  64. y_min_list = []
  65. for bno in range(num_boxes):
  66. y_min_list.append(min(dt_polys[bno][:, 1]))
  67. rank = np.argsort(np.array(y_min_list))
  68. dt_polys_rank = []
  69. for no in range(num_boxes):
  70. dt_polys_rank.append(dt_polys[rank[no]])
  71. return dt_polys_rank