about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/anthropic/_streaming.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/anthropic/_streaming.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/anthropic/_streaming.py')
-rw-r--r--.venv/lib/python3.12/site-packages/anthropic/_streaming.py443
1 files changed, 443 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/anthropic/_streaming.py b/.venv/lib/python3.12/site-packages/anthropic/_streaming.py
new file mode 100644
index 00000000..d43e2e6a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/anthropic/_streaming.py
@@ -0,0 +1,443 @@
+# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
+from __future__ import annotations
+
+import abc
+import json
+import inspect
+import warnings
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
+from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
+
+import httpx
+
+from ._utils import is_dict, extract_type_var_from_base
+
+if TYPE_CHECKING:
+    from ._client import Anthropic, AsyncAnthropic
+
+
+_T = TypeVar("_T")
+
+
+class _SyncStreamMeta(abc.ABCMeta):
+    @override
+    def __instancecheck__(self, instance: Any) -> bool:
+        # we override the `isinstance()` check for `Stream`
+        # as a previous version of the `MessageStream` class
+        # inherited from `Stream` & without this workaround,
+        # changing it to not inherit would be a breaking change.
+
+        from .lib.streaming import MessageStream
+
+        if isinstance(instance, MessageStream):
+            warnings.warn(
+                "Using `isinstance()` to check if a `MessageStream` object is an instance of `Stream` is deprecated & will be removed in the next major version",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            return True
+
+        return False
+
+
+class Stream(Generic[_T], metaclass=_SyncStreamMeta):
+    """Provides the core interface to iterate over a synchronous stream response."""
+
+    response: httpx.Response
+
+    _decoder: SSEBytesDecoder
+
+    def __init__(
+        self,
+        *,
+        cast_to: type[_T],
+        response: httpx.Response,
+        client: Anthropic,
+    ) -> None:
+        self.response = response
+        self._cast_to = cast_to
+        self._client = client
+        self._decoder = client._make_sse_decoder()
+        self._iterator = self.__stream__()
+
+    def __next__(self) -> _T:
+        return self._iterator.__next__()
+
+    def __iter__(self) -> Iterator[_T]:
+        for item in self._iterator:
+            yield item
+
+    def _iter_events(self) -> Iterator[ServerSentEvent]:
+        yield from self._decoder.iter_bytes(self.response.iter_bytes())
+
+    def __stream__(self) -> Iterator[_T]:
+        cast_to = cast(Any, self._cast_to)
+        response = self.response
+        process_data = self._client._process_response_data
+        iterator = self._iter_events()
+
+        for sse in iterator:
+            if sse.event == "completion":
+                yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+
+            if (
+                sse.event == "message_start"
+                or sse.event == "message_delta"
+                or sse.event == "message_stop"
+                or sse.event == "content_block_start"
+                or sse.event == "content_block_delta"
+                or sse.event == "content_block_stop"
+            ):
+                data = sse.json()
+                if is_dict(data) and "type" not in data:
+                    data["type"] = sse.event
+
+                yield process_data(data=data, cast_to=cast_to, response=response)
+
+            if sse.event == "ping":
+                continue
+
+            if sse.event == "error":
+                body = sse.data
+
+                try:
+                    body = sse.json()
+                    err_msg = f"{body}"
+                except Exception:
+                    err_msg = sse.data or f"Error code: {response.status_code}"
+
+                raise self._client._make_status_error(
+                    err_msg,
+                    body=body,
+                    response=self.response,
+                )
+
+        # Ensure the entire stream is consumed
+        for _sse in iterator:
+            ...
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        self.response.close()
+
+
+class _AsyncStreamMeta(abc.ABCMeta):
+    @override
+    def __instancecheck__(self, instance: Any) -> bool:
+        # we override the `isinstance()` check for `AsyncStream`
+        # as a previous version of the `AsyncMessageStream` class
+        # inherited from `AsyncStream` & without this workaround,
+        # changing it to not inherit would be a breaking change.
+
+        from .lib.streaming import AsyncMessageStream
+
+        if isinstance(instance, AsyncMessageStream):
+            warnings.warn(
+                "Using `isinstance()` to check if a `AsyncMessageStream` object is an instance of `AsyncStream` is deprecated & will be removed in the next major version",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            return True
+
+        return False
+
+
+class AsyncStream(Generic[_T], metaclass=_AsyncStreamMeta):
+    """Provides the core interface to iterate over an asynchronous stream response."""
+
+    response: httpx.Response
+
+    _decoder: SSEDecoder | SSEBytesDecoder
+
+    def __init__(
+        self,
+        *,
+        cast_to: type[_T],
+        response: httpx.Response,
+        client: AsyncAnthropic,
+    ) -> None:
+        self.response = response
+        self._cast_to = cast_to
+        self._client = client
+        self._decoder = client._make_sse_decoder()
+        self._iterator = self.__stream__()
+
+    async def __anext__(self) -> _T:
+        return await self._iterator.__anext__()
+
+    async def __aiter__(self) -> AsyncIterator[_T]:
+        async for item in self._iterator:
+            yield item
+
+    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
+        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+            yield sse
+
+    async def __stream__(self) -> AsyncIterator[_T]:
+        cast_to = cast(Any, self._cast_to)
+        response = self.response
+        process_data = self._client._process_response_data
+        iterator = self._iter_events()
+
+        async for sse in iterator:
+            if sse.event == "completion":
+                yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+
+            if (
+                sse.event == "message_start"
+                or sse.event == "message_delta"
+                or sse.event == "message_stop"
+                or sse.event == "content_block_start"
+                or sse.event == "content_block_delta"
+                or sse.event == "content_block_stop"
+            ):
+                data = sse.json()
+                if is_dict(data) and "type" not in data:
+                    data["type"] = sse.event
+
+                yield process_data(data=data, cast_to=cast_to, response=response)
+
+            if sse.event == "ping":
+                continue
+
+            if sse.event == "error":
+                body = sse.data
+
+                try:
+                    body = sse.json()
+                    err_msg = f"{body}"
+                except Exception:
+                    err_msg = sse.data or f"Error code: {response.status_code}"
+
+                raise self._client._make_status_error(
+                    err_msg,
+                    body=body,
+                    response=self.response,
+                )
+
+        # Ensure the entire stream is consumed
+        async for _sse in iterator:
+            ...
+
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.close()
+
+    async def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        await self.response.aclose()
+
+
+class ServerSentEvent:
+    def __init__(
+        self,
+        *,
+        event: str | None = None,
+        data: str | None = None,
+        id: str | None = None,
+        retry: int | None = None,
+    ) -> None:
+        if data is None:
+            data = ""
+
+        self._id = id
+        self._data = data
+        self._event = event or None
+        self._retry = retry
+
+    @property
+    def event(self) -> str | None:
+        return self._event
+
+    @property
+    def id(self) -> str | None:
+        return self._id
+
+    @property
+    def retry(self) -> int | None:
+        return self._retry
+
+    @property
+    def data(self) -> str:
+        return self._data
+
+    def json(self) -> Any:
+        return json.loads(self.data)
+
+    @override
+    def __repr__(self) -> str:
+        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
+
+
+class SSEDecoder:
+    _data: list[str]
+    _event: str | None
+    _retry: int | None
+    _last_event_id: str | None
+
+    def __init__(self) -> None:
+        self._event = None
+        self._data = []
+        self._last_event_id = None
+        self._retry = None
+
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        for chunk in self._iter_chunks(iterator):
+            # Split before decoding so splitlines() only uses \r and \n
+            for raw_line in chunk.splitlines():
+                line = raw_line.decode("utf-8")
+                sse = self.decode(line)
+                if sse:
+                    yield sse
+
+    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
+        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+        data = b""
+        for chunk in iterator:
+            for line in chunk.splitlines(keepends=True):
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
+
+    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        async for chunk in self._aiter_chunks(iterator):
+            # Split before decoding so splitlines() only uses \r and \n
+            for raw_line in chunk.splitlines():
+                line = raw_line.decode("utf-8")
+                sse = self.decode(line)
+                if sse:
+                    yield sse
+
+    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+        data = b""
+        async for chunk in iterator:
+            for line in chunk.splitlines(keepends=True):
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
+
+    def decode(self, line: str) -> ServerSentEvent | None:
+        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
+
+        if not line:
+            if not self._event and not self._data and not self._last_event_id and self._retry is None:
+                return None
+
+            sse = ServerSentEvent(
+                event=self._event,
+                data="\n".join(self._data),
+                id=self._last_event_id,
+                retry=self._retry,
+            )
+
+            # NOTE: as per the SSE spec, do not reset last_event_id.
+            self._event = None
+            self._data = []
+            self._retry = None
+
+            return sse
+
+        if line.startswith(":"):
+            return None
+
+        fieldname, _, value = line.partition(":")
+
+        if value.startswith(" "):
+            value = value[1:]
+
+        if fieldname == "event":
+            self._event = value
+        elif fieldname == "data":
+            self._data.append(value)
+        elif fieldname == "id":
+            if "\0" in value:
+                pass
+            else:
+                self._last_event_id = value
+        elif fieldname == "retry":
+            try:
+                self._retry = int(value)
+            except (TypeError, ValueError):
+                pass
+        else:
+            pass  # Field is ignored.
+
+        return None
+
+
+@runtime_checkable
+class SSEBytesDecoder(Protocol):
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        ...
+
+    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        ...
+
+
+def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
+    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
+    origin = get_origin(typ) or typ
+    return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
+
+
+def extract_stream_chunk_type(
+    stream_cls: type,
+    *,
+    failure_message: str | None = None,
+) -> type:
+    """Given a type like `Stream[T]`, returns the generic type variable `T`.
+
+    This also handles the case where a concrete subclass is given, e.g.
+    ```py
+    class MyStream(Stream[bytes]):
+        ...
+
+    extract_stream_chunk_type(MyStream) -> bytes
+    ```
+    """
+    from ._base_client import Stream, AsyncStream
+
+    return extract_type_var_from_base(
+        stream_cls,
+        index=0,
+        generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
+        failure_message=failure_message,
+    )