From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../litellm/llms/cohere/common_utils.py | 146 +++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py') diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py new file mode 100644 index 00000000..11ff73ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py @@ -0,0 +1,146 @@ +import json +from typing import List, Optional + +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ( + ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, + GenericStreamingChunk, +) + + +class CohereError(BaseLLMException): + def __init__(self, status_code, message): + super().__init__(status_code=status_code, message=message) + + +def validate_environment( + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, +) -> dict: + """ + Return headers to use for cohere chat completion request + + Cohere API Ref: https://docs.cohere.com/reference/chat + Expected headers: + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + "Authorization": "bearer $CO_API_KEY" + } + """ + headers.update( + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + } + ) + if api_key: + headers["Authorization"] = f"bearer {api_key}" + return headers + + +class ModelResponseIterator: + def __init__( + self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False + ): + self.streaming_response = streaming_response + self.response_iterator = self.streaming_response + self.content_blocks: List = [] + self.tool_index = -1 + self.json_mode = json_mode + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + provider_specific_fields = None + + index = int(chunk.get("index", 0)) + + if "text" in chunk: + text = chunk["text"] + elif "is_finished" in chunk and chunk["is_finished"] is True: + is_finished = chunk["is_finished"] + finish_reason = chunk["finish_reason"] + + if "citations" in chunk: + provider_specific_fields = {"citations": chunk["citations"]} + + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=index, + provider_specific_fields=provider_specific_fields, + ) + + return returned_chunk + + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") -- cgit v1.2.3