writers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. import json
  16. from pathlib import Path
  17. import numpy as np
  18. import pandas as pd
  19. import yaml
  20. from PIL import Image
  21. from ....utils.deps import class_requires_deps, is_dep_available
  22. from .tablepyxl import document_to_xl
  23. if is_dep_available("opencv-contrib-python"):
  24. import cv2
# Public names exported by a star-import of this module.
__all__ = [
    "WriterType",
    "ImageWriter",
    "TextWriter",
    "JsonWriter",
    "CSVWriter",
    "HtmlWriter",
    "XlsxWriter",
    "YAMLWriter",
    "VideoWriter",
    "MarkdownWriter",
]
  37. class WriterType(enum.Enum):
  38. """WriterType"""
  39. IMAGE = 1
  40. VIDEO = 2
  41. TEXT = 3
  42. JSON = 4
  43. HTML = 5
  44. XLSX = 6
  45. CSV = 7
  46. YAML = 8
  47. class _BaseWriter(object):
  48. """_BaseWriter"""
  49. def __init__(self, backend, **bk_args):
  50. super().__init__()
  51. if len(bk_args) == 0:
  52. bk_args = self.get_default_backend_args()
  53. self.bk_type = backend
  54. self.bk_args = bk_args
  55. self._backend = self.get_backend()
  56. def write(self, out_path, obj):
  57. """write"""
  58. raise NotImplementedError
  59. def get_backend(self, bk_args=None):
  60. """get backend"""
  61. if bk_args is None:
  62. bk_args = self.bk_args
  63. return self._init_backend(self.bk_type, bk_args)
  64. def set_backend(self, backend, **bk_args):
  65. self.bk_type = backend
  66. self.bk_args = bk_args
  67. self._backend = self.get_backend()
  68. def _init_backend(self, bk_type, bk_args):
  69. """init backend"""
  70. raise NotImplementedError
  71. def get_type(self):
  72. """get type"""
  73. raise NotImplementedError
  74. def get_default_backend_args(self):
  75. """get default backend arguments"""
  76. return {}
  77. class ImageWriter(_BaseWriter):
  78. """ImageWriter"""
  79. def __init__(self, backend="opencv", **bk_args):
  80. super().__init__(backend=backend, **bk_args)
  81. def write(self, out_path, obj):
  82. """write"""
  83. return self._backend.write_obj(str(out_path), obj)
  84. def _init_backend(self, bk_type, bk_args):
  85. """init backend"""
  86. if bk_type == "opencv":
  87. return OpenCVImageWriterBackend(**bk_args)
  88. elif bk_type == "pil" or bk_type == "pillow":
  89. return PILImageWriterBackend(**bk_args)
  90. else:
  91. raise ValueError("Unsupported backend type")
  92. def get_type(self):
  93. """get type"""
  94. return WriterType.IMAGE
  95. class VideoWriter(_BaseWriter):
  96. """VideoWriter"""
  97. def __init__(self, backend="opencv", **bk_args):
  98. super().__init__(backend=backend, **bk_args)
  99. def write(self, out_path, obj):
  100. """write"""
  101. return self._backend.write_obj(str(out_path), obj)
  102. def _init_backend(self, bk_type, bk_args):
  103. """init backend"""
  104. if bk_type == "opencv":
  105. return OpenCVVideoWriterBackend(**bk_args)
  106. else:
  107. raise ValueError("Unsupported backend type")
  108. def get_type(self):
  109. """get type"""
  110. return WriterType.VIDEO
  111. class TextWriter(_BaseWriter):
  112. """TextWriter"""
  113. def __init__(self, backend="python", **bk_args):
  114. super().__init__(backend=backend, **bk_args)
  115. def write(self, out_path, obj):
  116. """write"""
  117. return self._backend.write_obj(str(out_path), obj)
  118. def _init_backend(self, bk_type, bk_args):
  119. """init backend"""
  120. if bk_type == "python":
  121. return TextWriterBackend(**bk_args)
  122. else:
  123. raise ValueError("Unsupported backend type")
  124. def get_type(self):
  125. """get type"""
  126. return WriterType.TEXT
  127. class JsonWriter(_BaseWriter):
  128. def __init__(self, backend="json", **bk_args):
  129. super().__init__(backend=backend, **bk_args)
  130. def write(self, out_path, obj, **bk_args):
  131. return self._backend.write_obj(str(out_path), obj, **bk_args)
  132. def _init_backend(self, bk_type, bk_args):
  133. if bk_type == "json":
  134. return JsonWriterBackend(**bk_args)
  135. elif bk_type == "ujson":
  136. return UJsonWriterBackend(**bk_args)
  137. else:
  138. raise ValueError("Unsupported backend type")
  139. def get_type(self):
  140. """get type"""
  141. return WriterType.JSON
  142. class HtmlWriter(_BaseWriter):
  143. def __init__(self, backend="html", **bk_args):
  144. super().__init__(backend=backend, **bk_args)
  145. def write(self, out_path, obj, **bk_args):
  146. return self._backend.write_obj(str(out_path), obj, **bk_args)
  147. def _init_backend(self, bk_type, bk_args):
  148. if bk_type == "html":
  149. return HtmlWriterBackend(**bk_args)
  150. else:
  151. raise ValueError("Unsupported backend type")
  152. def get_type(self):
  153. """get type"""
  154. return WriterType.HTML
  155. class XlsxWriter(_BaseWriter):
  156. def __init__(self, backend="xlsx", **bk_args):
  157. super().__init__(backend=backend, **bk_args)
  158. def write(self, out_path, obj, **bk_args):
  159. return self._backend.write_obj(str(out_path), obj, **bk_args)
  160. def _init_backend(self, bk_type, bk_args):
  161. if bk_type == "xlsx":
  162. return XlsxWriterBackend(**bk_args)
  163. else:
  164. raise ValueError("Unsupported backend type")
  165. def get_type(self):
  166. """get type"""
  167. return WriterType.XLSX
  168. class YAMLWriter(_BaseWriter):
  169. def __init__(self, backend="PyYAML", **bk_args):
  170. super().__init__(backend=backend, **bk_args)
  171. def write(self, out_path, obj, **bk_args):
  172. return self._backend.write_obj(str(out_path), obj, **bk_args)
  173. def _init_backend(self, bk_type, bk_args):
  174. if bk_type == "PyYAML":
  175. return YAMLWriterBackend(**bk_args)
  176. else:
  177. raise ValueError("Unsupported backend type")
  178. def get_type(self):
  179. """get type"""
  180. return WriterType.YAML
  181. class MarkdownWriter(_BaseWriter):
  182. """MarkdownWriter"""
  183. def __init__(self, backend="markdown", **bk_args):
  184. super().__init__(backend=backend, **bk_args)
  185. def write(self, out_path, obj):
  186. """write"""
  187. return self._backend.write_obj(str(out_path), obj)
  188. def _init_backend(self, bk_type, bk_args):
  189. """init backend"""
  190. if bk_type == "markdown":
  191. return MarkdownWriterBackend(**bk_args)
  192. else:
  193. raise ValueError("Unsupported backend type")
  194. def get_type(self):
  195. """get type"""
  196. return WriterType.MARKDOWN
  197. class _BaseWriterBackend(object):
  198. """_BaseWriterBackend"""
  199. def write_obj(self, out_path, obj, **bk_args):
  200. """write object"""
  201. Path(out_path).parent.mkdir(parents=True, exist_ok=True)
  202. return self._write_obj(out_path, obj, **bk_args)
  203. def _write_obj(self, out_path, obj, **bk_args):
  204. """write object"""
  205. raise NotImplementedError
  206. class TextWriterBackend(_BaseWriterBackend):
  207. """TextWriterBackend"""
  208. def __init__(self, mode="w", encoding="utf-8"):
  209. super().__init__()
  210. self.mode = mode
  211. self.encoding = encoding
  212. def _write_obj(self, out_path, obj):
  213. """write text object"""
  214. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  215. f.write(obj)
  216. class HtmlWriterBackend(_BaseWriterBackend):
  217. def __init__(self, mode="w", encoding="utf-8"):
  218. super().__init__()
  219. self.mode = mode
  220. self.encoding = encoding
  221. def _write_obj(self, out_path, obj, **bk_args):
  222. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  223. f.write(obj)
