readers.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. import itertools
  16. import cv2
  17. from PIL import Image, ImageOps
# Names exported by a star-import of this module.
__all__ = ["ImageReader", "VideoReader", "ReaderType"]
  19. class ReaderType(enum.Enum):
  20. """ReaderType"""
  21. IMAGE = 1
  22. GENERATIVE = 2
  23. POINT_CLOUD = 3
  24. class _BaseReader(object):
  25. """_BaseReader"""
  26. def __init__(self, backend, **bk_args):
  27. super().__init__()
  28. if len(bk_args) == 0:
  29. bk_args = self.get_default_backend_args()
  30. self.bk_type = backend
  31. self.bk_args = bk_args
  32. self._backend = self.get_backend()
  33. def read(self, in_path):
  34. """read file from path"""
  35. raise NotImplementedError
  36. def get_backend(self, bk_args=None):
  37. """get the backend"""
  38. if bk_args is None:
  39. bk_args = self.bk_args
  40. return self._init_backend(self.bk_type, bk_args)
  41. def set_backend(self, backend, **bk_args):
  42. self.bk_type = backend
  43. self.bk_args = bk_args
  44. self._backend = self.get_backend()
  45. def _init_backend(self, bk_type, bk_args):
  46. """init backend"""
  47. raise NotImplementedError
  48. def get_type(self):
  49. """get type"""
  50. raise NotImplementedError
  51. def get_default_backend_args(self):
  52. """get default backend arguments"""
  53. return {}
  54. class ImageReader(_BaseReader):
  55. """ImageReader"""
  56. def __init__(self, backend="opencv", **bk_args):
  57. super().__init__(backend=backend, **bk_args)
  58. def read(self, in_path):
  59. """read the image file from path"""
  60. arr = self._backend.read_file(in_path)
  61. return arr
  62. def _init_backend(self, bk_type, bk_args):
  63. """init backend"""
  64. if bk_type == "opencv":
  65. return OpenCVImageReaderBackend(**bk_args)
  66. elif bk_type == "pil" or bk_type == "pillow":
  67. return PILImageReaderBackend(**bk_args)
  68. else:
  69. raise ValueError("Unsupported backend type")
  70. def get_type(self):
  71. """get type"""
  72. return ReaderType.IMAGE
  73. class _GenerativeReader(_BaseReader):
  74. """_GenerativeReader"""
  75. def get_type(self):
  76. """get type"""
  77. return ReaderType.GENERATIVE
  78. def is_generative_reader(reader):
  79. """is_generative_reader"""
  80. return isinstance(reader, _GenerativeReader)
  81. class VideoReader(_GenerativeReader):
  82. """VideoReader"""
  83. def __init__(
  84. self,
  85. backend="opencv",
  86. st_frame_id=0,
  87. max_num_frames=None,
  88. auto_close=True,
  89. **bk_args,
  90. ):
  91. super().__init__(backend=backend, **bk_args)
  92. self.st_frame_id = st_frame_id
  93. self.max_num_frames = max_num_frames
  94. self.auto_close = auto_close
  95. def read(self, in_path):
  96. """read vide file from path"""
  97. self._backend.set_pos(self.st_frame_id)
  98. gen = self._backend.read_file(in_path)
  99. if self.num_frames is not None:
  100. gen = itertools.islice(gen, self.num_frames)
  101. yield from gen
  102. if self.auto_close:
  103. self._backend.close()
  104. def _init_backend(self, bk_type, bk_args):
  105. """init backend"""
  106. if bk_type == "opencv":
  107. return OpenCVVideoReaderBackend(**bk_args)
  108. else:
  109. raise ValueError("Unsupported backend type")
  110. class _BaseReaderBackend(object):
  111. """_BaseReaderBackend"""
  112. def read_file(self, in_path):
  113. """read file from path"""
  114. raise NotImplementedError
  115. class _ImageReaderBackend(_BaseReaderBackend):
  116. """_ImageReaderBackend"""
  117. pass
  118. class OpenCVImageReaderBackend(_ImageReaderBackend):
  119. """OpenCVImageReaderBackend"""
  120. def __init__(self, flags=cv2.IMREAD_COLOR):
  121. super().__init__()
  122. self.flags = flags
  123. def read_file(self, in_path):
  124. """read image file from path by OpenCV"""
  125. return cv2.imread(in_path, flags=self.flags)
  126. class PILImageReaderBackend(_ImageReaderBackend):
  127. """PILImageReaderBackend"""
  128. def __init__(self):
  129. super().__init__()
  130. def read_file(self, in_path):
  131. """read image file from path by PIL"""
  132. return ImageOps.exif_transpose(Image.open(in_path))
  133. class _VideoReaderBackend(_BaseReaderBackend):
  134. """_VideoReaderBackend"""
  135. def set_pos(self, pos):
  136. """set pos"""
  137. raise NotImplementedError
  138. def close(self):
  139. """close io"""
  140. raise NotImplementedError
  141. class OpenCVVideoReaderBackend(_VideoReaderBackend):
  142. """OpenCVVideoReaderBackend"""
  143. def __init__(self, **bk_args):
  144. super().__init__()
  145. self.cap_init_args = bk_args
  146. self._cap = None
  147. self._pos = 0
  148. self._max_num_frames = None
  149. def read_file(self, in_path):
  150. """read vidio file from path"""
  151. if self._cap is not None:
  152. self._cap_release()
  153. self._cap = self._cap_open(in_path)
  154. if self._pos is not None:
  155. self._cap_set_pos()
  156. return self._read_frames(self._cap)
  157. def _read_frames(self, cap):
  158. """read frames"""
  159. while True:
  160. ret, frame = cap.read()
  161. if not ret:
  162. break
  163. yield frame
  164. self._cap_release()
  165. def _cap_open(self, video_path):
  166. self._cap = cv2.VideoCapture(video_path, **self.cap_init_args)
  167. if not self._cap.isOpened():
  168. raise RuntimeError(f"Failed to open {video_path}")
  169. return self._cap
  170. def _cap_release(self):
  171. self._cap.release()
  172. def _cap_set_pos(self):
  173. self._cap.set(cv2.CAP_PROP_POS_FRAMES, self._pos)
  174. def set_pos(self, pos):
  175. self._pos = pos
  176. def close(self):
  177. if self._cap is not None:
  178. self._cap_release()
  179. self._cap = None