|
|
@@ -256,7 +256,7 @@ class LearnableRepLayer(nn.Module):
|
|
|
input_dim = self.in_channels // self.groups
|
|
|
kernel_value = torch.zeros(
|
|
|
(self.in_channels, input_dim, self.kernel_size, self.kernel_size),
|
|
|
- dtype=branch.weight.dtype, device=branch.weight.device,
|
|
|
+ dtype=branch.weight.dtype, device=branch.weight.device,
|
|
|
)
|
|
|
for i in range(self.in_channels):
|
|
|
kernel_value[
|