instance_seg.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from ....utils import logging
  17. from ..base import BaseComponent
  18. from .det import restructured_boxes
  19. import cv2
  20. import numpy as np
  21. def extract_masks_from_boxes(boxes, masks):
  22. """
  23. Extracts the portion of each mask that is within the corresponding box.
  24. """
  25. new_masks = []
  26. for i, box in enumerate(boxes):
  27. x_min, y_min, x_max, y_max = box["coordinate"]
  28. x_min, y_min, x_max, y_max = map(
  29. lambda x: int(round(x)), [x_min, y_min, x_max, y_max]
  30. )
  31. cropped_mask = masks[i][y_min:y_max, x_min:x_max]
  32. new_masks.append(cropped_mask)
  33. return new_masks
  34. class InstanceSegPostProcess(BaseComponent):
  35. """Save Result Transform"""
  36. INPUT_KEYS = [["boxes", "masks", "img_size"], ["class_id", "masks", "img_size"]]
  37. OUTPUT_KEYS = ["img_path", "boxes", "masks"]
  38. DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks", "img_size": "ori_img_size"}
  39. DEAULT_OUTPUTS = {
  40. "boxes": "boxes",
  41. "masks": "masks",
  42. }
  43. def __init__(self, threshold=0.5, labels=None):
  44. super().__init__()
  45. self.threshold = threshold
  46. self.labels = labels
  47. def apply(self, masks, img_size, boxes=None, class_id=None):
  48. """apply"""
  49. if boxes is not None:
  50. expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
  51. boxes = boxes[expect_boxes, :]
  52. boxes = restructured_boxes(boxes, self.labels, img_size)
  53. masks = masks[expect_boxes, :, :]
  54. masks = extract_masks_from_boxes(boxes, masks)
  55. result = {"boxes": boxes, "masks": masks}
  56. else:
  57. mask_info = []
  58. class_id = [list(item) for item in zip(class_id[0], class_id[1])]
  59. selected_masks = []
  60. for i, info in enumerate(class_id):
  61. label_id = int(info[0])
  62. if info[1] < self.threshold:
  63. continue
  64. mask_info.append(
  65. {
  66. "label": self.labels[label_id],
  67. "score": info[1],
  68. "class_id": label_id,
  69. }
  70. )
  71. selected_masks.append(masks[i])
  72. result = {"boxes": mask_info, "masks": selected_masks}
  73. return result