mixin.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. import json
  16. from pathlib import Path
  17. import numpy as np
  18. from PIL import Image
  19. import pandas as pd
  20. from .....utils import logging
  21. from ....utils.io import (
  22. JsonWriter,
  23. ImageReader,
  24. ImageWriter,
  25. CSVWriter,
  26. HtmlWriter,
  27. XlsxWriter,
  28. TextWriter,
  29. )
  30. def _save_list_data(save_func, save_path, data, *args, **kwargs):
  31. save_path = Path(save_path)
  32. if data is None:
  33. return
  34. if isinstance(data, list):
  35. for idx, single in enumerate(data):
  36. save_func(
  37. (
  38. save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
  39. ).as_posix(),
  40. single,
  41. *args,
  42. **kwargs,
  43. )
  44. save_func(save_path.as_posix(), data, *args, **kwargs)
  45. logging.info(f"The result has been saved in {save_path}.")
  46. class StrMixin:
  47. @property
  48. def str(self):
  49. return self._to_str()
  50. def _to_str(self, data, json_format=False, indent=4, ensure_ascii=False):
  51. if json_format:
  52. return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
  53. else:
  54. return str(data)
  55. def print(self, json_format=False, indent=4, ensure_ascii=False):
  56. str_ = self._to_str(
  57. self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
  58. )
  59. logging.info(str_)
  60. class JsonMixin:
  61. def __init__(self):
  62. self._json_writer = JsonWriter()
  63. self._show_funcs.append(self.save_to_json)
  64. def _to_json(self):
  65. def _format_data(obj):
  66. if isinstance(obj, np.float32):
  67. return float(obj)
  68. elif isinstance(obj, np.ndarray):
  69. return [_format_data(item) for item in obj.tolist()]
  70. elif isinstance(obj, pd.DataFrame):
  71. return obj.to_json(orient="records", force_ascii=False)
  72. elif isinstance(obj, Path):
  73. return obj.as_posix()
  74. elif isinstance(obj, dict):
  75. return type(obj)({k: _format_data(v) for k, v in obj.items()})
  76. elif isinstance(obj, (list, tuple)):
  77. return [_format_data(i) for i in obj]
  78. else:
  79. return obj
  80. return _format_data(self)
  81. @property
  82. def json(self):
  83. return self._to_json()
  84. def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
  85. if not str(save_path).endswith(".json"):
  86. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
  87. _save_list_data(
  88. self._json_writer.write,
  89. save_path,
  90. self.json,
  91. indent=indent,
  92. ensure_ascii=ensure_ascii,
  93. *args,
  94. **kwargs,
  95. )
  96. class Base64Mixin:
  97. def __init__(self, *args, **kwargs):
  98. self._base64_writer = TextWriter(*args, **kwargs)
  99. self._show_funcs.append(self.save_to_base64)
  100. @abstractmethod
  101. def _to_base64(self):
  102. raise NotImplementedError
  103. @property
  104. def base64(self):
  105. return self._to_base64()
  106. def save_to_base64(self, save_path, *args, **kwargs):
  107. if not str(save_path).lower().endswith((".b64")):
  108. fp = Path(self["input_path"])
  109. save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
  110. _save_list_data(
  111. self._base64_writer.write, save_path, self.base64, *args, **kwargs
  112. )
  113. class ImgMixin:
  114. def __init__(self, backend="pillow", *args, **kwargs):
  115. self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
  116. self._show_funcs.append(self.save_to_img)
  117. @abstractmethod
  118. def _to_img(self):
  119. raise NotImplementedError
  120. @property
  121. def img(self):
  122. image = self._to_img()
  123. # The img must be a PIL.Image obj
  124. if isinstance(image, np.ndarray):
  125. return Image.fromarray(image)
  126. return image
  127. def save_to_img(self, save_path, *args, **kwargs):
  128. if not str(save_path).lower().endswith((".jpg", ".png")):
  129. fp = Path(self["input_path"])
  130. save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
  131. _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
  132. class CSVMixin:
  133. def __init__(self, backend="pandas", *args, **kwargs):
  134. self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
  135. self._show_funcs.append(self.save_to_csv)
  136. @abstractmethod
  137. def _to_csv(self):
  138. raise NotImplementedError
  139. def save_to_csv(self, save_path, *args, **kwargs):
  140. if not str(save_path).endswith(".csv"):
  141. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
  142. _save_list_data(
  143. self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
  144. )
  145. class HtmlMixin:
  146. def __init__(self, *args, **kwargs):
  147. self._html_writer = HtmlWriter(*args, **kwargs)
  148. self._show_funcs.append(self.save_to_html)
  149. @property
  150. def html(self):
  151. return self._to_html()
  152. def _to_html(self):
  153. return self["html"]
  154. def save_to_html(self, save_path, *args, **kwargs):
  155. if not str(save_path).endswith(".html"):
  156. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
  157. _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
  158. class XlsxMixin:
  159. def __init__(self, *args, **kwargs):
  160. self._xlsx_writer = XlsxWriter(*args, **kwargs)
  161. self._show_funcs.append(self.save_to_xlsx)
  162. def _to_xlsx(self):
  163. return self["html"]
  164. def save_to_xlsx(self, save_path, *args, **kwargs):
  165. if not str(save_path).endswith(".xlsx"):
  166. save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
  167. _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)