instance_seg.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from ....utils import logging
  17. from ..base import BaseComponent
  18. from .det import restructured_boxes
  def extract_masks_from_boxes(boxes, masks):
      """Crop each instance mask to the region covered by its bounding box.

      Args:
          boxes: Sequence of dicts. Each dict must provide a "coordinate"
              entry that unpacks to (x_min, y_min, x_max, y_max).
          masks: Indexable collection where masks[i] is a 2-D array paired
              with boxes[i] (e.g. a numpy array indexed along its first
              axis).

      Returns:
          list: One cropped mask per box, in the same order as ``boxes``.
      """
      new_masks = []
      for i, box in enumerate(boxes):
          # Round to the nearest pixel, then clamp at 0. A box that overshoots
          # the top/left image border would otherwise produce a negative slice
          # bound, which numpy interprets as "count from the end" and yields an
          # empty or wrong crop. Bounds past the far edge are already truncated
          # by slicing itself.
          x_min, y_min, x_max, y_max = (
              max(0, int(round(v))) for v in box["coordinate"]
          )
          new_masks.append(masks[i][y_min:y_max, x_min:x_max])
      return new_masks
  class InstanceSegPostProcess(BaseComponent):
      """Filter instance-segmentation predictions and crop masks to boxes."""

      INPUT_KEYS = ["boxes", "masks"]
      OUTPUT_KEYS = ["img_path", "boxes", "masks"]
      # NOTE(review): the "DEAULT" spelling is kept as-is; presumably it matches
      # the attribute names the base component looks up -- confirm before
      # renaming.
      DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks"}
      DEAULT_OUTPUTS = {"boxes": "boxes", "masks": "masks"}

      def __init__(self, threshold=0.5, labels=None):
          """Store the filtering threshold and the label list.

          Args:
              threshold: Rows whose column 1 is not strictly greater than this
                  value are dropped in ``apply``.
              labels: Passed through to ``restructured_boxes`` unchanged.
          """
          super().__init__()
          self.threshold = threshold
          self.labels = labels

      def apply(self, boxes, masks):
          """Drop rejected predictions, restructure boxes and crop the masks.

          Args:
              boxes: 2-D array with one row per prediction. Column 0 must be
                  greater than -1 and column 1 greater than the threshold for
                  the row to be kept (presumably class id and score -- verify
                  against the model output format).
              masks: 3-D array whose first axis lines up with ``boxes`` rows.

          Returns:
              dict: ``{"boxes": ..., "masks": ...}`` holding the restructured
              boxes and the per-box cropped masks.
          """
          above_threshold = boxes[:, 1] > self.threshold
          non_negative_id = boxes[:, 0] > -1
          keep = above_threshold & non_negative_id

          # The same boolean mask selects matching rows from both arrays so the
          # surviving boxes and masks stay aligned by index.
          kept_boxes = restructured_boxes(boxes[keep, :], self.labels)
          kept_masks = extract_masks_from_boxes(kept_boxes, masks[keep, :, :])
          return {"boxes": kept_boxes, "masks": kept_masks}