Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/groq/chat')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/groq/chat/handler.py | 76 |
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/groq/chat/transformation.py | 158 |
2 files changed, 234 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/handler.py
new file mode 100644
index 00000000..dc4c3222
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/handler.py
@@ -0,0 +1,76 @@
+"""
+Handles the chat completion request for groq
+"""
+
+from typing import Callable, List, Optional, Union, cast
+
+from httpx._config import Timeout
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import CustomStreamingDecoder
+from litellm.utils import ModelResponse
+
+from ...groq.chat.transformation import GroqChatConfig
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+
+
+class GroqChatCompletion(OpenAILikeChatHandler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        timeout: Optional[Union[float, Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        messages = GroqChatConfig()._transform_messages(
+            messages=cast(List[AllMessageValues], messages), model=model
+        )
+
+        if optional_params.get("stream") is True:
+            fake_stream = GroqChatConfig()._should_fake_stream(optional_params)
+        else:
+            fake_stream = False
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=custom_endpoint,
+            streaming_decoder=streaming_decoder,
+            fake_stream=fake_stream,
+        )
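A minimal usage sketch (not part of the diff): GroqChatCompletion is not normally constructed directly. litellm.completion routes any "groq/"-prefixed model through this handler, which is where the message transformation and fake-stream check above are applied. The model name below is illustrative, and a GROQ_API_KEY environment variable is assumed.

import litellm

# Dispatches to GroqChatCompletion.completion() via the "groq" provider prefix.
response = litellm.completion(
    model="groq/llama3-8b-8192",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)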
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/transformation.py
new file mode 100644
index 00000000..5b24f7d1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/groq/chat/transformation.py
@@ -0,0 +1,158 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
+"""
+
+from typing import List, Optional, Tuple, Union
+
+from pydantic import BaseModel
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionAssistantMessage,
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class GroqChatConfig(OpenAIGPTConfig):
+
+    frequency_penalty: Optional[int] = None
+    function_call: Optional[Union[str, dict]] = None
+    functions: Optional[list] = None
+    logit_bias: Optional[dict] = None
+    max_tokens: Optional[int] = None
+    n: Optional[int] = None
+    presence_penalty: Optional[int] = None
+    stop: Optional[Union[str, list]] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    response_format: Optional[dict] = None
+    tools: Optional[list] = None
+    tool_choice: Optional[Union[str, dict]] = None
+
+    def __init__(
+        self,
+        frequency_penalty: Optional[int] = None,
+        function_call: Optional[Union[str, dict]] = None,
+        functions: Optional[list] = None,
+        logit_bias: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        stop: Optional[Union[str, list]] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        response_format: Optional[dict] = None,
+        tools: Optional[list] = None,
+        tool_choice: Optional[Union[str, dict]] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
+        for idx, message in enumerate(messages):
+            """
+            1. Don't pass 'null' function_call assistant message to groq - https://github.com/BerriAI/litellm/issues/5839
+            """
+            if isinstance(message, BaseModel):
+                _message = message.model_dump()
+            else:
+                _message = message
+            assistant_message = _message.get("role") == "assistant"
+            if assistant_message:
+                new_message = ChatCompletionAssistantMessage(role="assistant")
+                for k, v in _message.items():
+                    if v is not None:
+                        new_message[k] = v  # type: ignore
+                messages[idx] = new_message
+
+        return messages
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
+        api_base = (
+            api_base
+            or get_secret_str("GROQ_API_BASE")
+            or "https://api.groq.com/openai/v1"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("GROQ_API_KEY")
+        return api_base, dynamic_api_key
+
+    def _should_fake_stream(self, optional_params: dict) -> bool:
+        """
+        Groq doesn't support 'response_format' while streaming
+        """
+        if optional_params.get("response_format") is not None:
+            return True
+
+        return False
+
+    def _create_json_tool_call_for_response_format(
+        self,
+        json_schema: dict,
+    ):
+        """
+        Handles creating a tool call for getting responses in JSON format.
+
+        Args:
+            json_schema (dict): The JSON schema the response should be in
+
+        Returns:
+            ChatCompletionToolParam: The tool call to send to the Groq API to get responses in JSON format
+        """
+        return ChatCompletionToolParam(
+            type="function",
+            function=ChatCompletionToolParamFunctionChunk(
+                name="json_tool_call",
+                parameters=json_schema,
+            ),
+        )
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool = False,
+    ) -> dict:
+        _response_format = non_default_params.get("response_format")
+        if _response_format is not None and isinstance(_response_format, dict):
+            json_schema: Optional[dict] = None
+            if "response_schema" in _response_format:
+                json_schema = _response_format["response_schema"]
+            elif "json_schema" in _response_format:
+                json_schema = _response_format["json_schema"]["schema"]
+            """
+            When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+            - You usually want to provide a single tool
+            - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+            - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
+            """
+            if json_schema is not None:
+                _tool_choice = {
+                    "type": "function",
+                    "function": {"name": "json_tool_call"},
+                }
+                _tool = self._create_json_tool_call_for_response_format(
+                    json_schema=json_schema,
+                )
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+                non_default_params.pop(
+                    "response_format", None
+                )  # only remove if it's a json_schema - handled via using groq's tool calling params.
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
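A sketch of the transformation logic above (not part of the diff; it calls private helpers directly, and the schema and model name are illustrative). map_openai_params rewrites a JSON-schema response_format into a forced "json_tool_call" tool call, _transform_messages rebuilds assistant messages without null fields per issue #5839, and _should_fake_stream flags streaming requests that carry a response_format.

from litellm.llms.groq.chat.transformation import GroqChatConfig

config = GroqChatConfig()

# A JSON-schema response_format becomes tools + tool_choice + json_mode.
optional_params = config.map_openai_params(
    non_default_params={
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "schema": {"type": "object", "properties": {"city": {"type": "string"}}}
            },
        }
    },
    optional_params={},
    model="llama3-8b-8192",
    drop_params=False,
)
# optional_params["tools"][0]["function"]["name"] == "json_tool_call"
# optional_params["tool_choice"]["function"]["name"] == "json_tool_call"
# optional_params["json_mode"] is True

# Assistant messages are rebuilt without null fields (issue #5839).
messages = config._transform_messages(
    messages=[{"role": "assistant", "content": "hi", "function_call": None}],
    model="llama3-8b-8192",
)
# -> [{"role": "assistant", "content": "hi"}]

# Streaming plus response_format falls back to a fake (non-streamed) stream.
assert config._should_fake_stream({"response_format": {"type": "json_object"}}) is True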