about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/snowflake
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/snowflake')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/snowflake/chat/transformation.py167
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/snowflake/common_utils.py34
2 files changed, 201 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/chat/transformation.py
new file mode 100644
index 00000000..d3634e79
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/chat/transformation.py
@@ -0,0 +1,167 @@
+"""
+Support for Snowflake REST API 
+"""
+
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+
+import httpx
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ...openai_like.chat.transformation import OpenAIGPTConfig
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class SnowflakeConfig(OpenAIGPTConfig):
+    """
+    source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
+    """
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["temperature", "max_tokens", "top_p", "response_format"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call
+
+        Args:
+            non_default_params (dict): Non-default parameters to filter.
+            optional_params (dict): Optional parameters to update.
+            model (str): Model name for parameter support check.
+
+        Returns:
+            dict: Updated optional_params with supported non-default parameters.
+        """
+        supported_openai_params = self.get_supported_openai_params(model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response_json = raw_response.json()
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = "snowflake/" + (returned_response.model or "")
+
+        if model is not None:
+            returned_response._hidden_params["model"] = model
+        return returned_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        """
+        Return headers to use for Snowflake completion request
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if api_key is None:
+            raise ValueError("Missing Snowflake JWT key")
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + api_key,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+            }
+        )
+        return headers
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+            or get_secret_str("SNOWFLAKE_API_BASE")
+        )
+        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        If api_base is not provided, use the default DeepSeek /chat/completions endpoint.
+        """
+        if not api_base:
+            api_base = f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+
+        return api_base
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        stream: bool = optional_params.pop("stream", None) or False
+        extra_body = optional_params.pop("extra_body", {})
+        return {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            **optional_params,
+            **extra_body,
+        }
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/common_utils.py
new file mode 100644
index 00000000..40c8270f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/snowflake/common_utils.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+
+class SnowflakeBase:
+    def validate_environment(
+        self,
+        headers: dict,
+        JWT: Optional[str] = None,
+    ) -> dict:
+        """
+        Return headers to use for Snowflake completion request
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if JWT is None:
+            raise ValueError("Missing Snowflake JWT key")
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + JWT,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+            }
+        )
+        return headers