meta_arch.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import numpy as np
  5. import paddle
  6. import paddle.nn as nn
  7. from paddlex.ppdet.core.workspace import register
  __all__ = ['BaseArch']


  # NOTE(review): `register` comes from paddlex.ppdet.core.workspace; presumably
  # it makes this class available to the config/workspace system — confirm.
  @register
  class BaseArch(nn.Layer):
      """Base class for detection model architectures.

      ``forward`` stores the batch on ``self.inputs`` and then calls the hooks
      ``model_arch``, ``get_loss`` and ``get_pred``. These hooks take no
      arguments, so they reach the batch through ``self.inputs``.

      In this base class, ``model_arch`` does nothing. ``get_loss`` and
      ``get_pred`` raise ``NotImplementedError``, so subclasses must override
      them.

      Args:
          data_format (str): layout flag. When it equals ``'NHWC'``,
              ``forward`` permutes ``inputs['image']`` with ``[0, 2, 3, 1]``
              before running the model. Defaults to ``'NCHW'``.
      """

      def __init__(self, data_format='NCHW'):
          super(BaseArch, self).__init__()
          self.data_format = data_format
          # Defined up front so the attribute exists before the first forward().
          self.inputs = {}

      def forward(self, inputs):
          """Run the architecture on one batch.

          Args:
              inputs (dict): batch fields keyed by name. It must contain
                  ``'image'`` when ``data_format`` is ``'NHWC'``. The caller's
                  dict is never modified.

          Returns:
              The result of ``get_loss()`` when ``self.training`` (inherited
              from ``nn.Layer``) is true, otherwise the result of
              ``get_pred()``.
          """
          if self.data_format == 'NHWC':
              # Work on a shallow copy. Writing the transposed tensor back into
              # the caller's dict would make a second call with the same dict
              # transpose the image again.
              inputs = dict(inputs)
              # NOTE(review): assumes the loader yields channels-first (NCHW)
              # images; the permutation [0, 2, 3, 1] moves the channel axis last.
              inputs['image'] = paddle.transpose(inputs['image'], [0, 2, 3, 1])
          self.inputs = inputs
          self.model_arch()
          if self.training:
              return self.get_loss()
          return self.get_pred()

      def build_inputs(self, data, input_def):
          """Pair each name in ``input_def`` with the element of ``data`` at the
          same position.

          Args:
              data: positionally indexable collection of input values.
              input_def: iterable of field names, in the same order as ``data``.

          Returns:
              dict: mapping from field name to ``data[i]``. If a name repeats,
              the later position wins.

          Raises:
              IndexError: if ``data`` (as a sequence) has fewer elements than
                  ``input_def`` has names.
          """
          return {name: data[idx] for idx, name in enumerate(input_def)}

      def model_arch(self):
          """Build or run the network on ``self.inputs``.

          This is a no-op here; subclasses override it.
          """
          pass

      def get_loss(self):
          """Return the training loss(es). Subclasses must override this."""
          raise NotImplementedError("Should implement get_loss method!")

      def get_pred(self):
          """Return the inference predictions. Subclasses must override this."""
          raise NotImplementedError("Should implement get_pred method!")