# Helpers for playing audio locally through sounddevice output streams.
# mypy: ignore-errors
from __future__ import annotations

import queue
import asyncio
from typing import Any, Union, Callable, AsyncGenerator, cast
from typing_extensions import TYPE_CHECKING

from .. import _legacy_response
from .._extras import numpy as np, sounddevice as sd
from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse

if TYPE_CHECKING:
    import numpy.typing as npt

# Sample rate passed as `samplerate=` to every sounddevice output stream opened
# in this module (presumably Hz, matching the audio source — confirm against
# the API producing the audio).
SAMPLE_RATE = 24000


class LocalAudioPlayer:
    """Play audio buffers or binary API responses through a sounddevice output stream.

    Streams are opened at ``SAMPLE_RATE``. 16-bit integer samples are scaled to
    float32 (divided by 32767) before they are handed to the stream.
    """

    def __init__(
        self,
        should_stop: Union[Callable[[], bool], None] = None,
    ):
        """Create a player.

        Args:
            should_stop: Optional zero-argument callable polled from the audio
                callback of :meth:`play`; playback stops once it returns a
                truthy value. It is not consulted by :meth:`play_stream`.
        """
        self.channels = 1
        self.dtype = np.float32
        self.should_stop = should_stop

    def _to_float32_frames(
        self,
        samples: Union[npt.NDArray[np.float32], npt.NDArray[np.int16]],
    ) -> npt.NDArray[np.float32]:
        """Return ``samples`` laid out as a 2-D ``(frames, channels)`` array.

        int16 input is scaled to float32 and reshaped to ``(-1, self.channels)``.
        Any other 1-D input is reshaped the same way so it can be copied into
        the 2-D ``outdata`` buffer of the stream callback. Arrays that already
        have two or more dimensions are returned unchanged.
        """
        if samples.dtype == np.int16 and self.dtype == np.float32:
            scaled = samples.astype(np.float32) / 32767.0
            return cast("npt.NDArray[np.float32]", scaled.reshape(-1, self.channels))
        if samples.ndim == 1:
            # A flat array has no channel axis; without this, `shape[1]` and the
            # slice assignment into `outdata` both fail.
            samples = samples.reshape(-1, self.channels)
        return cast("npt.NDArray[np.float32]", samples)

    async def _tts_response_to_buffer(
        self,
        response: Union[
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> npt.NDArray[np.float32]:
        """Read a whole binary response and decode it into float32 frames.

        The body is interpreted as raw int16 samples and returned as an
        ``(n, 1)`` float32 array.

        Note: synchronous responses are drained with a plain ``for`` loop, so
        this coroutine does not yield to the event loop while reading them.
        """
        chunks: list[bytes] = []
        if isinstance(response, (_legacy_response.HttpxBinaryResponseContent, StreamedBinaryAPIResponse)):
            for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)
        else:
            async for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)

        audio_bytes = b"".join(chunks)
        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        return audio_np.reshape(-1, 1)

    async def play(
        self,
        input: Union[
            npt.NDArray[np.int16],
            npt.NDArray[np.float32],
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> None:
        """Play a complete piece of audio and return once playback has finished.

        Args:
            input: Either a numpy array of int16 or float32 samples (1-D, or
                2-D as ``(frames, channels)``), or a binary response whose body
                is read fully before playback starts.

        Raises:
            ValueError: If ``input`` is an array whose dtype is not supported.
        """
        audio_content: npt.NDArray[np.float32]
        if isinstance(input, np.ndarray):
            is_int16 = input.dtype == np.int16 and self.dtype == np.float32
            if not is_int16 and input.dtype != np.float32:
                raise ValueError(f"Unsupported dtype: {input.dtype}")
            audio_content = self._to_float32_frames(input)
        else:
            audio_content = await self._tts_response_to_buffer(input)

        loop = asyncio.get_running_loop()
        finished = asyncio.Event()
        idx = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            # Runs on the audio thread: copy the next `frame_count` frames.
            nonlocal idx

            remainder = len(audio_content) - idx
            if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
                # The block is still played after CallbackStop, so it must not
                # be left uninitialised.
                outdata.fill(0)
                raise sd.CallbackStop
            valid_frames = min(frame_count, remainder)
            outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
            outdata[valid_frames:] = 0
            idx += valid_frames

        def on_finished() -> None:
            # Invoked by sounddevice once the stream is inactive, i.e. after
            # the already-queued audio has been played, or after the callback
            # failed. Hop back onto the event loop thread to wake `play`.
            loop.call_soon_threadsafe(finished.set)

        stream = sd.OutputStream(
            samplerate=SAMPLE_RATE,
            callback=callback,
            finished_callback=on_finished,
            dtype=audio_content.dtype,
            channels=audio_content.shape[1],
        )
        with stream:
            # Waiting for the finished callback (rather than for the moment the
            # stream callback stops) keeps the stream open until the tail of
            # the audio has been played.
            await finished.wait()

    async def play_stream(
        self,
        buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
    ) -> None:
        """Play audio chunks as they are produced by an async generator.

        Args:
            buffer_stream: Async generator yielding int16 or float32 sample
                arrays. Yielding ``None`` (or exhausting the generator) ends
                playback.
        """
        loop = asyncio.get_running_loop()
        finished = asyncio.Event()
        buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)

        async def buffer_producer():
            # `Queue.put` blocks while the bounded queue is full, so it is run
            # in the default executor to keep the event loop responsive.
            async for buffer in buffer_stream:
                if buffer is None:
                    break
                await loop.run_in_executor(None, buffer_queue.put, buffer)
            await loop.run_in_executor(None, buffer_queue.put, None)  # Signal completion

        # Playback state shared between successive callback invocations.
        current_buffer = None
        buffer_pos = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            # Runs on the audio thread: fill `outdata` from the queued chunks.
            nonlocal current_buffer, buffer_pos

            frames_written = 0
            while frames_written < frame_count:
                if current_buffer is None or buffer_pos >= len(current_buffer):
                    try:
                        next_buffer = buffer_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Producer is behind: pad with silence and try again on
                        # the next callback.
                        outdata[frames_written:] = 0
                        return

                    if next_buffer is None:
                        # End marker: silence the unwritten part of this final
                        # block before stopping, since it is still played.
                        outdata[frames_written:] = 0
                        raise sd.CallbackStop

                    current_buffer = self._to_float32_frames(next_buffer)
                    buffer_pos = 0

                remaining_frames = len(current_buffer) - buffer_pos
                frames_to_write = min(frame_count - frames_written, remaining_frames)
                outdata[frames_written : frames_written + frames_to_write] = current_buffer[
                    buffer_pos : buffer_pos + frames_to_write
                ]
                buffer_pos += frames_to_write
                frames_written += frames_to_write

        def on_finished() -> None:
            # Invoked by sounddevice once the stream is inactive; wake the
            # coroutine from the event loop thread.
            loop.call_soon_threadsafe(finished.set)

        producer_task = asyncio.create_task(buffer_producer())

        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            channels=self.channels,
            dtype=self.dtype,
            callback=callback,
            finished_callback=on_finished,
        ):
            await finished.wait()

        await producer_task