readers.py 5.4 KB


  1. #!/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import enum
  12. import itertools
  13. import cv2
  14. __all__ = ['ImageReader', 'VideoReader', 'ReaderType']
  15. class ReaderType(enum.Enum):
  16. """ ReaderType """
  17. IMAGE = 1
  18. GENERATIVE = 2
  19. POINT_CLOUD = 3
  20. class _BaseReader(object):
  21. """ _BaseReader """
  22. def __init__(self, backend, **bk_args):
  23. super().__init__()
  24. if len(bk_args) == 0:
  25. bk_args = self.get_default_backend_args()
  26. self.bk_type = backend
  27. self.bk_args = bk_args
  28. self._backend = self.get_backend()
  29. def read(self, in_path):
  30. """ read file from path """
  31. raise NotImplementedError
  32. def get_backend(self, bk_args=None):
  33. """ get the backend """
  34. if bk_args is None:
  35. bk_args = self.bk_args
  36. return self._init_backend(self.bk_type, bk_args)
  37. def _init_backend(self, bk_type, bk_args):
  38. """ init backend """
  39. raise NotImplementedError
  40. def get_type(self):
  41. """ get type """
  42. raise NotImplementedError
  43. def get_default_backend_args(self):
  44. """ get default backend arguments """
  45. return {}
  46. class ImageReader(_BaseReader):
  47. """ ImageReader """
  48. def __init__(self, backend='opencv', **bk_args):
  49. super().__init__(backend=backend, **bk_args)
  50. def read(self, in_path):
  51. """ read the image file from path """
  52. arr = self._backend.read_file(in_path)
  53. return arr
  54. def _init_backend(self, bk_type, bk_args):
  55. """ init backend """
  56. if bk_type == 'opencv':
  57. return OpenCVImageReaderBackend(**bk_args)
  58. else:
  59. raise ValueError("Unsupported backend type")
  60. def get_type(self):
  61. """ get type """
  62. return ReaderType.IMAGE
  63. class _GenerativeReader(_BaseReader):
  64. """ _GenerativeReader """
  65. def get_type(self):
  66. """ get type """
  67. return ReaderType.GENERATIVE
  68. def is_generative_reader(reader):
  69. """ is_generative_reader """
  70. return isinstance(reader, _GenerativeReader)
  71. class VideoReader(_GenerativeReader):
  72. """ VideoReader """
  73. def __init__(self,
  74. backend='opencv',
  75. st_frame_id=0,
  76. max_num_frames=None,
  77. auto_close=True,
  78. **bk_args):
  79. super().__init__(backend=backend, **bk_args)
  80. self.st_frame_id = st_frame_id
  81. self.max_num_frames = max_num_frames
  82. self.auto_close = auto_close
  83. def read(self, in_path):
  84. """ read vide file from path """
  85. self._backend.set_pos(self.st_frame_id)
  86. gen = self._backend.read_file(in_path)
  87. if self.num_frames is not None:
  88. gen = itertools.islice(gen, self.num_frames)
  89. yield from gen
  90. if self.auto_close:
  91. self._backend.close()
  92. def _init_backend(self, bk_type, bk_args):
  93. """ init backend """
  94. if bk_type == 'opencv':
  95. return OpenCVVideoReaderBackend(**bk_args)
  96. else:
  97. raise ValueError("Unsupported backend type")
  98. class _BaseReaderBackend(object):
  99. """ _BaseReaderBackend """
  100. def read_file(self, in_path):
  101. """ read file from path """
  102. raise NotImplementedError
  103. class _ImageReaderBackend(_BaseReaderBackend):
  104. """ _ImageReaderBackend """
  105. pass
  106. class OpenCVImageReaderBackend(_ImageReaderBackend):
  107. """ OpenCVImageReaderBackend """
  108. def __init__(self, flags=cv2.IMREAD_COLOR):
  109. super().__init__()
  110. self.flags = flags
  111. def read_file(self, in_path):
  112. """ read image file from path by OpenCV """
  113. return cv2.imread(in_path, flags=self.flags)
  114. class _VideoReaderBackend(_BaseReaderBackend):
  115. """ _VideoReaderBackend """
  116. def set_pos(self, pos):
  117. """ set pos """
  118. raise NotImplementedError
  119. def close(self):
  120. """ close io """
  121. raise NotImplementedError
  122. class OpenCVVideoReaderBackend(_VideoReaderBackend):
  123. """ OpenCVVideoReaderBackend """
  124. def __init__(self, **bk_args):
  125. super().__init__()
  126. self.cap_init_args = bk_args
  127. self._cap = None
  128. self._pos = 0
  129. self._max_num_frames = None
  130. def read_file(self, in_path):
  131. """ read vidio file from path """
  132. if self._cap is not None:
  133. self._cap_release()
  134. self._cap = self._cap_open(in_path)
  135. if self._pos is not None:
  136. self._cap_set_pos()
  137. return self._read_frames(self._cap)
  138. def _read_frames(self, cap):
  139. """ read frames """
  140. while True:
  141. ret, frame = cap.read()
  142. if not ret:
  143. break
  144. yield frame
  145. self._cap_release()
  146. def _cap_open(self, video_path):
  147. self._cap = cv2.VideoCapture(video_path, **self.cap_init_args)
  148. if not self._cap.isOpened():
  149. raise RuntimeError(f"Failed to open {video_path}")
  150. return self._cap
  151. def _cap_release(self):
  152. self._cap.release()
  153. def _cap_set_pos(self):
  154. self._cap.set(cv2.CAP_PROP_POS_FRAMES, self._pos)
  155. def set_pos(self, pos):
  156. self._pos = pos
  157. def close(self):
  158. if self._cap is not None:
  159. self._cap_release()
  160. self._cap = None