|
|
@@ -51,15 +51,38 @@ class HRNet(object):
|
|
|
|
|
|
self.width = width
|
|
|
self.has_se = has_se
|
|
|
+ self.num_modules = {
|
|
|
+ '18_small_v1': [1, 1, 1, 1],
|
|
|
+ '18': [1, 1, 4, 3],
|
|
|
+ '30': [1, 1, 4, 3],
|
|
|
+ '32': [1, 1, 4, 3],
|
|
|
+ '40': [1, 1, 4, 3],
|
|
|
+ '44': [1, 1, 4, 3],
|
|
|
+ '48': [1, 1, 4, 3],
|
|
|
+ '60': [1, 1, 4, 3],
|
|
|
+ '64': [1, 1, 4, 3]
|
|
|
+ }
|
|
|
+ self.num_blocks = {
|
|
|
+ '18_small_v1': [[1], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
|
|
|
+ '18': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '30': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '32': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '40': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '44': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '48': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '60': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
|
|
|
+ '64': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]]
|
|
|
+ }
|
|
|
self.channels = {
|
|
|
- 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
|
|
|
- 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
|
|
|
- 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
|
|
|
- 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
|
|
|
- 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
|
|
|
- 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
|
|
|
- 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
|
|
|
- 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]],
|
|
|
+ '18_small_v1': [[32], [16, 32], [16, 32, 64], [16, 32, 64, 128]],
|
|
|
+ '18': [[64], [18, 36], [18, 36, 72], [18, 36, 72, 144]],
|
|
|
+ '30': [[64], [30, 60], [30, 60, 120], [30, 60, 120, 240]],
|
|
|
+ '32': [[64], [32, 64], [32, 64, 128], [32, 64, 128, 256]],
|
|
|
+ '40': [[64], [40, 80], [40, 80, 160], [40, 80, 160, 320]],
|
|
|
+ '44': [[64], [44, 88], [44, 88, 176], [44, 88, 176, 352]],
|
|
|
+ '48': [[64], [48, 96], [48, 96, 192], [48, 96, 192, 384]],
|
|
|
+ '60': [[64], [60, 120], [60, 120, 240], [60, 120, 240, 480]],
|
|
|
+ '64': [[64], [64, 128], [64, 128, 256], [64, 128, 256, 512]],
|
|
|
}
|
|
|
|
|
|
self.freeze_at = freeze_at
|
|
|
@@ -73,31 +96,38 @@ class HRNet(object):
|
|
|
|
|
|
def net(self, input):
|
|
|
width = self.width
|
|
|
- channels_2, channels_3, channels_4 = self.channels[width]
|
|
|
- num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
|
|
|
+ channels_1, channels_2, channels_3, channels_4 = self.channels[str(
|
|
|
+ width)]
|
|
|
+ num_modules_1, num_modules_2, num_modules_3, num_modules_4 = self.num_modules[
|
|
|
+ str(width)]
|
|
|
+ num_blocks_1, num_blocks_2, num_blocks_3, num_blocks_4 = self.num_blocks[
|
|
|
+ str(width)]
|
|
|
|
|
|
x = self.conv_bn_layer(
|
|
|
input=input,
|
|
|
filter_size=3,
|
|
|
- num_filters=64,
|
|
|
+ num_filters=channels_1[0],
|
|
|
stride=2,
|
|
|
if_act=True,
|
|
|
name='layer1_1')
|
|
|
x = self.conv_bn_layer(
|
|
|
input=x,
|
|
|
filter_size=3,
|
|
|
- num_filters=64,
|
|
|
+ num_filters=channels_1[0],
|
|
|
stride=2,
|
|
|
if_act=True,
|
|
|
name='layer1_2')
|
|
|
|
|
|
- la1 = self.layer1(x, name='layer2')
|
|
|
+ la1 = self.layer1(x, num_blocks_1, channels_1, name='layer2')
|
|
|
tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
|
|
|
- st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
|
|
|
+ st2 = self.stage(
|
|
|
+ tr1, num_modules_2, num_blocks_2, channels_2, name='st2')
|
|
|
tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
|
|
|
- st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
|
|
|
+ st3 = self.stage(
|
|
|
+ tr2, num_modules_3, num_blocks_3, channels_3, name='st3')
|
|
|
tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
|
|
|
- st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
|
|
|
+ st4 = self.stage(
|
|
|
+ tr3, num_modules_4, num_blocks_4, channels_4, name='st4')
|
|
|
|
|
|
# classification
|
|
|
if self.num_classes:
|
|
|
@@ -139,12 +169,12 @@ class HRNet(object):
|
|
|
self.end_points = st4
|
|
|
return st4[-1]
|
|
|
|
|
|
- def layer1(self, input, name=None):
|
|
|
+ def layer1(self, input, num_blocks, channels, name=None):
|
|
|
conv = input
|
|
|
- for i in range(4):
|
|
|
+ for i in range(num_blocks[0]):
|
|
|
conv = self.bottleneck_block(
|
|
|
conv,
|
|
|
- num_filters=64,
|
|
|
+ num_filters=channels[0],
|
|
|
downsample=True if i == 0 else False,
|
|
|
name=name + '_' + str(i + 1))
|
|
|
return conv
|
|
|
@@ -178,7 +208,7 @@ class HRNet(object):
|
|
|
out = []
|
|
|
for i in range(len(channels)):
|
|
|
residual = x[i]
|
|
|
- for j in range(block_num):
|
|
|
+ for j in range(block_num[i]):
|
|
|
residual = self.basic_block(
|
|
|
residual,
|
|
|
channels[i],
|
|
|
@@ -240,10 +270,11 @@ class HRNet(object):
|
|
|
|
|
|
def high_resolution_module(self,
|
|
|
x,
|
|
|
+ num_blocks,
|
|
|
channels,
|
|
|
multi_scale_output=True,
|
|
|
name=None):
|
|
|
- residual = self.branches(x, 4, channels, name=name)
|
|
|
+ residual = self.branches(x, num_blocks, channels, name=name)
|
|
|
out = self.fuse_layers(
|
|
|
residual,
|
|
|
channels,
|
|
|
@@ -254,6 +285,7 @@ class HRNet(object):
|
|
|
def stage(self,
|
|
|
x,
|
|
|
num_modules,
|
|
|
+ num_blocks,
|
|
|
channels,
|
|
|
multi_scale_output=True,
|
|
|
name=None):
|
|
|
@@ -262,12 +294,13 @@ class HRNet(object):
|
|
|
if i == num_modules - 1 and multi_scale_output == False:
|
|
|
out = self.high_resolution_module(
|
|
|
out,
|
|
|
+ num_blocks,
|
|
|
channels,
|
|
|
multi_scale_output=False,
|
|
|
name=name + '_' + str(i + 1))
|
|
|
else:
|
|
|
out = self.high_resolution_module(
|
|
|
- out, channels, name=name + '_' + str(i + 1))
|
|
|
+ out, num_blocks, channels, name=name + '_' + str(i + 1))
|
|
|
|
|
|
return out
|
|
|
|