processors.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import List, Sequence, Tuple, Union, Optional
  16. import numpy as np
  17. from ....utils import logging
  18. from ..object_detection.processors import restructured_boxes
  19. import cv2
  20. def extract_masks_from_boxes(boxes, masks):
  21. """
  22. Extracts the portion of each mask that is within the corresponding box.
  23. """
  24. new_masks = []
  25. for i, box in enumerate(boxes):
  26. x_min, y_min, x_max, y_max = box["coordinate"]
  27. x_min, y_min, x_max, y_max = map(
  28. lambda x: int(round(x)), [x_min, y_min, x_max, y_max]
  29. )
  30. cropped_mask = masks[i][y_min:y_max, x_min:x_max]
  31. new_masks.append(cropped_mask)
  32. return new_masks
  33. class InstanceSegPostProcess(object):
  34. """Save Result Transform"""
  35. def __init__(self, threshold=0.5, labels=None):
  36. super().__init__()
  37. self.threshold = threshold
  38. self.labels = labels
  39. def apply(self, masks, img_size, boxes=None, class_id=None, threshold=None):
  40. """apply"""
  41. if boxes is not None:
  42. expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
  43. boxes = boxes[expect_boxes, :]
  44. boxes = restructured_boxes(boxes, self.labels, img_size)
  45. masks = masks[expect_boxes, :, :]
  46. masks = extract_masks_from_boxes(boxes, masks)
  47. result = {"boxes": boxes, "masks": masks}
  48. else:
  49. mask_info = []
  50. class_id = [list(item) for item in zip(class_id[0], class_id[1])]
  51. selected_masks = []
  52. for i, info in enumerate(class_id):
  53. label_id = int(info[0])
  54. if info[1] < threshold:
  55. continue
  56. mask_info.append(
  57. {
  58. "label": self.labels[label_id],
  59. "score": info[1],
  60. "class_id": label_id,
  61. }
  62. )
  63. selected_masks.append(masks[i])
  64. result = {"boxes": mask_info, "masks": selected_masks}
  65. return result
  66. def __call__(
  67. self,
  68. batch_outputs: List[dict],
  69. datas: List[dict],
  70. threshold: Optional[float] = None,
  71. ):
  72. """Apply the post-processing to a batch of outputs.
  73. Args:
  74. batch_outputs (List[dict]): The list of detection outputs.
  75. datas (List[dict]): The list of input data.
  76. threshold: Optional[float]: object score threshold for postprocess.
  77. Returns:
  78. List[Boxes]: The list of post-processed detection boxes.
  79. """
  80. outputs = []
  81. for data, output in zip(datas, batch_outputs):
  82. boxes_masks = self.apply(
  83. img_size=data["ori_img_size"],
  84. **output,
  85. threshold=threshold if threshold is not None else self.threshold
  86. )
  87. outputs.append(boxes_masks)
  88. return outputs