class XlsxWriterBackend(_BaseWriterBackend):
    """Backend that hands the object to ``tablepyxl.document_to_xl``."""

    def _write_obj(self, out_path, obj, **bk_args):
        """Pass ``obj`` and ``out_path`` to ``document_to_xl``.

        ``bk_args`` are accepted but not used.
        NOTE(review): presumably ``obj`` is an HTML table document that gets
        converted to a workbook at ``out_path`` -- confirm against tablepyxl.
        """
        document_to_xl(obj, out_path)
class _ImageWriterBackend(_BaseWriterBackend):
    """Common base class of the image writer backends in this module."""
  229. @class_requires_deps("opencv-contrib-python")
  230. class OpenCVImageWriterBackend(_ImageWriterBackend):
  231. """OpenCVImageWriterBackend"""
  232. def _write_obj(self, out_path, obj):
  233. """write image object by OpenCV"""
  234. if isinstance(obj, Image.Image):
  235. # Assuming the channel order is RGB.
  236. arr = np.asarray(obj)[:, :, ::-1]
  237. elif isinstance(obj, np.ndarray):
  238. arr = obj
  239. else:
  240. raise TypeError("Unsupported object type")
  241. return cv2.imwrite(out_path, arr)
  242. class PILImageWriterBackend(_ImageWriterBackend):
  243. """PILImageWriterBackend"""
  244. def __init__(self, format_=None):
  245. super().__init__()
  246. self.format = format_
  247. def _write_obj(self, out_path, obj):
  248. """write image object by PIL"""
  249. if isinstance(obj, Image.Image):
  250. img = obj
  251. elif isinstance(obj, np.ndarray):
  252. img = Image.fromarray(obj)
  253. else:
  254. raise TypeError("Unsupported object type")
  255. if len(img.getbands()) == 4:
  256. self.format = "PNG"
  257. return img.save(out_path, format=self.format)
