Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py | 218
1 file changed, 218 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
new file mode 100644
index 00000000..b82268be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
@@ -0,0 +1,218 @@
+import json
+from typing import Any, Coroutine, Dict, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.openai import CreateBatchRequest
+from litellm.types.llms.vertex_ai import (
+ VERTEX_CREDENTIALS_TYPES,
+ VertexAIBatchPredictionJob,
+)
+from litellm.types.utils import LiteLLMBatch
+
+from .transformation import VertexAIBatchTransformation
+
+
+class VertexAIBatchPrediction(VertexLLM):
+ def __init__(self, gcs_bucket_name: str, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.gcs_bucket_name = gcs_bucket_name
+
+ def create_batch(
+ self,
+ _is_async: bool,
+ create_batch_data: CreateBatchRequest,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+
+ sync_handler = _get_httpx_client()
+
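+        # Resolve an OAuth access token and the effective project id from the supplied credentials.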
+ access_token, project_id = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
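+        # Build the default regional batchPredictionJobs URL, falling back to us-central1.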
+ default_api_base = self.create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ )
+
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=None,
+ url=default_api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {access_token}",
+ }
+
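+        # Translate the OpenAI-style batch request into Vertex AI's BatchPredictionJob format.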
+ vertex_batch_request: VertexAIBatchPredictionJob = (
+ VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
+ request=create_batch_data
+ )
+ )
+
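+        # Defer to the async implementation when the caller runs on an event loop.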
+ if _is_async is True:
+ return self._async_create_batch(
+ vertex_batch_request=vertex_batch_request,
+ api_base=api_base,
+ headers=headers,
+ )
+
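+        # Synchronous path: POST the job, then translate the response back to OpenAI format.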
+ response = sync_handler.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(vertex_batch_request),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ async def _async_create_batch(
+ self,
+ vertex_batch_request: VertexAIBatchPredictionJob,
+ api_base: str,
+ headers: Dict[str, str],
+ ) -> LiteLLMBatch:
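+        # Reuse LiteLLM's shared async httpx client for the Vertex AI provider.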
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+ response = await client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(vertex_batch_request),
+ )
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ def create_vertex_url(
+ self,
+ vertex_location: str,
+ vertex_project: str,
+ ) -> str:
+ """Return the base url for the vertex garden models"""
+ # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
+
+ def retrieve_batch(
+ self,
+ _is_async: bool,
+ batch_id: str,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+ sync_handler = _get_httpx_client()
+
+ access_token, project_id = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ default_api_base = self.create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ )
+
+ # Append batch_id to the URL
+ default_api_base = f"{default_api_base}/{batch_id}"
+
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=None,
+ url=default_api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {access_token}",
+ }
+
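+        # Defer to the async implementation when requested.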
+ if _is_async is True:
+ return self._async_retrieve_batch(
+ api_base=api_base,
+ headers=headers,
+ )
+
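+        # Synchronous path: GET the job and translate it to OpenAI format.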
+ response = sync_handler.get(
+ url=api_base,
+ headers=headers,
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ async def _async_retrieve_batch(
+ self,
+ api_base: str,
+ headers: Dict[str, str],
+ ) -> LiteLLMBatch:
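+        # Reuse the shared async client, mirroring the create path.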
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+ response = await client.get(
+ url=api_base,
+ headers=headers,
+ )
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
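
A note on usage: this handler is normally reached through litellm's public batch API (e.g. litellm.create_batch with custom_llm_provider="vertex_ai") rather than constructed directly. The sketch below shows a direct call against the class above; the bucket, project, and location values are illustrative assumptions, and any CreateBatchRequest fields beyond those visible in this diff are not guaranteed.

    import httpx
    from litellm.types.llms.openai import CreateBatchRequest
    from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction

    # Hypothetical values; substitute your own bucket, project, and location.
    handler = VertexAIBatchPrediction(gcs_bucket_name="my-batch-bucket")

    request = CreateBatchRequest(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="gs://my-batch-bucket/input.jsonl",  # assumed: a GCS URI to the batch input
    )

    batch = handler.create_batch(
        _is_async=False,                  # True would return a coroutine to await
        create_batch_data=request,
        api_base=None,                    # None -> default regional endpoint from create_vertex_url
        vertex_credentials=None,          # assumed: fall back to application default credentials
        vertex_project="my-gcp-project",
        vertex_location="us-central1",
        timeout=httpx.Timeout(600.0),
        max_retries=3,
    )
    print(batch.id, batch.status)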