transforms.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from .keys import ShiTuRecKeys as K
  17. from ...base import BaseTransform
  18. from ....utils import logging
  19. __all__ = [
  20. "NormalizeFeatures",
  21. "PrintShiTuRecResult"
  22. ]
  23. class NormalizeFeatures(BaseTransform):
  24. """Normalize Features Transform"""
  25. def apply(self, data):
  26. """apply"""
  27. x = data[K.SHITU_REC_PRED]
  28. feas_norm = np.sqrt(np.sum(np.square(x), axis=0, keepdims=True))
  29. x = np.divide(x, feas_norm)
  30. data[K.SHITU_REC_RESULT] = x
  31. return data
  32. @classmethod
  33. def get_input_keys(cls):
  34. """get input keys"""
  35. return [K.IM_PATH, K.SHITU_REC_PRED]
  36. @classmethod
  37. def get_output_keys(cls):
  38. """get output keys"""
  39. return [K.SHITU_REC_RESULT]
  40. class PrintShiTuRecResult(BaseTransform):
  41. """Print Result Transform"""
  42. def apply(self, data):
  43. """apply"""
  44. logging.info("The prediction result is:")
  45. logging.info(data[K.SHITU_REC_RESULT])
  46. return data
  47. @classmethod
  48. def get_input_keys(cls):
  49. """get input keys"""
  50. return [K.SHITU_REC_RESULT]
  51. @classmethod
  52. def get_output_keys(cls):
  53. """get output keys"""
  54. return []