author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py')
-rw-r--r--   .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py   784
1 file changed, 784 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
new file mode 100644
index 00000000..744e1eb3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
@@ -0,0 +1,784 @@
+import json
+import os
+import time
+from typing import Any, Callable, Optional, cast
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.bedrock.common_utils import ModelResponseIterator
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url="https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class TextStreamer:
+ """
+ Fake streaming iterator for Vertex AI Model Garden calls
+ """
+
+ def __init__(self, text):
+ self.text = text.split() # let's assume words as a streaming unit
+ self.index = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.index < len(self.text):
+ result = self.text[self.index]
+ self.index += 1
+ return result
+ else:
+ raise StopIteration
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if self.index < len(self.text):
+ result = self.text[self.index]
+ self.index += 1
+ return result
+ else:
+ raise StopAsyncIteration # once we run out of data to stream, we raise this error
+
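+# Illustrative sketch (comments only): TextStreamer stands in for a real SDK stream, so
+# callers can iterate it like any other completion stream, e.g.
+#     streamer = TextStreamer("The quick brown fox")
+#     list(streamer)                      # -> ["The", "quick", "brown", "fox"]
+#     # async callers use `async for word in TextStreamer(...)` instead.
+# In async_streaming() below it is handed to CustomStreamWrapper as the completion_stream.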
+
+def _get_client_cache_key(
+ model: str, vertex_project: Optional[str], vertex_location: Optional[str]
+):
+ _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+ return _cache_key
+
+
+def _get_client_from_cache(client_cache_key: str):
+ return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)
+
+
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+ litellm.in_memory_llm_clients_cache.set_cache(
+ key=client_cache_key,
+ value=vertex_llm_model,
+ ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+ )
+
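+# Illustrative sketch: the three helpers above memoise initialised Vertex clients per
+# (model, project, location) so repeated calls can skip re-initialisation, e.g.
+#     _key = _get_client_cache_key("chat-bison", "my-project", "us-central1")  # hypothetical values
+#     _set_client_in_cache(_key, chat_model)   # chat_model: a previously built ChatModel
+#     _get_client_from_cache(_key)             # returns the cached model until the TTL expires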
+
+def completion( # noqa: PLR0915
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ optional_params: dict,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ litellm_params=None,
+ logger_fn=None,
+ acompletion: bool = False,
+):
+ """
+ Handler for non-Gemini, non-Anthropic Vertex AI calls: older PaLM models and
+ Vertex AI Model Garden deployments.
+
+ For Vertex AI Anthropic models, see `vertex_anthropic.py`.
+ For Gemini, see `vertex_httpx.py`.
+ """
+ try:
+ import vertexai
+ except Exception:
+ raise VertexAIError(
+ status_code=400,
+ message="vertexai import failed. Please run `pip install google-cloud-aiplatform`; it is required for the 'vertex_ai/' route on LiteLLM.",
+ )
+
+ if not (
+ hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
+ ):
+ raise VertexAIError(
+ status_code=400,
+ message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+ )
+ try:
+ import google.auth # type: ignore
+ from google.cloud import aiplatform # type: ignore
+ from google.cloud.aiplatform_v1beta1.types import (
+ content as gapic_content_types, # type: ignore
+ )
+ from google.protobuf import json_format # type: ignore
+ from google.protobuf.struct_pb2 import Value # type: ignore
+ from vertexai.language_models import CodeGenerationModel, TextGenerationModel
+ from vertexai.preview.generative_models import GenerativeModel
+ from vertexai.preview.language_models import ChatModel, CodeChatModel
+
+ ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+ print_verbose(
+ f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+ )
+
+ _cache_key = _get_client_cache_key(
+ model=model, vertex_project=vertex_project, vertex_location=vertex_location
+ )
+ _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
+
+ if _vertex_llm_model_object is None:
+ from google.auth.credentials import Credentials
+
+ if vertex_credentials is not None and isinstance(vertex_credentials, str):
+ import google.oauth2.service_account
+
+ json_obj = json.loads(vertex_credentials)
+
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+ else:
+ creds, _ = google.auth.default(quota_project_id=vertex_project)
+ print_verbose(
+ f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+ )
+ vertexai.init(
+ project=vertex_project,
+ location=vertex_location,
+ credentials=cast(Credentials, creds),
+ )
+
+ ## Load Config
+ config = litellm.VertexAIConfig.get_config()
+ for k, v in config.items():
+ if k not in optional_params:
+ optional_params[k] = v
+
+ ## Process safety settings into format expected by vertex AI
+ safety_settings = None
+ if "safety_settings" in optional_params:
+ safety_settings = optional_params.pop("safety_settings")
+ if not isinstance(safety_settings, list):
+ raise ValueError("safety_settings must be a list")
+ if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
+ raise ValueError("safety_settings must be a list of dicts")
+ safety_settings = [
+ gapic_content_types.SafetySetting(x) for x in safety_settings
+ ]
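+ # Illustrative sketch: callers are expected to pass safety_settings as a list of dicts in
+ # the Vertex AI shape, e.g. (category/threshold values are examples only):
+ #     [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}]
+ # each of which is coerced into a gapic SafetySetting proto above.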
+
+ # vertexai does not use an API key; it looks for credentials.json in the environment
+
+ prompt = " ".join(
+ [
+ message.get("content")
+ for message in messages
+ if isinstance(message.get("content", None), str)
+ ]
+ )
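+ # For example, messages = [{"role": "user", "content": "hi"}, {"role": "user", "content": "there"}]
+ # collapses into the single prompt string "hi there"; non-string content (e.g. image parts) is skipped.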
+
+ mode = ""
+
+ request_str = ""
+ response_obj = None
+ instances = None
+ client_options = {
+ "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
+ }
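+ # e.g. vertex_location="us-central1" (example region) targets the regional endpoint
+ # "us-central1-aiplatform.googleapis.com"; client_options is only used by the gapic
+ # PredictionServiceClient calls in the Model Garden ("custom") paths below.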
+ fake_stream = False
+ if (
+ model in litellm.vertex_language_models
+ or model in litellm.vertex_vision_models
+ ):
+ llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
+ mode = "vision"
+ request_str += f"llm_model = GenerativeModel({model})\n"
+ elif model in litellm.vertex_chat_models:
+ llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
+ mode = "chat"
+ request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
+ elif model in litellm.vertex_text_models:
+ llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+ model
+ )
+ mode = "text"
+ request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
+ elif model in litellm.vertex_code_text_models:
+ llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+ model
+ )
+ mode = "text"
+ request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
+ fake_stream = True
+ elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models
+ llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
+ mode = "chat"
+ request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
+ elif model == "private":
+ mode = "private"
+ model = optional_params.pop("model_id", None)
+ # private endpoint requires a dict instead of JSON
+ instances = [optional_params.copy()]
+ instances[0]["prompt"] = prompt
+ llm_model = aiplatform.PrivateEndpoint(
+ endpoint_name=model,
+ project=vertex_project,
+ location=vertex_location,
+ )
+ request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
+ else: # assume vertex model garden on public endpoint
+ mode = "custom"
+
+ instances = [optional_params.copy()]
+ instances[0]["prompt"] = prompt
+ instances = [
+ json_format.ParseDict(instance_dict, Value())
+ for instance_dict in instances
+ ]
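+ # Illustrative sketch: with prompt="tell me a joke" and optional_params={"temperature": 0.2}
+ # (hypothetical values), each instance dict {"temperature": 0.2, "prompt": "tell me a joke"}
+ # is converted by json_format.ParseDict into a protobuf Value, the type expected by the
+ # gapic PredictionServiceClient.predict(instances=...) calls below.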
+ # Will determine the API used based on async parameter
+ llm_model = None
+
+ # NOTE: async prediction and streaming under "private" mode aren't supported by aiplatform right now
+ if acompletion is True:
+ data = {
+ "llm_model": llm_model,
+ "mode": mode,
+ "prompt": prompt,
+ "logging_obj": logging_obj,
+ "request_str": request_str,
+ "model": model,
+ "model_response": model_response,
+ "encoding": encoding,
+ "messages": messages,
+ "print_verbose": print_verbose,
+ "client_options": client_options,
+ "instances": instances,
+ "vertex_location": vertex_location,
+ "vertex_project": vertex_project,
+ "safety_settings": safety_settings,
+ **optional_params,
+ }
+ if optional_params.get("stream", False) is True:
+ # async streaming
+ return async_streaming(**data)
+
+ return async_completion(**data)
+
+ completion_response = None
+
+ stream = optional_params.pop(
+ "stream", None
+ ) # See the note below on handling streaming for Vertex AI
+ if mode == "chat":
+ chat = llm_model.start_chat()
+ request_str += "chat = llm_model.start_chat()\n"
+
+ if fake_stream is not True and stream is True:
+ # NOTE: Vertex AI does not accept stream=True as a param and raises an error.
+ # We handle this by removing 'stream' from optional_params before sending the request;
+ # after we get the response, we add optional_params["stream"] = True back, since main.py
+ # needs to know it is a streaming response in order to transform it into the OpenAI format.
+ optional_params.pop(
+ "stream", None
+ ) # vertex ai raises an error when passing stream in optional params
+
+ request_str += (
+ f"chat.send_message_streaming({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ model_response = chat.send_message_streaming(prompt, **optional_params)
+
+ return model_response
+
+ request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ completion_response = chat.send_message(prompt, **optional_params).text
+ elif mode == "text":
+ if fake_stream is not True and stream is True:
+ request_str += (
+ f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ model_response = llm_model.predict_streaming(prompt, **optional_params)
+
+ return model_response
+
+ request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ completion_response = llm_model.predict(prompt, **optional_params).text
+ elif mode == "custom":
+ """
+ Vertex AI Model Garden
+ """
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ llm_model = aiplatform.gapic.PredictionServiceClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response = llm_model.predict(
+ endpoint=endpoint_path, instances=instances
+ ).predictions
+
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream is True:
+ response = TextStreamer(completion_response)
+ return response
+ elif mode == "private":
+ """
+ Vertex AI Model Garden deployed on private endpoint
+ """
+ if instances is None:
+ raise ValueError("instances are required for private endpoint")
+ if llm_model is None:
+ raise ValueError("Unable to pick client for private endpoint")
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ request_str += f"llm_model.predict(instances={instances})\n"
+ response = llm_model.predict(instances=instances).predictions
+
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream is True:
+ response = TextStreamer(completion_response)
+ return response
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt, api_key=None, original_response=completion_response
+ )
+
+ ## RESPONSE OBJECT
+ if isinstance(completion_response, litellm.Message):
+ model_response.choices[0].message = completion_response # type: ignore
+ elif len(str(completion_response)) > 0:
+ model_response.choices[0].message.content = str(completion_response) # type: ignore
+ model_response.created = int(time.time())
+ model_response.model = model
+ ## CALCULATING USAGE
+ if model in litellm.vertex_language_models and response_obj is not None:
+ model_response.choices[0].finish_reason = map_finish_reason(
+ response_obj.candidates[0].finish_reason.name
+ )
+ usage = Usage(
+ prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+ completion_tokens=response_obj.usage_metadata.candidates_token_count,
+ total_tokens=response_obj.usage_metadata.total_token_count,
+ )
+ else:
+ # init prompt tokens
+ # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
+ prompt_tokens, completion_tokens, _ = 0, 0, 0
+ if response_obj is not None:
+ if hasattr(response_obj, "usage_metadata") and hasattr(
+ response_obj.usage_metadata, "prompt_token_count"
+ ):
+ prompt_tokens = response_obj.usage_metadata.prompt_token_count
+ completion_tokens = (
+ response_obj.usage_metadata.candidates_token_count
+ )
+ else:
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ )
+
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+
+ if fake_stream is True and stream is True:
+ return ModelResponseIterator(model_response)
+ return model_response
+ except Exception as e:
+ if isinstance(e, VertexAIError):
+ raise e
+ raise litellm.APIConnectionError(
+ message=str(e), llm_provider="vertex_ai", model=model
+ )
+
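+# Illustrative usage sketch (comments only): this handler is normally reached through
+# litellm's router rather than called directly, e.g. something like
+#     litellm.completion(
+#         model="vertex_ai/chat-bison",                  # example non-Gemini model name
+#         messages=[{"role": "user", "content": "hi"}],
+#         vertex_project="my-project",                   # hypothetical project / region
+#         vertex_location="us-central1",
+#     )
+# is expected to resolve to completion(...) above with the "vertex_ai/" prefix stripped.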
+
+async def async_completion( # noqa: PLR0915
+ llm_model,
+ mode: str,
+ prompt: str,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ request_str: str,
+ print_verbose: Callable,
+ logging_obj,
+ encoding,
+ client_options=None,
+ instances=None,
+ vertex_project=None,
+ vertex_location=None,
+ safety_settings=None,
+ **optional_params,
+):
+ """
+ Async (`acompletion`) handler for the non-Gemini Vertex AI routes above: chat, text,
+ Model Garden ("custom"), and private endpoints.
+ """
+ try:
+ response_obj = None
+ completion_response = None
+ if mode == "chat":
+ # chat-bison etc.
+ chat = llm_model.start_chat()
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response_obj = await chat.send_message_async(prompt, **optional_params)
+ completion_response = response_obj.text
+ elif mode == "text":
+ # gecko etc.
+ request_str += f"llm_model.predict_async({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response_obj = await llm_model.predict_async(prompt, **optional_params)
+ completion_response = response_obj.text
+ elif mode == "custom":
+ """
+ Vertex AI Model Garden
+ """
+ from google.cloud import aiplatform # type: ignore
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response_obj = await llm_model.predict(
+ endpoint=endpoint_path,
+ instances=instances,
+ )
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+ elif mode == "private":
+ request_str += f"llm_model.predict_async(instances={instances})\n"
+ response_obj = await llm_model.predict_async(
+ instances=instances,
+ )
+
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt, api_key=None, original_response=completion_response
+ )
+
+ ## RESPONSE OBJECT
+ if isinstance(completion_response, litellm.Message):
+ model_response.choices[0].message = completion_response # type: ignore
+ elif len(str(completion_response)) > 0:
+ model_response.choices[0].message.content = str( # type: ignore
+ completion_response
+ )
+ model_response.created = int(time.time())
+ model_response.model = model
+ ## CALCULATING USAGE
+ if model in litellm.vertex_language_models and response_obj is not None:
+ model_response.choices[0].finish_reason = map_finish_reason(
+ response_obj.candidates[0].finish_reason.name
+ )
+ usage = Usage(
+ prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+ completion_tokens=response_obj.usage_metadata.candidates_token_count,
+ total_tokens=response_obj.usage_metadata.total_token_count,
+ )
+ else:
+ # init prompt tokens
+ # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
+ prompt_tokens, completion_tokens, _ = 0, 0, 0
+ if response_obj is not None and (
+ hasattr(response_obj, "usage_metadata")
+ and hasattr(response_obj.usage_metadata, "prompt_token_count")
+ ):
+ prompt_tokens = response_obj.usage_metadata.prompt_token_count
+ completion_tokens = response_obj.usage_metadata.candidates_token_count
+ else:
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ )
+
+ # set usage
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+
+
+async def async_streaming( # noqa: PLR0915
+ llm_model,
+ mode: str,
+ prompt: str,
+ model: str,
+ model_response: ModelResponse,
+ messages: list,
+ print_verbose: Callable,
+ logging_obj,
+ request_str: str,
+ encoding=None,
+ client_options=None,
+ instances=None,
+ vertex_project=None,
+ vertex_location=None,
+ safety_settings=None,
+ **optional_params,
+):
+ """
+ Async streaming handler for the non-Gemini Vertex AI routes above: chat, text, Model Garden ("custom"), and private endpoints.
+ """
+ response: Any = None
+ if mode == "chat":
+ chat = llm_model.start_chat()
+ optional_params.pop(
+ "stream", None
+ ) # vertex ai raises an error when passing stream in optional params
+ request_str += (
+ f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response = chat.send_message_streaming_async(prompt, **optional_params)
+
+ elif mode == "text":
+ optional_params.pop(
+ "stream", None
+ ) # See note above on handling streaming for vertex ai
+ request_str += (
+ f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response = llm_model.predict_streaming_async(prompt, **optional_params)
+ elif mode == "custom":
+ from google.cloud import aiplatform # type: ignore
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ stream = optional_params.pop("stream", None)
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response_obj = await llm_model.predict(
+ endpoint=endpoint_path,
+ instances=instances,
+ )
+
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream:
+ response = TextStreamer(completion_response)
+
+ elif mode == "private":
+ if instances is None:
+ raise ValueError("Instances are required for private endpoint")
+ stream = optional_params.pop("stream", None)
+ _ = instances[0].pop("stream", None)
+ request_str += f"llm_model.predict_async(instances={instances})\n"
+ response_obj = await llm_model.predict_async(
+ instances=instances,
+ )
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream:
+ response = TextStreamer(completion_response)
+
+ if response is None:
+ raise ValueError("Unable to generate response")
+
+ logging_obj.post_call(input=prompt, api_key=None, original_response=response)
+
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="vertex_ai",
+ logging_obj=logging_obj,
+ )
+
+ return streamwrapper
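+# Illustrative consumption sketch: the CustomStreamWrapper returned above is an async
+# iterator of OpenAI-style chunks, so streaming callers typically drain it with something like
+#     async for chunk in streamwrapper:
+#         print(chunk.choices[0].delta.content or "", end="")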