| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/cohere | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere')
9 files changed, 1350 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py new file mode 100644 index 00000000..3ceec2db --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/chat/transformation.py @@ -0,0 +1,368 @@ +import json +import time +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2 +from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse, Usage + +from ..common_utils import ModelResponseIterator as CohereModelResponseIterator +from ..common_utils import validate_environment as cohere_validate_environment + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class CohereError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[httpx.Headers] = None, + ): + self.status_code = status_code + self.message = message + self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat") + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + status_code=status_code, + message=message, + headers=headers, + ) + + +class CohereChatConfig(BaseConfig): + """ + Configuration class for Cohere's API interface. + + Args: + preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one. + chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model. + generation_id (str, optional): Unique identifier for the generated reply. + response_id (str, optional): Unique identifier for the response. + conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation. + prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'. + connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply. + search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries. + documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite. + temperature (float, optional): A non-negative float that tunes the degree of randomness in generation. + max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response. + k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step. + p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation. + frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens. + presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens. + tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking. + tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools. + seed (int, optional): A seed to assist reproducibility of the model's response. 
+ """ + + preamble: Optional[str] = None + chat_history: Optional[list] = None + generation_id: Optional[str] = None + response_id: Optional[str] = None + conversation_id: Optional[str] = None + prompt_truncation: Optional[str] = None + connectors: Optional[list] = None + search_queries_only: Optional[bool] = None + documents: Optional[list] = None + temperature: Optional[int] = None + max_tokens: Optional[int] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + tools: Optional[list] = None + tool_results: Optional[list] = None + seed: Optional[int] = None + + def __init__( + self, + preamble: Optional[str] = None, + chat_history: Optional[list] = None, + generation_id: Optional[str] = None, + response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt_truncation: Optional[str] = None, + connectors: Optional[list] = None, + search_queries_only: Optional[bool] = None, + documents: Optional[list] = None, + temperature: Optional[int] = None, + max_tokens: Optional[int] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + tools: Optional[list] = None, + tool_results: Optional[list] = None, + seed: Optional[int] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + return cohere_validate_environment( + headers=headers, + model=model, + messages=messages, + optional_params=optional_params, + api_key=api_key, + ) + + def get_supported_openai_params(self, model: str) -> List[str]: + return [ + "stream", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "tools", + "tool_choice", + "seed", + "extra_headers", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "stream": + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "n": + optional_params["num_generations"] = value + if param == "top_p": + optional_params["p"] = value + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "tools": + optional_params["tools"] = value + if param == "seed": + optional_params["seed"] = value + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + + ## Load Config + for k, v in litellm.CohereChatConfig.get_config().items(): + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + most_recent_message, chat_history = cohere_messages_pt_v2( + messages=messages, model=model, llm_provider="cohere_chat" + ) + + ## Handle Tool Calling + if 
"tools" in optional_params: + _is_function_call = True + cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"]) + optional_params["tools"] = cohere_tools + if isinstance(most_recent_message, dict): + optional_params["tool_results"] = [most_recent_message] + elif isinstance(most_recent_message, str): + optional_params["message"] = most_recent_message + + ## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails + if len(chat_history) > 0 and chat_history[-1]["role"] == "USER": + optional_params["force_single_step"] = True + + return optional_params + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + + try: + raw_response_json = raw_response.json() + model_response.choices[0].message.content = raw_response_json["text"] # type: ignore + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + ## ADD CITATIONS + if "citations" in raw_response_json: + setattr(model_response, "citations", raw_response_json["citations"]) + + ## Tool calling response + cohere_tools_response = raw_response_json.get("tool_calls", None) + if cohere_tools_response is not None and cohere_tools_response != []: + # convert cohere_tools_response to OpenAI response format + tool_calls = [] + for tool in cohere_tools_response: + function_name = tool.get("name", "") + generation_id = tool.get("generation_id", "") + parameters = tool.get("parameters", {}) + tool_call = { + "id": f"call_{generation_id}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(parameters), + }, + } + tool_calls.append(tool_call) + _message = litellm.Message( + tool_calls=tool_calls, + content=None, + ) + model_response.choices[0].message = _message # type: ignore + + ## CALCULATING USAGE - use cohere `billed_units` for returning usage + billed_units = raw_response_json.get("meta", {}).get("billed_units", {}) + + prompt_tokens = billed_units.get("input_tokens", 0) + completion_tokens = billed_units.get("output_tokens", 0) + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + def _construct_cohere_tool( + self, + tools: Optional[list] = None, + ): + if tools is None: + tools = [] + cohere_tools = [] + for tool in tools: + cohere_tool = self._translate_openai_tool_to_cohere(tool) + cohere_tools.append(cohere_tool) + return cohere_tools + + def _translate_openai_tool_to_cohere( + self, + openai_tool: dict, + ): + # cohere tools look like this + """ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True + } + } + } + """ + + # OpenAI tools look like this + """ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given 
location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + """ + cohere_tool = { + "name": openai_tool["function"]["name"], + "description": openai_tool["function"]["description"], + "parameter_definitions": {}, + } + + for param_name, param_def in openai_tool["function"]["parameters"][ + "properties" + ].items(): + required_params = ( + openai_tool.get("function", {}) + .get("parameters", {}) + .get("required", []) + ) + cohere_param_def = { + "description": param_def.get("description", ""), + "type": param_def.get("type", ""), + "required": param_name in required_params, + } + cohere_tool["parameter_definitions"][param_name] = cohere_param_def + + return cohere_tool + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return CohereModelResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(status_code=status_code, message=error_message) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py new file mode 100644 index 00000000..11ff73ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/common_utils.py @@ -0,0 +1,146 @@ +import json +from typing import List, Optional + +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ( + ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, + GenericStreamingChunk, +) + + +class CohereError(BaseLLMException): + def __init__(self, status_code, message): + super().__init__(status_code=status_code, message=message) + + +def validate_environment( + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, +) -> dict: + """ + Return headers to use for cohere chat completion request + + Cohere API Ref: https://docs.cohere.com/reference/chat + Expected headers: + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + "Authorization": "bearer $CO_API_KEY" + } + """ + headers.update( + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + } + ) + if api_key: + headers["Authorization"] = f"bearer {api_key}" + return headers + + +class ModelResponseIterator: + def __init__( + self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False + ): + self.streaming_response = streaming_response + self.response_iterator = self.streaming_response + self.content_blocks: List = [] + self.tool_index = -1 + self.json_mode = json_mode + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + provider_specific_fields = None + + index = int(chunk.get("index", 0)) + + if "text" in chunk: + text = 
chunk["text"] + elif "is_finished" in chunk and chunk["is_finished"] is True: + is_finished = chunk["is_finished"] + finish_reason = chunk["finish_reason"] + + if "citations" in chunk: + provider_specific_fields = {"citations": chunk["citations"]} + + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=index, + provider_specific_fields=provider_specific_fields, + ) + + return returned_chunk + + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py new file mode 100644 index 00000000..6a779511 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/handler.py @@ -0,0 +1,5 @@ +""" +Cohere /generate API - uses `llm_http_handler.py` to make httpx requests + +Request/Response transformation is handled in `transformation.py` +""" diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py new file mode 100644 index 00000000..bdfcda02 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/completion/transformation.py @@ -0,0 +1,264 @@ +import time +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + convert_content_list_to_str, +) +from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import Choices, Message, ModelResponse, Usage + +from ..common_utils import CohereError +from ..common_utils import ModelResponseIterator as CohereModelResponseIterator +from 
..common_utils import validate_environment as cohere_validate_environment + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class CohereTextConfig(BaseConfig): + """ + Reference: https://docs.cohere.com/reference/generate + + The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters: + + - `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5. + + - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20. + + - `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END. + + - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75. + + - `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc. + + - `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text. + + - `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text. + + - `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0. + + - `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0. + + - `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens. + + - `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared. + + - `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE. + + - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. 
{"hello_world": 1233} + """ + + num_generations: Optional[int] = None + max_tokens: Optional[int] = None + truncate: Optional[str] = None + temperature: Optional[int] = None + preset: Optional[str] = None + end_sequences: Optional[list] = None + stop_sequences: Optional[list] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + return_likelihoods: Optional[str] = None + logit_bias: Optional[dict] = None + + def __init__( + self, + num_generations: Optional[int] = None, + max_tokens: Optional[int] = None, + truncate: Optional[str] = None, + temperature: Optional[int] = None, + preset: Optional[str] = None, + end_sequences: Optional[list] = None, + stop_sequences: Optional[list] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + return_likelihoods: Optional[str] = None, + logit_bias: Optional[dict] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + return cohere_validate_environment( + headers=headers, + model=model, + messages=messages, + optional_params=optional_params, + api_key=api_key, + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(status_code=status_code, message=error_message) + + def get_supported_openai_params(self, model: str) -> List: + return [ + "stream", + "temperature", + "max_tokens", + "logit_bias", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "extra_headers", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "stream": + optional_params["stream"] = value + elif param == "temperature": + optional_params["temperature"] = value + elif param == "max_tokens": + optional_params["max_tokens"] = value + elif param == "n": + optional_params["num_generations"] = value + elif param == "logit_bias": + optional_params["logit_bias"] = value + elif param == "top_p": + optional_params["p"] = value + elif param == "frequency_penalty": + optional_params["frequency_penalty"] = value + elif param == "presence_penalty": + optional_params["presence_penalty"] = value + elif param == "stop": + optional_params["stop_sequences"] = value + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + prompt = " ".join( + convert_content_list_to_str(message=message) for message in messages + ) + + ## Load Config + config = litellm.CohereConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## Handle Tool Calling + if "tools" in optional_params: + _is_function_call = True + tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api( + 
tools=optional_params["tools"] + ) + optional_params["tools"] = tool_calling_system_prompt + + data = { + "model": model, + "prompt": prompt, + **optional_params, + } + + return data + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + prompt = " ".join( + convert_content_list_to_str(message=message) for message in messages + ) + completion_response = raw_response.json() + choices_list = [] + for idx, item in enumerate(completion_response["generations"]): + if len(item["text"]) > 0: + message_obj = Message(content=item["text"]) + else: + message_obj = Message(content=None) + choice_obj = Choices( + finish_reason=item["finish_reason"], + index=idx + 1, + message=message_obj, + ) + choices_list.append(choice_obj) + model_response.choices = choices_list # type: ignore + + ## CALCULATING USAGE + prompt_tokens = len(encoding.encode(prompt)) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + def _construct_cohere_tool_for_completion_api( + self, + tools: Optional[List] = None, + ) -> dict: + if tools is None: + tools = [] + return {"tools": tools} + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return CohereModelResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py new file mode 100644 index 00000000..e7f22ea7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/handler.py @@ -0,0 +1,178 @@ +import json +from typing import Any, Callable, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + get_async_httpx_client, +) +from litellm.types.llms.bedrock import CohereEmbeddingRequest +from litellm.types.utils import EmbeddingResponse + +from .transformation import CohereEmbeddingConfig + + +def validate_environment(api_key, headers: dict): + headers.update( + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + } + ) + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +class CohereError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://api.cohere.ai/v1/generate" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +async def async_embedding( + model: 
str, + data: Union[dict, CohereEmbeddingRequest], + input: list, + model_response: litellm.utils.EmbeddingResponse, + timeout: Optional[Union[float, httpx.Timeout]], + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + api_base: str, + api_key: Optional[str], + headers: dict, + encoding: Callable, + client: Optional[AsyncHTTPHandler] = None, +): + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": api_base, + }, + ) + ## COMPLETION CALL + + if client is None: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.COHERE, + params={"timeout": timeout}, + ) + + try: + response = await client.post(api_base, headers=headers, data=json.dumps(data)) + except httpx.HTTPStatusError as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=e.response.text, + ) + raise e + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + + ## PROCESS RESPONSE ## + return CohereEmbeddingConfig()._transform_response( + response=response, + api_key=api_key, + logging_obj=logging_obj, + data=data, + model_response=model_response, + model=model, + encoding=encoding, + input=input, + ) + + +def embedding( + model: str, + input: list, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + headers: dict, + encoding: Any, + data: Optional[Union[dict, CohereEmbeddingRequest]] = None, + complete_api_base: Optional[str] = None, + api_key: Optional[str] = None, + aembedding: Optional[bool] = None, + timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None), + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, +): + headers = validate_environment(api_key, headers=headers) + embed_url = complete_api_base or "https://api.cohere.ai/v1/embed" + model = model + + data = data or CohereEmbeddingConfig()._transform_request( + model=model, input=input, inference_params=optional_params + ) + + ## ROUTING + if aembedding is True: + return async_embedding( + model=model, + data=data, + input=input, + model_response=model_response, + timeout=timeout, + logging_obj=logging_obj, + optional_params=optional_params, + api_base=embed_url, + api_key=api_key, + headers=headers, + encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), + ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) + + ## COMPLETION CALL + if client is None or not isinstance(client, HTTPHandler): + client = HTTPHandler(concurrent_limit=1) + + response = client.post(embed_url, headers=headers, data=json.dumps(data)) + + return CohereEmbeddingConfig()._transform_response( + response=response, + api_key=api_key, + logging_obj=logging_obj, + data=data, + model_response=model_response, + model=model, + encoding=encoding, + input=input, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py new file mode 100644 index 00000000..22e157a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/embed/transformation.py @@ -0,0 +1,153 @@ +""" +Transformation logic from OpenAI /v1/embeddings 
format to Cohere's /v1/embed format. + +Why separate file? Make it easy to see how transformation works + +Convers +- v3 embedding models +- v2 embedding models + +Docs - https://docs.cohere.com/v2/reference/embed +""" + +from typing import Any, List, Optional, Union + +import httpx + +from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.types.llms.bedrock import ( + CohereEmbeddingRequest, + CohereEmbeddingRequestWithModel, +) +from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage +from litellm.utils import is_base64_encoded + + +class CohereEmbeddingConfig: + """ + Reference: https://docs.cohere.com/v2/reference/embed + """ + + def __init__(self) -> None: + pass + + def get_supported_openai_params(self) -> List[str]: + return ["encoding_format"] + + def map_openai_params( + self, non_default_params: dict, optional_params: dict + ) -> dict: + for k, v in non_default_params.items(): + if k == "encoding_format": + optional_params["embedding_types"] = v + return optional_params + + def _is_v3_model(self, model: str) -> bool: + return "3" in model + + def _transform_request( + self, model: str, input: List[str], inference_params: dict + ) -> CohereEmbeddingRequestWithModel: + is_encoded = False + for input_str in input: + is_encoded = is_base64_encoded(input_str) + + if is_encoded: # check if string is b64 encoded image or not + transformed_request = CohereEmbeddingRequestWithModel( + model=model, + images=input, + input_type="image", + ) + else: + transformed_request = CohereEmbeddingRequestWithModel( + model=model, + texts=input, + input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE, + ) + + for k, v in inference_params.items(): + transformed_request[k] = v # type: ignore + + return transformed_request + + def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage: + + input_tokens = 0 + + text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens") + + image_tokens: Optional[int] = meta.get("billed_units", {}).get("images") + + prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None + if image_tokens is None and text_tokens is None: + for text in input: + input_tokens += len(encoding.encode(text)) + else: + prompt_tokens_details = PromptTokensDetailsWrapper( + image_tokens=image_tokens, + text_tokens=text_tokens, + ) + if image_tokens: + input_tokens += image_tokens + if text_tokens: + input_tokens += text_tokens + + return Usage( + prompt_tokens=input_tokens, + completion_tokens=0, + total_tokens=input_tokens, + prompt_tokens_details=prompt_tokens_details, + ) + + def _transform_response( + self, + response: httpx.Response, + api_key: Optional[str], + logging_obj: LiteLLMLoggingObj, + data: Union[dict, CohereEmbeddingRequest], + model_response: EmbeddingResponse, + model: str, + encoding: Any, + input: list, + ) -> EmbeddingResponse: + + response_json = response.json() + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response_json, + ) + """ + response + { + 'object': "list", + 'data': [ + + ] + 'model', + 'usage' + } + """ + embeddings = response_json["embeddings"] + output_data = [] + for idx, embedding in enumerate(embeddings): + output_data.append( + {"object": "embedding", "index": idx, "embedding": embedding} + ) + model_response.object = "list" + model_response.data = output_data + model_response.model = model + 
input_tokens = 0 + for text in input: + input_tokens += len(encoding.encode(text)) + + setattr( + model_response, + "usage", + self._calculate_usage(input, encoding, response_json.get("meta", {})), + ) + + return model_response diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py new file mode 100644 index 00000000..e94f1859 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py @@ -0,0 +1,5 @@ +""" +Cohere Rerank - uses `llm_http_handler.py` to make httpx requests + +Request/Response transformation is handled in `transformation.py` +""" diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py new file mode 100644 index 00000000..f3624d92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.rerank import OptionalRerankParams, RerankRequest +from litellm.types.utils import RerankResponse + +from ..common_utils import CohereError + + +class CohereRerankConfig(BaseRerankConfig): + """ + Reference: https://docs.cohere.com/v2/reference/rerank + """ + + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/v1/rerank"): + api_base = f"{api_base}/v1/rerank" + return api_base + return "https://api.cohere.ai/v1/rerank" + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_n", + "max_chunks_per_doc", + "rank_fields", + "return_documents", + ] + + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + max_tokens_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + """ + Map Cohere rerank params + + No mapping required - returns all supported params + """ + return OptionalRerankParams( + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + ) + + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = ( + get_secret_str("COHERE_API_KEY") + or get_secret_str("CO_API_KEY") + or litellm.cohere_key + ) + + if api_key is None: + raise ValueError( + "Cohere API key is required. 
Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'" + ) + + default_headers = { + "Authorization": f"bearer {api_key}", + "accept": "application/json", + "content-type": "application/json", + } + + # If 'Authorization' is provided in headers, it overrides the default. + if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + + def transform_rerank_request( + self, + model: str, + optional_rerank_params: OptionalRerankParams, + headers: dict, + ) -> dict: + if "query" not in optional_rerank_params: + raise ValueError("query is required for Cohere rerank") + if "documents" not in optional_rerank_params: + raise ValueError("documents is required for Cohere rerank") + rerank_request = RerankRequest( + model=model, + query=optional_rerank_params["query"], + documents=optional_rerank_params["documents"], + top_n=optional_rerank_params.get("top_n", None), + rank_fields=optional_rerank_params.get("rank_fields", None), + return_documents=optional_rerank_params.get("return_documents", None), + max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None), + ) + return rerank_request.model_dump(exclude_none=True) + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + """ + Transform Cohere rerank response + + No transformation required, litellm follows cohere API response format + """ + try: + raw_response_json = raw_response.json() + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + return RerankResponse(**raw_response_json) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(message=error_message, status_code=status_code)
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py new file mode 100644 index 00000000..a93cb982 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank_v2/transformation.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, List, Optional, Union + +from litellm.llms.cohere.rerank.transformation import CohereRerankConfig +from litellm.types.rerank import OptionalRerankParams, RerankRequest + +class CohereRerankV2Config(CohereRerankConfig): + """ + Reference: https://docs.cohere.com/v2/reference/rerank + """ + + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/v2/rerank"): + api_base = f"{api_base}/v2/rerank" + return api_base + return "https://api.cohere.ai/v2/rerank" + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_n", + "max_tokens_per_doc", + "rank_fields", + "return_documents", + ] + + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + max_tokens_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + """ + Map Cohere rerank params + + No mapping required - returns all supported params + """ + return OptionalRerankParams( + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_tokens_per_doc=max_tokens_per_doc, + ) + + def transform_rerank_request( + self, + model: str, + optional_rerank_params: OptionalRerankParams, + headers: dict, + ) -> dict: + if "query" not in optional_rerank_params: + raise ValueError("query is required for Cohere rerank") + if "documents" not in optional_rerank_params: + raise ValueError("documents is required for Cohere rerank") + rerank_request = RerankRequest( + model=model, + query=optional_rerank_params["query"], + documents=optional_rerank_params["documents"], + top_n=optional_rerank_params.get("top_n", None), + rank_fields=optional_rerank_params.get("rank_fields", None), + return_documents=optional_rerank_params.get("return_documents", None), + max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None), + ) + return rerank_request.model_dump(exclude_none=True)
\ No newline at end of file |
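
The chat transformation added above renames OpenAI-style request parameters to their Cohere equivalents (`top_p` -> `p`, `n` -> `num_generations`, `stop` -> `stop_sequences`) and converts OpenAI tool definitions into Cohere's `parameter_definitions` shape. Below is a minimal sketch of that mapping, based only on the methods visible in this diff; the import path assumes the package layout introduced here, and `"command-r"` is just a placeholder model id, not something this diff prescribes.

```python
# Sketch: exercising the param/tool mapping from chat/transformation.py.
# Assumes litellm is installed with the module layout added in this diff.
from litellm.llms.cohere.chat.transformation import CohereChatConfig

config = CohereChatConfig()

# OpenAI-style params are renamed to Cohere's names; values pass through unchanged.
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.3, "top_p": 0.9, "n": 2, "stop": ["\n"]},
    optional_params={},
    model="command-r",  # placeholder model id
    drop_params=False,
)
print(optional_params)
# {'temperature': 0.3, 'p': 0.9, 'num_generations': 2, 'stop_sequences': ['\n']}

# An OpenAI "function" tool becomes a Cohere tool with parameter_definitions,
# where "required" is derived from the OpenAI "required" list.
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}
print(config._translate_openai_tool_to_cohere(openai_tool))
# {'name': 'get_current_weather',
#  'description': 'Get the current weather in a given location',
#  'parameter_definitions': {
#      'location': {'description': 'The city and state, e.g. San Francisco, CA',
#                   'type': 'string', 'required': True},
#      'unit': {'description': '', 'type': 'string', 'required': False}}}
```

Because `map_openai_params` only renames keys, any value the caller passes wins over a class-level default: `transform_request` copies `CohereChatConfig.get_config()` entries into `optional_params` only when the key is not already present.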