about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py
new file mode 100644
index 00000000..11ff73ef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py
@@ -0,0 +1,146 @@
+import json
+from typing import List, Optional
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
+
+
+class CohereError(BaseLLMException):
+    def __init__(self, status_code, message):
+        super().__init__(status_code=status_code, message=message)
+
+
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for cohere chat completion request
+
+    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Expected headers:
+    {
+        "Request-Source": "unspecified:litellm",
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "bearer $CO_API_KEY"
+    }
+    """
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
+    if api_key:
+        headers["Authorization"] = f"bearer {api_key}"
+    return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")