readers.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. import itertools
  16. from pathlib import Path
  17. import cv2
  18. from .....utils.download import download
  19. from .....utils.cache import CACHE_DIR
  20. __all__ = ['ImageReader', 'VideoReader', 'ReaderType']
  21. class ReaderType(enum.Enum):
  22. """ ReaderType """
  23. IMAGE = 1
  24. GENERATIVE = 2
  25. POINT_CLOUD = 3
  26. class _BaseReader(object):
  27. """ _BaseReader """
  28. def __init__(self, backend, **bk_args):
  29. super().__init__()
  30. if len(bk_args) == 0:
  31. bk_args = self.get_default_backend_args()
  32. self.bk_type = backend
  33. self.bk_args = bk_args
  34. self._backend = self.get_backend()
  35. def read(self, in_path):
  36. """ read file from path """
  37. raise NotImplementedError
  38. def get_backend(self, bk_args=None):
  39. """ get the backend """
  40. if bk_args is None:
  41. bk_args = self.bk_args
  42. return self._init_backend(self.bk_type, bk_args)
  43. def _init_backend(self, bk_type, bk_args):
  44. """ init backend """
  45. raise NotImplementedError
  46. def get_type(self):
  47. """ get type """
  48. raise NotImplementedError
  49. def get_default_backend_args(self):
  50. """ get default backend arguments """
  51. return {}
  52. def _download_from_url(self, in_path):
  53. if in_path.startswith("http"):
  54. file_name = Path(in_path).name
  55. save_path = Path(CACHE_DIR) / "predict_input" / file_name
  56. download(in_path, save_path, overwrite=True)
  57. return save_path.as_posix()
  58. return in_path
  59. class ImageReader(_BaseReader):
  60. """ ImageReader """
  61. def __init__(self, backend='opencv', **bk_args):
  62. super().__init__(backend=backend, **bk_args)
  63. def read(self, in_path):
  64. """ read the image file from path """
  65. # XXX: auto download for url
  66. in_path = self._download_from_url(in_path)
  67. arr = self._backend.read_file(in_path)
  68. return arr
  69. def _init_backend(self, bk_type, bk_args):
  70. """ init backend """
  71. if bk_type == 'opencv':
  72. return OpenCVImageReaderBackend(**bk_args)
  73. else:
  74. raise ValueError("Unsupported backend type")
  75. def get_type(self):
  76. """ get type """
  77. return ReaderType.IMAGE
  78. class _GenerativeReader(_BaseReader):
  79. """ _GenerativeReader """
  80. def get_type(self):
  81. """ get type """
  82. return ReaderType.GENERATIVE
  83. def is_generative_reader(reader):
  84. """ is_generative_reader """
  85. return isinstance(reader, _GenerativeReader)
  86. class VideoReader(_GenerativeReader):
  87. """ VideoReader """
  88. def __init__(self,
  89. backend='opencv',
  90. st_frame_id=0,
  91. max_num_frames=None,
  92. auto_close=True,
  93. **bk_args):
  94. super().__init__(backend=backend, **bk_args)
  95. self.st_frame_id = st_frame_id
  96. self.max_num_frames = max_num_frames
  97. self.auto_close = auto_close
  98. def read(self, in_path):
  99. """ read vide file from path """
  100. self._backend.set_pos(self.st_frame_id)
  101. gen = self._backend.read_file(in_path)
  102. if self.num_frames is not None:
  103. gen = itertools.islice(gen, self.num_frames)
  104. yield from gen
  105. if self.auto_close:
  106. self._backend.close()
  107. def _init_backend(self, bk_type, bk_args):
  108. """ init backend """
  109. if bk_type == 'opencv':
  110. return OpenCVVideoReaderBackend(**bk_args)
  111. else:
  112. raise ValueError("Unsupported backend type")
  113. class _BaseReaderBackend(object):
  114. """ _BaseReaderBackend """
  115. def read_file(self, in_path):
  116. """ read file from path """
  117. raise NotImplementedError
  118. class _ImageReaderBackend(_BaseReaderBackend):
  119. """ _ImageReaderBackend """
  120. pass
  121. class OpenCVImageReaderBackend(_ImageReaderBackend):
  122. """ OpenCVImageReaderBackend """
  123. def __init__(self, flags=cv2.IMREAD_COLOR):
  124. super().__init__()
  125. self.flags = flags
  126. def read_file(self, in_path):
  127. """ read image file from path by OpenCV """
  128. return cv2.imread(in_path, flags=self.flags)
  129. class _VideoReaderBackend(_BaseReaderBackend):
  130. """ _VideoReaderBackend """
  131. def set_pos(self, pos):
  132. """ set pos """
  133. raise NotImplementedError
  134. def close(self):
  135. """ close io """
  136. raise NotImplementedError
  137. class OpenCVVideoReaderBackend(_VideoReaderBackend):
  138. """ OpenCVVideoReaderBackend """
  139. def __init__(self, **bk_args):
  140. super().__init__()
  141. self.cap_init_args = bk_args
  142. self._cap = None
  143. self._pos = 0
  144. self._max_num_frames = None
  145. def read_file(self, in_path):
  146. """ read vidio file from path """
  147. if self._cap is not None:
  148. self._cap_release()
  149. self._cap = self._cap_open(in_path)
  150. if self._pos is not None:
  151. self._cap_set_pos()
  152. return self._read_frames(self._cap)
  153. def _read_frames(self, cap):
  154. """ read frames """
  155. while True:
  156. ret, frame = cap.read()
  157. if not ret:
  158. break
  159. yield frame
  160. self._cap_release()
  161. def _cap_open(self, video_path):
  162. self._cap = cv2.VideoCapture(video_path, **self.cap_init_args)
  163. if not self._cap.isOpened():
  164. raise RuntimeError(f"Failed to open {video_path}")
  165. return self._cap
  166. def _cap_release(self):
  167. self._cap.release()
  168. def _cap_set_pos(self):
  169. self._cap.set(cv2.CAP_PROP_POS_FRAMES, self._pos)
  170. def set_pos(self, pos):
  171. self._pos = pos
  172. def close(self):
  173. if self._cap is not None:
  174. self._cap_release()
  175. self._cap = None