intracl.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from torch import nn
  2. class IntraCLBlock(nn.Module):
  3. def __init__(self, in_channels=96, reduce_factor=4):
  4. super(IntraCLBlock, self).__init__()
  5. self.channels = in_channels
  6. self.rf = reduce_factor
  7. self.conv1x1_reduce_channel = nn.Conv2d(
  8. self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0
  9. )
  10. self.conv1x1_return_channel = nn.Conv2d(
  11. self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0
  12. )
  13. self.v_layer_7x1 = nn.Conv2d(
  14. self.channels // self.rf,
  15. self.channels // self.rf,
  16. kernel_size=(7, 1),
  17. stride=(1, 1),
  18. padding=(3, 0),
  19. )
  20. self.v_layer_5x1 = nn.Conv2d(
  21. self.channels // self.rf,
  22. self.channels // self.rf,
  23. kernel_size=(5, 1),
  24. stride=(1, 1),
  25. padding=(2, 0),
  26. )
  27. self.v_layer_3x1 = nn.Conv2d(
  28. self.channels // self.rf,
  29. self.channels // self.rf,
  30. kernel_size=(3, 1),
  31. stride=(1, 1),
  32. padding=(1, 0),
  33. )
  34. self.q_layer_1x7 = nn.Conv2d(
  35. self.channels // self.rf,
  36. self.channels // self.rf,
  37. kernel_size=(1, 7),
  38. stride=(1, 1),
  39. padding=(0, 3),
  40. )
  41. self.q_layer_1x5 = nn.Conv2d(
  42. self.channels // self.rf,
  43. self.channels // self.rf,
  44. kernel_size=(1, 5),
  45. stride=(1, 1),
  46. padding=(0, 2),
  47. )
  48. self.q_layer_1x3 = nn.Conv2d(
  49. self.channels // self.rf,
  50. self.channels // self.rf,
  51. kernel_size=(1, 3),
  52. stride=(1, 1),
  53. padding=(0, 1),
  54. )
  55. # base
  56. self.c_layer_7x7 = nn.Conv2d(
  57. self.channels // self.rf,
  58. self.channels // self.rf,
  59. kernel_size=(7, 7),
  60. stride=(1, 1),
  61. padding=(3, 3),
  62. )
  63. self.c_layer_5x5 = nn.Conv2d(
  64. self.channels // self.rf,
  65. self.channels // self.rf,
  66. kernel_size=(5, 5),
  67. stride=(1, 1),
  68. padding=(2, 2),
  69. )
  70. self.c_layer_3x3 = nn.Conv2d(
  71. self.channels // self.rf,
  72. self.channels // self.rf,
  73. kernel_size=(3, 3),
  74. stride=(1, 1),
  75. padding=(1, 1),
  76. )
  77. self.bn = nn.BatchNorm2d(self.channels)
  78. self.relu = nn.ReLU()
  79. def forward(self, x):
  80. x_new = self.conv1x1_reduce_channel(x)
  81. x_7_c = self.c_layer_7x7(x_new)
  82. x_7_v = self.v_layer_7x1(x_new)
  83. x_7_q = self.q_layer_1x7(x_new)
  84. x_7 = x_7_c + x_7_v + x_7_q
  85. x_5_c = self.c_layer_5x5(x_7)
  86. x_5_v = self.v_layer_5x1(x_7)
  87. x_5_q = self.q_layer_1x5(x_7)
  88. x_5 = x_5_c + x_5_v + x_5_q
  89. x_3_c = self.c_layer_3x3(x_5)
  90. x_3_v = self.v_layer_3x1(x_5)
  91. x_3_q = self.q_layer_1x3(x_5)
  92. x_3 = x_3_c + x_3_v + x_3_q
  93. x_relation = self.conv1x1_return_channel(x_3)
  94. x_relation = self.bn(x_relation)
  95. x_relation = self.relu(x_relation)
  96. return x + x_relation
  97. def build_intraclblock_list(num_block):
  98. IntraCLBlock_list = nn.ModuleList()
  99. for i in range(num_block):
  100. IntraCLBlock_list.append(IntraCLBlock())
  101. return IntraCLBlock_list