mixin.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. import json
  16. from pathlib import Path
  17. import numpy as np
  18. from PIL import Image
  19. from ....utils import logging
  20. from ...utils.io import (
  21. JsonWriter,
  22. ImageReader,
  23. ImageWriter,
  24. CSVWriter,
  25. HtmlWriter,
  26. XlsxWriter,
  27. )
  28. def _save_list_data(save_func, save_path, data, *args, **kwargs):
  29. save_path = Path(save_path)
  30. if data is None:
  31. return
  32. if isinstance(data, list):
  33. for idx, single in enumerate(data):
  34. save_func(
  35. (
  36. save_path.parent / f"{save_path.stem}_{idx}.{save_path.suffix}"
  37. ).as_posix(),
  38. single,
  39. *args,
  40. **kwargs,
  41. )
  42. save_func(save_path.as_posix(), data, *args, **kwargs)
  43. logging.info(f"The result has been saved in {save_path}.")
  44. class StrMixin:
  45. def __init__(self):
  46. self._show_func_register()(self.print)
  47. @property
  48. def str(self):
  49. return self._to_str()
  50. def _to_str(self):
  51. return str(self)
  52. def print(self, json_format=False, indent=4, ensure_ascii=False):
  53. str_ = self._to_str()
  54. if json_format:
  55. str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
  56. logging.info(str_)
  57. class JsonMixin:
  58. def __init__(self):
  59. self._json_writer = JsonWriter()
  60. self._show_func_register()(self.save_to_json)
  61. def _to_json(self):
  62. def _format_data(obj):
  63. if isinstance(obj, np.float32):
  64. return float(obj)
  65. if isinstance(obj, np.ndarray):
  66. return [_format_data(item) for item in obj.tolist()]
  67. elif isinstance(obj, dict):
  68. return type(obj)({k: _format_data(v) for k, v in obj.items()})
  69. elif isinstance(obj, (list, tuple)):
  70. return [_format_data(i) for i in obj]
  71. else:
  72. return obj
  73. return _format_data(self)
  74. @property
  75. def json(self):
  76. return self._to_json()
  77. def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
  78. if not save_path.endswith(".json"):
  79. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
  80. _save_list_data(
  81. self._json_writer.write,
  82. save_path,
  83. self.json,
  84. indent=indent,
  85. ensure_ascii=ensure_ascii,
  86. *args,
  87. **kwargs,
  88. )
  89. class ImgMixin:
  90. def __init__(self, backend="pillow", *args, **kwargs):
  91. self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
  92. self._show_func_register()(self.save_to_img)
  93. @abstractmethod
  94. def _to_img(self):
  95. raise NotImplementedError
  96. @property
  97. def img(self):
  98. image = self._to_img()
  99. # The img must be a PIL.Image obj
  100. if isinstance(image, np.ndarray):
  101. return Image.fromarray(image)
  102. return image
  103. def save_to_img(self, save_path, *args, **kwargs):
  104. if not save_path.lower().endswith((".jpg", ".png")):
  105. fp = Path(self["input_path"])
  106. save_path = Path(save_path) / f"{fp.stem}.{fp.suffix}"
  107. _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
  108. class CSVMixin:
  109. def __init__(self, backend="pandas", *args, **kwargs):
  110. self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
  111. self._show_func_register()(self.save_to_csv)
  112. @abstractmethod
  113. def _to_csv(self):
  114. raise NotImplementedError
  115. def save_to_csv(self, save_path, *args, **kwargs):
  116. if not save_path.endswith(".csv"):
  117. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
  118. _save_list_data(
  119. self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
  120. )
  121. class HtmlMixin:
  122. def __init__(self, *args, **kwargs):
  123. self._html_writer = HtmlWriter(*args, **kwargs)
  124. self._show_func_register()(self.save_to_html)
  125. @property
  126. def html(self):
  127. return self._to_html()
  128. def _to_html(self):
  129. return self["html"]
  130. def save_to_html(self, save_path, *args, **kwargs):
  131. if not save_path.endswith(".html"):
  132. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
  133. _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
  134. class XlsxMixin:
  135. def __init__(self, *args, **kwargs):
  136. self._xlsx_writer = XlsxWriter(*args, **kwargs)
  137. self._show_func_register()(self.save_to_xlsx)
  138. def _to_xlsx(self):
  139. return self["html"]
  140. def save_to_xlsx(self, save_path, *args, **kwargs):
  141. if not save_path.endswith(".xlsx"):
  142. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
  143. _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)