# meta_arch.py (1.1 KB, 44 lines in the original file)
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import paddle
  5. import paddle.nn as nn
  6. from paddlex.ppdet.core.workspace import register
  # Names exported by ``from <module> import *``: only the base class below.
  __all__ = [
      'BaseArch',
  ]
  @register
  class BaseArch(nn.Layer):
      """Base class for detection architectures.

      ``forward`` stores the batch on ``self.inputs``, calls the
      ``model_arch`` hook, then returns ``get_loss()`` in training mode
      (``self.training``) or ``get_pred()`` otherwise. Subclasses override
      those hooks.

      Args:
          data_format (str): layout switch for ``inputs['image']``. When it
              equals ``'NHWC'`` the image is transposed with permutation
              ``[0, 2, 3, 1]`` before the hooks run; any other value leaves
              the image untouched. Defaults to ``'NCHW'``.
      """

      def __init__(self, data_format='NCHW'):
          super(BaseArch, self).__init__()
          self.data_format = data_format

      def forward(self, inputs):
          """Run one forward pass and return the loss or the predictions.

          Args:
              inputs (dict): named batch fields. Must contain ``'image'``
                  when ``self.data_format == 'NHWC'``. The caller's dict is
                  not modified.

          Returns:
              The result of ``self.get_loss()`` when the layer is in training
              mode, otherwise the result of ``self.get_pred()``.
          """
          if self.data_format == 'NHWC':
              # Transpose on a shallow copy: rewriting the caller's dict in
              # place would make a second forward pass over the same batch
              # permute the already-permuted image again.
              # NOTE(review): presumably the loader yields NCHW images; the
              # permutation moves axis 1 to the end (NCHW -> NHWC).
              inputs = dict(inputs)
              inputs['image'] = paddle.transpose(inputs['image'], [0, 2, 3, 1])
          # The hooks below take no arguments; presumably subclasses read the
          # batch from self.inputs.
          self.inputs = inputs
          self.model_arch()

          if self.training:
              return self.get_loss()
          return self.get_pred()

      def build_inputs(self, data, input_def):
          """Pair positional batch entries with their field names.

          Args:
              data (sequence): positional batch entries; the i-th entry is
                  stored under the i-th name of ``input_def``.
              input_def (iterable): field names, in the same order as
                  ``data``.

          Returns:
              dict: maps each name in ``input_def`` to the ``data`` entry at
              the same position. ``data`` must be at least as long as
              ``input_def`` (indexing past its end raises, e.g. ``IndexError``
              for a list); extra trailing entries are ignored. With duplicate
              names, the later position wins.
          """
          return {name: data[i] for i, name in enumerate(input_def)}

      def model_arch(self):
          """Hook called by ``forward`` after ``self.inputs`` is set.

          The base implementation does nothing.
          """
          pass

      def get_loss(self):
          """Return the training loss; called by ``forward`` in training mode.

          Raises:
              NotImplementedError: always, unless a subclass overrides it.
          """
          raise NotImplementedError("Should implement get_loss method!")

      def get_pred(self):
          """Return predictions; called by ``forward`` outside training mode.

          Raises:
              NotImplementedError: always, unless a subclass overrides it.
          """
          raise NotImplementedError("Should implement get_pred method!")