writers.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import enum
  16. import cv2
  17. import numpy as np
  18. from PIL import Image
  # Public names exported by ``from <module> import *``.
  __all__ = [
      'ImageWriter',
      'TextWriter',
      'WriterType',
  ]
  class WriterType(enum.Enum):
      """Enumerates the kinds of writers; ``get_type`` returns one of these."""

      IMAGE = 1  # returned by ImageWriter.get_type
      VIDEO = 2  # not returned by ImageWriter or TextWriter
      TEXT = 3  # returned by TextWriter.get_type
  class _BaseWriter:
      """Base class for writers that delegate the actual I/O to a backend.

      Subclasses implement ``_init_backend`` (turn a backend name plus keyword
      arguments into a backend object), ``write`` and ``get_type``, and may
      override ``get_default_backend_args``.
      """

      def __init__(self, backend, **bk_args):
          """Store the backend configuration and build the backend eagerly.

          Args:
              backend: Backend name, interpreted by ``_init_backend``.
              **bk_args: Keyword arguments for the backend constructor.
          """
          super().__init__()
          # With no explicit backend arguments, fall back to the subclass's
          # defaults; explicit arguments replace the defaults entirely.
          if not bk_args:
              bk_args = self.get_default_backend_args()
          self.bk_type = backend
          self.bk_args = bk_args
          self._backend = self.get_backend()

      def write(self, out_path, obj):
          """Write ``obj`` to ``out_path``; must be overridden by subclasses."""
          raise NotImplementedError

      def get_backend(self, bk_args=None):
          """Build a backend of type ``self.bk_type``.

          Args:
              bk_args: Keyword arguments for the backend; ``None`` means use
                  ``self.bk_args``.

          Returns:
              Whatever ``_init_backend`` returns.
          """
          if bk_args is None:
              bk_args = self.bk_args
          return self._init_backend(self.bk_type, bk_args)

      def _init_backend(self, bk_type, bk_args):
          """Create the backend object; must be overridden by subclasses."""
          raise NotImplementedError

      def get_type(self):
          """Return the ``WriterType`` of this writer; must be overridden."""
          raise NotImplementedError

      def get_default_backend_args(self):
          """Return the backend arguments used when none are passed (``{}``)."""
          return {}
  class ImageWriter(_BaseWriter):
      """Writer that saves images through a pluggable backend.

      Backends: ``'opencv'`` (default) and ``'pillow'``. Extra keyword
      arguments are forwarded to the backend constructor.
      """

      def __init__(self, backend='opencv', **bk_args):
          super().__init__(backend=backend, **bk_args)

      def write(self, out_path, obj):
          """Write the image ``obj`` to ``out_path`` via the configured backend."""
          backend = self._backend
          return backend.write_obj(out_path, obj)

      def _init_backend(self, bk_type, bk_args):
          """Instantiate the backend class registered under ``bk_type``.

          Raises:
              ValueError: If ``bk_type`` names no known image backend.
          """
          # Resolved at call time: the backend classes are defined further
          # down in this module.
          candidates = (
              ('opencv', OpenCVImageWriterBackend),
              ('pillow', PILImageWriterBackend),
          )
          for name, backend_cls in candidates:
              if bk_type == name:
                  return backend_cls(**bk_args)
          raise ValueError("Unsupported backend type")

      def get_type(self):
          """Return ``WriterType.IMAGE``."""
          return WriterType.IMAGE
  class TextWriter(_BaseWriter):
      """Writer that saves text through a pluggable backend.

      Only the ``'python'`` backend exists; extra keyword arguments are
      forwarded to ``TextWriterBackend``.
      """

      def __init__(self, backend='python', **bk_args):
          super().__init__(backend=backend, **bk_args)

      def write(self, out_path, obj):
          """Write the text ``obj`` to ``out_path`` via the configured backend."""
          backend = self._backend
          return backend.write_obj(out_path, obj)

      def _init_backend(self, bk_type, bk_args):
          """Create the text backend for ``bk_type``.

          Raises:
              ValueError: If ``bk_type`` is not ``'python'``.
          """
          if bk_type == 'python':
              return TextWriterBackend(**bk_args)
          raise ValueError("Unsupported backend type")

      def get_type(self):
          """Return ``WriterType.TEXT``."""
          return WriterType.TEXT
  class _BaseWriterBackend:
      """Base class for writer backends.

      ``write_obj`` makes sure the destination directory exists and then
      delegates to ``_write_obj``, which subclasses implement.
      """

      def write_obj(self, out_path, obj):
          """Write ``obj`` to ``out_path``, creating parent directories first.

          Args:
              out_path: Destination file path.
              obj: Object to write; its interpretation is up to the subclass.

          Returns:
              Whatever the subclass's ``_write_obj`` returns.
          """
          out_dir = os.path.dirname(out_path)
          # A bare file name (e.g. "out.png") has no directory component and
          # os.makedirs('') raises FileNotFoundError, so only create real dirs.
          if out_dir:
              os.makedirs(out_dir, exist_ok=True)
          return self._write_obj(out_path, obj)

      def _write_obj(self, out_path, obj):
          """Perform the actual write; must be overridden by subclasses."""
          raise NotImplementedError
  class TextWriterBackend(_BaseWriterBackend):
      """Backend that writes text using Python's built-in ``open``."""

      def __init__(self, mode='w', encoding='utf-8'):
          super().__init__()
          # File mode and text encoding, passed straight through to ``open``.
          self.mode, self.encoding = mode, encoding

      def _write_obj(self, out_path, obj):
          """Open ``out_path`` with the stored mode/encoding and write ``obj``."""
          with open(out_path, mode=self.mode, encoding=self.encoding) as fh:
              fh.write(obj)
  class _ImageWriterBackend(_BaseWriterBackend):
      """Common base class for the image writer backends (OpenCV, Pillow)."""
  class OpenCVImageWriterBackend(_ImageWriterBackend):
      """Image writer backend built on ``cv2.imwrite``."""

      def _write_obj(self, out_path, obj):
          """Write an image with OpenCV.

          Args:
              out_path: Destination path; ``cv2.imwrite`` picks the encoder
                  from its extension.
              obj: A ``PIL.Image.Image`` or a ``numpy.ndarray``. ndarrays are
                  passed to ``cv2.imwrite`` unchanged, and OpenCV interprets
                  3/4-channel data as BGR(A).

          Returns:
              The value returned by ``cv2.imwrite``.

          Raises:
              TypeError: If ``obj`` is neither a PIL image nor an ndarray.
          """
          if isinstance(obj, Image.Image):
              arr = self._pil_to_cv_array(obj)
          elif isinstance(obj, np.ndarray):
              arr = obj
          else:
              raise TypeError("Unsupported object type")
          return cv2.imwrite(out_path, arr)

      @staticmethod
      def _pil_to_cv_array(img):
          """Convert a PIL image to an ndarray in OpenCV's channel order."""
          arr = np.asarray(img)
          # PIL color images are RGB(A) while cv2.imwrite expects BGR(A);
          # without this reorder, red and blue are swapped in the output file.
          if img.mode == 'RGB':
              return np.ascontiguousarray(arr[..., ::-1])
          if img.mode == 'RGBA':
              return np.ascontiguousarray(arr[..., [2, 1, 0, 3]])
          # Other modes (e.g. single-channel 'L') are written as-is.
          return arr
  class PILImageWriterBackend(_ImageWriterBackend):
      """Image writer backend built on Pillow's ``Image.save``."""

      def __init__(self, format_=None):
          super().__init__()
          # Forwarded as ``format=`` to ``Image.save``; ``None`` lets Pillow
          # infer the format from the file extension.
          self.format = format_

      def _write_obj(self, out_path, obj):
          """Save a PIL image or an ndarray to ``out_path`` with Pillow.

          Raises:
              TypeError: If ``obj`` is neither a PIL image nor an ndarray.
          """
          if not isinstance(obj, (Image.Image, np.ndarray)):
              raise TypeError("Unsupported object type")
          pil_img = Image.fromarray(obj) if isinstance(obj, np.ndarray) else obj
          return pil_img.save(out_path, format=self.format)