瀏覽代碼

Merge pull request #824 from will-jl944/develop_jf

Refine multi-card evaluation logistics
FlyingQianMM 4 年之前
父節點
當前提交
bcc2173797

+ 2 - 1
dygraph/paddlex/cv/models/base.py

@@ -333,12 +333,13 @@ class BaseModel:
             eval_epoch_tic = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
-                    self.eval_metrics, self.eval_details = self.evaluate(
+                    eval_result = self.evaluate(
                         eval_dataset,
                         batch_size=eval_batch_size,
                         return_details=True)
                     # 保存最优模型
                     if local_rank == 0:
+                        self.eval_metrics, self.eval_details = eval_result
                         logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                             i + 1, dict2str(self.eval_metrics)))
                         best_accuracy_key = list(self.eval_metrics.keys())[0]

+ 6 - 2
dygraph/paddlex/cv/models/classifier.py

@@ -90,6 +90,7 @@ class BaseClassifier(BaseModel):
             acc1 = paddle.metric.accuracy(softmax_out, label=labels)
             k = min(5, self.num_classes)
             acck = paddle.metric.accuracy(softmax_out, label=labels, k=k)
+            prediction = softmax_out
             # multi cards eval
             if paddle.distributed.get_world_size() > 1:
                 acc1 = paddle.distributed.all_reduce(
@@ -98,9 +99,12 @@ class BaseClassifier(BaseModel):
                 acck = paddle.distributed.all_reduce(
                     acck, op=paddle.distributed.ReduceOp.
                     SUM) / paddle.distributed.get_world_size()
+                prediction = []
+                paddle.distributed.all_gather(prediction, softmax_out)
+                prediction = paddle.concat(prediction, axis=0)
 
             outputs = OrderedDict([('acc1', acc1), ('acc{}'.format(k), acck),
-                                   ('prediction', softmax_out)])
+                                   ('prediction', prediction)])
 
         else:
             # mode == 'train'
@@ -347,7 +351,7 @@ class BaseClassifier(BaseModel):
             for step, data in enumerate(self.eval_data_loader()):
                 outputs = self.run(self.net, data, mode='eval')
                 if return_details:
-                    eval_details.append(outputs['prediction'].numpy())
+                    eval_details.append(outputs['prediction'].tolist())
                 outputs.pop('prediction')
                 eval_metrics.update(outputs)
         if return_details:

+ 0 - 3
dygraph/paddlex/cv/models/detector.py

@@ -419,9 +419,6 @@ class BaseDetector(BaseModel):
             if return_details:
                 return scores, self.eval_details
             return scores
-        if return_details:
-            return None, None
-        return None
 
     def predict(self, img_file, transforms=None):
         """