# Copyright (c) Facebook, Inc. and its affiliates.
from itertools import count
from typing import Dict, List, Optional

import torch

from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures import Instances
from detectron2.utils.events import get_event_storage

@META_ARCH_REGISTRY.register()
class VLGeneralizedRCNN(GeneralizedRCNN):
    """
    A vision-language variant of Generalized R-CNN: identical to
    :class:`GeneralizedRCNN` except that the backbone consumes a dict of
    inputs (see :meth:`get_batch`) rather than a raw image tensor.
    Like any Generalized R-CNN, it contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model,
                  used in inference. See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
            In training, a dict[str, Tensor] of losses is returned instead.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Unlike the parent class, which calls self.backbone(images.tensor),
        # the VL backbone consumes the input dict assembled by get_batch().
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
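
    # Illustrative only (not from the original file): a minimal training call,
    # assuming a hypothetical `model` built elsewhere (e.g., via detectron2's
    # build_model) with META_ARCHITECTURE set to "VLGeneralizedRCNN".
    #
    #   inputs = [{"image": img_chw,            # (C, H, W) image tensor
    #              "instances": gt_instances,   # ground-truth Instances
    #              "height": h, "width": w}]    # desired output resolution
    #   loss_dict = model(inputs)               # model.training must be True
    #   sum(loss_dict.values()).backward()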
    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # As in forward(), the VL backbone takes the dict built by get_batch()
        # rather than the raw tensor used by the parent class.
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        return results
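
    # Illustrative only: per the docstring above, passing precomputed boxes
    # skips proposal generation and box prediction. `known_instances` is a
    # hypothetical list of Instances with "pred_boxes" and "pred_classes".
    #
    #   results = model.inference(inputs,
    #                             detected_instances=known_instances,
    #                             do_postprocess=False)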
    def get_batch(self, examples, images):
        """
        Pack the preprocessed images (and, in the general case, any extra
        per-image annotations) into the dict consumed by the backbone.
        """
        if len(examples) >= 1 and "bbox" not in examples[0]:  # image_only
            return {"images": images.tensor}
        # The original fall-through returned the undefined name `input`
        # (i.e., the Python builtin); fail loudly until the box-conditioned
        # path is implemented.
        raise NotImplementedError("get_batch only supports image-only inputs.")
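
    # Sketch of the image-only contract established above (shapes illustrative):
    #
    #   batch = {"images": images.tensor}   # padded batch, shape (N, C, H, W)
    #   features = self.backbone(batch)     # dict[str, Tensor] of feature maps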
    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs, using a fixed mini-batch size
        of 2 instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, single_input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(single_input)
            instances.append(instance)
            # Flush a full mini-batch, or the final partial one.
            if len(inputs) == 2 or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=True,
                    )
                )
                inputs, instances = [], []
        return outputs
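

# A minimal usage sketch, not part of the original file: run eval-mode batched
# inference. `model` is assumed to be a VLGeneralizedRCNN instance built
# elsewhere (e.g., with detectron2.modeling.build_model on a config whose
# META_ARCHITECTURE is "VLGeneralizedRCNN" and whose backbone accepts the dict
# produced by get_batch).
def run_batched_inference(model, batched_inputs):
    """Return post-processed predictions, two images per forward pass."""
    model.eval()
    with torch.no_grad():
        return model._batch_inference(batched_inputs)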