Browse Source

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into develop

jiangjiajun 5 years ago
parent
commit
cfdf5e234a
1 changed files with 2 additions and 2 deletions
  1. 2 2
      paddlex/cv/nets/hrnet.py

+ 2 - 2
paddlex/cv/nets/hrnet.py

@@ -71,7 +71,7 @@ class HRNet(object):
         self.end_points = []
         return
 
-    def net(self, input, class_dim=1000):
+    def net(self, input):
         width = self.width
         channels_2, channels_3, channels_4 = self.channels[width]
         num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
@@ -125,7 +125,7 @@ class HRNet(object):
             stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
             out = fluid.layers.fc(
                 input=pool,
-                size=class_dim,
+                size=self.num_classes,
                 param_attr=ParamAttr(
                     name='fc_weights',
                     initializer=fluid.initializer.Uniform(-stdv, stdv)),