@@ -24,10 +24,11 @@ from paddle.nn.initializer import Constant
 
 from paddlex.ppdet.core.workspace import register
 from ..initializer import normal_, constant_, bias_init_with_prob
-from paddlex.ppdet.modeling.bbox_utils import bbox_center
+from paddlex.ppdet.modeling.bbox_utils import bbox_center, batch_distance2bbox
 from ..losses import GIoULoss
-from paddle.vision.ops import deform_conv2d
 from paddlex.ppdet.modeling.layers import ConvNormLayer
+from paddlex.ppdet.modeling.ops import get_static_shape
+from paddlex.ppdet.modeling.assigners.utils import generate_anchors_for_grid_cell
 
 
 class ScaleReg(nn.Layer):
@@ -84,25 +85,13 @@ class TaskDecomposition(nn.Layer):
         normal_(self.la_conv1.weight, std=0.001)
         normal_(self.la_conv2.weight, std=0.001)
 
-    def forward(self, feat, avg_feat=None):
-        b, _, h, w = feat.shape
-        if avg_feat is None:
-            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
+    def forward(self, feat, avg_feat):
+        b, _, h, w = get_static_shape(feat)
         weight = F.relu(self.la_conv1(avg_feat))
-        weight = F.sigmoid(self.la_conv2(weight))
-
-        # here new_conv_weight = layer_attention_weight * conv_weight
-        # in order to save memory and FLOPs.
-        conv_weight = weight.reshape([b, 1, self.stacked_convs, 1]) * \
-            self.reduction_conv.conv.weight.reshape(
-                [1, self.feat_channels, self.stacked_convs, self.feat_channels])
-        conv_weight = conv_weight.reshape(
-            [b, self.feat_channels, self.in_channels])
-        feat = feat.reshape([b, self.in_channels, h * w])
-        feat = paddle.bmm(conv_weight, feat).reshape(
-            [b, self.feat_channels, h, w])
-        if self.norm_type is not None:
-            feat = self.reduction_conv.norm(feat)
+        weight = F.sigmoid(self.la_conv2(weight)).unsqueeze(-1)
+        feat = paddle.reshape(
+            feat, [b, self.stacked_convs, self.feat_channels, h, w]) * weight
+        feat = self.reduction_conv(feat.flatten(1, 2))
         feat = F.relu(feat)
         return feat
 
@@ -211,81 +200,32 @@ class TOODHead(nn.Layer):
         normal_(self.cls_prob_conv2.weight, std=0.01)
         constant_(self.cls_prob_conv2.bias, bias_cls)
         normal_(self.reg_offset_conv1.weight, std=0.001)
-        normal_(self.reg_offset_conv2.weight, std=0.001)
+        constant_(self.reg_offset_conv2.weight)
         constant_(self.reg_offset_conv2.bias)
 
-    def _generate_anchors(self, feats):
-        anchors, num_anchors_list = [], []
-        stride_tensor_list = []
-        for feat, stride in zip(feats, self.fpn_strides):
-            _, _, h, w = feat.shape
-            cell_half_size = self.grid_cell_scale * stride * 0.5
-            shift_x = (paddle.arange(end=w) + self.grid_cell_offset) * stride
-            shift_y = (paddle.arange(end=h) + self.grid_cell_offset) * stride
-            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
-            anchor = paddle.stack(
-                [
-                    shift_x - cell_half_size, shift_y - cell_half_size,
-                    shift_x + cell_half_size, shift_y + cell_half_size
-                ],
-                axis=-1)
-            anchors.append(anchor.reshape([-1, 4]))
-            num_anchors_list.append(len(anchors[-1]))
-            stride_tensor_list.append(
-                paddle.full([num_anchors_list[-1], 1], stride))
-        return anchors, num_anchors_list, stride_tensor_list
-
-    @staticmethod
-    def _batch_distance2bbox(points, distance, max_shapes=None):
-        """Decode distance prediction to bounding box.
-        Args:
-            points (Tensor): [B, l, 2]
-            distance (Tensor): [B, l, 4]
-            max_shapes (tuple): [B, 2], "h w" format, Shape of the image.
-        Returns:
-            Tensor: Decoded bboxes.
-        """
-        x1 = points[:, :, 0] - distance[:, :, 0]
-        y1 = points[:, :, 1] - distance[:, :, 1]
-        x2 = points[:, :, 0] + distance[:, :, 2]
-        y2 = points[:, :, 1] + distance[:, :, 3]
-        bboxes = paddle.stack([x1, y1, x2, y2], -1)
-        if max_shapes is not None:
-            out_bboxes = []
-            for bbox, max_shape in zip(bboxes, max_shapes):
-                bbox[:, 0] = bbox[:, 0].clip(min=0, max=max_shape[1])
-                bbox[:, 1] = bbox[:, 1].clip(min=0, max=max_shape[0])
-                bbox[:, 2] = bbox[:, 2].clip(min=0, max=max_shape[1])
-                bbox[:, 3] = bbox[:, 3].clip(min=0, max=max_shape[0])
-                out_bboxes.append(bbox)
-            out_bboxes = paddle.stack(out_bboxes)
-            return out_bboxes
-        return bboxes
-
-    @staticmethod
-    def _deform_sampling(feat, offset):
-        """ Sampling the feature according to offset.
-        Args:
-            feat (Tensor): Feature
-            offset (Tensor): Spatial offset for for feature sampliing
-        """
-        # it is an equivalent implementation of bilinear interpolation
-        # you can also use F.grid_sample instead
-        c = feat.shape[1]
-        weight = paddle.ones([c, 1, 1, 1])
-        y = deform_conv2d(feat, offset, weight, deformable_groups=c, groups=c)
-        return y
+    def _reg_grid_sample(self, feat, offset, anchor_points):
+        b, _, h, w = get_static_shape(feat)
+        feat = paddle.reshape(feat, [-1, 1, h, w])
+        offset = paddle.reshape(offset, [-1, 2, h, w]).transpose([0, 2, 3, 1])
+        grid_shape = paddle.concat([w, h]).astype('float32')
+        grid = (offset + anchor_points) / grid_shape
+        grid = 2 * grid.clip(0., 1.) - 1
+        feat = F.grid_sample(feat, grid)
+        feat = paddle.reshape(feat, [b, -1, h, w])
+        return feat
 
     def forward(self, feats):
         assert len(feats) == len(self.fpn_strides), \
             "The size of feats is not equal to size of fpn_strides"
 
-        anchors, num_anchors_list, stride_tensor_list = self._generate_anchors(
-            feats)
+        anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
+            feats, self.fpn_strides, self.grid_cell_scale,
+            self.grid_cell_offset)
+
         cls_score_list, bbox_pred_list = [], []
         for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs,
                                                    anchors, self.fpn_strides):
-            b, _, h, w = feat.shape
+            b, _, h, w = get_static_shape(feat)
             inter_feats = []
             for inter_conv in self.inter_convs:
                 feat = F.relu(inter_conv(feat))
@@ -309,16 +249,16 @@ class TOODHead(nn.Layer):
 
             # reg prediction and alignment
             reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
-            reg_dist = reg_dist.transpose([0, 2, 3, 1]).reshape([b, -1, 4])
+            reg_dist = reg_dist.flatten(2).transpose([0, 2, 1])
             anchor_centers = bbox_center(anchor).unsqueeze(0) / stride
-            reg_bbox = self._batch_distance2bbox(
-                anchor_centers.tile([b, 1, 1]), reg_dist)
+            reg_bbox = batch_distance2bbox(anchor_centers, reg_dist)
             if self.use_align_head:
-                reg_bbox = reg_bbox.reshape([b, h, w, 4]).transpose(
-                    [0, 3, 1, 2])
                 reg_offset = F.relu(self.reg_offset_conv1(feat))
                 reg_offset = self.reg_offset_conv2(reg_offset)
-                bbox_pred = self._deform_sampling(reg_bbox, reg_offset)
+                reg_bbox = reg_bbox.transpose([0, 2, 1]).reshape([b, 4, h, w])
+                anchor_centers = anchor_centers.reshape([1, h, w, 2])
+                bbox_pred = self._reg_grid_sample(reg_bbox, reg_offset,
+                                                  anchor_centers)
                 bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1])
             else:
                 bbox_pred = reg_bbox
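
Note (not part of the patch): a minimal sketch of what the shared bbox_utils.batch_distance2bbox call computes, mirroring the per-head _batch_distance2bbox removed above; the optional clipping to image shapes is omitted and the function name here is hypothetical, for illustration only.

import paddle

def batch_distance2bbox_sketch(points, distance):
    """Decode per-point (left, top, right, bottom) distances into xyxy boxes.

    points: [B, L, 2] anchor centers; distance: [B, L, 4] predicted distances.
    """
    x1 = points[:, :, 0] - distance[:, :, 0]
    y1 = points[:, :, 1] - distance[:, :, 1]
    x2 = points[:, :, 0] + distance[:, :, 2]
    y2 = points[:, :, 1] + distance[:, :, 3]
    return paddle.stack([x1, y1, x2, y2], axis=-1)  # [B, L, 4] xyxy boxes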