# meta_arch.py (2.3 KB)
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import paddle
  5. import paddle.nn as nn
  6. from paddlex.ppdet.core.workspace import register
  7. __all__ = ['BaseArch']
  8. @register
  9. class BaseArch(nn.Layer):
  10. def __init__(self, data_format='NCHW'):
  11. super(BaseArch, self).__init__()
  12. self.data_format = data_format
  13. self.inputs = {}
  14. self.fuse_norm = False
  15. def load_meanstd(self, cfg_transform):
  16. self.scale = 1.
  17. self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(
  18. (1, 3, 1, 1))
  19. self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape(
  20. (1, 3, 1, 1))
  21. for item in cfg_transform:
  22. if 'NormalizeImage' in item:
  23. self.mean = paddle.to_tensor(item['NormalizeImage'][
  24. 'mean']).reshape((1, 3, 1, 1))
  25. self.std = paddle.to_tensor(item['NormalizeImage'][
  26. 'std']).reshape((1, 3, 1, 1))
  27. if item['NormalizeImage']['is_scale']:
  28. self.scale = 1. / 255.
  29. break
  30. if self.data_format == 'NHWC':
  31. self.mean = self.mean.reshape(1, 1, 1, 3)
  32. self.std = self.std.reshape(1, 1, 1, 3)
  33. def forward(self, inputs):
  34. if self.data_format == 'NHWC':
  35. image = inputs['image']
  36. inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])
  37. if self.fuse_norm:
  38. image = inputs['image']
  39. self.inputs['image'] = (image * self.scale - self.mean) / self.std
  40. self.inputs['im_shape'] = inputs['im_shape']
  41. self.inputs['scale_factor'] = inputs['scale_factor']
  42. else:
  43. self.inputs = inputs
  44. self.model_arch()
  45. if self.training:
  46. out = self.get_loss()
  47. else:
  48. out = self.get_pred()
  49. return out
  50. def build_inputs(self, data, input_def):
  51. inputs = {}
  52. for i, k in enumerate(input_def):
  53. inputs[k] = data[i]
  54. return inputs
  55. def model_arch(self, ):
  56. pass
  57. def get_loss(self, ):
  58. raise NotImplementedError("Should implement get_loss method!")
  59. def get_pred(self, ):
  60. raise NotImplementedError("Should implement get_pred method!")