about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
new file mode 100644
index 00000000..9884f420
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
@@ -0,0 +1,207 @@
+import json
+from typing import AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import verbose_logger
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
+
+_response_stream_shape_cache = None
+
+
+class SagemakerError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+class AWSEventStreamDecoder:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
+        from botocore.parsers import EventStreamJSONParser
+
+        self.model = model
+        self.parser = EventStreamJSONParser()
+        self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
+        _token = chunk_data.get("token", {}) or {}
+        _index = chunk_data.get("index", None) or 0
+        is_finished = False
+        finish_reason = ""
+
+        _text = _token.get("text", "")
+        if _text == "<|endoftext|>":
+            return GChunk(
+                text="",
+                index=_index,
+                is_finished=True,
+                finish_reason="stop",
+                usage=None,
+            )
+
+        return GChunk(
+            text=_text,
+            index=_index,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            usage=None,
+        )
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
+        from botocore.eventstream import EventStreamBuffer
+
+        event_stream_buffer = EventStreamBuffer()
+        accumulated_json = ""
+
+        for chunk in iterator:
+            event_stream_buffer.add_data(chunk)
+            for event in event_stream_buffer:
+                message = self._parse_message_from_event(event)
+                if message:
+                    # remove data: prefix and "\n\n" at the end
+                    message = (
+                        litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                        or ""
+                    )
+                    message = message.replace("\n\n", "")
+
+                    # Accumulate JSON data
+                    accumulated_json += message
+
+                    # Try to parse the accumulated JSON
+                    try:
+                        _data = json.loads(accumulated_json)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
+                        # Reset accumulated_json after successful parsing
+                        accumulated_json = ""
+                    except json.JSONDecodeError:
+                        # If it's not valid JSON yet, continue to the next event
+                        continue
+
+        # Handle any remaining data after the iterator is exhausted
+        if accumulated_json:
+            try:
+                _data = json.loads(accumulated_json)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
+                # Handle or log any unparseable data at the end
+                verbose_logger.error(
+                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
+                )
+                yield None
+
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
+        from botocore.eventstream import EventStreamBuffer
+
+        event_stream_buffer = EventStreamBuffer()
+        accumulated_json = ""
+
+        async for chunk in iterator:
+            event_stream_buffer.add_data(chunk)
+            for event in event_stream_buffer:
+                message = self._parse_message_from_event(event)
+                if message:
+                    verbose_logger.debug("sagemaker  parsed chunk bytes %s", message)
+                    # remove data: prefix and "\n\n" at the end
+                    message = (
+                        litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                        or ""
+                    )
+                    message = message.replace("\n\n", "")
+
+                    # Accumulate JSON data
+                    accumulated_json += message
+
+                    # Try to parse the accumulated JSON
+                    try:
+                        _data = json.loads(accumulated_json)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
+                        # Reset accumulated_json after successful parsing
+                        accumulated_json = ""
+                    except json.JSONDecodeError:
+                        # If it's not valid JSON yet, continue to the next event
+                        continue
+
+        # Handle any remaining data after the iterator is exhausted
+        if accumulated_json:
+            try:
+                _data = json.loads(accumulated_json)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
+                # Handle or log any unparseable data at the end
+                verbose_logger.error(
+                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
+                )
+                yield None
+
+    def _parse_message_from_event(self, event) -> Optional[str]:
+        response_dict = event.to_response_dict()
+        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
+
+        if response_dict["status_code"] != 200:
+            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
+            chunk = response_dict.get("body")
+            if not chunk:
+                return None
+
+            return chunk.decode()  # type: ignore[no-any-return]
+
+
+def get_response_stream_shape():
+    global _response_stream_shape_cache
+    if _response_stream_shape_cache is None:
+
+        from botocore.loaders import Loader
+        from botocore.model import ServiceModel
+
+        loader = Loader()
+        sagemaker_service_dict = loader.load_service_model(
+            "sagemaker-runtime", "service-2"
+        )
+        sagemaker_service_model = ServiceModel(sagemaker_service_dict)
+        _response_stream_shape_cache = sagemaker_service_model.shape_for(
+            "InvokeEndpointWithResponseStreamOutput"
+        )
+    return _response_stream_shape_cache