readers.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. import itertools
  16. import cv2
  17. import pandas as pd
  18. from PIL import Image, ImageOps
# Names exported by `from <module> import *`; everything else in this module
# must be imported explicitly.
__all__ = ["ImageReader", "VideoReader", "ReaderType", "TSReader"]
  20. class ReaderType(enum.Enum):
  21. """ReaderType"""
  22. IMAGE = 1
  23. GENERATIVE = 2
  24. POINT_CLOUD = 3
  25. TS = 4
  26. class _BaseReader(object):
  27. """_BaseReader"""
  28. def __init__(self, backend, **bk_args):
  29. super().__init__()
  30. if len(bk_args) == 0:
  31. bk_args = self.get_default_backend_args()
  32. self.bk_type = backend
  33. self.bk_args = bk_args
  34. self._backend = self.get_backend()
  35. def read(self, in_path):
  36. """read file from path"""
  37. raise NotImplementedError
  38. def get_backend(self, bk_args=None):
  39. """get the backend"""
  40. if bk_args is None:
  41. bk_args = self.bk_args
  42. return self._init_backend(self.bk_type, bk_args)
  43. def _init_backend(self, bk_type, bk_args):
  44. """init backend"""
  45. raise NotImplementedError
  46. def get_type(self):
  47. """get type"""
  48. raise NotImplementedError
  49. def get_default_backend_args(self):
  50. """get default backend arguments"""
  51. return {}
  52. class ImageReader(_BaseReader):
  53. """ImageReader"""
  54. def __init__(self, backend="opencv", **bk_args):
  55. super().__init__(backend=backend, **bk_args)
  56. def read(self, in_path):
  57. """read the image file from path"""
  58. arr = self._backend.read_file(in_path)
  59. return arr
  60. def _init_backend(self, bk_type, bk_args):
  61. """init backend"""
  62. if bk_type == "opencv":
  63. return OpenCVImageReaderBackend(**bk_args)
  64. elif bk_type == "pil":
  65. return PILImageReaderBackend(**bk_args)
  66. else:
  67. raise ValueError("Unsupported backend type")
  68. def get_type(self):
  69. """get type"""
  70. return ReaderType.IMAGE
  71. class _GenerativeReader(_BaseReader):
  72. """_GenerativeReader"""
  73. def get_type(self):
  74. """get type"""
  75. return ReaderType.GENERATIVE
  76. def is_generative_reader(reader):
  77. """is_generative_reader"""
  78. return isinstance(reader, _GenerativeReader)
  79. class VideoReader(_GenerativeReader):
  80. """VideoReader"""
  81. def __init__(
  82. self,
  83. backend="opencv",
  84. st_frame_id=0,
  85. max_num_frames=None,
  86. auto_close=True,
  87. **bk_args,
  88. ):
  89. super().__init__(backend=backend, **bk_args)
  90. self.st_frame_id = st_frame_id
  91. self.max_num_frames = max_num_frames
  92. self.auto_close = auto_close
  93. def read(self, in_path):
  94. """read vide file from path"""
  95. self._backend.set_pos(self.st_frame_id)
  96. gen = self._backend.read_file(in_path)
  97. if self.num_frames is not None:
  98. gen = itertools.islice(gen, self.num_frames)
  99. yield from gen
  100. if self.auto_close:
  101. self._backend.close()
  102. def _init_backend(self, bk_type, bk_args):
  103. """init backend"""
  104. if bk_type == "opencv":
  105. return OpenCVVideoReaderBackend(**bk_args)
  106. else:
  107. raise ValueError("Unsupported backend type")
  108. class _BaseReaderBackend(object):
  109. """_BaseReaderBackend"""
  110. def read_file(self, in_path):
  111. """read file from path"""
  112. raise NotImplementedError
  113. class _ImageReaderBackend(_BaseReaderBackend):
  114. """_ImageReaderBackend"""
  115. pass
  116. class OpenCVImageReaderBackend(_ImageReaderBackend):
  117. """OpenCVImageReaderBackend"""
  118. def __init__(self, flags=cv2.IMREAD_COLOR):
  119. super().__init__()
  120. self.flags = flags
  121. def read_file(self, in_path):
  122. """read image file from path by OpenCV"""
  123. return cv2.imread(in_path, flags=self.flags)
  124. class PILImageReaderBackend(_ImageReaderBackend):
  125. """PILImageReaderBackend"""
  126. def __init__(self):
  127. super().__init__()
  128. def read_file(self, in_path):
  129. """read image file from path by PIL"""
  130. return ImageOps.exif_transpose(Image.open(in_path))
  131. class _VideoReaderBackend(_BaseReaderBackend):
  132. """_VideoReaderBackend"""
  133. def set_pos(self, pos):
  134. """set pos"""
  135. raise NotImplementedError
  136. def close(self):
  137. """close io"""
  138. raise NotImplementedError
  139. class OpenCVVideoReaderBackend(_VideoReaderBackend):
  140. """OpenCVVideoReaderBackend"""
  141. def __init__(self, **bk_args):
  142. super().__init__()
  143. self.cap_init_args = bk_args
  144. self._cap = None
  145. self._pos = 0
  146. self._max_num_frames = None
  147. def read_file(self, in_path):
  148. """read vidio file from path"""
  149. if self._cap is not None:
  150. self._cap_release()
  151. self._cap = self._cap_open(in_path)
  152. if self._pos is not None:
  153. self._cap_set_pos()
  154. return self._read_frames(self._cap)
  155. def _read_frames(self, cap):
  156. """read frames"""
  157. while True:
  158. ret, frame = cap.read()
  159. if not ret:
  160. break
  161. yield frame
  162. self._cap_release()
  163. def _cap_open(self, video_path):
  164. self._cap = cv2.VideoCapture(video_path, **self.cap_init_args)
  165. if not self._cap.isOpened():
  166. raise RuntimeError(f"Failed to open {video_path}")
  167. return self._cap
  168. def _cap_release(self):
  169. self._cap.release()
  170. def _cap_set_pos(self):
  171. self._cap.set(cv2.CAP_PROP_POS_FRAMES, self._pos)
  172. def set_pos(self, pos):
  173. self._pos = pos
  174. def close(self):
  175. if self._cap is not None:
  176. self._cap_release()
  177. self._cap = None
  178. class TSReader(_BaseReader):
  179. """TSReader"""
  180. def __init__(self, backend="pandas", **bk_args):
  181. super().__init__(backend=backend, **bk_args)
  182. def read(self, in_path):
  183. """read the image file from path"""
  184. arr = self._backend.read_file(in_path)
  185. return arr
  186. def _init_backend(self, bk_type, bk_args):
  187. """init backend"""
  188. if bk_type == "pandas":
  189. return PandasTSReaderBackend(**bk_args)
  190. else:
  191. raise ValueError("Unsupported backend type")
  192. def get_type(self):
  193. """get type"""
  194. return ReaderType.TS
  195. class _TSReaderBackend(_BaseReaderBackend):
  196. """_TSReaderBackend"""
  197. pass
  198. class PandasTSReaderBackend(_TSReaderBackend):
  199. """PandasTSReaderBackend"""
  200. def __init__(self):
  201. super().__init__()
  202. def read_file(self, in_path):
  203. """read image file from path by OpenCV"""
  204. return pd.read_csv(in_path)