processors.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np

from ....utils import logging
from ...utils.benchmark import benchmark
from ..object_detection.processors import restructured_boxes


def extract_masks_from_boxes(boxes, masks):
    """Extract the portion of each mask that lies within the corresponding box.

    Args:
        boxes: Restructured detection boxes, each a dict with a "coordinate"
            entry of (x_min, y_min, x_max, y_max).
        masks: Full-image masks aligned with ``boxes`` by index.

    Returns:
        A list of per-instance masks cropped to their bounding boxes.
    """
    new_masks = []
    for i, box in enumerate(boxes):
        x_min, y_min, x_max, y_max = box["coordinate"]
        # Round the (possibly float) coordinates to integer pixel indices.
        x_min, y_min, x_max, y_max = map(
            lambda x: int(round(x)), [x_min, y_min, x_max, y_max]
        )
        cropped_mask = masks[i][y_min:y_max, x_min:x_max]
        new_masks.append(cropped_mask)
    return new_masks
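

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, hypothetical example of what extract_masks_from_boxes does: given
# one full-image binary mask and one restructured box, it returns the mask
# cropped to that box. The shapes and coordinates below are made-up values.
def _example_extract_masks_from_boxes():
    full_mask = np.zeros((100, 100), dtype="uint8")
    full_mask[20:40, 30:60] = 1  # paint a rectangular "instance"
    boxes = [{"coordinate": (30.0, 20.0, 60.0, 40.0)}]
    cropped = extract_masks_from_boxes(boxes, np.stack([full_mask]))
    assert cropped[0].shape == (20, 30)  # (y_max - y_min, x_max - x_min)
    return cropped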


@benchmark.timeit
class InstanceSegPostProcess(object):
    """Post-process the raw outputs of an instance segmentation model."""

    def __init__(self, threshold=0.5, labels=None):
        super().__init__()
        self.threshold = threshold
        self.labels = labels

    def apply(self, masks, img_size, boxes=None, class_id=None, threshold=None):
        """Filter predictions by score and pack them into a result dict.

        Args:
            masks: Predicted instance masks.
            img_size: Original image size used to restructure the boxes.
            boxes: Raw detection boxes, one row per instance in the form
                (class_id, score, x_min, y_min, x_max, y_max), if available.
            class_id: Per-instance (class_id, score) pairs, used when the
                model predicts masks without boxes.
            threshold: Object score threshold.
        """
        if boxes is not None:
            # Box-based output: keep instances with a valid class id and a
            # score above the threshold, then crop each mask to its box.
            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
            boxes = boxes[expect_boxes, :]
            boxes = restructured_boxes(boxes, self.labels, img_size)
            masks = masks[expect_boxes, :, :]
            masks = extract_masks_from_boxes(boxes, masks)
            result = {"boxes": boxes, "masks": masks}
        else:
            # Mask-only output: class_id carries (label, score) pairs that
            # are aligned with the predicted masks by index.
            mask_info = []
            class_id = [list(item) for item in zip(class_id[0], class_id[1])]
            selected_masks = []
            for i, info in enumerate(class_id):
                label_id = int(info[0])
                if info[1] < threshold:
                    continue
                mask_info.append(
                    {
                        "label": self.labels[label_id],
                        "score": info[1],
                        "class_id": label_id,
                    }
                )
                selected_masks.append(masks[i])
            result = {"boxes": mask_info, "masks": selected_masks}
        return result

    def __call__(
        self,
        batch_outputs: List[dict],
        datas: List[dict],
        threshold: Optional[float] = None,
    ):
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.
            threshold (Optional[float]): Object score threshold for
                post-processing; falls back to the threshold given at
                construction time when not provided.

        Returns:
            List[dict]: The list of post-processed boxes and masks.
        """
        outputs = []
        for data, output in zip(datas, batch_outputs):
            boxes_masks = self.apply(
                img_size=data["ori_img_size"],
                **output,
                threshold=threshold if threshold is not None else self.threshold,
            )
            outputs.append(boxes_masks)
        return outputs
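

# --- Illustrative usage sketch (not part of the original module) ------------
# A hypothetical end-to-end call with made-up tensors, assuming box-based
# outputs where each row is (class_id, score, x_min, y_min, x_max, y_max),
# matching how apply() indexes the columns above. Label names, sizes, and
# scores here are invented for illustration only.
if __name__ == "__main__":
    postprocess = InstanceSegPostProcess(threshold=0.5, labels=["person", "car"])
    fake_boxes = np.array(
        [
            [0, 0.9, 10.0, 10.0, 50.0, 60.0],  # kept: score above threshold
            [1, 0.3, 5.0, 5.0, 20.0, 20.0],  # dropped: score below threshold
        ],
        dtype="float32",
    )
    fake_masks = np.zeros((2, 100, 100), dtype="uint8")
    fake_masks[0, 10:60, 10:50] = 1
    results = postprocess(
        batch_outputs=[{"boxes": fake_boxes, "masks": fake_masks}],
        datas=[{"ori_img_size": [100, 100]}],
    )
    print(len(results[0]["boxes"]), results[0]["masks"][0].shape)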