author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py')

 -rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py | 207
 1 file changed, 207 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
new file mode 100644
index 00000000..9884f420
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/common_utils.py
@@ -0,0 +1,207 @@
+import json
+from typing import AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import verbose_logger
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
+
+_response_stream_shape_cache = None
+
+
+class SagemakerError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+class AWSEventStreamDecoder:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
+        from botocore.parsers import EventStreamJSONParser
+
+        self.model = model
+        self.parser = EventStreamJSONParser()
+        self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
+        _token = chunk_data.get("token", {}) or {}
+        _index = chunk_data.get("index", None) or 0
+        is_finished = False
+        finish_reason = ""
+
+        _text = _token.get("text", "")
+        if _text == "<|endoftext|>":
+            return GChunk(
+                text="",
+                index=_index,
+                is_finished=True,
+                finish_reason="stop",
+                usage=None,
+            )
+
+        return GChunk(
+            text=_text,
+            index=_index,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            usage=None,
+        )
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
+        from botocore.eventstream import EventStreamBuffer
+
+        event_stream_buffer = EventStreamBuffer()
+        accumulated_json = ""
+
+        for chunk in iterator:
+            event_stream_buffer.add_data(chunk)
+            for event in event_stream_buffer:
+                message = self._parse_message_from_event(event)
+                if message:
+                    # remove data: prefix and "\n\n" at the end
+                    message = (
+                        litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                        or ""
+                    )
+                    message = message.replace("\n\n", "")
+
+                    # Accumulate JSON data
+                    accumulated_json += message
+
+                    # Try to parse the accumulated JSON
+                    try:
+                        _data = json.loads(accumulated_json)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
+                        # Reset accumulated_json after successful parsing
+                        accumulated_json = ""
+                    except json.JSONDecodeError:
+                        # If it's not valid JSON yet, continue to the next event
+                        continue
+
+        # Handle any remaining data after the iterator is exhausted
+        if accumulated_json:
+            try:
+                _data = json.loads(accumulated_json)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
+                # Handle or log any unparseable data at the end
+                verbose_logger.error(
+                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
+                )
+                yield None
+
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
+        from botocore.eventstream import EventStreamBuffer
+
+        event_stream_buffer = EventStreamBuffer()
+        accumulated_json = ""
+
+        async for chunk in iterator:
+            event_stream_buffer.add_data(chunk)
+            for event in event_stream_buffer:
+                message = self._parse_message_from_event(event)
+                if message:
+                    verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
+                    # remove data: prefix and "\n\n" at the end
+                    message = (
+                        litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                        or ""
+                    )
+                    message = message.replace("\n\n", "")
+
+                    # Accumulate JSON data
+                    accumulated_json += message
+
+                    # Try to parse the accumulated JSON
+                    try:
+                        _data = json.loads(accumulated_json)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
+                        # Reset accumulated_json after successful parsing
+                        accumulated_json = ""
+                    except json.JSONDecodeError:
+                        # If it's not valid JSON yet, continue to the next event
+                        continue
+
+        # Handle any remaining data after the iterator is exhausted
+        if accumulated_json:
+            try:
+                _data = json.loads(accumulated_json)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
+                # Handle or log any unparseable data at the end
+                verbose_logger.error(
+                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
+                )
+                yield None
+
+    def _parse_message_from_event(self, event) -> Optional[str]:
+        response_dict = event.to_response_dict()
+        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
+
+        if response_dict["status_code"] != 200:
+            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
+            chunk = response_dict.get("body")
+            if not chunk:
+                return None
+
+            return chunk.decode()  # type: ignore[no-any-return]
+
+
+def get_response_stream_shape():
+    global _response_stream_shape_cache
+    if _response_stream_shape_cache is None:
+
+        from botocore.loaders import Loader
+        from botocore.model import ServiceModel
+
+        loader = Loader()
+        sagemaker_service_dict = loader.load_service_model(
+            "sagemaker-runtime", "service-2"
+        )
+        sagemaker_service_model = ServiceModel(sagemaker_service_dict)
+        _response_stream_shape_cache = sagemaker_service_model.shape_for(
+            "InvokeEndpointWithResponseStreamOutput"
+        )
+    return _response_stream_shape_cache
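
For reference, a minimal sketch (not part of the commit) of what the decoder's chunk parser does with the TGI-style token payloads that SageMaker text-generation containers emit. The model name and the sample chunk dicts below are illustrative, and `_chunk_parser` is an internal method exercised here only to show the mapping:

    from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder

    # Hypothetical model name; any string works for parsing purposes.
    decoder = AWSEventStreamDecoder(model="sagemaker/my-endpoint")

    # An ordinary token chunk maps to a GenericStreamingChunk (a TypedDict).
    chunk = decoder._chunk_parser(chunk_data={"index": 0, "token": {"text": "Hello"}})
    print(chunk["text"], chunk["is_finished"])  # Hello False

    # The <|endoftext|> sentinel closes the stream with finish_reason "stop".
    done = decoder._chunk_parser(chunk_data={"token": {"text": "<|endoftext|>"}})
    print(done["is_finished"], done["finish_reason"])  # True stop

In litellm's SageMaker handler the decoder is driven through `iter_bytes`/`aiter_bytes` over the raw byte stream of an `InvokeEndpointWithResponseStream` response; the JSON-accumulation loop in those methods exists because a single JSON payload can arrive split across several AWS event-stream frames, so partial fragments are buffered until `json.loads` succeeds.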