Эх сурвалжийг харах

delete self.arrange_transforms

FlyingQianMM 5 жил өмнө
parent
commit
e5b030126e

+ 5 - 2
paddlex/cv/models/classifier.py

@@ -223,8 +223,11 @@ class BaseClassifier(BaseAPI):
           tuple (metrics, eval_details): 当return_details为True时,增加返回dict,
               包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         k = min(5, self.num_classes)

+ 5 - 2
paddlex/cv/models/faster_rcnn.py

@@ -312,8 +312,11 @@ class FasterRCNN(BaseAPI):
                 eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
                 预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric

+ 5 - 2
paddlex/cv/models/mask_rcnn.py

@@ -254,8 +254,11 @@ class MaskRCNN(FasterRCNN):
                 预测框坐标、预测框得分;'mask',对应元素预测区域结果列表,每个预测结果由图像id、
                 预测区域类别id、预测区域坐标、预测区域得分;’gt‘:真实标注框和标注区域相关信息。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric

+ 5 - 2
paddlex/cv/models/yolo_v3.py

@@ -289,8 +289,11 @@ class YOLOv3(BaseAPI):
                 eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
                 预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric