about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py368
1 files changed, 368 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
new file mode 100644
index 00000000..3ceec2db
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
@@ -0,0 +1,368 @@
+import json
+import time
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class CohereError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[httpx.Headers] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+        )
+
+
+class CohereChatConfig(BaseConfig):
+    """
+    Configuration class for Cohere's API interface.
+
+    Args:
+        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
+        chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
+        generation_id (str, optional): Unique identifier for the generated reply.
+        response_id (str, optional): Unique identifier for the response.
+        conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
+        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
+        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
+        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
+        documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
+        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
+        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
+        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
+        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
+        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
+        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
+        seed (int, optional): A seed to assist reproducibility of the model's response.
+    """
+
+    preamble: Optional[str] = None
+    chat_history: Optional[list] = None
+    generation_id: Optional[str] = None
+    response_id: Optional[str] = None
+    conversation_id: Optional[str] = None
+    prompt_truncation: Optional[str] = None
+    connectors: Optional[list] = None
+    search_queries_only: Optional[bool] = None
+    documents: Optional[list] = None
+    temperature: Optional[int] = None
+    max_tokens: Optional[int] = None
+    k: Optional[int] = None
+    p: Optional[int] = None
+    frequency_penalty: Optional[int] = None
+    presence_penalty: Optional[int] = None
+    tools: Optional[list] = None
+    tool_results: Optional[list] = None
+    seed: Optional[int] = None
+
+    def __init__(
+        self,
+        preamble: Optional[str] = None,
+        chat_history: Optional[list] = None,
+        generation_id: Optional[str] = None,
+        response_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        prompt_truncation: Optional[str] = None,
+        connectors: Optional[list] = None,
+        search_queries_only: Optional[bool] = None,
+        documents: Optional[list] = None,
+        temperature: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        k: Optional[int] = None,
+        p: Optional[int] = None,
+        frequency_penalty: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        tools: Optional[list] = None,
+        tool_results: Optional[list] = None,
+        seed: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+            "tools",
+            "tool_choice",
+            "seed",
+            "extra_headers",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "n":
+                optional_params["num_generations"] = value
+            if param == "top_p":
+                optional_params["p"] = value
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "seed":
+                optional_params["seed"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+
+        ## Load Config
+        for k, v in litellm.CohereChatConfig.get_config().items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        most_recent_message, chat_history = cohere_messages_pt_v2(
+            messages=messages, model=model, llm_provider="cohere_chat"
+        )
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
+            optional_params["tools"] = cohere_tools
+        if isinstance(most_recent_message, dict):
+            optional_params["tool_results"] = [most_recent_message]
+        elif isinstance(most_recent_message, str):
+            optional_params["message"] = most_recent_message
+
+        ## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
+        if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
+            optional_params["force_single_step"] = True
+
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+
+        try:
+            raw_response_json = raw_response.json()
+            model_response.choices[0].message.content = raw_response_json["text"]  # type: ignore
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        ## ADD CITATIONS
+        if "citations" in raw_response_json:
+            setattr(model_response, "citations", raw_response_json["citations"])
+
+        ## Tool calling response
+        cohere_tools_response = raw_response_json.get("tool_calls", None)
+        if cohere_tools_response is not None and cohere_tools_response != []:
+            # convert cohere_tools_response to OpenAI response format
+            tool_calls = []
+            for tool in cohere_tools_response:
+                function_name = tool.get("name", "")
+                generation_id = tool.get("generation_id", "")
+                parameters = tool.get("parameters", {})
+                tool_call = {
+                    "id": f"call_{generation_id}",
+                    "type": "function",
+                    "function": {
+                        "name": function_name,
+                        "arguments": json.dumps(parameters),
+                    },
+                }
+                tool_calls.append(tool_call)
+            _message = litellm.Message(
+                tool_calls=tool_calls,
+                content=None,
+            )
+            model_response.choices[0].message = _message  # type: ignore
+
+        ## CALCULATING USAGE - use cohere `billed_units` for returning usage
+        billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
+
+        prompt_tokens = billed_units.get("input_tokens", 0)
+        completion_tokens = billed_units.get("output_tokens", 0)
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool(
+        self,
+        tools: Optional[list] = None,
+    ):
+        if tools is None:
+            tools = []
+        cohere_tools = []
+        for tool in tools:
+            cohere_tool = self._translate_openai_tool_to_cohere(tool)
+            cohere_tools.append(cohere_tool)
+        return cohere_tools
+
+    def _translate_openai_tool_to_cohere(
+        self,
+        openai_tool: dict,
+    ):
+        # cohere tools look like this
+        """
+        {
+        "name": "query_daily_sales_report",
+        "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
+        "parameter_definitions": {
+            "day": {
+                "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
+                "type": "str",
+                "required": True
+            }
+        }
+        }
+        """
+
+        # OpenAI tools look like this
+        """
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+        """
+        cohere_tool = {
+            "name": openai_tool["function"]["name"],
+            "description": openai_tool["function"]["description"],
+            "parameter_definitions": {},
+        }
+
+        for param_name, param_def in openai_tool["function"]["parameters"][
+            "properties"
+        ].items():
+            required_params = (
+                openai_tool.get("function", {})
+                .get("parameters", {})
+                .get("required", [])
+            )
+            cohere_param_def = {
+                "description": param_def.get("description", ""),
+                "type": param_def.get("type", ""),
+                "required": param_name in required_params,
+            }
+            cohere_tool["parameter_definitions"][param_name] = cohere_param_def
+
+        return cohere_tool
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(status_code=status_code, message=error_message)