writers.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import enum
  16. import json
  17. from pathlib import Path
  18. import cv2
  19. import numpy as np
  20. from PIL import Image
  21. from .tablepyxl import document_to_xl
  22. __all__ = [
  23. "ImageWriter",
  24. "TextWriter",
  25. "JsonWriter",
  26. "WriterType",
  27. "HtmlWriter",
  28. "XlsxWriter",
  29. ]
  30. class WriterType(enum.Enum):
  31. """WriterType"""
  32. IMAGE = 1
  33. VIDEO = 2
  34. TEXT = 3
  35. JSON = 4
  36. HTML = 5
  37. XLSX = 6
  38. class _BaseWriter(object):
  39. """_BaseWriter"""
  40. def __init__(self, backend, **bk_args):
  41. super().__init__()
  42. if len(bk_args) == 0:
  43. bk_args = self.get_default_backend_args()
  44. self.bk_type = backend
  45. self.bk_args = bk_args
  46. self._backend = self.get_backend()
  47. def write(self, out_path, obj):
  48. """write"""
  49. raise NotImplementedError
  50. def get_backend(self, bk_args=None):
  51. """get backend"""
  52. if bk_args is None:
  53. bk_args = self.bk_args
  54. return self._init_backend(self.bk_type, bk_args)
  55. def set_backend(self, backend, **bk_args):
  56. self.bk_type = backend
  57. self.bk_args = bk_args
  58. self._backend = self.get_backend()
  59. def _init_backend(self, bk_type, bk_args):
  60. """init backend"""
  61. raise NotImplementedError
  62. def get_type(self):
  63. """get type"""
  64. raise NotImplementedError
  65. def get_default_backend_args(self):
  66. """get default backend arguments"""
  67. return {}
  68. class ImageWriter(_BaseWriter):
  69. """ImageWriter"""
  70. def __init__(self, backend="opencv", **bk_args):
  71. super().__init__(backend=backend, **bk_args)
  72. def write(self, out_path, obj):
  73. """write"""
  74. return self._backend.write_obj(out_path, obj)
  75. def _init_backend(self, bk_type, bk_args):
  76. """init backend"""
  77. if bk_type == "opencv":
  78. return OpenCVImageWriterBackend(**bk_args)
  79. elif bk_type == "pil" or bk_type == "pillow":
  80. return PILImageWriterBackend(**bk_args)
  81. else:
  82. raise ValueError("Unsupported backend type")
  83. def get_type(self):
  84. """get type"""
  85. return WriterType.IMAGE
  86. class TextWriter(_BaseWriter):
  87. """TextWriter"""
  88. def __init__(self, backend="python", **bk_args):
  89. super().__init__(backend=backend, **bk_args)
  90. def write(self, out_path, obj):
  91. """write"""
  92. return self._backend.write_obj(out_path, obj)
  93. def _init_backend(self, bk_type, bk_args):
  94. """init backend"""
  95. if bk_type == "python":
  96. return TextWriterBackend(**bk_args)
  97. else:
  98. raise ValueError("Unsupported backend type")
  99. def get_type(self):
  100. """get type"""
  101. return WriterType.TEXT
  102. class JsonWriter(_BaseWriter):
  103. def __init__(self, backend="json", **bk_args):
  104. super().__init__(backend=backend, **bk_args)
  105. def write(self, out_path, obj, **bk_args):
  106. return self._backend.write_obj(out_path, obj, **bk_args)
  107. def _init_backend(self, bk_type, bk_args):
  108. if bk_type == "json":
  109. return JsonWriterBackend(**bk_args)
  110. elif bk_type == "ujson":
  111. return UJsonWriterBackend(**bk_args)
  112. else:
  113. raise ValueError("Unsupported backend type")
  114. def get_type(self):
  115. """get type"""
  116. return WriterType.JSON
  117. class HtmlWriter(_BaseWriter):
  118. def __init__(self, backend="html", **bk_args):
  119. super().__init__(backend=backend, **bk_args)
  120. def write(self, out_path, obj, **bk_args):
  121. return self._backend.write_obj(out_path, obj, **bk_args)
  122. def _init_backend(self, bk_type, bk_args):
  123. if bk_type == "html":
  124. return HtmlWriterBackend(**bk_args)
  125. else:
  126. raise ValueError("Unsupported backend type")
  127. def get_type(self):
  128. """get type"""
  129. return WriterType.HTML
  130. class XlsxWriter(_BaseWriter):
  131. def __init__(self, backend="xlsx", **bk_args):
  132. super().__init__(backend=backend, **bk_args)
  133. def write(self, out_path, obj, **bk_args):
  134. return self._backend.write_obj(out_path, obj, **bk_args)
  135. def _init_backend(self, bk_type, bk_args):
  136. if bk_type == "xlsx":
  137. return XlsxWriterBackend(**bk_args)
  138. else:
  139. raise ValueError("Unsupported backend type")
  140. def get_type(self):
  141. """get type"""
  142. return WriterType.XLSX
  143. class _BaseWriterBackend(object):
  144. """_BaseWriterBackend"""
  145. def write_obj(self, out_path, obj):
  146. """write object"""
  147. out_dir = os.path.dirname(out_path)
  148. os.makedirs(out_dir, exist_ok=True)
  149. return self._write_obj(out_path, obj)
  150. def _write_obj(self, out_path, obj):
  151. """write object"""
  152. raise NotImplementedError
  153. class TextWriterBackend(_BaseWriterBackend):
  154. """TextWriterBackend"""
  155. def __init__(self, mode="w", encoding="utf-8"):
  156. super().__init__()
  157. self.mode = mode
  158. self.encoding = encoding
  159. def _write_obj(self, out_path, obj):
  160. """write text object"""
  161. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  162. f.write(obj)
  163. class HtmlWriterBackend(_BaseWriterBackend):
  164. def __init__(self, mode="w", encoding="utf-8"):
  165. super().__init__()
  166. self.mode = mode
  167. self.encoding = encoding
  168. def _write_obj(self, out_path, obj, **bk_args):
  169. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  170. f.write(obj)
  171. class XlsxWriterBackend(_BaseWriterBackend):
  172. def _write_obj(self, out_path, obj, **bk_args):
  173. document_to_xl(obj, out_path)
  174. class _ImageWriterBackend(_BaseWriterBackend):
  175. """_ImageWriterBackend"""
  176. pass
  177. class OpenCVImageWriterBackend(_ImageWriterBackend):
  178. """OpenCVImageWriterBackend"""
  179. def _write_obj(self, out_path, obj):
  180. """write image object by OpenCV"""
  181. if isinstance(obj, Image.Image):
  182. arr = np.asarray(obj)
  183. elif isinstance(obj, np.ndarray):
  184. arr = obj
  185. else:
  186. raise TypeError("Unsupported object type")
  187. return cv2.imwrite(out_path, arr)
  188. class PILImageWriterBackend(_ImageWriterBackend):
  189. """PILImageWriterBackend"""
  190. def __init__(self, format_=None):
  191. super().__init__()
  192. self.format = format_
  193. def _write_obj(self, out_path, obj):
  194. """write image object by PIL"""
  195. if isinstance(obj, Image.Image):
  196. img = obj
  197. elif isinstance(obj, np.ndarray):
  198. img = Image.fromarray(obj)
  199. else:
  200. raise TypeError("Unsupported object type")
  201. return img.save(out_path, format=self.format)
  202. class _BaseJsonWriterBackend(object):
  203. def __init__(self, indent=4, ensure_ascii=False):
  204. super().__init__()
  205. self.indent = indent
  206. self.ensure_ascii = ensure_ascii
  207. def write_obj(self, out_path, obj, **bk_args):
  208. Path(out_path).parent.mkdir(parents=True, exist_ok=True)
  209. return self._write_obj(out_path, obj, **bk_args)
  210. def _write_obj(self, out_path, obj):
  211. raise NotImplementedError
  212. class JsonWriterBackend(_BaseJsonWriterBackend):
  213. def _write_obj(self, out_path, obj, **bk_args):
  214. with open(out_path, "w") as f:
  215. json.dump(obj, f, **bk_args)
  216. class UJsonWriterBackend(_BaseJsonWriterBackend):
  217. # TODO
  218. def _write_obj(self, out_path, obj, **bk_args):
  219. raise NotImplementedError