소스 검색

add NotImplementedError for _fix_transforms_shape

FlyingQianMM 4 년 전
부모
커밋
4edf96ad1b
2개의 변경된 파일3개의 추가작업 그리고 9개의 파일을 삭제
  1. 2 8
      dygraph/paddlex/cv/models/classifier.py
  2. 1 1
      dygraph/paddlex/cv/models/detector.py

+ 2 - 8
dygraph/paddlex/cv/models/classifier.py

@@ -76,14 +76,8 @@ class BaseClassifier(BaseModel):
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
-                normalize_op_idx = len(self.test_transforms.transforms)
-                for idx, op in enumerate(self.test_transforms.transforms):
-                    name = op.__class__.__name__
-                    if name == 'Normalize':
-                        normalize_op_idx = idx
-
-                self.test_transforms.transforms.insert(
-                    normalize_op_idx, Resize(target_size=image_shape))
+                self.test_transforms.transforms.append(
+                    Resize(target_size=image_shape))
 
     def _get_test_inputs(self, image_shape):
         if image_shape is not None:

+ 1 - 1
dygraph/paddlex/cv/models/detector.py

@@ -58,7 +58,7 @@ class BaseDetector(BaseModel):
         return net
 
     def _fix_transforms_shape(self, image_shape):
-        pass
+        raise NotImplementedError("_fix_transforms_shape: not implemented!")
 
     def _get_test_inputs(self, image_shape):
         if image_shape is not None: