Przeglądaj źródła

do not build infernet for detectors

will-jl944 4 lat temu
rodzic
commit
91478e87a2

+ 2 - 1
paddlex/cv/models/base.py

@@ -584,7 +584,8 @@ class BaseModel:
         return pipeline_info
         return pipeline_info
 
 
     def _build_inference_net(self):
     def _build_inference_net(self):
-        infer_net = InferNet(self.net, self.model_type)
+        infer_net = self.net if self.model_type == 'detector' else InferNet(
+            self.net, self.model_type)
         infer_net.eval()
         infer_net.eval()
         return infer_net
         return infer_net
 
 

+ 6 - 6
paddlex/cv/models/utils/infer_nets.py

@@ -23,13 +23,13 @@ class PostProcessor(paddle.nn.Layer):
     def forward(self, net_outputs):
     def forward(self, net_outputs):
         if self.model_type == 'classifier':
         if self.model_type == 'classifier':
             outputs = paddle.nn.functional.softmax(net_outputs, axis=1)
             outputs = paddle.nn.functional.softmax(net_outputs, axis=1)
-        elif self.model_type == 'segmenter':
+        else:
             # score_map, label_map
             # score_map, label_map
-            outputs = paddle.transpose(paddle.nn.functional.softmax(net_outputs, axis=1), perm=[0, 2, 3, 1]), \
-                      paddle.transpose(paddle.argmax(net_outputs, axis=1, keepdim=True, dtype='int32'),
+            logit = net_outputs[0]
+            outputs = paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1]), \
+                      paddle.transpose(paddle.argmax(logit, axis=1, keepdim=True, dtype='int32'),
                                        perm=[0, 2, 3, 1])
                                        perm=[0, 2, 3, 1])
-        else:
-            outputs = net_outputs
+
         return outputs
         return outputs
 
 
 
 
@@ -40,7 +40,7 @@ class InferNet(paddle.nn.Layer):
         self.postprocessor = PostProcessor(model_type)
         self.postprocessor = PostProcessor(model_type)
 
 
     def forward(self, x):
     def forward(self, x):
-        net_outputs = self.net(x)[0]
+        net_outputs = self.net(x)
         outputs = self.postprocessor(net_outputs)
         outputs = self.postprocessor(net_outputs)
 
 
         return outputs
         return outputs