|
|
@@ -116,20 +116,7 @@ def fix_input_shape(info, fixed_input_shape=None):
|
|
|
resize = {'ResizeByShort': {}}
|
|
|
padding = {'Padding': {}}
|
|
|
if info['_Attributes']['model_type'] == 'classifier':
|
|
|
- crop_size = 0
|
|
|
- for transform in info['Transforms']:
|
|
|
- if 'CenterCrop' in transform:
|
|
|
- crop_size = transform['CenterCrop']['crop_size']
|
|
|
- break
|
|
|
- assert crop_size == fixed_input_shape[
|
|
|
- 0], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
|
|
|
- crop_size)
|
|
|
- assert crop_size == fixed_input_shape[
|
|
|
- 1], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
|
|
|
- crop_size)
|
|
|
- if crop_size == 0:
|
|
|
- logging.warning(
|
|
|
- "fixed_input_shape must == input shape when trainning")
|
|
|
+ pass
|
|
|
else:
|
|
|
resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
|
|
|
resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
|