writers.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import enum
  13. import cv2
  14. import numpy as np
  15. from PIL import Image
  16. __all__ = ['ImageWriter', 'TextWriter', 'WriterType']
  17. class WriterType(enum.Enum):
  18. """ WriterType """
  19. IMAGE = 1
  20. VIDEO = 2
  21. TEXT = 3
  22. class _BaseWriter(object):
  23. """ _BaseWriter """
  24. def __init__(self, backend, **bk_args):
  25. super().__init__()
  26. if len(bk_args) == 0:
  27. bk_args = self.get_default_backend_args()
  28. self.bk_type = backend
  29. self.bk_args = bk_args
  30. self._backend = self.get_backend()
  31. def write(self, out_path, obj):
  32. """ write """
  33. raise NotImplementedError
  34. def get_backend(self, bk_args=None):
  35. """ get backend """
  36. if bk_args is None:
  37. bk_args = self.bk_args
  38. return self._init_backend(self.bk_type, bk_args)
  39. def _init_backend(self, bk_type, bk_args):
  40. """ init backend """
  41. raise NotImplementedError
  42. def get_type(self):
  43. """ get type """
  44. raise NotImplementedError
  45. def get_default_backend_args(self):
  46. """ get default backend arguments """
  47. return {}
  48. class ImageWriter(_BaseWriter):
  49. """ ImageWriter """
  50. def __init__(self, backend='opencv', **bk_args):
  51. super().__init__(backend=backend, **bk_args)
  52. def write(self, out_path, obj):
  53. """ write """
  54. return self._backend.write_obj(out_path, obj)
  55. def _init_backend(self, bk_type, bk_args):
  56. """ init backend """
  57. if bk_type == 'opencv':
  58. return OpenCVImageWriterBackend(**bk_args)
  59. elif bk_type == 'pillow':
  60. return PILImageWriterBackend(**bk_args)
  61. else:
  62. raise ValueError("Unsupported backend type")
  63. def get_type(self):
  64. """ get type """
  65. return WriterType.IMAGE
  66. class TextWriter(_BaseWriter):
  67. """ TextWriter """
  68. def __init__(self, backend='python', **bk_args):
  69. super().__init__(backend=backend, **bk_args)
  70. def write(self, out_path, obj):
  71. """ write """
  72. return self._backend.write_obj(out_path, obj)
  73. def _init_backend(self, bk_type, bk_args):
  74. """ init backend """
  75. if bk_type == 'python':
  76. return TextWriterBackend(**bk_args)
  77. else:
  78. raise ValueError("Unsupported backend type")
  79. def get_type(self):
  80. """ get type """
  81. return WriterType.TEXT
  82. class _BaseWriterBackend(object):
  83. """ _BaseWriterBackend """
  84. def write_obj(self, out_path, obj):
  85. """ write object """
  86. out_dir = os.path.dirname(out_path)
  87. os.makedirs(out_dir, exist_ok=True)
  88. return self._write_obj(out_path, obj)
  89. def _write_obj(self, out_path, obj):
  90. """ write object """
  91. raise NotImplementedError
  92. class TextWriterBackend(_BaseWriterBackend):
  93. """ TextWriterBackend """
  94. def __init__(self, mode='w', encoding='utf-8'):
  95. super().__init__()
  96. self.mode = mode
  97. self.encoding = encoding
  98. def _write_obj(self, out_path, obj):
  99. """ write text object """
  100. with open(out_path, mode=self.mode, encoding=self.encoding) as f:
  101. f.write(obj)
  102. class _ImageWriterBackend(_BaseWriterBackend):
  103. """ _ImageWriterBackend """
  104. pass
  105. class OpenCVImageWriterBackend(_ImageWriterBackend):
  106. """ OpenCVImageWriterBackend """
  107. def _write_obj(self, out_path, obj):
  108. """ write image object by OpenCV """
  109. if isinstance(obj, Image.Image):
  110. arr = np.asarray(obj)
  111. elif isinstance(obj, np.ndarray):
  112. arr = obj
  113. else:
  114. raise TypeError("Unsupported object type")
  115. return cv2.imwrite(out_path, arr)
  116. class PILImageWriterBackend(_ImageWriterBackend):
  117. """ PILImageWriterBackend """
  118. def __init__(self, format_=None):
  119. super().__init__()
  120. self.format = format_
  121. def _write_obj(self, out_path, obj):
  122. """ write image object by PIL """
  123. if isinstance(obj, Image.Image):
  124. img = obj
  125. elif isinstance(obj, np.ndarray):
  126. img = Image.fromarray(obj)
  127. else:
  128. raise TypeError("Unsupported object type")
  129. return img.save(out_path, format=self.format)