# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage

from detectron2.modeling.backbone import Backbone, build_backbone
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image

from contextlib import contextmanager
from itertools import count


@META_ARCH_REGISTRY.register()
class VLGeneralizedRCNN(GeneralizedRCNN):
    """
    Generalized R-CNN. Any model that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # features = self.backbone(images.tensor)
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
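
    # A minimal sketch of the training call, assuming a built model in train
    # mode; the tensor shape and the `gt` Instances below are illustrative only:
    #
    #   batched_inputs = [{"image": torch.zeros(3, 800, 1216), "instances": gt}]
    #   loss_dict = model(batched_inputs)       # dict of scalar loss tensors
    #   sum(loss_dict.values()).backward()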

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # features = self.backbone(images.tensor)
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results

    def get_batch(self, examples, images):
        """Assemble the dict of inputs consumed by the backbone."""
        if len(examples) >= 1 and "bbox" not in examples[0]:  # image_only
            return {"images": images.tensor}

        # The original code fell through to `return input`, which returned the
        # Python builtin rather than a batch. Fail loudly instead for inputs
        # that carry a "bbox" key, since that path is not implemented here.
        raise NotImplementedError("get_batch only supports image-only inputs")
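
    # Example of the image-only branch above (shapes are hypothetical):
    #
    #   examples = [{"image": torch.zeros(3, 800, 1216)}]   # no "bbox" key
    #   self.get_batch(examples, images)   # -> {"images": images.tensor}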

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs,
        using a fixed batch size of 2 instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, single_input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(single_input)
            instances.append(instance)
            # Flush a chunk once it holds 2 items or the input list is exhausted.
            if len(inputs) == 2 or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=True,
                    )
                )
                inputs, instances = [], []
        return outputs
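

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module API): a hedged example of building and
# running this meta-architecture with detectron2. The config path below is
# hypothetical; any config whose MODEL.META_ARCHITECTURE is
# "VLGeneralizedRCNN" would do.
#
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_model
#
#   cfg = get_cfg()
#   cfg.merge_from_file("configs/vl_rcnn.yaml")        # hypothetical path
#   cfg.MODEL.META_ARCHITECTURE = "VLGeneralizedRCNN"
#   model = build_model(cfg).eval()
#
#   with torch.no_grad():
#       preds = model.inference([{"image": img_chw, "height": h, "width": w}])
# ---------------------------------------------------------------------------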