Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py        368
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py               146
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py           5
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py  264
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py               178
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py        153
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py                5
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py       151
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py     80
9 files changed, 1350 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
new file mode 100644
index 00000000..3ceec2db
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
@@ -0,0 +1,368 @@
+import json
+import time
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class CohereError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[httpx.Headers] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+        )
+
+
+class CohereChatConfig(BaseConfig):
+    """
+    Configuration class for Cohere's API interface.
+
+    Args:
+        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
+        chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
+        generation_id (str, optional): Unique identifier for the generated reply.
+        response_id (str, optional): Unique identifier for the response.
+        conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
+        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
+        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
+        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
+        documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
+        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
+        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
+        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
+        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
+        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
+        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
+        seed (int, optional): A seed to assist reproducibility of the model's response.
+    """
+
+    preamble: Optional[str] = None
+    chat_history: Optional[list] = None
+    generation_id: Optional[str] = None
+    response_id: Optional[str] = None
+    conversation_id: Optional[str] = None
+    prompt_truncation: Optional[str] = None
+    connectors: Optional[list] = None
+    search_queries_only: Optional[bool] = None
+    documents: Optional[list] = None
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    k: Optional[int] = None
+    p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    tools: Optional[list] = None
+    tool_results: Optional[list] = None
+    seed: Optional[int] = None
+
+    def __init__(
+        self,
+        preamble: Optional[str] = None,
+        chat_history: Optional[list] = None,
+        generation_id: Optional[str] = None,
+        response_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        prompt_truncation: Optional[str] = None,
+        connectors: Optional[list] = None,
+        search_queries_only: Optional[bool] = None,
+        documents: Optional[list] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        k: Optional[int] = None,
+        p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        tools: Optional[list] = None,
+        tool_results: Optional[list] = None,
+        seed: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+            "tools",
+            "tool_choice",
+            "seed",
+            "extra_headers",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "n":
+                optional_params["num_generations"] = value
+            if param == "top_p":
+                optional_params["p"] = value
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "seed":
+                optional_params["seed"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+
+        ## Load Config
+        for k, v in litellm.CohereChatConfig.get_config().items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        most_recent_message, chat_history = cohere_messages_pt_v2(
+            messages=messages, model=model, llm_provider="cohere_chat"
+        )
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
+            optional_params["tools"] = cohere_tools
+        if isinstance(most_recent_message, dict):
+            optional_params["tool_results"] = [most_recent_message]
+        elif isinstance(most_recent_message, str):
+            optional_params["message"] = most_recent_message
+
+        ## if the last chat_history message has role USER (which happens when tool_results are being sent), set force_single_step=True, otherwise the Cohere API fails
+        if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
+            optional_params["force_single_step"] = True
+
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+
+        try:
+            raw_response_json = raw_response.json()
+            model_response.choices[0].message.content = raw_response_json["text"]  # type: ignore
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        ## ADD CITATIONS
+        if "citations" in raw_response_json:
+            setattr(model_response, "citations", raw_response_json["citations"])
+
+        ## Tool calling response
+        cohere_tools_response = raw_response_json.get("tool_calls", None)
+        if cohere_tools_response is not None and cohere_tools_response != []:
+            # convert cohere_tools_response to OpenAI response format
+            tool_calls = []
+            for tool in cohere_tools_response:
+                function_name = tool.get("name", "")
+                generation_id = tool.get("generation_id", "")
+                parameters = tool.get("parameters", {})
+                tool_call = {
+                    "id": f"call_{generation_id}",
+                    "type": "function",
+                    "function": {
+                        "name": function_name,
+                        "arguments": json.dumps(parameters),
+                    },
+                }
+                tool_calls.append(tool_call)
+            _message = litellm.Message(
+                tool_calls=tool_calls,
+                content=None,
+            )
+            model_response.choices[0].message = _message  # type: ignore
+
+        ## CALCULATING USAGE - use cohere `billed_units` for returning usage
+        billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
+
+        prompt_tokens = billed_units.get("input_tokens", 0)
+        completion_tokens = billed_units.get("output_tokens", 0)
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool(
+        self,
+        tools: Optional[list] = None,
+    ):
+        if tools is None:
+            tools = []
+        cohere_tools = []
+        for tool in tools:
+            cohere_tool = self._translate_openai_tool_to_cohere(tool)
+            cohere_tools.append(cohere_tool)
+        return cohere_tools
+
+    def _translate_openai_tool_to_cohere(
+        self,
+        openai_tool: dict,
+    ):
+        # cohere tools look like this
+        """
+        {
+        "name": "query_daily_sales_report",
+        "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
+        "parameter_definitions": {
+            "day": {
+                "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
+                "type": "str",
+                "required": True
+            }
+        }
+        }
+        """
+
+        # OpenAI tools look like this
+        """
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+        """
+        cohere_tool = {
+            "name": openai_tool["function"]["name"],
+            "description": openai_tool["function"]["description"],
+            "parameter_definitions": {},
+        }
+
+        for param_name, param_def in openai_tool["function"]["parameters"][
+            "properties"
+        ].items():
+            required_params = (
+                openai_tool.get("function", {})
+                .get("parameters", {})
+                .get("required", [])
+            )
+            cohere_param_def = {
+                "description": param_def.get("description", ""),
+                "type": param_def.get("type", ""),
+                "required": param_name in required_params,
+            }
+            cohere_tool["parameter_definitions"][param_name] = cohere_param_def
+
+        return cohere_tool
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(status_code=status_code, message=error_message)
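As a rough illustration of how the chat transformation above is used, here is a minimal sketch. It assumes the litellm package shown in this diff is importable; the model name and parameter values are made up for illustration, and _translate_openai_tool_to_cohere is an internal helper, so treat this as a sketch rather than a supported API.

    from litellm.llms.cohere.chat.transformation import CohereChatConfig

    config = CohereChatConfig()

    # OpenAI-style params are renamed for Cohere: top_p -> p, stop -> stop_sequences, n -> num_generations
    params = config.map_openai_params(
        non_default_params={"temperature": 0.3, "top_p": 0.9, "stop": ["\n\n"], "n": 1},
        optional_params={},
        model="command-r",  # illustrative model name
        drop_params=False,
    )
    # params is roughly: {"temperature": 0.3, "p": 0.9, "stop_sequences": ["\n\n"], "num_generations": 1}

    # An OpenAI tool definition becomes a Cohere tool with parameter_definitions
    openai_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
    cohere_tool = config._translate_openai_tool_to_cohere(openai_tool)
    # cohere_tool["parameter_definitions"]["location"]["required"] is True; "unit" is not required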
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py
new file mode 100644
index 00000000..11ff73ef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py
@@ -0,0 +1,146 @@
+import json
+from typing import List, Optional
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
+
+
+class CohereError(BaseLLMException):
+    def __init__(self, status_code, message):
+        super().__init__(status_code=status_code, message=message)
+
+
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for cohere chat completion request
+
+    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Expected headers:
+    {
+        "Request-Source": "unspecified:litellm",
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "bearer $CO_API_KEY"
+    }
+    """
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
+    if api_key:
+        headers["Authorization"] = f"bearer {api_key}"
+    return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
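For reference, a hedged sketch of how ModelResponseIterator is consumed. The chunks below are made-up stand-ins for Cohere's newline-delimited streaming JSON; only the litellm classes shown in this diff are assumed to exist.

    from litellm.llms.cohere.common_utils import ModelResponseIterator

    fake_stream = iter([
        '{"text": "Hel", "index": 0}',
        '{"text": "lo!", "index": 0}',
        '{"is_finished": true, "finish_reason": "COMPLETE", "index": 0}',
    ])

    iterator = ModelResponseIterator(streaming_response=fake_stream, sync_stream=True)
    for chunk in iterator:
        # each chunk is a GenericStreamingChunk (a TypedDict)
        print(chunk["text"], chunk["is_finished"], chunk["finish_reason"])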
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py
new file mode 100644
index 00000000..6a779511
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py
@@ -0,0 +1,5 @@
+"""
+Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py
new file mode 100644
index 00000000..bdfcda02
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py
@@ -0,0 +1,264 @@
+import time
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+from ..common_utils import CohereError
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class CohereTextConfig(BaseConfig):
+    """
+    Reference: https://docs.cohere.com/reference/generate
+
+    The class `CohereTextConfig` provides configuration for Cohere's /generate API. Below are the parameters:
+
+    - `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
+
+    - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
+
+    - `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
+
+    - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
+
+    - `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
+
+    - `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
+
+    - `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
+
+    - `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
+
+    - `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
+
+    - `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
+
+    - `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
+
+    - `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
+
+    - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
+    """
+
+    num_generations: Optional[int] = None
+    max_tokens: Optional[int] = None
+    truncate: Optional[str] = None
+    temperature: Optional[float] = None
+    preset: Optional[str] = None
+    end_sequences: Optional[list] = None
+    stop_sequences: Optional[list] = None
+    k: Optional[int] = None
+    p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    return_likelihoods: Optional[str] = None
+    logit_bias: Optional[dict] = None
+
+    def __init__(
+        self,
+        num_generations: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        truncate: Optional[str] = None,
+        temperature: Optional[float] = None,
+        preset: Optional[str] = None,
+        end_sequences: Optional[list] = None,
+        stop_sequences: Optional[list] = None,
+        k: Optional[int] = None,
+        p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        return_likelihoods: Optional[str] = None,
+        logit_bias: Optional[dict] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(status_code=status_code, message=error_message)
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "logit_bias",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+            "extra_headers",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "stream":
+                optional_params["stream"] = value
+            elif param == "temperature":
+                optional_params["temperature"] = value
+            elif param == "max_tokens":
+                optional_params["max_tokens"] = value
+            elif param == "n":
+                optional_params["num_generations"] = value
+            elif param == "logit_bias":
+                optional_params["logit_bias"] = value
+            elif param == "top_p":
+                optional_params["p"] = value
+            elif param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            elif param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            elif param == "stop":
+                optional_params["stop_sequences"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+
+        ## Load Config
+        config = litellm.CohereConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
+                tools=optional_params["tools"]
+            )
+            optional_params["tools"] = tool_calling_system_prompt
+
+        data = {
+            "model": model,
+            "prompt": prompt,
+            **optional_params,
+        }
+
+        return data
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+        completion_response = raw_response.json()
+        choices_list = []
+        for idx, item in enumerate(completion_response["generations"]):
+            if len(item["text"]) > 0:
+                message_obj = Message(content=item["text"])
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(
+                finish_reason=item["finish_reason"],
+                index=idx + 1,
+                message=message_obj,
+            )
+            choices_list.append(choice_obj)
+        model_response.choices = choices_list  # type: ignore
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool_for_completion_api(
+        self,
+        tools: Optional[List] = None,
+    ) -> dict:
+        if tools is None:
+            tools = []
+        return {"tools": tools}
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
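A minimal sketch of the /generate transformation above, assuming the litellm package in this diff is importable; the model name, prompt, and parameters are illustrative only.

    from litellm.llms.cohere.completion.transformation import CohereTextConfig

    config = CohereTextConfig()
    optional_params = config.map_openai_params(
        non_default_params={"max_tokens": 50, "temperature": 0.7, "stop": ["END"]},
        optional_params={},
        model="command",  # illustrative model name
        drop_params=False,
    )
    data = config.transform_request(
        model="command",
        messages=[{"role": "user", "content": "Write a haiku about the sea."}],
        optional_params=optional_params,
        litellm_params={},
        headers={},
    )
    # data is roughly:
    # {"model": "command", "prompt": "Write a haiku about the sea.",
    #  "max_tokens": 50, "temperature": 0.7, "stop_sequences": ["END"]}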
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py
new file mode 100644
index 00000000..e7f22ea7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py
@@ -0,0 +1,178 @@
+import json
+from typing import Any, Callable, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.llms.bedrock import CohereEmbeddingRequest
+from litellm.types.utils import EmbeddingResponse
+
+from .transformation import CohereEmbeddingConfig
+
+
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    return headers
+
+
+class CohereError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://api.cohere.ai/v1/generate"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+async def async_embedding(
+    model: str,
+    data: Union[dict, CohereEmbeddingRequest],
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Optional[Union[float, httpx.Timeout]],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+
+    if client is None:
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.COHERE,
+            params={"timeout": timeout},
+        )
+
+    try:
+        response = await client.post(api_base, headers=headers, data=json.dumps(data))
+    except httpx.HTTPStatusError as e:
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=e.response.text,
+        )
+        raise e
+    except Exception as e:
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=str(e),
+        )
+        raise e
+
+    ## PROCESS RESPONSE ##
+    return CohereEmbeddingConfig()._transform_response(
+        response=response,
+        api_key=api_key,
+        logging_obj=logging_obj,
+        data=data,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
+def embedding(
+    model: str,
+    input: list,
+    model_response: EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    headers: dict,
+    encoding: Any,
+    data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
+    complete_api_base: Optional[str] = None,
+    api_key: Optional[str] = None,
+    aembedding: Optional[bool] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+):
+    headers = validate_environment(api_key, headers=headers)
+    embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
+    model = model
+
+    data = data or CohereEmbeddingConfig()._transform_request(
+        model=model, input=input, inference_params=optional_params
+    )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+            client=(
+                client
+                if client is not None and isinstance(client, AsyncHTTPHandler)
+                else None
+            ),
+        )
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+
+    ## COMPLETION CALL
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
+
+    return CohereEmbeddingConfig()._transform_response(
+        response=response,
+        api_key=api_key,
+        logging_obj=logging_obj,
+        data=data,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
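A quick sketch of the header handling in validate_environment above; the API key is a placeholder, not a real credential.

    from litellm.llms.cohere.embed.handler import validate_environment

    headers = validate_environment(api_key="co-xxxx", headers={})
    # headers is roughly:
    # {"Request-Source": "unspecified:litellm", "accept": "application/json",
    #  "content-type": "application/json", "Authorization": "Bearer co-xxxx"}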
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py
new file mode 100644
index 00000000..22e157a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py
@@ -0,0 +1,153 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.
+
+Why a separate file? To make it easy to see how the transformation works.
+
+Covers:
+- v3 embedding models
+- v2 embedding models
+
+Docs - https://docs.cohere.com/v2/reference/embed
+"""
+
+from typing import Any, List, Optional, Union
+
+import httpx
+
+from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.types.llms.bedrock import (
+    CohereEmbeddingRequest,
+    CohereEmbeddingRequestWithModel,
+)
+from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
+from litellm.utils import is_base64_encoded
+
+
+class CohereEmbeddingConfig:
+    """
+    Reference: https://docs.cohere.com/v2/reference/embed
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_supported_openai_params(self) -> List[str]:
+        return ["encoding_format"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for k, v in non_default_params.items():
+            if k == "encoding_format":
+                optional_params["embedding_types"] = v
+        return optional_params
+
+    def _is_v3_model(self, model: str) -> bool:
+        return "3" in model
+
+    def _transform_request(
+        self, model: str, input: List[str], inference_params: dict
+    ) -> CohereEmbeddingRequestWithModel:
+        is_encoded = False
+        for input_str in input:
+            is_encoded = is_base64_encoded(input_str)
+
+        if is_encoded:  # check if string is b64 encoded image or not
+            transformed_request = CohereEmbeddingRequestWithModel(
+                model=model,
+                images=input,
+                input_type="image",
+            )
+        else:
+            transformed_request = CohereEmbeddingRequestWithModel(
+                model=model,
+                texts=input,
+                input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
+            )
+
+        for k, v in inference_params.items():
+            transformed_request[k] = v  # type: ignore
+
+        return transformed_request
+
+    def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
+
+        input_tokens = 0
+
+        text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
+
+        image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
+
+        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
+        if image_tokens is None and text_tokens is None:
+            for text in input:
+                input_tokens += len(encoding.encode(text))
+        else:
+            prompt_tokens_details = PromptTokensDetailsWrapper(
+                image_tokens=image_tokens,
+                text_tokens=text_tokens,
+            )
+            if image_tokens:
+                input_tokens += image_tokens
+            if text_tokens:
+                input_tokens += text_tokens
+
+        return Usage(
+            prompt_tokens=input_tokens,
+            completion_tokens=0,
+            total_tokens=input_tokens,
+            prompt_tokens_details=prompt_tokens_details,
+        )
+
+    def _transform_response(
+        self,
+        response: httpx.Response,
+        api_key: Optional[str],
+        logging_obj: LiteLLMLoggingObj,
+        data: Union[dict, CohereEmbeddingRequest],
+        model_response: EmbeddingResponse,
+        model: str,
+        encoding: Any,
+        input: list,
+    ) -> EmbeddingResponse:
+
+        response_json = response.json()
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response_json,
+        )
+        """
+            response 
+            {
+                'object': "list",
+                'data': [
+                
+                ]
+                'model', 
+                'usage'
+            }
+        """
+        embeddings = response_json["embeddings"]
+        output_data = []
+        for idx, embedding in enumerate(embeddings):
+            output_data.append(
+                {"object": "embedding", "index": idx, "embedding": embedding}
+            )
+        model_response.object = "list"
+        model_response.data = output_data
+        model_response.model = model
+        input_tokens = 0
+        for text in input:
+            input_tokens += len(encoding.encode(text))
+
+        setattr(
+            model_response,
+            "usage",
+            self._calculate_usage(input, encoding, response_json.get("meta", {})),
+        )
+
+        return model_response
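A hedged sketch of the embedding transformation above. It assumes the litellm package in this diff is importable; the model name and inputs are illustrative, _transform_request is an internal helper, and the default input_type comes from litellm's COHERE_DEFAULT_EMBEDDING_INPUT_TYPE setting.

    from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig

    config = CohereEmbeddingConfig()
    request = config._transform_request(
        model="embed-english-v3.0",  # illustrative model name
        input=["hello world", "good morning"],
        inference_params={"embedding_types": ["float"]},
    )
    # request is roughly:
    # {"model": "embed-english-v3.0", "texts": ["hello world", "good morning"],
    #  "input_type": <COHERE_DEFAULT_EMBEDDING_INPUT_TYPE>, "embedding_types": ["float"]}

    # Usage is taken from Cohere's billed_units when present
    # (it falls back to local token counting otherwise).
    usage = config._calculate_usage(
        input=["hello world"], encoding=None, meta={"billed_units": {"input_tokens": 3}}
    )
    # usage.prompt_tokens == 3, usage.total_tokens == 3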
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py
new file mode 100644
index 00000000..e94f1859
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py
new file mode 100644
index 00000000..f3624d92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py
@@ -0,0 +1,151 @@
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..common_utils import CohereError
+
+
+class CohereRerankConfig(BaseRerankConfig):
+    """
+    Reference: https://docs.cohere.com/v2/reference/rerank
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v1/rerank"):
+                api_base = f"{api_base}/v1/rerank"
+            return api_base
+        return "https://api.cohere.ai/v1/rerank"
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            "max_chunks_per_doc",
+            "rank_fields",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+        max_tokens_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere rerank params
+
+        No mapping required - returns all supported params
+        """
+        return OptionalRerankParams(
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("COHERE_API_KEY")
+                or get_secret_str("CO_API_KEY")
+                or litellm.cohere_key
+            )
+
+        if api_key is None:
+            raise ValueError(
+                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Cohere rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Cohere rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            top_n=optional_rerank_params.get("top_n", None),
+            rank_fields=optional_rerank_params.get("rank_fields", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform Cohere rerank response
+
+        No transformation required, litellm follows cohere API response format
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        return RerankResponse(**raw_response_json)
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(message=error_message, status_code=status_code)
\ No newline at end of file
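To show how the v1 rerank config above fits together, a short sketch; the query, documents, and model name are illustrative only, and litellm is assumed to be importable.

    from litellm.llms.cohere.rerank.transformation import CohereRerankConfig

    config = CohereRerankConfig()
    config.get_complete_url(api_base=None, model="rerank-english-v3.0")
    # -> "https://api.cohere.ai/v1/rerank"
    config.get_complete_url(api_base="https://my-proxy.example.com/", model="rerank-english-v3.0")
    # -> "https://my-proxy.example.com/v1/rerank"

    params = config.map_cohere_rerank_params(
        non_default_params=None,
        model="rerank-english-v3.0",
        drop_params=False,
        query="What is the capital of France?",
        documents=["Paris is the capital of France.", "Berlin is in Germany."],
        top_n=1,
    )
    body = config.transform_rerank_request(
        model="rerank-english-v3.0", optional_rerank_params=params, headers={}
    )
    # body is roughly (None values are dropped):
    # {"model": "rerank-english-v3.0", "query": "What is the capital of France?",
    #  "documents": [...], "top_n": 1, "return_documents": True}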
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py
new file mode 100644
index 00000000..a93cb982
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py
@@ -0,0 +1,80 @@
+from typing import Any, Dict, List, Optional, Union
+
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.types.rerank import OptionalRerankParams, RerankRequest
+
+class CohereRerankV2Config(CohereRerankConfig):
+    """
+    Reference: https://docs.cohere.com/v2/reference/rerank
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v2/rerank"):
+                api_base = f"{api_base}/v2/rerank"
+            return api_base
+        return "https://api.cohere.ai/v2/rerank"
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            "max_tokens_per_doc",
+            "rank_fields",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+        max_tokens_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere rerank params
+
+        No mapping required - returns all supported params
+        """
+        return OptionalRerankParams(
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_tokens_per_doc=max_tokens_per_doc,
+        )
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Cohere rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Cohere rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            top_n=optional_rerank_params.get("top_n", None),
+            rank_fields=optional_rerank_params.get("rank_fields", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+            max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
\ No newline at end of file
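And the v2 variant, which differs only in the endpoint (/v2/rerank) and in accepting max_tokens_per_doc instead of max_chunks_per_doc; again an illustrative, hedged sketch.

    from litellm.llms.cohere.rerank_v2.transformation import CohereRerankV2Config

    v2 = CohereRerankV2Config()
    v2.get_complete_url(api_base=None, model="rerank-v3.5")
    # -> "https://api.cohere.ai/v2/rerank"

    params = v2.map_cohere_rerank_params(
        non_default_params=None,
        model="rerank-v3.5",  # illustrative model name
        drop_params=False,
        query="best pizza in town",
        documents=["Doc A", "Doc B"],
        top_n=2,
        max_tokens_per_doc=512,
    )
    body = v2.transform_rerank_request(
        model="rerank-v3.5", optional_rerank_params=params, headers={}
    )
    # body is roughly:
    # {"model": "rerank-v3.5", "query": "best pizza in town", "documents": ["Doc A", "Doc B"],
    #  "top_n": 2, "return_documents": True, "max_tokens_per_doc": 512}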