|
|
@@ -76,14 +76,8 @@ class BaseClassifier(BaseModel):
|
|
|
def _fix_transforms_shape(self, image_shape):
|
|
|
if hasattr(self, 'test_transforms'):
|
|
|
if self.test_transforms is not None:
|
|
|
- normalize_op_idx = len(self.test_transforms.transforms)
|
|
|
- for idx, op in enumerate(self.test_transforms.transforms):
|
|
|
- name = op.__class__.__name__
|
|
|
- if name == 'Normalize':
|
|
|
- normalize_op_idx = idx
|
|
|
-
|
|
|
- self.test_transforms.transforms.insert(
|
|
|
- normalize_op_idx, Resize(target_size=image_shape))
|
|
|
+ self.test_transforms.transforms.append(
|
|
|
+ Resize(target_size=image_shape))
|
|
|
|
|
|
def _get_test_inputs(self, image_shape):
|
|
|
if image_shape is not None:
|