Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cerebras')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/cerebras/chat.py | 83 |
1 file changed, 83 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cerebras/chat.py b/.venv/lib/python3.12/site-packages/litellm/llms/cerebras/chat.py
new file mode 100644
index 00000000..4e9c6811
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cerebras/chat.py
@@ -0,0 +1,83 @@
+"""
+Cerebras Chat Completions API
+
+this is OpenAI compatible - no translation needed / occurs
+"""
+
+from typing import Optional
+
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class CerebrasConfig(OpenAIGPTConfig):
+    """
+    Reference: https://inference-docs.cerebras.ai/api-reference/chat-completions
+
+    Below are the parameters:
+    """
+
+    max_tokens: Optional[int] = None
+    response_format: Optional[dict] = None
+    seed: Optional[int] = None
+    stream: Optional[bool] = None
+    top_p: Optional[int] = None
+    tool_choice: Optional[str] = None
+    tools: Optional[list] = None
+    user: Optional[str] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        response_format: Optional[dict] = None,
+        seed: Optional[int] = None,
+        stop: Optional[str] = None,
+        stream: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[int] = None,
+        tool_choice: Optional[str] = None,
+        tools: Optional[list] = None,
+        user: Optional[str] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str) -> list:
+        """
+        Get the supported OpenAI params for the given model
+
+        """
+
+        return [
+            "max_tokens",
+            "max_completion_tokens",
+            "response_format",
+            "seed",
+            "stop",
+            "stream",
+            "temperature",
+            "top_p",
+            "tool_choice",
+            "tools",
+            "user",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
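
As the docstring notes, the Cerebras API is OpenAI-compatible, so the only translation this config performs is renaming the OpenAI "max_completion_tokens" parameter to "max_tokens"; other supported parameters pass through unchanged, and unsupported ones are not copied into optional_params. A minimal sketch of that mapping (the model name is illustrative):

    from litellm.llms.cerebras.chat import CerebrasConfig

    config = CerebrasConfig()
    mapped = config.map_openai_params(
        non_default_params={
            "max_completion_tokens": 256,
            "temperature": 0.2,
            "logprobs": True,  # assumed unsupported param, for illustration
        },
        optional_params={},
        model="llama3.1-8b",
        drop_params=True,
    )
    # mapped == {"max_tokens": 256, "temperature": 0.2}
    # "max_completion_tokens" was renamed to "max_tokens";
    # "logprobs" is not in the supported list, so it was dropped.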
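
In normal use this class is not instantiated directly; litellm dispatches to it when a model is requested with the "cerebras/" provider prefix. A hedged end-to-end sketch under that assumption (the API key and model ID are placeholders; check the Cerebras docs for current model names):

    import os
    import litellm

    os.environ["CEREBRAS_API_KEY"] = "..."  # placeholder key

    response = litellm.completion(
        model="cerebras/llama3.1-8b",  # illustrative model ID
        messages=[{"role": "user", "content": "Hello"}],
        max_completion_tokens=256,  # remapped to max_tokens by CerebrasConfig
    )
    print(response.choices[0].message.content)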