瀏覽代碼

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into develop

jiangjiajun 5 年之前
父節點
當前提交
cfdf5e234a
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      paddlex/cv/nets/hrnet.py

+ 2 - 2
paddlex/cv/nets/hrnet.py

@@ -71,7 +71,7 @@ class HRNet(object):
         self.end_points = []
         return
 
-    def net(self, input, class_dim=1000):
+    def net(self, input):
         width = self.width
         channels_2, channels_3, channels_4 = self.channels[width]
         num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
@@ -125,7 +125,7 @@ class HRNet(object):
             stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
             out = fluid.layers.fc(
                 input=pool,
-                size=class_dim,
+                size=self.num_classes,
                 param_attr=ParamAttr(
                     name='fc_weights',
                     initializer=fluid.initializer.Uniform(-stdv, stdv)),