|
|
@@ -223,8 +223,11 @@ class BaseClassifier(BaseAPI):
|
|
|
tuple (metrics, eval_details): 当return_details为True时,增加返回dict,
|
|
|
包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。
|
|
|
"""
|
|
|
- self.arrange_transforms(
|
|
|
- transforms=eval_dataset.transforms, mode='eval')
|
|
|
+ arrange_transforms(
|
|
|
+ model_type=self.model_type,
|
|
|
+ class_name=self.__class__.__name__,
|
|
|
+ transforms=eval_dataset.transforms,
|
|
|
+ mode='eval')
|
|
|
data_generator = eval_dataset.generator(
|
|
|
batch_size=batch_size, drop_last=False)
|
|
|
k = min(5, self.num_classes)
|