class _VideoWriterBackend(_BaseWriterBackend):
    """Common base class of the video writer backends in this module."""
  260. @class_requires_deps("opencv-contrib-python")
  261. class OpenCVVideoWriterBackend(_VideoWriterBackend):
  262. """OpenCVImageWriterBackend"""
  263. def _write_obj(self, out_path, obj):
  264. """write video object by OpenCV"""
  265. obj, fps = obj
  266. if isinstance(obj, np.ndarray):
  267. vr = obj
  268. width, height = vr[0].shape[1], vr[0].shape[0]
  269. fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Alternatively, use 'XVID'
  270. out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  271. for frame in vr:
  272. out.write(frame)
  273. out.release()
  274. else:
  275. raise TypeError("Unsupported object type")
  276. class _BaseJsonWriterBackend(object):
  277. def __init__(self, indent=4, ensure_ascii=False):
  278. super().__init__()
  279. self.indent = indent
  280. self.ensure_ascii = ensure_ascii
  281. def write_obj(self, out_path, obj, **bk_args):
  282. Path(out_path).parent.mkdir(parents=True, exist_ok=True)
  283. return self._write_obj(out_path, obj, **bk_args)
  284. def _write_obj(self, out_path, obj):
  285. raise NotImplementedError
  286. class JsonWriterBackend(_BaseJsonWriterBackend):
  287. def _write_obj(self, out_path, obj, **bk_args):
  288. with open(out_path, "w", encoding="utf-8") as f:
  289. json.dump(obj, f, **bk_args)
class UJsonWriterBackend(_BaseJsonWriterBackend):
    """Placeholder JSON backend selected by the "ujson" backend name."""

    # TODO
    def _write_obj(self, out_path, obj, **bk_args):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError
  294. class YAMLWriterBackend(_BaseWriterBackend):
  295. def __init__(self, mode="w", encoding="utf-8"):
  296. super().__init__()
  297. self.mode = mode
  298. self.encoding = encoding
  299. def _write_obj(self, out_path, obj, **bk_args):
  300. """write text object"""
  301. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  302. yaml.dump(obj, f, **bk_args)
  303. class CSVWriter(_BaseWriter):
  304. """CSVWriter"""
  305. def __init__(self, backend="pandas", **bk_args):
  306. super().__init__(backend=backend, **bk_args)
  307. def write(self, out_path, obj):
  308. """write"""
  309. return self._backend.write_obj(str(out_path), obj)
  310. def _init_backend(self, bk_type, bk_args):
  311. """init backend"""
  312. if bk_type == "pandas":
  313. return PandasCSVWriterBackend(**bk_args)
  314. else:
  315. raise ValueError("Unsupported backend type")
  316. def get_type(self):
  317. """get type"""
  318. return WriterType.CSV
class _CSVWriterBackend(_BaseWriterBackend):
    """Common base class of the CSV writer backends in this module."""
  321. class PandasCSVWriterBackend(_CSVWriterBackend):
  322. """PILImageWriterBackend"""
  323. def __init__(self):
  324. super().__init__()
  325. def _write_obj(self, out_path, obj):
  326. """write image object by PIL"""
  327. if isinstance(obj, pd.DataFrame):
  328. ts = obj
  329. else:
  330. raise TypeError("Unsupported object type")
  331. return ts.to_csv(out_path)
  332. class MarkdownWriterBackend(_BaseWriterBackend):
  333. """MarkdownWriterBackend"""
  334. def __init__(self):
  335. super().__init__()
  336. def _write_obj(self, out_path, obj):
  337. """write markdown obj"""
  338. with open(out_path, mode="w", encoding="utf-8", errors="replace") as f:
  339. f.write(obj)