aboutsummaryrefslogtreecommitdiff
# mypy: ignore-errors
from __future__ import annotations

import io
import time
import wave
import asyncio
from typing import Any, Type, Union, Generic, TypeVar, Callable, overload
from typing_extensions import TYPE_CHECKING, Literal

from .._types import FileTypes, FileContent
from .._extras import numpy as np, sounddevice as sd

if TYPE_CHECKING:
    import numpy.typing as npt

# Sample rate (frames per second) used both to open the input stream and to
# stamp the header of generated WAV files.
SAMPLE_RATE = 24_000

# Element type of the recorded sample arrays; restricted to NumPy scalar types.
DType = TypeVar("DType", bound=np.generic)


class Microphone(Generic[DType]):
    """Record audio from the default input device via ``sounddevice``.

    A recording started with :meth:`record` stops when ``timeout`` seconds
    have elapsed (if given), when ``should_record`` returns a falsy value
    (if given), or when the underlying input stream becomes inactive for any
    other reason.
    """

    def __init__(
        self,
        channels: int = 1,
        dtype: Type[DType] = np.int16,
        should_record: Union[Callable[[], bool], None] = None,
        timeout: Union[float, None] = None,
    ):
        """
        Args:
            channels: Number of input channels to capture.
            dtype: NumPy scalar type of the captured samples; its item size is
                also used as the WAV sample width.
            should_record: Optional predicate polled from the audio callback;
                recording stops the first time it returns a falsy value.
            timeout: Optional maximum recording duration in seconds, measured
                from the start of :meth:`record`.
        """
        self.channels = channels
        self.dtype = dtype
        self.should_record = should_record
        self.buffer_chunks = []
        self.timeout = timeout
        self.has_record_function = callable(should_record)

    def _ndarray_to_wav(self, audio_data: npt.NDArray[DType]) -> FileTypes:
        """Encode ``audio_data`` as an in-memory WAV file.

        Returns:
            A ``(filename, file_object, content_type)`` tuple whose file object
            is a ``BytesIO`` rewound to the start.
        """
        # NOTE(review): the stdlib ``wave`` module always writes an integer-PCM
        # header, so float dtypes end up as raw float bytes under a PCM header.
        # Confirm callers only use integer dtypes when requesting WAV output.
        buffer: FileContent = io.BytesIO()
        with wave.open(buffer, "w") as wav_file:
            wav_file.setnchannels(self.channels)
            wav_file.setsampwidth(np.dtype(self.dtype).itemsize)
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_data.tobytes())
        buffer.seek(0)
        return ("audio.wav", buffer, "audio/wav")

    @overload
    async def record(self, return_ndarray: Literal[True]) -> npt.NDArray[DType]: ...

    @overload
    async def record(self, return_ndarray: Literal[False]) -> FileTypes: ...

    @overload
    async def record(self, return_ndarray: None = ...) -> FileTypes: ...

    async def record(self, return_ndarray: Union[bool, None] = False) -> Union[npt.NDArray[DType], FileTypes]:
        """Record from the input device until a stop condition is met.

        Args:
            return_ndarray: If truthy, return the raw samples as an array of
                shape ``(frames, channels)``. Otherwise return the WAV file
                tuple produced by :meth:`_ndarray_to_wav`.

        Returns:
            The recorded samples, or an in-memory WAV file. An empty recording
            yields an array of shape ``(0, channels)`` (or an empty WAV).
        """
        # Inside a coroutine, get_running_loop() is the supported way to obtain
        # the loop. The stream callbacks run outside the event loop thread, so
        # they must hand control back via call_soon_threadsafe.
        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        self.buffer_chunks: list[npt.NDArray[DType]] = []
        start_time = time.perf_counter()

        def signal_done() -> None:
            # Safe to call more than once: setting an already-set Event is a no-op.
            loop.call_soon_threadsafe(event.set)

        def callback(
            indata: npt.NDArray[DType],
            _frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            elapsed = time.perf_counter() - start_time
            if self.timeout is not None and elapsed > self.timeout:
                signal_done()
                raise sd.CallbackStop

            if callable(self.should_record) and not self.should_record():
                signal_done()
                raise sd.CallbackStop

            # Copy: sounddevice does not guarantee `indata` stays valid after
            # the callback returns.
            self.buffer_chunks.append(indata.copy())

        stream = sd.InputStream(
            callback=callback,
            # Also wake the waiter if the stream stops for any reason other than
            # our own CallbackStop (e.g. a device error). Otherwise
            # `event.wait()` below would block forever.
            finished_callback=signal_done,
            dtype=self.dtype,
            samplerate=SAMPLE_RATE,
            channels=self.channels,
        )
        with stream:
            await event.wait()

        # Join the chunks into one (frames, channels) array. The empty case
        # keeps the same 2-D shape so callers can rely on a consistent rank.
        concatenated_chunks: npt.NDArray[DType] = (
            np.concatenate(self.buffer_chunks, axis=0)
            if self.buffer_chunks
            else np.empty((0, self.channels), dtype=self.dtype)
        )

        if return_ndarray:
            return concatenated_chunks
        return self._ndarray_to_wav(concatenated_chunks)