author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py        62
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py  114
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py      73
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py                      246
4 files changed, 495 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py
new file mode 100644
index 00000000..d87b2e03
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py
@@ -0,0 +1,62 @@
+import types
+from typing import Optional
+
+import litellm
+
+
+class VertexAIAi21Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
+
+    The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface
+
+    -> Supports all OpenAI parameters
+    """
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
+        if "max_completion_tokens" in non_default_params:
+            non_default_params["max_tokens"] = non_default_params.pop(
+                "max_completion_tokens"
+            )
+        return litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=drop_params,
+        )
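
A minimal usage sketch of the AI21 parameter mapping above. The direct instantiation and the model name are illustrative assumptions, not a documented public entry point:

    from litellm.llms.vertex_ai.vertex_ai_partner_models.ai21.transformation import (
        VertexAIAi21Config,
    )

    config = VertexAIAi21Config()
    optional_params = config.map_openai_params(
        non_default_params={"max_completion_tokens": 256, "temperature": 0.2},
        optional_params={},
        model="jamba-1.5-mini",  # illustrative model name
        drop_params=False,
    )
    # "max_completion_tokens" is renamed to "max_tokens" before delegating to
    # the OpenAI-style mapping, so optional_params now carries max_tokens=256
    # with temperature passed through unchanged.
    print(optional_params)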
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py
new file mode 100644
index 00000000..ab0555b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py
@@ -0,0 +1,114 @@
+# What is this?
+## Handler file for calling claude-3 on vertex ai
+from typing import Any, List, Optional
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ....anthropic.chat.transformation import AnthropicConfig
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexAIAnthropicConfig(AnthropicConfig):
+    """
+    Reference: https://docs.anthropic.com/claude/reference/messages_post
+
+    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
+
+    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
+    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
+
+    The class `VertexAIAnthropicConfig` provides configuration for Vertex AI's Anthropic API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer) max tokens
+    - `anthropic_version` Required (string) version of anthropic for vertex - e.g. "vertex-2023-10-16"
+    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
+    - `temperature` Optional (float) The amount of randomness injected into the response
+    - `top_p` Optional (float) Use nucleus sampling.
+    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
+    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        data = super().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        data.pop("model", None)  # vertex anthropic doesn't accept 'model' parameter
+        return data
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response = super().transform_response(
+            model,
+            raw_response,
+            model_response,
+            logging_obj,
+            request_data,
+            messages,
+            optional_params,
+            litellm_params,
+            encoding,
+            api_key,
+            json_mode,
+        )
+        response.model = model
+
+        return response
+
+    @classmethod
+    def is_supported_model(cls, model: str, custom_llm_provider: str) -> bool:
+        """
+        Check if the model is supported by the VertexAI Anthropic API.
+        """
+        if (
+            custom_llm_provider != "vertex_ai"
+            and custom_llm_provider != "vertex_ai_beta"
+        ):
+            return False
+        if "claude" in model.lower():
+            return True
+        elif model in litellm.vertex_anthropic_models:
+            return True
+        return False
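
A short sketch of the routing check implemented by `is_supported_model` above; the model string is illustrative:

    from litellm.llms.vertex_ai.vertex_ai_partner_models.anthropic.transformation import (
        VertexAIAnthropicConfig,
    )

    # Any model whose name contains "claude", routed through a vertex_ai
    # provider, is accepted.
    assert VertexAIAnthropicConfig.is_supported_model(
        model="claude-3-5-sonnet@20240620", custom_llm_provider="vertex_ai"
    )
    # The same model under a non-vertex provider is rejected.
    assert not VertexAIAnthropicConfig.is_supported_model(
        model="claude-3-5-sonnet@20240620", custom_llm_provider="anthropic"
    )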
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py
new file mode 100644
index 00000000..cf46f4a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py
@@ -0,0 +1,73 @@
+import types
+from typing import Optional
+
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class VertexAILlama3Config(OpenAIGPTConfig):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
+
+    - `max_tokens` Optional (integer) max tokens
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self, model: str):
+        supported_params = super().get_supported_openai_params(model=model)
+        try:
+            supported_params.remove("max_retries")
+        except ValueError:  # list.remove raises ValueError, not KeyError
+            pass
+        return supported_params
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
+        if "max_completion_tokens" in non_default_params:
+            non_default_params["max_tokens"] = non_default_params.pop(
+                "max_completion_tokens"
+            )
+        return super().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=drop_params,
+        )
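
A brief sketch of the supported-parameter filtering above; the model name is an illustrative assumption:

    from litellm.llms.vertex_ai.vertex_ai_partner_models.llama3.transformation import (
        VertexAILlama3Config,
    )

    config = VertexAILlama3Config()
    params = config.get_supported_openai_params(model="meta/llama3-405b-instruct-maas")
    # "max_retries" is stripped from the inherited OpenAI parameter list; the
    # remaining entries pass through unchanged.
    assert "max_retries" not in params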
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py
new file mode 100644
index 00000000..fb239363
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py
@@ -0,0 +1,246 @@
+# What is this?
+## API Handler for calling Vertex AI Partner Models
+from enum import Enum
+from typing import Callable, Optional, Union
+
+import httpx  # type: ignore
+
+import litellm
+from litellm import LlmProviders
+from litellm.utils import ModelResponse
+
+from ..vertex_llm_base import VertexBase
+
+
+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"
+    claude = "claude"
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+def create_vertex_url(
+    vertex_location: str,
+    vertex_project: str,
+    partner: VertexPartnerProvider,
+    stream: Optional[bool],
+    model: str,
+    api_base: Optional[str] = None,
+) -> str:
+    """Return the base url for the vertex partner models"""
+    if partner == VertexPartnerProvider.llama:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi/chat/completions"
+    elif partner == VertexPartnerProvider.mistralai:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+    elif partner == VertexPartnerProvider.ai21:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
+    elif partner == VertexPartnerProvider.claude:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
+class VertexAIPartnerModels(VertexBase):
+    def __init__(self) -> None:
+        pass
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        api_base: Optional[str],
+        optional_params: dict,
+        custom_prompt_dict: dict,
+        headers: Optional[dict],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        logger_fn=None,
+        acompletion: bool = False,
+        client=None,
+    ):
+        try:
+            import vertexai
+
+            from litellm.llms.anthropic.chat import AnthropicChatCompletion
+            from litellm.llms.codestral.completion.handler import (
+                CodestralTextCompletion,
+            )
+            from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
+            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+                VertexLLM,
+            )
+        except Exception as e:
+            raise VertexAIError(
+                status_code=400,
+                message=f"""vertexai import failed. Please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
+            )
+
+        if not (
+            hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
+        ):
+            raise VertexAIError(
+                status_code=400,
+                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+            )
+        try:
+            vertex_httpx_logic = VertexLLM()
+
+            access_token, project_id = vertex_httpx_logic._ensure_access_token(
+                credentials=vertex_credentials,
+                project_id=vertex_project,
+                custom_llm_provider="vertex_ai",
+            )
+
+            openai_like_chat_completions = OpenAILikeChatHandler()
+            codestral_fim_completions = CodestralTextCompletion()
+            anthropic_chat_completions = AnthropicChatCompletion()
+
+            ## CONSTRUCT API BASE
+            stream: bool = optional_params.get("stream", False) or False
+
+            optional_params["stream"] = stream
+
+            if "llama" in model:
+                partner = VertexPartnerProvider.llama
+            elif "mistral" in model or "codestral" in model:
+                partner = VertexPartnerProvider.mistralai
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
+            elif "claude" in model:
+                partner = VertexPartnerProvider.claude
+
+            default_api_base = create_vertex_url(
+                vertex_location=vertex_location or "us-central1",
+                vertex_project=vertex_project or project_id,
+                partner=partner,
+                stream=stream,
+                model=model,
+            )
+
+            if len(default_api_base.split(":")) > 1:
+                endpoint = default_api_base.split(":")[-1]
+            else:
+                endpoint = ""
+
+            _, api_base = self._check_custom_proxy(
+                api_base=api_base,
+                custom_llm_provider="vertex_ai",
+                gemini_api_key=None,
+                endpoint=endpoint,
+                stream=stream,
+                auth_header=None,
+                url=default_api_base,
+            )
+
+            if "codestral" in model or "mistral" in model:
+                model = model.split("@")[0]
+
+            if "codestral" in model and litellm_params.get("text_completion") is True:
+                optional_params["model"] = model
+                text_completion_model_response = litellm.TextCompletionResponse(
+                    stream=stream
+                )
+                return codestral_fim_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    api_key=access_token,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=text_completion_model_response,
+                    print_verbose=print_verbose,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    acompletion=acompletion,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,
+                    encoding=encoding,
+                )
+            elif "claude" in model:
+                if headers is None:
+                    headers = {}
+                headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+                optional_params.update(
+                    {
+                        "anthropic_version": "vertex-2023-10-16",
+                        "is_vertex_request": True,
+                    }
+                )
+
+                return anthropic_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    custom_prompt_dict=litellm.custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,  # for calculating input/output tokens
+                    api_key=access_token,
+                    logging_obj=logging_obj,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,
+                    custom_llm_provider=LlmProviders.VERTEX_AI.value,
+                )
+
+            return openai_like_chat_completions.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                api_key=access_token,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                acompletion=acompletion,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                client=client,
+                timeout=timeout,
+                encoding=encoding,
+                custom_llm_provider="vertex_ai",
+                custom_endpoint=True,
+            )
+
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise e
+            raise VertexAIError(status_code=500, message=str(e))
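
A small sketch of the endpoint construction performed by `create_vertex_url` above; the project id and model name are illustrative:

    from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
        VertexPartnerProvider,
        create_vertex_url,
    )

    url = create_vertex_url(
        vertex_location="us-central1",
        vertex_project="my-gcp-project",  # illustrative project id
        partner=VertexPartnerProvider.claude,
        stream=True,
        model="claude-3-5-sonnet@20240620",
    )
    # Claude requests go to the anthropic publisher path and use the
    # :streamRawPredict verb because stream=True.
    print(url)
    # https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/publishers/anthropic/models/claude-3-5-sonnet@20240620:streamRawPredict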