processors.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional
  15. from ...utils.benchmark import benchmark
  16. from ..object_detection.processors import restructured_boxes
  17. def extract_masks_from_boxes(boxes, masks):
  18. """
  19. Extracts the portion of each mask that is within the corresponding box.
  20. """
  21. new_masks = []
  22. for i, box in enumerate(boxes):
  23. x_min, y_min, x_max, y_max = box["coordinate"]
  24. x_min, y_min, x_max, y_max = map(
  25. lambda x: int(round(x)), [x_min, y_min, x_max, y_max]
  26. )
  27. cropped_mask = masks[i][y_min:y_max, x_min:x_max]
  28. new_masks.append(cropped_mask)
  29. return new_masks
  30. @benchmark.timeit
  31. class InstanceSegPostProcess(object):
  32. """Save Result Transform"""
  33. def __init__(self, threshold=0.5, labels=None):
  34. super().__init__()
  35. self.threshold = threshold
  36. self.labels = labels
  37. def apply(self, masks, img_size, boxes=None, class_id=None, threshold=None):
  38. """apply"""
  39. if boxes is not None:
  40. expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
  41. boxes = boxes[expect_boxes, :]
  42. boxes = restructured_boxes(boxes, self.labels, img_size)
  43. masks = masks[expect_boxes, :, :]
  44. masks = extract_masks_from_boxes(boxes, masks)
  45. result = {"boxes": boxes, "masks": masks}
  46. else:
  47. mask_info = []
  48. class_id = [list(item) for item in zip(class_id[0], class_id[1])]
  49. selected_masks = []
  50. for i, info in enumerate(class_id):
  51. label_id = int(info[0])
  52. if info[1] < threshold:
  53. continue
  54. mask_info.append(
  55. {
  56. "label": self.labels[label_id],
  57. "score": info[1],
  58. "class_id": label_id,
  59. }
  60. )
  61. selected_masks.append(masks[i])
  62. result = {"boxes": mask_info, "masks": selected_masks}
  63. return result
  64. def __call__(
  65. self,
  66. batch_outputs: List[dict],
  67. datas: List[dict],
  68. threshold: Optional[float] = None,
  69. ):
  70. """Apply the post-processing to a batch of outputs.
  71. Args:
  72. batch_outputs (List[dict]): The list of detection outputs.
  73. datas (List[dict]): The list of input data.
  74. threshold: Optional[float]: object score threshold for postprocess.
  75. Returns:
  76. List[Boxes]: The list of post-processed detection boxes.
  77. """
  78. outputs = []
  79. for data, output in zip(datas, batch_outputs):
  80. boxes_masks = self.apply(
  81. img_size=data["ori_img_size"],
  82. **output,
  83. threshold=threshold if threshold is not None else self.threshold
  84. )
  85. outputs.append(boxes_masks)
  86. return outputs