Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/xai')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/xai/chat/transformation.py | 53
1 file changed, 53 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/xai/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/xai/chat/transformation.py
new file mode 100644
index 00000000..734c6eb2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/xai/chat/transformation.py
@@ -0,0 +1,53 @@
+from typing import Optional, Tuple
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+XAI_API_BASE = "https://api.x.ai/v1"
+
+
+class XAIChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
+        return api_base, dynamic_api_key
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "frequency_penalty",
+            "logit_bias",
+            "logprobs",
+            "max_tokens",
+            "n",
+            "presence_penalty",
+            "response_format",
+            "seed",
+            "stop",
+            "stream",
+            "stream_options",
+            "temperature",
+            "tool_choice",
+            "tools",
+            "top_logprobs",
+            "top_p",
+            "user",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool = False,
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                if value is not None:
+                    optional_params[param] = value
+        return optional_params
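
For reference, a minimal sketch of how the two methods above behave when called directly. It assumes litellm is installed and that XAI_API_BASE / XAI_API_KEY are unset in the environment; the key, model name, and parameter values passed in are illustrative only, not how litellm invokes this config internally.

from litellm.llms.xai.chat.transformation import XAIChatConfig

config = XAIChatConfig()

# Endpoint/key resolution: explicit arguments win, then the
# XAI_API_BASE / XAI_API_KEY secrets, then the hard-coded default.
api_base, api_key = config._get_openai_compatible_provider_info(
    api_base=None,
    api_key="xai-example-key",  # placeholder, not a real key
)
assert api_base == "https://api.x.ai/v1"  # holds only if XAI_API_BASE is unset
assert api_key == "xai-example-key"

# Parameter mapping: max_completion_tokens is renamed to max_tokens,
# supported non-None params pass through, everything else is dropped.
optional_params = config.map_openai_params(
    non_default_params={
        "max_completion_tokens": 256,   # renamed to max_tokens
        "temperature": 0.7,             # in the supported list, passed through
        "stop": None,                   # supported but None, so skipped
        "parallel_tool_calls": True,    # not in the list, silently dropped
    },
    optional_params={},
    model="grok-beta",  # illustrative model name
)
assert optional_params == {"max_tokens": 256, "temperature": 0.7}

Note that drop_params is accepted for interface compatibility with the OpenAIGPTConfig base class but is never read in this implementation: parameters outside the supported list are dropped unconditionally.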