writers.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import enum
  16. import cv2
  17. import pandas as pd
  18. import numpy as np
  19. from PIL import Image
  # Public names exported by ``from <this module> import *``.
  __all__ = [
      "ImageWriter",
      "TextWriter",
      "WriterType",
      "TSWriter",
  ]
  class WriterType(enum.Enum):
      """Kind of output produced by a writer (see each writer's ``get_type``).

      Members are numbered 1-4 in declaration order.
      """

      # enum.auto() on a plain Enum starts at 1 and counts up, which matches
      # the explicit values IMAGE=1, VIDEO=2, TEXT=3, TS=4.
      IMAGE = enum.auto()
      VIDEO = enum.auto()
      TEXT = enum.auto()
      TS = enum.auto()
  class _BaseWriter:
      """Common front end shared by the concrete writers.

      A writer remembers a backend name plus keyword arguments and builds the
      backend object once, during construction, through ``_init_backend``.
      Subclasses are expected to implement ``write``, ``_init_backend`` and
      ``get_type``; the base versions raise ``NotImplementedError``.
      """

      def __init__(self, backend, **bk_args):
          super().__init__()
          # No explicit backend arguments: use the subclass-provided defaults.
          if not bk_args:
              bk_args = self.get_default_backend_args()
          self.bk_type = backend
          self.bk_args = bk_args
          self._backend = self.get_backend()

      def write(self, out_path, obj):
          """Write ``obj`` to ``out_path``; must be overridden."""
          raise NotImplementedError

      def get_backend(self, bk_args=None):
          """Build a backend of type ``self.bk_type``.

          Args:
              bk_args (dict | None): backend keyword arguments; ``None`` means
                  use ``self.bk_args``.

          Returns:
              Whatever ``_init_backend`` returns.
          """
          effective_args = self.bk_args if bk_args is None else bk_args
          return self._init_backend(self.bk_type, effective_args)

      def _init_backend(self, bk_type, bk_args):
          """Instantiate the backend named ``bk_type``; must be overridden."""
          raise NotImplementedError

      def get_type(self):
          """Return this writer's ``WriterType``; must be overridden."""
          raise NotImplementedError

      def get_default_backend_args(self):
          """Return default backend kwargs (none at this level)."""
          return dict()
  class ImageWriter(_BaseWriter):
      """Writer that saves images through an OpenCV or Pillow backend.

      Args:
          backend (str): ``"opencv"`` (default) or ``"pillow"``.
          **bk_args: keyword arguments forwarded to the backend constructor.
      """

      def __init__(self, backend="opencv", **bk_args):
          super().__init__(backend=backend, **bk_args)

      def write(self, out_path, obj):
          """Save image ``obj`` to ``out_path``.

          Returns:
              The result of the backend's ``write_obj``.
          """
          backend = self._backend
          return backend.write_obj(out_path, obj)

      def _init_backend(self, bk_type, bk_args):
          """Create the backend selected by ``bk_type``.

          Raises:
              ValueError: if ``bk_type`` is neither "opencv" nor "pillow".
          """
          if bk_type == "opencv":
              return OpenCVImageWriterBackend(**bk_args)
          if bk_type == "pillow":
              return PILImageWriterBackend(**bk_args)
          raise ValueError("Unsupported backend type")

      def get_type(self):
          """Return ``WriterType.IMAGE``."""
          return WriterType.IMAGE
  class TextWriter(_BaseWriter):
      """Writer that saves text using Python's built-in file I/O.

      Args:
          backend (str): only ``"python"`` (the default) is supported.
          **bk_args: keyword arguments forwarded to ``TextWriterBackend``.
      """

      def __init__(self, backend="python", **bk_args):
          super().__init__(backend=backend, **bk_args)

      def write(self, out_path, obj):
          """Write text ``obj`` to ``out_path``.

          Returns:
              The result of the backend's ``write_obj``.
          """
          backend = self._backend
          return backend.write_obj(out_path, obj)

      def _init_backend(self, bk_type, bk_args):
          """Create the backend selected by ``bk_type``.

          Raises:
              ValueError: if ``bk_type`` is not "python".
          """
          if bk_type != "python":
              raise ValueError("Unsupported backend type")
          return TextWriterBackend(**bk_args)

      def get_type(self):
          """Return ``WriterType.TEXT``."""
          return WriterType.TEXT
  class _BaseWriterBackend:
      """Base class for backends that serialize one object to one file path.

      ``write_obj`` prepares the destination directory and then delegates the
      actual serialization to ``_write_obj``, which subclasses implement.
      """

      def write_obj(self, out_path, obj):
          """Write ``obj`` to ``out_path``, creating parent directories first.

          Args:
              out_path (str): destination file path.
              obj: object to serialize; accepted types depend on the subclass.

          Returns:
              Whatever the subclass's ``_write_obj`` returns.
          """
          out_dir = os.path.dirname(out_path)
          # A bare file name such as "out.png" has an empty directory part,
          # and os.makedirs("") raises FileNotFoundError; only create the
          # parent directory when there actually is one.
          if out_dir:
              os.makedirs(out_dir, exist_ok=True)
          return self._write_obj(out_path, obj)

      def _write_obj(self, out_path, obj):
          """Serialize ``obj`` to ``out_path``; must be overridden."""
          raise NotImplementedError
  class TextWriterBackend(_BaseWriterBackend):
      """Backend that writes ``obj`` to a file opened with built-in ``open``.

      Args:
          mode (str): file mode passed to ``open``; default ``"w"``.
          encoding (str): text encoding; default ``"utf-8"``. It is not passed
              for binary modes (modes containing ``"b"``), because ``open``
              rejects an encoding in binary mode.
      """

      def __init__(self, mode="w", encoding="utf-8"):
          super().__init__()
          self.mode = mode
          self.encoding = encoding

      def _write_obj(self, out_path, obj):
          """Write ``obj`` to ``out_path`` using the configured mode/encoding.

          ``obj`` must be ``str`` for text modes and bytes-like for binary
          modes, as required by the file object's ``write``.
          """
          # open() raises ValueError when an encoding is combined with a
          # binary mode, so only supply it for text modes.
          encoding = None if "b" in self.mode else self.encoding
          with open(out_path, mode=self.mode, encoding=encoding) as f:
              f.write(obj)
  class _ImageWriterBackend(_BaseWriterBackend):
      """Common base for image writer backends.

      It adds no behavior of its own; concrete subclasses supply
      ``_write_obj``.
      """
  class OpenCVImageWriterBackend(_ImageWriterBackend):
      """Image backend that saves with ``cv2.imwrite``."""

      def _write_obj(self, out_path, obj):
          """Save ``obj`` to ``out_path`` with OpenCV.

          Args:
              out_path (str): destination file path.
              obj (PIL.Image.Image | numpy.ndarray): image to save. ndarrays
                  are passed through unchanged, so colour data is assumed to
                  already be in OpenCV's BGR(A) channel order. PIL images in
                  ``"RGB"`` / ``"RGBA"`` mode are converted to BGR / BGRA
                  first; PIL images in other modes are passed through as-is.

          Returns:
              The value returned by ``cv2.imwrite``.

          Raises:
              TypeError: if ``obj`` is neither a PIL image nor an ndarray.
          """
          if isinstance(obj, Image.Image):
              arr = np.asarray(obj)
              # PIL stores colour channels as RGB(A) while cv2.imwrite expects
              # BGR(A); without this swap red and blue are exchanged on disk.
              if obj.mode == "RGB":
                  arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
              elif obj.mode == "RGBA":
                  arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)
          elif isinstance(obj, np.ndarray):
              arr = obj
          else:
              raise TypeError("Unsupported object type")
          return cv2.imwrite(out_path, arr)
  class PILImageWriterBackend(_ImageWriterBackend):
      """Image backend that saves with Pillow's ``Image.save``.

      Args:
          format_ (str | None): format name forwarded to ``Image.save`` as
              ``format``; ``None`` leaves the choice to Pillow.
      """

      def __init__(self, format_=None):
          super().__init__()
          self.format = format_

      def _write_obj(self, out_path, obj):
          """Save ``obj`` (PIL image or ndarray) to ``out_path`` via Pillow.

          ndarrays are wrapped with ``Image.fromarray`` before saving.

          Returns:
              The value returned by ``Image.save``.

          Raises:
              TypeError: if ``obj`` is neither a PIL image nor an ndarray.
          """
          if not isinstance(obj, (Image.Image, np.ndarray)):
              raise TypeError("Unsupported object type")
          pil_img = obj if isinstance(obj, Image.Image) else Image.fromarray(obj)
          return pil_img.save(out_path, format=self.format)
  class TSWriter(_BaseWriter):
      """Writer for ``WriterType.TS`` objects.

      With the only supported backend, ``"pandas"``, the objects are pandas
      DataFrames written as CSV.

      Args:
          backend (str): only ``"pandas"`` (the default) is supported.
          **bk_args: keyword arguments forwarded to the backend constructor.
      """

      def __init__(self, backend="pandas", **bk_args):
          super().__init__(backend=backend, **bk_args)

      def write(self, out_path, obj):
          """Write ``obj`` to ``out_path``.

          Returns:
              The result of the backend's ``write_obj``.
          """
          backend = self._backend
          return backend.write_obj(out_path, obj)

      def _init_backend(self, bk_type, bk_args):
          """Create the backend selected by ``bk_type``.

          Raises:
              ValueError: if ``bk_type`` is not "pandas".
          """
          if bk_type != "pandas":
              raise ValueError("Unsupported backend type")
          return PandasTSWriterBackend(**bk_args)

      def get_type(self):
          """Return ``WriterType.TS``."""
          return WriterType.TS
  class _TSWriterBackend(_BaseWriterBackend):
      """Common base for ``WriterType.TS`` writer backends.

      It adds no behavior of its own; concrete subclasses supply
      ``_write_obj``.
      """
  class PandasTSWriterBackend(_TSWriterBackend):
      """TS backend that writes a pandas DataFrame to CSV with ``to_csv``."""

      def __init__(self):
          super().__init__()

      def _write_obj(self, out_path, obj):
          """Write DataFrame ``obj`` to ``out_path`` as CSV.

          ``DataFrame.to_csv`` is called with its default options.

          Returns:
              The value returned by ``DataFrame.to_csv``.

          Raises:
              TypeError: if ``obj`` is not a ``pandas.DataFrame``.
          """
          if not isinstance(obj, pd.DataFrame):
              raise TypeError("Unsupported object type")
          return obj.to_csv(out_path)