| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from . import cls_transforms
- from . import det_transforms
- from . import seg_transforms
- from . import visualize
# Expose ``visualize`` directly at package level. Up to this line the name
# ``visualize`` refers to the submodule bound by ``from . import visualize``;
# rebinding it here replaces that module reference with the submodule's own
# ``visualize`` attribute (presumably the drawing function), so callers can
# write ``transforms.visualize(...)``.
visualize = visualize.visualize
def build_transforms(model_type, transforms_info, to_rgb=True):
    """Rebuild an evaluation transform pipeline from its serialized form.

    Args:
        model_type (str): One of "classifier", "detector" or "segmenter";
            selects the transforms module operators are looked up in.
        transforms_info (list[dict]): One single-key dict per operator, mapping
            the operator class name to the keyword arguments passed to its
            constructor, e.g. ``[{"ResizeByShort": {"short_size": 256}}]``.
        to_rgb (bool): Stored on the resulting ``Compose`` as ``to_rgb``.

    Returns:
        The ``Compose`` object built from the listed operators, in order.

    Raises:
        Exception: If ``model_type`` is not recognized, or an operator name
            does not exist in the selected transforms module.
    """
    if model_type == "classifier":
        from . import cls_transforms as T
    elif model_type == "detector":
        from . import det_transforms as T
    elif model_type == "segmenter":
        from . import seg_transforms as T
    else:
        # Without this branch ``T`` would stay unbound and the caller would
        # get an opaque UnboundLocalError instead of a useful message.
        raise Exception("Unrecognized model type: {}".format(model_type))
    transforms = list()
    for op_info in transforms_info:
        # Each entry describes exactly one operator: {op_name: op_attr}.
        op_name = list(op_info)[0]
        op_attr = op_info[op_name]
        if not hasattr(T, op_name):
            raise Exception(
                "There's no operator named '{}' in transforms of {}".format(
                    op_name, model_type))
        # Treat a missing (None) attribute mapping as "no keyword arguments"
        # rather than failing on ``**None``.
        transforms.append(getattr(T, op_name)(**(op_attr or {})))
    eval_transforms = T.Compose(transforms)
    eval_transforms.to_rgb = to_rgb
    return eval_transforms
def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
    """Rebuild evaluation transforms for models saved in the legacy format.

    Legacy (old-version) model loading; only models exported by the PaddleX
    front end are supported.

    Args:
        model_type (str): One of "classifier", "detector" or "segmenter".
        transforms_info (list): Sequence of ``(op_name, op_attr)`` entries,
            where ``op_attr`` is a dict of constructor keyword arguments.
        batch_transforms_info (list): Sequence of ``(op_name, op_attr)``
            entries for batch-level transforms. Only consulted for detectors,
            where only a leading ``PaddingMiniBatch`` entry is supported. May
            be empty or None.

    Returns:
        The ``Compose`` object built from the converted operators.

    Raises:
        Exception: If ``model_type`` is not recognized, or an operator name
            does not exist in the selected transforms module.
        AssertionError: If the first batch transform is not
            ``PaddingMiniBatch``.
    """
    # ``logging`` is not imported at module level in this file, so bring it
    # into scope here. A named logger avoids the implicit basicConfig() that
    # the module-level ``logging.debug`` performs on an unconfigured root.
    import logging
    logging.getLogger(__name__).debug(
        "Use build_transforms_v1 to reconstruct transforms")
    if model_type == "classifier":
        from . import cls_transforms as T
    elif model_type == "detector":
        from . import det_transforms as T
    elif model_type == "segmenter":
        from . import seg_transforms as T
    else:
        # Without this branch ``T`` would be unbound at first use.
        raise Exception("Unrecognized model type: {}".format(model_type))
    transforms = list()
    for op_info in transforms_info:
        op_name = op_info[0]
        op_attr = op_info[1]
        # These legacy operators are dropped when rebuilding the pipeline.
        if op_name in ('DecodeImage', 'Permute') or op_name.startswith(
                'Arrange'):
            continue
        if op_name == 'ResizeByShort':
            # Legacy exports may store the short side as ``target_size``;
            # keep only the two keywords passed on to ResizeByShort.
            if 'short_size' in op_attr:
                short_size = op_attr['short_size']
            else:
                short_size = op_attr['target_size']
            op_attr = {
                'short_size': short_size,
                'max_size': op_attr.get('max_size', -1),
            }
        if not hasattr(T, op_name):
            raise Exception(
                "There's no operator named '{}' in transforms of {}".format(
                    op_name, model_type))
        transforms.append(getattr(T, op_name)(**op_attr))
    if model_type == "detector" and batch_transforms_info:
        op_name = batch_transforms_info[0][0]
        op_attr = batch_transforms_info[0][1]
        if op_name != "PaddingMiniBatch":
            # Explicit raise instead of ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(
                "Only PaddingMiniBatch transform is supported for batch transform")
        # The batch padding step is replaced by a Padding operator appended
        # to the pipeline, using the same coarsest_stride.
        padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
        transforms.append(padding)
    eval_transforms = T.Compose(transforms)
    return eval_transforms
def arrange_transforms(model_type, class_name, transforms, mode='train'):
    """Ensure ``transforms`` ends with the Arrange operator for this model.

    If the last operator's class name starts with ``Arrange`` it is replaced;
    otherwise a new Arrange operator is appended. ``transforms`` is modified
    in place.

    Args:
        model_type (str): One of "classifier", "segmenter" or "detector".
        class_name (str): Model class name. For detectors it selects
            ``det_transforms.Arrange<class_name>``; "PPYOLO" uses
            ``ArrangeYOLOv3``.
        transforms: Object whose ``transforms`` attribute is a mutable list
            of operators (e.g. a ``Compose``).
        mode (str): Passed as ``mode`` to the Arrange operator's constructor.

    Raises:
        Exception: If ``model_type`` is not recognized.
        AttributeError: If ``det_transforms`` has no matching Arrange class.
    """
    if model_type == 'classifier':
        arrange_transform = cls_transforms.ArrangeClassifier
    elif model_type == 'segmenter':
        arrange_transform = seg_transforms.ArrangeSegmenter
    elif model_type == 'detector':
        # PPYOLO reuses YOLOv3's arrange operator.
        if class_name == "PPYOLO":
            arrange_name = 'ArrangeYOLOv3'
        else:
            arrange_name = 'Arrange{}'.format(class_name)
        arrange_transform = getattr(det_transforms, arrange_name)
    else:
        # This is a free function: report the ``model_type`` argument itself
        # (there is no ``self`` in scope).
        raise Exception("Unrecognized model type: {}".format(model_type))
    ops = transforms.transforms
    # Guard against an empty operator list before inspecting the last entry.
    if ops and type(ops[-1]).__name__.startswith('Arrange'):
        # Replace the existing trailing Arrange op so only one remains.
        ops[-1] = arrange_transform(mode=mode)
    else:
        ops.append(arrange_transform(mode=mode))
|