base_result.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import random
  16. import time
  17. from pathlib import Path
  18. import numpy as np
  19. from ....utils import logging
  20. from .mixin import JsonMixin, StrMixin
  21. class BaseResult(dict, JsonMixin, StrMixin):
  22. """Base class for result objects that can save themselves.
  23. This class inherits from dict and provides properties and methods for handling result.
  24. """
  25. def __init__(self, data: dict) -> None:
  26. """Initializes the BaseResult with the given data.
  27. Args:
  28. data (dict): The initial data.
  29. """
  30. super().__init__(data)
  31. self._save_funcs = []
  32. StrMixin.__init__(self)
  33. JsonMixin.__init__(self)
  34. np.set_printoptions(threshold=1, edgeitems=1)
  35. self._rand_fn = None
  36. def save_all(self, save_path: str) -> None:
  37. """Calls all registered save methods with the given save path.
  38. Args:
  39. save_path (str): The path to save the result to.
  40. """
  41. for func in self._save_funcs:
  42. signature = inspect.signature(func)
  43. if "save_path" in signature.parameters:
  44. func(save_path=save_path)
  45. else:
  46. func()
  47. def _get_input_fn(self):
  48. if self.get("input_path", None) is None:
  49. if self._rand_fn:
  50. return self._rand_fn
  51. timestamp = int(time.time())
  52. random_number = random.randint(1000, 9999)
  53. fp = f"{timestamp}_{random_number}"
  54. logging.warning(
  55. f"There is not input file name as reference for name of saved result file. So the saved result file would be named with timestamp and random number: `{fp}`."
  56. )
  57. self._rand_fn = Path(fp).name
  58. return self._rand_fn
  59. fp = self["input_path"]
  60. return Path(fp).name