# yolo_head.py (4.5 KB)

  1. import paddle
  2. import paddle.nn as nn
  3. import paddle.nn.functional as F
  4. from paddle import ParamAttr
  5. from paddle.regularizer import L2Decay
  6. from paddlex.ppdet.core.workspace import register
  7. def _de_sigmoid(x, eps=1e-7):
  8. x = paddle.clip(x, eps, 1. / eps)
  9. x = paddle.clip(1. / x - 1., eps, 1. / eps)
  10. x = -paddle.log(x)
  11. return x
  12. @register
  13. class YOLOv3Head(nn.Layer):
  14. __shared__ = ['num_classes', 'data_format']
  15. __inject__ = ['loss']
  16. def __init__(self,
  17. in_channels=[1024, 512, 256],
  18. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  19. [59, 119], [116, 90], [156, 198], [373, 326]],
  20. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  21. num_classes=80,
  22. loss='YOLOv3Loss',
  23. iou_aware=False,
  24. iou_aware_factor=0.4,
  25. data_format='NCHW'):
  26. """
  27. Head for YOLOv3 network
  28. Args:
  29. num_classes (int): number of foreground classes
  30. anchors (list): anchors
  31. anchor_masks (list): anchor masks
  32. loss (object): YOLOv3Loss instance
  33. iou_aware (bool): whether to use iou_aware
  34. iou_aware_factor (float): iou aware factor
  35. data_format (str): data format, NCHW or NHWC
  36. """
  37. super(YOLOv3Head, self).__init__()
  38. assert len(in_channels) > 0, "in_channels length should > 0"
  39. self.in_channels = in_channels
  40. self.num_classes = num_classes
  41. self.loss = loss
  42. self.iou_aware = iou_aware
  43. self.iou_aware_factor = iou_aware_factor
  44. self.parse_anchor(anchors, anchor_masks)
  45. self.num_outputs = len(self.anchors)
  46. self.data_format = data_format
  47. self.yolo_outputs = []
  48. for i in range(len(self.anchors)):
  49. if self.iou_aware:
  50. num_filters = len(self.anchors[i]) * (self.num_classes + 6)
  51. else:
  52. num_filters = len(self.anchors[i]) * (self.num_classes + 5)
  53. name = 'yolo_output.{}'.format(i)
  54. conv = nn.Conv2D(
  55. in_channels=self.in_channels[i],
  56. out_channels=num_filters,
  57. kernel_size=1,
  58. stride=1,
  59. padding=0,
  60. data_format=data_format,
  61. bias_attr=ParamAttr(regularizer=L2Decay(0.)))
  62. conv.skip_quant = True
  63. yolo_output = self.add_sublayer(name, conv)
  64. self.yolo_outputs.append(yolo_output)
  65. def parse_anchor(self, anchors, anchor_masks):
  66. self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
  67. self.mask_anchors = []
  68. anchor_num = len(anchors)
  69. for masks in anchor_masks:
  70. self.mask_anchors.append([])
  71. for mask in masks:
  72. assert mask < anchor_num, "anchor mask index overflow"
  73. self.mask_anchors[-1].extend(anchors[mask])
  74. def forward(self, feats, targets=None):
  75. assert len(feats) == len(self.anchors)
  76. yolo_outputs = []
  77. for i, feat in enumerate(feats):
  78. yolo_output = self.yolo_outputs[i](feat)
  79. if self.data_format == 'NHWC':
  80. yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2])
  81. yolo_outputs.append(yolo_output)
  82. if self.training:
  83. return self.loss(yolo_outputs, targets, self.anchors)
  84. else:
  85. if self.iou_aware:
  86. y = []
  87. for i, out in enumerate(yolo_outputs):
  88. na = len(self.anchors[i])
  89. ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
  90. b, c, h, w = x.shape
  91. no = c // na
  92. x = x.reshape((b, na, no, h * w))
  93. ioup = ioup.reshape((b, na, 1, h * w))
  94. obj = x[:, :, 4:5, :]
  95. ioup = F.sigmoid(ioup)
  96. obj = F.sigmoid(obj)
  97. obj_t = (obj**(1 - self.iou_aware_factor)) * (
  98. ioup**self.iou_aware_factor)
  99. obj_t = _de_sigmoid(obj_t)
  100. loc_t = x[:, :, :4, :]
  101. cls_t = x[:, :, 5:, :]
  102. y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
  103. y_t = y_t.reshape((b, c, h, w))
  104. y.append(y_t)
  105. return y
  106. else:
  107. return yolo_outputs
  108. @classmethod
  109. def from_config(cls, cfg, input_shape):
  110. return {'in_channels': [i.channels for i in input_shape], }