__init__.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from . import cls_transforms
  15. from . import det_transforms
  16. from . import seg_transforms
  17. from . import visualize
  18. visualize = visualize.visualize
  19. def build_transforms(model_type, transforms_info, to_rgb=True):
  20. if model_type == "classifier":
  21. from . import cls_transforms as T
  22. elif model_type == "detector":
  23. from . import det_transforms as T
  24. elif model_type == "segmenter":
  25. from . import seg_transforms as T
  26. transforms = list()
  27. for op_info in transforms_info:
  28. op_name = list(op_info.keys())[0]
  29. op_attr = op_info[op_name]
  30. if not hasattr(T, op_name):
  31. raise Exception(
  32. "There's no operator named '{}' in transforms of {}".format(
  33. op_name, model_type))
  34. transforms.append(getattr(T, op_name)(**op_attr))
  35. eval_transforms = T.Compose(transforms)
  36. eval_transforms.to_rgb = to_rgb
  37. return eval_transforms
  38. def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
  39. """ 老版本模型加载,仅支持PaddleX前端导出的模型
  40. """
  41. logging.debug("Use build_transforms_v1 to reconstruct transforms")
  42. if model_type == "classifier":
  43. from . import cls_transforms as T
  44. elif model_type == "detector":
  45. from . import det_transforms as T
  46. elif model_type == "segmenter":
  47. from . import seg_transforms as T
  48. transforms = list()
  49. for op_info in transforms_info:
  50. op_name = op_info[0]
  51. op_attr = op_info[1]
  52. if op_name == 'DecodeImage':
  53. continue
  54. if op_name == 'Permute':
  55. continue
  56. if op_name == 'ResizeByShort':
  57. op_attr_new = dict()
  58. if 'short_size' in op_attr:
  59. op_attr_new['short_size'] = op_attr['short_size']
  60. else:
  61. op_attr_new['short_size'] = op_attr['target_size']
  62. op_attr_new['max_size'] = op_attr.get('max_size', -1)
  63. op_attr = op_attr_new
  64. if op_name.startswith('Arrange'):
  65. continue
  66. if not hasattr(T, op_name):
  67. raise Exception(
  68. "There's no operator named '{}' in transforms of {}".format(
  69. op_name, model_type))
  70. transforms.append(getattr(T, op_name)(**op_attr))
  71. if model_type == "detector" and len(batch_transforms_info) > 0:
  72. op_name = batch_transforms_info[0][0]
  73. op_attr = batch_transforms_info[0][1]
  74. assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
  75. padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
  76. transforms.append(padding)
  77. eval_transforms = T.Compose(transforms)
  78. return eval_transforms
  79. def arrange_transforms(model_type, class_name, transforms, mode='train'):
  80. # 给transforms添加arrange操作
  81. if model_type == 'classifier':
  82. arrange_transform = cls_transforms.ArrangeClassifier
  83. elif model_type == 'segmenter':
  84. arrange_transform = seg_transforms.ArrangeSegmenter
  85. elif model_type == 'detector':
  86. if class_name == "PPYOLO":
  87. arrange_name = 'ArrangeYOLOv3'
  88. else:
  89. arrange_name = 'Arrange{}'.format(class_name)
  90. arrange_transform = getattr(det_transforms, arrange_name)
  91. else:
  92. raise Exception("Unrecognized model type: {}".format(self.model_type))
  93. if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
  94. transforms.transforms[-1] = arrange_transform(mode=mode)
  95. else:
  96. transforms.transforms.append(arrange_transform(mode=mode))