| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py | 368 |
1 file changed, 368 insertions(+), 0 deletions(-)
```diff
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
new file mode 100644
index 00000000..3ceec2db
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py
@@ -0,0 +1,368 @@
+import json
+import time
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class CohereError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[httpx.Headers] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+        )
+
+
+class CohereChatConfig(BaseConfig):
+    """
+    Configuration class for Cohere's API interface.
+
+    Args:
+        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
+        chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
+        generation_id (str, optional): Unique identifier for the generated reply.
+        response_id (str, optional): Unique identifier for the response.
+        conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
+        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
+        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
+        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
+        documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
+        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
+        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
+        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
+        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
+        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
+        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
+        seed (int, optional): A seed to assist reproducibility of the model's response.
+    """
+
+    preamble: Optional[str] = None
+    chat_history: Optional[list] = None
+    generation_id: Optional[str] = None
+    response_id: Optional[str] = None
+    conversation_id: Optional[str] = None
+    prompt_truncation: Optional[str] = None
+    connectors: Optional[list] = None
+    search_queries_only: Optional[bool] = None
+    documents: Optional[list] = None
+    temperature: Optional[int] = None
+    max_tokens: Optional[int] = None
+    k: Optional[int] = None
+    p: Optional[int] = None
+    frequency_penalty: Optional[int] = None
+    presence_penalty: Optional[int] = None
+    tools: Optional[list] = None
+    tool_results: Optional[list] = None
+    seed: Optional[int] = None
+
+    def __init__(
+        self,
+        preamble: Optional[str] = None,
+        chat_history: Optional[list] = None,
+        generation_id: Optional[str] = None,
+        response_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        prompt_truncation: Optional[str] = None,
+        connectors: Optional[list] = None,
+        search_queries_only: Optional[bool] = None,
+        documents: Optional[list] = None,
+        temperature: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        k: Optional[int] = None,
+        p: Optional[int] = None,
+        frequency_penalty: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        tools: Optional[list] = None,
+        tool_results: Optional[list] = None,
+        seed: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+            "tools",
+            "tool_choice",
+            "seed",
+            "extra_headers",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "n":
+                optional_params["num_generations"] = value
+            if param == "top_p":
+                optional_params["p"] = value
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "seed":
+                optional_params["seed"] = value
+        return optional_params
```
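Note how `map_openai_params` quietly renames several OpenAI parameters to their Cohere equivalents (`top_p` → `p`, `stop` → `stop_sequences`, `n` → `num_generations`). A minimal sketch of that behavior, assuming litellm is importable; the model name `command-r` is illustrative, not something this diff pins down:

```python
# Sketch only: exercising map_openai_params from the file added in this diff.
from litellm.llms.cohere.chat.transformation import CohereChatConfig

config = CohereChatConfig()
mapped = config.map_openai_params(
    non_default_params={"top_p": 0.9, "stop": ["\n\n"], "n": 2},
    optional_params={},
    model="command-r",  # illustrative model name
    drop_params=False,
)
# OpenAI names come back under Cohere's names:
# {"p": 0.9, "stop_sequences": ["\n\n"], "num_generations": 2}
print(mapped)
```

The diff continues with the request and response transformations: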
+ """ + + preamble: Optional[str] = None + chat_history: Optional[list] = None + generation_id: Optional[str] = None + response_id: Optional[str] = None + conversation_id: Optional[str] = None + prompt_truncation: Optional[str] = None + connectors: Optional[list] = None + search_queries_only: Optional[bool] = None + documents: Optional[list] = None + temperature: Optional[int] = None + max_tokens: Optional[int] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + tools: Optional[list] = None + tool_results: Optional[list] = None + seed: Optional[int] = None + + def __init__( + self, + preamble: Optional[str] = None, + chat_history: Optional[list] = None, + generation_id: Optional[str] = None, + response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt_truncation: Optional[str] = None, + connectors: Optional[list] = None, + search_queries_only: Optional[bool] = None, + documents: Optional[list] = None, + temperature: Optional[int] = None, + max_tokens: Optional[int] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + tools: Optional[list] = None, + tool_results: Optional[list] = None, + seed: Optional[int] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + return cohere_validate_environment( + headers=headers, + model=model, + messages=messages, + optional_params=optional_params, + api_key=api_key, + ) + + def get_supported_openai_params(self, model: str) -> List[str]: + return [ + "stream", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "tools", + "tool_choice", + "seed", + "extra_headers", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "stream": + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "n": + optional_params["num_generations"] = value + if param == "top_p": + optional_params["p"] = value + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "tools": + optional_params["tools"] = value + if param == "seed": + optional_params["seed"] = value + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + + ## Load Config + for k, v in litellm.CohereChatConfig.get_config().items(): + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + most_recent_message, chat_history = cohere_messages_pt_v2( + messages=messages, model=model, llm_provider="cohere_chat" + ) + + ## Handle Tool Calling + if 
"tools" in optional_params: + _is_function_call = True + cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"]) + optional_params["tools"] = cohere_tools + if isinstance(most_recent_message, dict): + optional_params["tool_results"] = [most_recent_message] + elif isinstance(most_recent_message, str): + optional_params["message"] = most_recent_message + + ## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails + if len(chat_history) > 0 and chat_history[-1]["role"] == "USER": + optional_params["force_single_step"] = True + + return optional_params + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + + try: + raw_response_json = raw_response.json() + model_response.choices[0].message.content = raw_response_json["text"] # type: ignore + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + ## ADD CITATIONS + if "citations" in raw_response_json: + setattr(model_response, "citations", raw_response_json["citations"]) + + ## Tool calling response + cohere_tools_response = raw_response_json.get("tool_calls", None) + if cohere_tools_response is not None and cohere_tools_response != []: + # convert cohere_tools_response to OpenAI response format + tool_calls = [] + for tool in cohere_tools_response: + function_name = tool.get("name", "") + generation_id = tool.get("generation_id", "") + parameters = tool.get("parameters", {}) + tool_call = { + "id": f"call_{generation_id}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(parameters), + }, + } + tool_calls.append(tool_call) + _message = litellm.Message( + tool_calls=tool_calls, + content=None, + ) + model_response.choices[0].message = _message # type: ignore + + ## CALCULATING USAGE - use cohere `billed_units` for returning usage + billed_units = raw_response_json.get("meta", {}).get("billed_units", {}) + + prompt_tokens = billed_units.get("input_tokens", 0) + completion_tokens = billed_units.get("output_tokens", 0) + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + def _construct_cohere_tool( + self, + tools: Optional[list] = None, + ): + if tools is None: + tools = [] + cohere_tools = [] + for tool in tools: + cohere_tool = self._translate_openai_tool_to_cohere(tool) + cohere_tools.append(cohere_tool) + return cohere_tools + + def _translate_openai_tool_to_cohere( + self, + openai_tool: dict, + ): + # cohere tools look like this + """ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True + } + } + } + """ + + # OpenAI tools look like this + """ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given 
location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + """ + cohere_tool = { + "name": openai_tool["function"]["name"], + "description": openai_tool["function"]["description"], + "parameter_definitions": {}, + } + + for param_name, param_def in openai_tool["function"]["parameters"][ + "properties" + ].items(): + required_params = ( + openai_tool.get("function", {}) + .get("parameters", {}) + .get("required", []) + ) + cohere_param_def = { + "description": param_def.get("description", ""), + "type": param_def.get("type", ""), + "required": param_name in required_params, + } + cohere_tool["parameter_definitions"][param_name] = cohere_param_def + + return cohere_tool + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return CohereModelResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(status_code=status_code, message=error_message) |