mixin.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. import json
  16. from pathlib import Path
  17. import numpy as np
  18. from PIL import Image
  19. import pandas as pd
  20. from ....utils import logging
  21. from ...utils.io import (
  22. JsonWriter,
  23. ImageReader,
  24. ImageWriter,
  25. CSVWriter,
  26. HtmlWriter,
  27. XlsxWriter,
  28. )
  29. def _save_list_data(save_func, save_path, data, *args, **kwargs):
  30. save_path = Path(save_path)
  31. if data is None:
  32. return
  33. if isinstance(data, list):
  34. for idx, single in enumerate(data):
  35. save_func(
  36. (
  37. save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
  38. ).as_posix(),
  39. single,
  40. *args,
  41. **kwargs,
  42. )
  43. save_func(save_path.as_posix(), data, *args, **kwargs)
  44. logging.info(f"The result has been saved in {save_path}.")
  45. class StrMixin:
  46. @property
  47. def str(self):
  48. return self._to_str()
  49. def _to_str(self):
  50. return str(self)
  51. def print(self, json_format=False, indent=4, ensure_ascii=False):
  52. str_ = self._to_str()
  53. if json_format:
  54. str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
  55. logging.info(str_)
  56. class JsonMixin:
  57. def __init__(self):
  58. self._json_writer = JsonWriter()
  59. self._show_func_register()(self.save_to_json)
  60. def _to_json(self):
  61. def _format_data(obj):
  62. if isinstance(obj, np.float32):
  63. return float(obj)
  64. elif isinstance(obj, np.ndarray):
  65. return [_format_data(item) for item in obj.tolist()]
  66. elif isinstance(obj, pd.DataFrame):
  67. return obj.to_json(orient="records", force_ascii=False)
  68. elif isinstance(obj, Path):
  69. return obj.as_posix()
  70. elif isinstance(obj, dict):
  71. return type(obj)({k: _format_data(v) for k, v in obj.items()})
  72. elif isinstance(obj, (list, tuple)):
  73. return [_format_data(i) for i in obj]
  74. else:
  75. return obj
  76. return _format_data(self)
  77. @property
  78. def json(self):
  79. return self._to_json()
  80. def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
  81. if not str(save_path).endswith(".json"):
  82. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
  83. _save_list_data(
  84. self._json_writer.write,
  85. save_path,
  86. self.json,
  87. indent=indent,
  88. ensure_ascii=ensure_ascii,
  89. *args,
  90. **kwargs,
  91. )
  92. class ImgMixin:
  93. def __init__(self, backend="pillow", *args, **kwargs):
  94. self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
  95. self._show_func_register()(self.save_to_img)
  96. @abstractmethod
  97. def _to_img(self):
  98. raise NotImplementedError
  99. @property
  100. def img(self):
  101. image = self._to_img()
  102. # The img must be a PIL.Image obj
  103. if isinstance(image, np.ndarray):
  104. return Image.fromarray(image)
  105. return image
  106. def save_to_img(self, save_path, *args, **kwargs):
  107. if not str(save_path).lower().endswith((".jpg", ".png")):
  108. fp = Path(self["input_path"])
  109. save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
  110. _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
  111. class CSVMixin:
  112. def __init__(self, backend="pandas", *args, **kwargs):
  113. self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
  114. self._show_func_register()(self.save_to_csv)
  115. @abstractmethod
  116. def _to_csv(self):
  117. raise NotImplementedError
  118. def save_to_csv(self, save_path, *args, **kwargs):
  119. if not str(save_path).endswith(".csv"):
  120. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
  121. _save_list_data(
  122. self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
  123. )
  124. class HtmlMixin:
  125. def __init__(self, *args, **kwargs):
  126. self._html_writer = HtmlWriter(*args, **kwargs)
  127. self._show_func_register()(self.save_to_html)
  128. @property
  129. def html(self):
  130. return self._to_html()
  131. def _to_html(self):
  132. return self["html"]
  133. def save_to_html(self, save_path, *args, **kwargs):
  134. if not str(save_path).endswith(".html"):
  135. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
  136. _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
  137. class XlsxMixin:
  138. def __init__(self, *args, **kwargs):
  139. self._xlsx_writer = XlsxWriter(*args, **kwargs)
  140. self._show_func_register()(self.save_to_xlsx)
  141. def _to_xlsx(self):
  142. return self["html"]
  143. def save_to_xlsx(self, save_path, *args, **kwargs):
  144. if not str(save_path).endswith(".xlsx"):
  145. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
  146. _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)