microphone.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import io
  4. import time
  5. import wave
  6. import asyncio
  7. from typing import Any, Type, Union, Generic, TypeVar, Callable, overload
  8. from typing_extensions import TYPE_CHECKING, Literal
  9. from .._types import FileTypes, FileContent
  10. from .._extras import numpy as np, sounddevice as sd
  11. if TYPE_CHECKING:
  12. import numpy.typing as npt
  13. SAMPLE_RATE = 24000
  14. DType = TypeVar("DType", bound=np.generic)
  15. class Microphone(Generic[DType]):
  16. def __init__(
  17. self,
  18. channels: int = 1,
  19. dtype: Type[DType] = np.int16,
  20. should_record: Union[Callable[[], bool], None] = None,
  21. timeout: Union[float, None] = None,
  22. ):
  23. self.channels = channels
  24. self.dtype = dtype
  25. self.should_record = should_record
  26. self.buffer_chunks = []
  27. self.timeout = timeout
  28. self.has_record_function = callable(should_record)
  29. def _ndarray_to_wav(self, audio_data: npt.NDArray[DType]) -> FileTypes:
  30. buffer: FileContent = io.BytesIO()
  31. with wave.open(buffer, "w") as wav_file:
  32. wav_file.setnchannels(self.channels)
  33. wav_file.setsampwidth(np.dtype(self.dtype).itemsize)
  34. wav_file.setframerate(SAMPLE_RATE)
  35. wav_file.writeframes(audio_data.tobytes())
  36. buffer.seek(0)
  37. return ("audio.wav", buffer, "audio/wav")
  38. @overload
  39. async def record(self, return_ndarray: Literal[True]) -> npt.NDArray[DType]: ...
  40. @overload
  41. async def record(self, return_ndarray: Literal[False]) -> FileTypes: ...
  42. @overload
  43. async def record(self, return_ndarray: None = ...) -> FileTypes: ...
  44. async def record(self, return_ndarray: Union[bool, None] = False) -> Union[npt.NDArray[DType], FileTypes]:
  45. loop = asyncio.get_event_loop()
  46. event = asyncio.Event()
  47. self.buffer_chunks: list[npt.NDArray[DType]] = []
  48. start_time = time.perf_counter()
  49. def callback(
  50. indata: npt.NDArray[DType],
  51. _frame_count: int,
  52. _time_info: Any,
  53. _status: Any,
  54. ):
  55. execution_time = time.perf_counter() - start_time
  56. reached_recording_timeout = execution_time > self.timeout if self.timeout is not None else False
  57. if reached_recording_timeout:
  58. loop.call_soon_threadsafe(event.set)
  59. raise sd.CallbackStop
  60. should_be_recording = self.should_record() if callable(self.should_record) else True
  61. if not should_be_recording:
  62. loop.call_soon_threadsafe(event.set)
  63. raise sd.CallbackStop
  64. self.buffer_chunks.append(indata.copy())
  65. stream = sd.InputStream(
  66. callback=callback,
  67. dtype=self.dtype,
  68. samplerate=SAMPLE_RATE,
  69. channels=self.channels,
  70. )
  71. with stream:
  72. await event.wait()
  73. # Concatenate all chunks into a single buffer, handle empty case
  74. concatenated_chunks: npt.NDArray[DType] = (
  75. np.concatenate(self.buffer_chunks, axis=0)
  76. if len(self.buffer_chunks) > 0
  77. else np.array([], dtype=self.dtype)
  78. )
  79. if return_ndarray:
  80. return concatenated_chunks
  81. else:
  82. return self._ndarray_to_wav(concatenated_chunks)