author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed  /.venv/lib/python3.12/site-packages/google/genai
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai')
20 files changed, 24595 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/__init__.py b/.venv/lib/python3.12/site-packages/google/genai/__init__.py
new file mode 100644
index 00000000..358b726f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Google Gen AI SDK"""
+
+from .client import Client
+from . import version
+
+__version__ = version.__version__
+
+__all__ = ['Client']
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_api_client.py b/.venv/lib/python3.12/site-packages/google/genai/_api_client.py
new file mode 100644
index 00000000..cc0dada1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_api_client.py
@@ -0,0 +1,697 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+"""Base client for calling HTTP APIs sending and receiving JSON."""
+
+import asyncio
+import copy
+from dataclasses import dataclass
+import datetime
+import io
+import json
+import logging
+import os
+import sys
+from typing import Any, Optional, Tuple, TypedDict, Union
+from urllib.parse import urlparse, urlunparse
+
+import google.auth
+import google.auth.credentials
+from google.auth.transport.requests import AuthorizedSession
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+import requests
+
+from . import errors
+from . import version
+
+
+class HttpOptions(BaseModel):
+  """HTTP options for the api client."""
+  model_config = ConfigDict(extra='forbid')
+
+  base_url: Optional[str] = Field(
+      default=None,
+      description="""The base URL for the AI platform service endpoint.""",
+  )
+  api_version: Optional[str] = Field(
+      default=None,
+      description="""Specifies the version of the API to use.""",
+  )
+  headers: Optional[dict[str, str]] = Field(
+      default=None,
+      description="""Additional HTTP headers to be sent with the request.""",
+  )
+  response_payload: Optional[dict] = Field(
+      default=None,
+      description="""If set, the response payload will be returned in the supplied dict.""",
+  )
+  timeout: Optional[Union[float, Tuple[float, float]]] = Field(
+      default=None,
+      description="""Timeout for the request in seconds.""",
+  )
+  skip_project_and_location_in_path: bool = Field(
+      default=False,
+      description="""If set to True, the project and location will not be appended to the path.""",
+  )
+
+
+class HttpOptionsDict(TypedDict):
+  """HTTP options for the api client."""
+
+  base_url: Optional[str] = None
+  """The base URL for the AI platform service endpoint."""
+  api_version: Optional[str] = None
+  """Specifies the version of the API to use."""
+  headers: Optional[dict[str, str]] = None
+  """Additional HTTP headers to be sent with the request."""
+  response_payload: Optional[dict] = None
+  """If set, the response payload will be returned in the supplied dict."""
+  timeout: Optional[Union[float, Tuple[float, float]]] = None
+  """Timeout for the request in seconds."""
+  skip_project_and_location_in_path: bool = False
+  """If set to True, the project and location will not be appended to the path."""
+
+HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
+
+
+def _append_library_version_headers(headers: dict[str, str]) -> None:
+  """Appends the telemetry header to the headers dict."""
+  library_label = f'google-genai-sdk/{version.__version__}'
+  language_label = 'gl-python/' + sys.version.split()[0]
+  version_header_value = f'{library_label} {language_label}'
+  if (
+      'user-agent' in headers
+      and version_header_value not in headers['user-agent']
+  ):
+    headers['user-agent'] += f' {version_header_value}'
+  elif 'user-agent' not in headers:
+    headers['user-agent'] = version_header_value
+  if (
+      'x-goog-api-client' in headers
+      and version_header_value not in headers['x-goog-api-client']
+  ):
+    headers['x-goog-api-client'] += f' {version_header_value}'
+  elif 'x-goog-api-client' not in headers:
+    headers['x-goog-api-client'] = version_header_value
+
+
+def _patch_http_options(
+    options: HttpOptionsDict, patch_options: HttpOptionsDict
+) -> HttpOptionsDict:
+  # use shallow copy so we don't override the original objects.
+  copy_option = HttpOptionsDict()
+  copy_option.update(options)
+  for patch_key, patch_value in patch_options.items():
+    # if both are dicts, update the copy.
+    # This is to handle cases like merging headers.
+    if isinstance(patch_value, dict) and isinstance(
+        copy_option.get(patch_key, None), dict
+    ):
+      copy_option[patch_key] = {}
+      copy_option[patch_key].update(
+          options[patch_key]
+      )  # shallow copy from original options.
+      copy_option[patch_key].update(patch_value)
+    elif patch_value is not None:  # Accept empty values.
+      copy_option[patch_key] = patch_value
+  _append_library_version_headers(copy_option['headers'])
+  return copy_option
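For instance, merging per-request options into the client defaults behaves roughly like this (a minimal sketch; the header and timeout values are invented for illustration):

    defaults = HttpOptionsDict(
        base_url='https://generativelanguage.googleapis.com/',
        headers={'Content-Type': 'application/json'},
    )
    overrides = HttpOptionsDict(headers={'x-goog-api-key': 'MY_KEY'}, timeout=30.0)
    patched = _patch_http_options(defaults, overrides)
    # patched['headers'] now holds Content-Type and x-goog-api-key, plus the
    # user-agent / x-goog-api-client telemetry values appended by
    # _append_library_version_headers; patched['timeout'] == 30.0.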
+
+
+def _join_url_path(base_url: str, path: str) -> str:
+  parsed_base = urlparse(base_url)
+  base_path = parsed_base.path[:-1] if parsed_base.path.endswith('/') else parsed_base.path
+  path = path[1:] if path.startswith('/') else path
+  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
+
+
+@dataclass
+class HttpRequest:
+  headers: dict[str, str]
+  url: str
+  method: str
+  data: Union[dict[str, object], bytes]
+  timeout: Optional[Union[float, Tuple[float, float]]] = None
+
+
+class HttpResponse:
+
+  def __init__(
+      self,
+      headers: dict[str, str],
+      response_stream: Union[Any, str] = None,
+      byte_stream: Union[Any, bytes] = None,
+  ):
+    self.status_code = 200
+    self.headers = headers
+    self.response_stream = response_stream
+    self.byte_stream = byte_stream
+
+  @property
+  def text(self) -> str:
+    if not self.response_stream[0]:  # Empty response
+      return ''
+    return json.loads(self.response_stream[0])
+
+  def segments(self):
+    if isinstance(self.response_stream, list):
+      # list of objects retrieved from replay or from non-streaming API.
+      for chunk in self.response_stream:
+        yield json.loads(chunk) if chunk else {}
+    elif self.response_stream is None:
+      yield from []
+    else:
+      # Iterator of objects retrieved from the API.
+      for chunk in self.response_stream.iter_lines():
+        if chunk:
+          # In streaming mode, the chunk of JSON is prefixed with "data:" which
+          # we must strip before parsing.
+          if chunk.startswith(b'data: '):
+            chunk = chunk[len(b'data: ') :]
+          yield json.loads(str(chunk, 'utf-8'))
+
+  def byte_segments(self):
+    if isinstance(self.byte_stream, list):
+      # list of objects retrieved from replay or from non-streaming API.
+      yield from self.byte_stream
+    elif self.byte_stream is None:
+      yield from []
+    else:
+      raise ValueError(
+          'Byte segments are not supported for streaming responses.'
+      )
+
+  def copy_to_dict(self, response_payload: dict[str, object]):
+    for attribute in dir(self):
+      response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
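The segments() generator above normalizes three shapes of response: replayed lists, empty bodies, and live byte streams whose JSON chunks arrive with a "data: " prefix. A standalone sketch of the streaming branch (the payload is invented):

    import json

    raw = b'data: {"candidates": [{"index": 0}]}'
    if raw.startswith(b'data: '):
        raw = raw[len(b'data: '):]
    print(json.loads(str(raw, 'utf-8')))  # {'candidates': [{'index': 0}]}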
+
+
+class ApiClient:
+  """Client for calling HTTP APIs sending and receiving JSON."""
+
+  def __init__(
+      self,
+      vertexai: Union[bool, None] = None,
+      api_key: Union[str, None] = None,
+      credentials: google.auth.credentials.Credentials = None,
+      project: Union[str, None] = None,
+      location: Union[str, None] = None,
+      http_options: HttpOptionsOrDict = None,
+  ):
+    self.vertexai = vertexai
+    if self.vertexai is None:
+      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
+          'true',
+          '1',
+      ]:
+        self.vertexai = True
+
+    # Validate explicitly set initializer values.
+    if (project or location) and api_key:
+      # API cannot consume both project/location and api_key.
+      raise ValueError(
+          'Project/location and API key are mutually exclusive in the client initializer.'
+      )
+    elif credentials and api_key:
+      # API cannot consume both credentials and api_key.
+      raise ValueError(
+          'Credentials and API key are mutually exclusive in the client initializer.'
+      )
+
+    # Validate http_options if a dict is provided.
+    if isinstance(http_options, dict):
+      try:
+        HttpOptions.model_validate(http_options)
+      except ValidationError as e:
+        raise ValueError(f'Invalid http_options: {e}')
+    elif isinstance(http_options, HttpOptions):
+      http_options = http_options.model_dump()
+
+    # Retrieve implicitly set values from the environment.
+    env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
+    env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
+    env_api_key = os.environ.get('GOOGLE_API_KEY', None)
+    self.project = project or env_project
+    self.location = location or env_location
+    self.api_key = api_key or env_api_key
+
+    self._credentials = credentials
+    self._http_options = HttpOptionsDict()
+
+    # Handle when to use Vertex AI in express mode (api key).
+    # Explicit initializer arguments are already validated above.
+    if self.vertexai:
+      if credentials:
+        # Explicit credentials take precedence over implicit api_key.
+        logging.info(
+            'The user provided Google Cloud credentials will take precedence'
+            + ' over the API key from the environment variable.'
+        )
+        self.api_key = None
+      elif (env_location or env_project) and api_key:
+        # Explicit api_key takes precedence over implicit project/location.
+        logging.info(
+            'The user provided Vertex AI API key will take precedence over the'
+            + ' project/location from the environment variables.'
+        )
+        self.project = None
+        self.location = None
+      elif (project or location) and env_api_key:
+        # Explicit project/location takes precedence over implicit api_key.
+        logging.info(
+            'The user provided project/location will take precedence over the'
+            + ' Vertex AI API key from the environment variable.'
+        )
+        self.api_key = None
+      elif (env_location or env_project) and env_api_key:
+        # Implicit project/location takes precedence over implicit api_key.
+        logging.info(
+            'The project/location from the environment variables will take'
+            + ' precedence over the API key from the environment variables.'
+        )
+        self.api_key = None
+      if not self.project and not self.api_key:
+        self.project = google.auth.default()[1]
+      if not ((self.project and self.location) or self.api_key):
+        raise ValueError(
+            'Project and location or API key must be set when using the Vertex '
+            'AI API.'
+        )
+      if self.api_key:
+        self._http_options['base_url'] = (
+            f'https://aiplatform.googleapis.com/'
+        )
+      else:
+        self._http_options['base_url'] = (
+            f'https://{self.location}-aiplatform.googleapis.com/'
+        )
+      self._http_options['api_version'] = 'v1beta1'
+    else:  # ML Dev API
+      if not self.api_key:
+        raise ValueError('API key must be set when using the Google AI API.')
+      self._http_options['base_url'] = (
+          'https://generativelanguage.googleapis.com/'
+      )
+      self._http_options['api_version'] = 'v1beta'
+    # Default options for both clients.
+    self._http_options['headers'] = {'Content-Type': 'application/json'}
+    if self.api_key and not self.vertexai:
+      self._http_options['headers']['x-goog-api-key'] = self.api_key
+    # Update the http options with the user provided http options.
+    if http_options:
+      self._http_options = _patch_http_options(self._http_options, http_options)
+    else:
+      _append_library_version_headers(self._http_options['headers'])
+
+  def _websocket_base_url(self):
+    url_parts = urlparse(self._http_options['base_url'])
+    return url_parts._replace(scheme='wss').geturl()
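As a usage sketch of the precedence rules above (the environment values are hypothetical, and GOOGLE_API_KEY is assumed unset):

    import os

    os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'
    os.environ['GOOGLE_CLOUD_PROJECT'] = 'my-project'
    os.environ['GOOGLE_CLOUD_LOCATION'] = 'us-central1'
    client = ApiClient()
    # Vertex AI mode with no API key: base_url resolves to
    # 'https://us-central1-aiplatform.googleapis.com/' and api_version to
    # 'v1beta1'.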
+
+  def _build_request(
+      self,
+      http_method: str,
+      path: str,
+      request_dict: dict[str, object],
+      http_options: HttpOptionsDict = None,
+  ) -> HttpRequest:
+    # Remove all special dict keys such as _url and _query.
+    keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
+    for key in keys_to_delete:
+      del request_dict[key]
+    # patch the http options with the user provided settings.
+    if http_options:
+      patched_http_options = _patch_http_options(
+          self._http_options, http_options
+      )
+    else:
+      patched_http_options = self._http_options
+    skip_project_and_location_in_path_val = patched_http_options.get(
+        'skip_project_and_location_in_path', False
+    )
+    if (
+        self.vertexai
+        and not path.startswith('projects/')
+        and not skip_project_and_location_in_path_val
+        and not self.api_key
+    ):
+      path = f'projects/{self.project}/locations/{self.location}/' + path
+    elif self.vertexai and self.api_key:
+      path = f'{path}?key={self.api_key}'
+    url = _join_url_path(
+        patched_http_options['base_url'],
+        patched_http_options['api_version'] + '/' + path,
+    )
+    return HttpRequest(
+        method=http_method,
+        url=url,
+        headers=patched_http_options['headers'],
+        data=request_dict,
+        timeout=patched_http_options.get('timeout', None),
+    )
+
+  def _request(
+      self,
+      http_request: HttpRequest,
+      stream: bool = False,
+  ) -> HttpResponse:
+    if self.vertexai and not self.api_key:
+      if not self._credentials:
+        self._credentials, _ = google.auth.default(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+      authed_session = AuthorizedSession(self._credentials)
+      authed_session.stream = stream
+      response = authed_session.request(
+          http_request.method.upper(),
+          http_request.url,
+          headers=http_request.headers,
+          data=json.dumps(http_request.data)
+          if http_request.data
+          else None,
+          timeout=http_request.timeout,
+      )
+      errors.APIError.raise_for_response(response)
+      return HttpResponse(
+          response.headers, response if stream else [response.text]
+      )
+    else:
+      return self._request_unauthorized(http_request, stream)
+
+  def _request_unauthorized(
+      self,
+      http_request: HttpRequest,
+      stream: bool = False,
+  ) -> HttpResponse:
+    data = None
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data)
+      else:
+        data = http_request.data
+
+    http_session = requests.Session()
+    response = http_session.request(
+        method=http_request.method,
+        url=http_request.url,
+        headers=http_request.headers,
+        data=data,
+        timeout=http_request.timeout,
+        stream=stream,
+    )
+    errors.APIError.raise_for_response(response)
+    return HttpResponse(
+        response.headers, response if stream else [response.text]
+    )
+
+  async def _async_request(
+      self, http_request: HttpRequest, stream: bool = False
+  ):
+    if self.vertexai:
+      if not self._credentials:
+        self._credentials, _ = google.auth.default(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+      return await asyncio.to_thread(
+          self._request,
+          http_request,
+          stream=stream,
+      )
+    else:
+      return await asyncio.to_thread(
+          self._request,
+          http_request,
+          stream=stream,
+      )
+
+  def get_read_only_http_options(self) -> HttpOptionsDict:
+    copied = HttpOptionsDict()
+    if isinstance(self._http_options, BaseModel):
+      self._http_options = self._http_options.model_dump()
+    copied.update(self._http_options)
+    return copied
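A worked example of the path logic in _build_request (the project, location, and model names are placeholders):

    client = ApiClient(vertexai=True, project='my-proj', location='us-central1')
    req = client._build_request(
        'post', 'models/gemini-pro:generateContent', {'contents': []}
    )
    # With no API key set, the project/location prefix is prepended, so
    # req.url becomes 'https://us-central1-aiplatform.googleapis.com/v1beta1/
    # projects/my-proj/locations/us-central1/models/gemini-pro:generateContent'.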
+
+  def request(
+      self,
+      http_method: str,
+      path: str,
+      request_dict: dict[str, object],
+      http_options: HttpOptionsDict = None,
+  ):
+    http_request = self._build_request(
+        http_method, path, request_dict, http_options
+    )
+    response = self._request(http_request, stream=False)
+    if http_options and 'response_payload' in http_options:
+      response.copy_to_dict(http_options['response_payload'])
+    return response.text
+
+  def request_streamed(
+      self,
+      http_method: str,
+      path: str,
+      request_dict: dict[str, object],
+      http_options: HttpOptionsDict = None,
+  ):
+    http_request = self._build_request(
+        http_method, path, request_dict, http_options
+    )
+
+    session_response = self._request(http_request, stream=True)
+    if http_options and 'response_payload' in http_options:
+      session_response.copy_to_dict(http_options['response_payload'])
+    for chunk in session_response.segments():
+      yield chunk
+
+  async def async_request(
+      self,
+      http_method: str,
+      path: str,
+      request_dict: dict[str, object],
+      http_options: HttpOptionsDict = None,
+  ) -> dict[str, object]:
+    http_request = self._build_request(
+        http_method, path, request_dict, http_options
+    )
+
+    result = await self._async_request(http_request=http_request, stream=False)
+    if http_options and 'response_payload' in http_options:
+      result.copy_to_dict(http_options['response_payload'])
+    return result.text
+
+  async def async_request_streamed(
+      self,
+      http_method: str,
+      path: str,
+      request_dict: dict[str, object],
+      http_options: HttpOptionsDict = None,
+  ):
+    http_request = self._build_request(
+        http_method, path, request_dict, http_options
+    )
+
+    response = await self._async_request(http_request=http_request, stream=True)
+
+    for chunk in response.segments():
+      yield chunk
+    if http_options and 'response_payload' in http_options:
+      response.copy_to_dict(http_options['response_payload'])
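How the two sync flavors are typically consumed (the endpoint path and handle() are stand-ins for caller code):

    # Non-streaming: returns the parsed JSON body as a whole.
    body = client.request('post', 'models/gemini-pro:generateContent', request_dict)

    # Streaming: yields one parsed JSON chunk per server-sent segment.
    for chunk in client.request_streamed(
        'post', 'models/gemini-pro:streamGenerateContent', request_dict
    ):
        handle(chunk)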
+
+  def upload_file(
+      self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
+  ) -> str:
+    """Transfers a file to the given URL.
+
+    Args:
+      file_path: The full path to the file or a file like object inherited from
+        io.BytesIO. If the local file path is not found, an error will be
+        raised.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    if isinstance(file_path, io.IOBase):
+      return self._upload_fd(file_path, upload_url, upload_size)
+    else:
+      with open(file_path, 'rb') as file:
+        return self._upload_fd(file, upload_url, upload_size)
+
+  def _upload_fd(
+      self, file: io.IOBase, upload_url: str, upload_size: int
+  ) -> str:
+    """Transfers a file to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    offset = 0
+    # Upload the file in chunks
+    while True:
+      file_chunk = file.read(1024 * 1024 * 8)  # 8 MB chunk size
+      chunk_size = 0
+      if file_chunk:
+        chunk_size = len(file_chunk)
+      upload_command = 'upload'
+      # If last chunk, finalize the upload.
+      if chunk_size + offset >= upload_size:
+        upload_command += ', finalize'
+      request = HttpRequest(
+          method='POST',
+          url=upload_url,
+          headers={
+              'X-Goog-Upload-Command': upload_command,
+              'X-Goog-Upload-Offset': str(offset),
+              'Content-Length': str(chunk_size),
+          },
+          data=file_chunk,
+      )
+
+      response = self._request_unauthorized(request, stream=False)
+      offset += chunk_size
+      if response.headers['X-Goog-Upload-Status'] != 'active':
+        break  # upload is complete or it has been interrupted.
+
+      if upload_size <= offset:  # Status is not finalized.
+        raise ValueError(
+            'All content has been uploaded, but the upload status is not'
+            f' finalized. {response.headers}, body: {response.text}'
+        )
+
+    if response.headers['X-Goog-Upload-Status'] != 'final':
+      raise ValueError(
+          'Failed to upload file: Upload status is not finalized. headers:'
+          f' {response.headers}, body: {response.text}'
+      )
+    return response.text
+
+  def download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    http_request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
+    )
+    return self._download_file_request(http_request).byte_stream[0]
+
+  def _download_file_request(
+      self,
+      http_request: HttpRequest,
+  ) -> HttpResponse:
+    data = None
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data, cls=RequestJsonEncoder)
+      else:
+        data = http_request.data
+
+    http_session = requests.Session()
+    response = http_session.request(
+        method=http_request.method,
+        url=http_request.url,
+        headers=http_request.headers,
+        data=data,
+        timeout=http_request.timeout,
+        stream=False,
+    )
+
+    errors.APIError.raise_for_response(response)
+    return HttpResponse(response.headers, byte_stream=[response.content])
+
+
+  async def async_upload_file(
+      self,
+      file_path: Union[str, io.IOBase],
+      upload_url: str,
+      upload_size: int,
+  ) -> str:
+    """Transfers a file asynchronously to the given URL.
+
+    Args:
+      file_path: The full path to the file. If the local file path is not found,
+        an error will be raised.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    return await asyncio.to_thread(
+        self.upload_file,
+        file_path,
+        upload_url,
+        upload_size,
+    )
+
+  async def _async_upload_fd(
+      self,
+      file: io.IOBase,
+      upload_url: str,
+      upload_size: int,
+  ) -> str:
+    """Transfers a file asynchronously to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    return await asyncio.to_thread(
+        self._upload_fd,
+        file,
+        upload_url,
+        upload_size,
+    )
+
+  async def async_download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    return await asyncio.to_thread(
+        self.download_file,
+        path,
+        http_options,
+    )
+
+  # This method does nothing in the real api client. It is used in the
+  # replay_api_client to verify the response from the SDK method matches the
+  # recorded response.
+  def _verify_response(self, response_model: BaseModel):
+    pass
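The chunked loop in _upload_fd speaks the resumable-upload header protocol. In miniature, each 8 MB POST carries headers like these (the offset and size values are illustrative):

    headers = {
        'X-Goog-Upload-Command': 'upload, finalize',  # ', finalize' only on the last chunk
        'X-Goog-Upload-Offset': '8388608',            # bytes already accepted by the server
        'Content-Length': '1024',                     # size of this chunk
    }
    # The server answers with X-Goog-Upload-Status: 'active' while it expects
    # more chunks, and 'final' once the upload is committed.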
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py b/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
new file mode 100644
index 00000000..12d1df7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
@@ -0,0 +1,294 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import types as typing_types
+from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
+import pydantic
+from . import types
+
+_py_builtin_type_to_schema_type = {
+    str: 'STRING',
+    int: 'INTEGER',
+    float: 'NUMBER',
+    bool: 'BOOLEAN',
+    list: 'ARRAY',
+    dict: 'OBJECT',
+}
+
+
+def _is_builtin_primitive_or_compound(
+    annotation: inspect.Parameter.annotation,
+) -> bool:
+  return annotation in _py_builtin_type_to_schema_type.keys()
+
+
+def _raise_for_any_of_if_mldev(schema: types.Schema):
+  if schema.any_of:
+    raise ValueError(
+        'AnyOf is not supported in function declaration schema for Google AI.'
+    )
+
+
+def _raise_for_default_if_mldev(schema: types.Schema):
+  if schema.default is not None:
+    raise ValueError(
+        'Default value is not supported in function declaration schema for'
+        ' Google AI.'
+    )
+
+
+def _raise_for_nullable_if_mldev(schema: types.Schema):
+  if schema.nullable:
+    raise ValueError(
+        'Nullable is not supported in function declaration schema for'
+        ' Google AI.'
+    )
+
+
+def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
+  if not variant == 'VERTEX_AI':
+    _raise_for_any_of_if_mldev(schema)
+    _raise_for_default_if_mldev(schema)
+    _raise_for_nullable_if_mldev(schema)
+
+
+def _is_default_value_compatible(
+    default_value: Any, annotation: inspect.Parameter.annotation
+) -> bool:
+  # None type is expected to be handled external to this function
+  if _is_builtin_primitive_or_compound(annotation):
+    return isinstance(default_value, annotation)
+
+  if (
+      isinstance(annotation, _GenericAlias)
+      or isinstance(annotation, typing_types.GenericAlias)
+      or isinstance(annotation, typing_types.UnionType)
+  ):
+    origin = get_origin(annotation)
+    if origin in (Union, typing_types.UnionType):
+      return any(
+          _is_default_value_compatible(default_value, arg)
+          for arg in get_args(annotation)
+      )
+
+    if origin is dict:
+      return isinstance(default_value, dict)
+
+    if origin is list:
+      if not isinstance(default_value, list):
+        return False
+      # most tricky case, element in list is union type
+      # need to apply any logic within all
+      # see test case test_generic_alias_complex_array_with_default_value
+      # a: typing.List[int | str | float | bool]
+      # default_value: [1, 'a', 1.1, True]
+      return all(
+          any(
+              _is_default_value_compatible(item, arg)
+              for arg in get_args(annotation)
+          )
+          for item in default_value
+      )
+
+    if origin is Literal:
+      return default_value in get_args(annotation)
+
+  # return False for any other unrecognized annotation
+  # let caller handle the raise
+  return False
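A quick illustration of the compatibility rules above:

    from typing import Literal

    _is_default_value_compatible(3, int)                     # True
    _is_default_value_compatible([1, 'a'], list[int | str])  # True
    _is_default_value_compatible('x', Literal['a', 'b'])     # False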
+
+
+def _parse_schema_from_parameter(
+    variant: str, param: inspect.Parameter, func_name: str
+) -> types.Schema:
+  """parse schema from parameter.
+
+  from the simplest case to the most complex case.
+  """
+  schema = types.Schema()
+  default_value_error_msg = (
+      f'Default value {param.default} of parameter {param} of function'
+      f' {func_name} is not compatible with the parameter annotation'
+      f' {param.annotation}.'
+  )
+  if _is_builtin_primitive_or_compound(param.annotation):
+    if param.default is not inspect.Parameter.empty:
+      if not _is_default_value_compatible(param.default, param.annotation):
+        raise ValueError(default_value_error_msg)
+      schema.default = param.default
+    schema.type = _py_builtin_type_to_schema_type[param.annotation]
+    _raise_if_schema_unsupported(variant, schema)
+    return schema
+  if (
+      isinstance(param.annotation, typing_types.UnionType)
+      # only parse simple UnionType, example int | str | float | bool
+      # complex types.UnionType will be invoked in raise branch
+      and all(
+          (_is_builtin_primitive_or_compound(arg) or arg is type(None))
+          for arg in get_args(param.annotation)
+      )
+  ):
+    schema.type = 'OBJECT'
+    schema.any_of = []
+    unique_types = set()
+    for arg in get_args(param.annotation):
+      if arg.__name__ == 'NoneType':  # Optional type
+        schema.nullable = True
+        continue
+      schema_in_any_of = _parse_schema_from_parameter(
+          variant,
+          inspect.Parameter(
+              'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
+          ),
+          func_name,
+      )
+      if (
+          schema_in_any_of.model_dump_json(exclude_none=True)
+          not in unique_types
+      ):
+        schema.any_of.append(schema_in_any_of)
+        unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+    if len(schema.any_of) == 1:  # param: list | None -> Array
+      schema.type = schema.any_of[0].type
+      schema.any_of = None
+    if (
+        param.default is not inspect.Parameter.empty
+        and param.default is not None
+    ):
+      if not _is_default_value_compatible(param.default, param.annotation):
+        raise ValueError(default_value_error_msg)
+      schema.default = param.default
+    _raise_if_schema_unsupported(variant, schema)
+    return schema
+  if isinstance(param.annotation, _GenericAlias) or isinstance(
+      param.annotation, typing_types.GenericAlias
+  ):
+    origin = get_origin(param.annotation)
+    args = get_args(param.annotation)
+    if origin is dict:
+      schema.type = 'OBJECT'
+      if param.default is not inspect.Parameter.empty:
+        if not _is_default_value_compatible(param.default, param.annotation):
+          raise ValueError(default_value_error_msg)
+        schema.default = param.default
+      _raise_if_schema_unsupported(variant, schema)
+      return schema
+    if origin is Literal:
+      if not all(isinstance(arg, str) for arg in args):
+        raise ValueError(
+            f'Literal type {param.annotation} must be a list of strings.'
+        )
+      schema.type = 'STRING'
+      schema.enum = list(args)
+      if param.default is not inspect.Parameter.empty:
+        if not _is_default_value_compatible(param.default, param.annotation):
+          raise ValueError(default_value_error_msg)
+        schema.default = param.default
+      _raise_if_schema_unsupported(variant, schema)
+      return schema
+    if origin is list:
+      schema.type = 'ARRAY'
+      schema.items = _parse_schema_from_parameter(
+          variant,
+          inspect.Parameter(
+              'item',
+              inspect.Parameter.POSITIONAL_OR_KEYWORD,
+              annotation=args[0],
+          ),
+          func_name,
+      )
+      if param.default is not inspect.Parameter.empty:
+        if not _is_default_value_compatible(param.default, param.annotation):
+          raise ValueError(default_value_error_msg)
+        schema.default = param.default
+      _raise_if_schema_unsupported(variant, schema)
+      return schema
+    if origin is Union:
+      schema.any_of = []
+      schema.type = 'OBJECT'
+      unique_types = set()
+      for arg in args:
+        if arg.__name__ == 'NoneType':  # Optional type
+          schema.nullable = True
+          continue
+        schema_in_any_of = _parse_schema_from_parameter(
+            variant,
+            inspect.Parameter(
+                'item',
+                inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                annotation=arg,
+            ),
+            func_name,
+        )
+        if (
+            schema_in_any_of.model_dump_json(exclude_none=True)
+            not in unique_types
+        ):
+          schema.any_of.append(schema_in_any_of)
+          unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+      if len(schema.any_of) == 1:  # param: Union[List, None] -> Array
+        schema.type = schema.any_of[0].type
+        schema.any_of = None
+      if (
+          param.default is not None
+          and param.default is not inspect.Parameter.empty
+      ):
+        if not _is_default_value_compatible(param.default, param.annotation):
+          raise ValueError(default_value_error_msg)
+        schema.default = param.default
+      _raise_if_schema_unsupported(variant, schema)
+      return schema
+  # all other generic alias will be invoked in raise branch
+  if (
+      inspect.isclass(param.annotation)
+      # for user defined class, we only support pydantic model
+      and issubclass(param.annotation, pydantic.BaseModel)
+  ):
+    if (
+        param.default is not inspect.Parameter.empty
+        and param.default is not None
+    ):
+      schema.default = param.default
+    schema.type = 'OBJECT'
+    schema.properties = {}
+    for field_name, field_info in param.annotation.model_fields.items():
+      schema.properties[field_name] = _parse_schema_from_parameter(
+          variant,
+          inspect.Parameter(
+              field_name,
+              inspect.Parameter.POSITIONAL_OR_KEYWORD,
+              annotation=field_info.annotation,
+          ),
+          func_name,
+      )
+    _raise_if_schema_unsupported(variant, schema)
+    return schema
+  raise ValueError(
+      f'Failed to parse the parameter {param} of function {func_name} for'
+      ' automatic function calling. Automatic function calling works best with'
+      ' simpler function signature schemas; consider manually parsing your'
+      f' function declaration for function {func_name}.'
+  )
+
+
+def _get_required_fields(schema: types.Schema) -> list[str]:
+  if not schema.properties:
+    return
+  return [
+      field_name
+      for field_name, field_schema in schema.properties.items()
+      if not field_schema.nullable and field_schema.default is None
+  ]
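To make the traversal concrete: a tool signature like the following (hypothetical) yields a STRING enum schema for mode and an ARRAY schema for room_ids, assuming the VERTEX_AI variant:

    import inspect
    from typing import Literal

    def set_lights(mode: Literal['on', 'off'], room_ids: list[int]) -> None:
        ...

    params = inspect.signature(set_lights).parameters
    schema = _parse_schema_from_parameter('VERTEX_AI', params['mode'], 'set_lights')
    # schema.type == 'STRING', schema.enum == ['on', 'off']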
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_common.py b/.venv/lib/python3.12/site-packages/google/genai/_common.py
new file mode 100644
index 00000000..3211fd40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_common.py
@@ -0,0 +1,272 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common utilities for the SDK."""
+
+import base64
+import datetime
+import enum
+import typing
+from typing import Union
+import uuid
+
+import pydantic
+from pydantic import alias_generators
+
+from . import _api_client
+
+
+def set_value_by_path(data, keys, value):
+  """Examples:
+
+  set_value_by_path({}, ['a', 'b'], v)
+    -> {'a': {'b': v}}
+  set_value_by_path({}, ['a', 'b[]', c], [v1, v2])
+    -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
+  set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
+    -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
+  """
+  if value is None:
+    return
+  for i, key in enumerate(keys[:-1]):
+    if key.endswith('[]'):
+      key_name = key[:-2]
+      if key_name not in data:
+        if isinstance(value, list):
+          data[key_name] = [{} for _ in range(len(value))]
+        else:
+          raise ValueError(
+              f'value {value} must be a list given an array path {key}'
+          )
+      if isinstance(value, list):
+        for j, d in enumerate(data[key_name]):
+          set_value_by_path(d, keys[i + 1 :], value[j])
+      else:
+        for d in data[key_name]:
+          set_value_by_path(d, keys[i + 1 :], value)
+      return
+
+    data = data.setdefault(key, {})
+
+  existing_data = data.get(keys[-1])
+  # If there is an existing value, merge, not overwrite.
+  if existing_data is not None:
+    # Don't overwrite existing non-empty value with new empty value.
+    # This is triggered when handling tuning datasets.
+    if not value:
+      pass
+    # Don't fail when overwriting value with same value
+    elif value == existing_data:
+      pass
+    # Instead of overwriting dictionary with another dictionary, merge them.
+    # This is important for handling training and validation datasets in tuning.
+    elif isinstance(existing_data, dict) and isinstance(value, dict):
+      # Merging dictionaries. Consider deep merging in the future.
+      existing_data.update(value)
+    else:
+      raise ValueError(
+          f'Cannot set value for an existing key. Key: {keys[-1]};'
+          f' Existing value: {existing_data}; New value: {value}.'
+      )
+  else:
+    data[keys[-1]] = value
+
+
+def get_value_by_path(data: object, keys: list[str]):
+  """Examples:
+
+  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
+    -> v
+  get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
+    -> [v1, v2]
+  """
+  if keys == ['_self']:
+    return data
+  for i, key in enumerate(keys):
+    if not data:
+      return None
+    if key.endswith('[]'):
+      key_name = key[:-2]
+      if key_name in data:
+        return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
+      else:
+        return None
+    else:
+      if key in data:
+        data = data[key]
+      elif isinstance(data, BaseModel) and hasattr(data, key):
+        data = getattr(data, key)
+      else:
+        return None
+  return data
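A runnable illustration of the two path helpers, mirroring their docstrings:

    data = {}
    set_value_by_path(data, ['a', 'b[]', 'c'], ['v1', 'v2'])
    # data == {'a': {'b': [{'c': 'v1'}, {'c': 'v2'}]}}
    get_value_by_path(data, ['a', 'b[]', 'c'])
    # -> ['v1', 'v2']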
+ """ + if isinstance(obj, pydantic.BaseModel): + return obj.model_dump(exclude_none=True) + elif isinstance(obj, dict): + return {key: convert_to_dict(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_to_dict(item) for item in obj] + else: + return obj + + +def _remove_extra_fields( + model: pydantic.BaseModel, response: dict[str, object] +) -> None: + """Removes extra fields from the response that are not in the model. + + Mutates the response in place. + """ + + key_values = list(response.items()) + + for key, value in key_values: + # Need to convert to snake case to match model fields names + # ex: UsageMetadata + alias_map = { + field_info.alias: key for key, field_info in model.model_fields.items() + } + + if key not in model.model_fields and key not in alias_map: + response.pop(key) + continue + + key = alias_map.get(key, key) + + annotation = model.model_fields[key].annotation + + # Get the BaseModel if Optional + if typing.get_origin(annotation) is Union: + annotation = typing.get_args(annotation)[0] + + # if dict, assume BaseModel but also check that field type is not dict + # example: FunctionCall.args + if isinstance(value, dict) and typing.get_origin(annotation) is not dict: + _remove_extra_fields(annotation, value) + elif isinstance(value, list): + for item in value: + # assume a list of dict is list of BaseModel + if isinstance(item, dict): + _remove_extra_fields(typing.get_args(annotation)[0], item) + + +class BaseModel(pydantic.BaseModel): + + model_config = pydantic.ConfigDict( + alias_generator=alias_generators.to_camel, + populate_by_name=True, + from_attributes=True, + protected_namespaces=(), + extra='forbid', + # This allows us to use arbitrary types in the model. E.g. PIL.Image. + arbitrary_types_allowed=True, + ser_json_bytes='base64', + val_json_bytes='base64', + ) + + @classmethod + def _from_response( + cls, response: dict[str, object], kwargs: dict[str, object] + ) -> 'BaseModel': + # To maintain forward compatibility, we need to remove extra fields from + # the response. + # We will provide another mechanism to allow users to access these fields. + _remove_extra_fields(cls, response) + validated_response = cls.model_validate(response) + return validated_response + + def to_json_dict(self) -> dict[str, object]: + return self.model_dump(exclude_none=True, mode='json') + + +class CaseInSensitiveEnum(str, enum.Enum): + """Case insensitive enum.""" + + @classmethod + def _missing_(cls, value): + try: + return cls[value.upper()] # Try to access directly with uppercase + except KeyError: + try: + return cls[value.lower()] # Try to access directly with lowercase + except KeyError as e: + raise ValueError(f"{value} is not a valid {cls.__name__}") from e + + +def timestamped_unique_name() -> str: + """Composes a timestamped unique name. + + Returns: + A string representing a unique name. + """ + timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + unique_id = uuid.uuid4().hex[0:5] + return f'{timestamp}_{unique_id}' + + +def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]: + """Converts unserializable types in dict to json.dumps() compatible types. + + This function is called in models.py after calling convert_to_dict(). The + convert_to_dict() can convert pydantic object to dict. However, the input to + convert_to_dict() is dict mixed of pydantic object and nested dict(the output + of converters). 
+
+
+def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
+  """Converts unserializable types in dict to json.dumps() compatible types.
+
+  This function is called in models.py after calling convert_to_dict(). The
+  convert_to_dict() can convert pydantic object to dict. However, the input to
+  convert_to_dict() is dict mixed of pydantic object and nested dict (the output
+  of converters). So they may be bytes in the dict and they are out of
+  `ser_json_bytes` control in model_dump(mode='json') called in
+  `convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
+
+  Returns:
+    A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
+    to compatible type (e.g. base64 encoded string, isoformat date string).
+  """
+  processed_data = {}
+  if not isinstance(data, dict):
+    return data
+  for key, value in data.items():
+    if isinstance(value, bytes):
+      processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
+    elif isinstance(value, datetime.datetime):
+      processed_data[key] = value.isoformat()
+    elif isinstance(value, dict):
+      processed_data[key] = encode_unserializable_types(value)
+    elif isinstance(value, list):
+      if all(isinstance(v, bytes) for v in value):
+        processed_data[key] = [
+            base64.urlsafe_b64encode(v).decode('ascii') for v in value
+        ]
+      if all(isinstance(v, datetime.datetime) for v in value):
+        processed_data[key] = [v.isoformat() for v in value]
+      else:
+        processed_data[key] = [encode_unserializable_types(v) for v in value]
+    else:
+      processed_data[key] = value
+  return processed_data
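For example (the values are invented; b'abc' urlsafe-base64-encodes to 'YWJj'):

    import datetime

    payload = {
        'created': datetime.datetime(2024, 1, 1, 12, 0),
        'nested': {'blob': b'abc'},
    }
    encode_unserializable_types(payload)
    # {'created': '2024-01-01T12:00:00', 'nested': {'blob': 'YWJj'}}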
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
new file mode 100644
index 00000000..db8d377b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
@@ -0,0 +1,310 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Extra utils depending on types that are shared between sync and async modules.
+"""
+
+import inspect
+import logging
+from typing import Any, Callable, Dict, get_args, get_origin, Optional, types as typing_types, Union
+
+import pydantic
+
+from . import _common
+from . import errors
+from . import types
+
+
+_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
+
+
+def format_destination(
+    src: str,
+    config: Optional[types.CreateBatchJobConfigOrDict] = None,
+) -> types.CreateBatchJobConfig:
+  """Formats the destination uri based on the source uri."""
+  config = (
+      types._CreateBatchJobParameters(config=config).config
+      or types.CreateBatchJobConfig()
+  )
+
+  unique_name = None
+  if not config.display_name:
+    unique_name = _common.timestamped_unique_name()
+    config.display_name = f'genai_batch_job_{unique_name}'
+
+  if not config.dest:
+    if src.startswith('gs://') and src.endswith('.jsonl'):
+      # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
+      # uri prefix will be "gs://bucket/path/to/src/dest".
+      config.dest = f'{src[:-6]}/dest'
+    elif src.startswith('bq://'):
+      # If source uri is "bq://project.dataset.src", then the destination
+      # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
+      unique_name = unique_name or _common.timestamped_unique_name()
+      config.dest = f'{src}_dest_{unique_name}'
+    else:
+      raise ValueError(f'Unsupported source: {src}')
+  return config
+
+
+def get_function_map(
+    config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> dict[str, object]:
+  """Returns a function map from the config."""
+  config_model = (
+      types.GenerateContentConfig(**config)
+      if config and isinstance(config, dict)
+      else config
+  )
+  function_map = {}
+  if not config_model:
+    return function_map
+  if config_model.tools:
+    for tool in config_model.tools:
+      if callable(tool):
+        if inspect.iscoroutinefunction(tool):
+          raise errors.UnsupportedFunctionError(
+              f'Function {tool.__name__} is a coroutine function, which is not'
+              ' supported for automatic function calling. Please manually invoke'
+              f' {tool.__name__} to get the function response.'
+          )
+        function_map[tool.__name__] = tool
+  return function_map
+
+
+def convert_number_values_for_function_call_args(
+    args: Union[dict[str, object], list[object], object],
+) -> Union[dict[str, object], list[object], object]:
+  """Converts float values with no decimal to integers."""
+  if isinstance(args, float) and args.is_integer():
+    return int(args)
+  if isinstance(args, dict):
+    return {
+        key: convert_number_values_for_function_call_args(value)
+        for key, value in args.items()
+    }
+  if isinstance(args, list):
+    return [
+        convert_number_values_for_function_call_args(value) for value in args
+    ]
+  return args
+
+
+def _is_annotation_pydantic_model(annotation: Any) -> bool:
+  return inspect.isclass(annotation) and issubclass(
+      annotation, pydantic.BaseModel
+  )
+
+
+def convert_if_exist_pydantic_model(
+    value: Any, annotation: Any, param_name: str, func_name: str
+) -> Any:
+  if isinstance(value, dict) and _is_annotation_pydantic_model(annotation):
+    try:
+      return annotation(**value)
+    except pydantic.ValidationError as e:
+      raise errors.UnknownFunctionCallArgumentError(
+          f'Failed to parse parameter {param_name} for function'
+          f' {func_name} from function call part because function call argument'
+          f' value {value} is not compatible with parameter annotation'
+          f' {annotation}, due to error {e}'
+      )
+  if isinstance(value, list) and get_origin(annotation) == list:
+    item_type = get_args(annotation)[0]
+    return [
+        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
+        for item in value
+    ]
+  if isinstance(value, dict) and get_origin(annotation) == dict:
+    _, value_type = get_args(annotation)
+    return {
+        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
+        for k, v in value.items()
+    }
+  # example 1: typing.Union[int, float]
+  # example 2: int | float equivalent to typing.types.UnionType[int, float]
+  if get_origin(annotation) in (Union, typing_types.UnionType):
+    for arg in get_args(annotation):
+      if isinstance(value, arg) or (
+          isinstance(value, dict) and _is_annotation_pydantic_model(arg)
+      ):
+        try:
+          return convert_if_exist_pydantic_model(
+              value, arg, param_name, func_name
+          )
+        # do not raise here because there could be multiple pydantic model types
+        # in the union type.
+        except pydantic.ValidationError:
+          continue
+    # if none of the union type is matched, raise error
+    raise errors.UnknownFunctionCallArgumentError(
+        f'Failed to parse parameter {param_name} for function'
+        f' {func_name} from function call part because function call argument'
+        f' value {value} cannot be converted to parameter annotation'
+        f' {annotation}.'
+    )
+  # the only exception for value and annotation type to be different is int and
+  # float. see convert_number_values_for_function_call_args function for context
+  if isinstance(value, int) and annotation is float:
+    return value
+  if not isinstance(value, annotation):
+    raise errors.UnknownFunctionCallArgumentError(
+        f'Failed to parse parameter {param_name} for function {func_name} from'
+        f' function call part because function call argument value {value} is'
+        f' not compatible with parameter annotation {annotation}.'
+    )
+  return value
+
+
+def invoke_function_from_dict_args(
+    args: Dict[str, Any], function_to_invoke: Callable
+) -> Any:
+  signature = inspect.signature(function_to_invoke)
+  func_name = function_to_invoke.__name__
+  converted_args = {}
+  for param_name, param in signature.parameters.items():
+    if param_name in args:
+      converted_args[param_name] = convert_if_exist_pydantic_model(
+          args[param_name],
+          param.annotation,
+          param_name,
+          func_name,
+      )
+  try:
+    return function_to_invoke(**converted_args)
+  except Exception as e:
+    raise errors.FunctionInvocationError(
+        f'Failed to invoke function {func_name} with converted arguments'
+        f' {converted_args} from model returned function call argument'
+        f' {args} because of error {e}'
+    )
+
+
+def get_function_response_parts(
+    response: types.GenerateContentResponse,
+    function_map: dict[str, object],
+) -> list[types.Part]:
+  """Returns the function response parts from the response."""
+  func_response_parts = []
+  for part in response.candidates[0].content.parts:
+    if not part.function_call:
+      continue
+    func_name = part.function_call.name
+    func = function_map[func_name]
+    args = convert_number_values_for_function_call_args(part.function_call.args)
+    try:
+      response = {'result': invoke_function_from_dict_args(args, func)}
+    except Exception as e:  # pylint: disable=broad-except
+      response = {'error': str(e)}
+    func_response = types.Part.from_function_response(func_name, response)
+
+    func_response_parts.append(func_response)
+  return func_response_parts
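End to end, these helpers turn a model function call into a function response roughly like this (get_weather is a hypothetical tool):

    def get_weather(city: str, days: int) -> dict:
        return {'city': city, 'days': days}

    # The model sends JSON numbers, so 3.0 may arrive for an int parameter.
    args = convert_number_values_for_function_call_args(
        {'city': 'Paris', 'days': 3.0}
    )
    invoke_function_from_dict_args(args, get_weather)
    # {'city': 'Paris', 'days': 3}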
+
+
+def should_disable_afc(
+    config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> bool:
+  """Returns whether automatic function calling is enabled."""
+  config_model = (
+      types.GenerateContentConfig(**config)
+      if config and isinstance(config, dict)
+      else config
+  )
+
+  # If max_remote_calls is less or equal to 0, warn and disable AFC.
+  if (
+      config_model
+      and config_model.automatic_function_calling
+      and config_model.automatic_function_calling.maximum_remote_calls
+      is not None
+      and int(config_model.automatic_function_calling.maximum_remote_calls)
+      <= 0
+  ):
+    logging.warning(
+        'max_remote_calls in automatic_function_calling_config'
+        f' {config_model.automatic_function_calling.maximum_remote_calls} is'
+        ' less than or equal to 0. Disabling automatic function calling.'
+        ' Please set max_remote_calls to a positive integer.'
+    )
+    return True
+
+  # Default to enable AFC if not specified.
+  if (
+      not config_model
+      or not config_model.automatic_function_calling
+      or config_model.automatic_function_calling.disable is None
+  ):
+    return False
+
+  if (
+      config_model.automatic_function_calling.disable
+      and config_model.automatic_function_calling.maximum_remote_calls
+      is not None
+      and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
+  ):
+    logging.warning(
+        '`automatic_function_calling.disable` is set to `True`. But'
+        ' `automatic_function_calling.maximum_remote_calls` is set to be a'
+        ' positive number'
+        f' {config_model.automatic_function_calling.maximum_remote_calls}.'
+        ' Disabling automatic function calling. If you want to enable'
+        ' automatic function calling, please set'
+        ' `automatic_function_calling.disable` to `False` or leave it unset,'
+        ' and set `automatic_function_calling.maximum_remote_calls` to a'
+        ' positive integer or leave'
+        ' `automatic_function_calling.maximum_remote_calls` unset.'
+    )
+
+  return config_model.automatic_function_calling.disable
+
+
+def get_max_remote_calls_afc(
+    config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> int:
+  """Returns the remaining remote calls for automatic function calling."""
+  if should_disable_afc(config):
+    raise ValueError(
+        'automatic function calling is not enabled, but SDK is trying to get'
+        ' max remote calls.'
+    )
+  config_model = (
+      types.GenerateContentConfig(**config)
+      if config and isinstance(config, dict)
+      else config
+  )
+  if (
+      not config_model
+      or not config_model.automatic_function_calling
+      or config_model.automatic_function_calling.maximum_remote_calls is None
+  ):
+    return _DEFAULT_MAX_REMOTE_CALLS_AFC
+  return int(config_model.automatic_function_calling.maximum_remote_calls)
+
+
+def should_append_afc_history(
+    config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> bool:
+  config_model = (
+      types.GenerateContentConfig(**config)
+      if config and isinstance(config, dict)
+      else config
+  )
+  if (
+      not config_model
+      or not config_model.automatic_function_calling
+  ):
+    return True
+  return not config_model.automatic_function_calling.ignore_call_history
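The resulting decision table, in miniature (the dict configs shown assume pydantic coerces the nested dict into GenerateContentConfig, which these functions do via types.GenerateContentConfig(**config)):

    should_disable_afc(None)   # False: AFC is on by default
    should_disable_afc(
        {'automatic_function_calling': {'maximum_remote_calls': 0}}
    )                          # True, with a warning
    should_disable_afc(
        {'automatic_function_calling': {'disable': True}}
    )                          # True
    get_max_remote_calls_afc(None)  # 10, i.e. _DEFAULT_MAX_REMOTE_CALLS_AFC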
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_replay_api_client.py b/.venv/lib/python3.12/site-packages/google/genai/_replay_api_client.py
new file mode 100644
index 00000000..3af2336d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_replay_api_client.py
@@ -0,0 +1,449 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Replay API client."""
+
+import base64
+import copy
+import datetime
+import inspect
+import io
+import json
+import os
+import re
+from typing import Any, Literal, Optional, Union
+
+import google.auth
+from requests.exceptions import HTTPError
+
+from . import errors
+from ._api_client import ApiClient
+from ._api_client import HttpOptions
+from ._api_client import HttpRequest
+from ._api_client import HttpResponse
+from ._common import BaseModel
+
+
+def _redact_version_numbers(version_string: str) -> str:
+  """Redacts version numbers in the form x.y.z from a string."""
+  return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
+
+
+def _redact_language_label(language_label: str) -> str:
+  """Removed because replay requests are used for all languages."""
+  return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)
+
+
+def _redact_request_headers(headers):
+  """Redacts headers that should not be recorded."""
+  redacted_headers = {}
+  for header_name, header_value in headers.items():
+    if header_name.lower() == 'x-goog-api-key':
+      redacted_headers[header_name] = '{REDACTED}'
+    elif header_name.lower() == 'user-agent':
+      redacted_headers[header_name] = _redact_language_label(
+          _redact_version_numbers(header_value)
+      )
+    elif header_name.lower() == 'x-goog-api-client':
+      redacted_headers[header_name] = _redact_language_label(
+          _redact_version_numbers(header_value)
+      )
+    else:
+      redacted_headers[header_name] = header_value
+  return redacted_headers
+
+
+def _redact_request_url(url: str) -> str:
+  # Redact all the url parts before the resource name, so the test can work
+  # against any project, location, version, or whether it's EasyGCP.
+  result = re.sub(
+      r'.*/projects/[^/]+/locations/[^/]+/',
+      '{VERTEX_URL_PREFIX}/',
+      url,
+  )
+  result = re.sub(
+      r'.*-aiplatform.googleapis.com/[^/]+/',
+      '{VERTEX_URL_PREFIX}/',
+      result,
+  )
+  result = re.sub(
+      r'https://generativelanguage.googleapis.com/[^/]+',
+      '{MLDEV_URL_PREFIX}',
+      result,
+  )
+  return result
+
+
+def _redact_project_location_path(path: str) -> str:
+  # Redact a field in the request that is known to vary based on project and
+  # location.
+  if 'projects/' in path and 'locations/' in path:
+    result = re.sub(
+        r'projects/[^/]+/locations/[^/]+/',
+        '{PROJECT_AND_LOCATION_PATH}/',
+        path,
+    )
+    return result
+  else:
+    return path
+
+
+def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
+  for key, value in body.items():
+    if isinstance(value, str):
+      body[key] = _redact_project_location_path(value)
+
+
+def redact_http_request(http_request: HttpRequest):
+  http_request.headers = _redact_request_headers(http_request.headers)
+  http_request.url = _redact_request_url(http_request.url)
+  _redact_request_body(http_request.data)
+
+
+def _current_file_path_and_line():
+  """Prints the current file path and line number."""
+  frame = inspect.currentframe().f_back.f_back
+  filepath = inspect.getfile(frame)
+  lineno = frame.f_lineno
+  return f'File: {filepath}, Line: {lineno}'
+
+
+def _debug_print(message: str):
+  print(
+      'DEBUG (test',
+      os.environ.get('PYTEST_CURRENT_TEST'),
+      ')',
+      _current_file_path_and_line(),
+      ':\n ',
+      message,
+  )
+
+
+class ReplayRequest(BaseModel):
+  """Represents a single request in a replay."""
+
+  method: str
+  url: str
+  headers: dict[str, str]
+  body_segments: list[dict[str, object]]
+
+
+class ReplayResponse(BaseModel):
+  """Represents a single response in a replay."""
+
+  status_code: int = 200
+  headers: dict[str, str]
+  body_segments: list[dict[str, object]]
+  byte_segments: Optional[list[bytes]] = None
+  sdk_response_segments: list[dict[str, object]]
+
+  def model_post_init(self, __context: Any) -> None:
+    # Remove headers that are not deterministic so the replay files don't change
+    # every time they are recorded.
+    self.headers.pop('Date', None)
+    self.headers.pop('Server-Timing', None)
+
+
+class ReplayInteraction(BaseModel):
+  """Represents a single interaction, request and response in a replay."""
+
+  request: ReplayRequest
+  response: ReplayResponse
+
+
+class ReplayFile(BaseModel):
+  """Represents a recorded session."""
+
+  replay_id: str
+  interactions: list[ReplayInteraction]
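For instance, the URL redaction normalizes both backends to stable placeholders:

    _redact_request_url(
        'https://us-central1-aiplatform.googleapis.com/v1beta1/'
        'projects/p/locations/us-central1/models/m:generateContent'
    )
    # -> '{VERTEX_URL_PREFIX}/models/m:generateContent'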
+ self.replay_session = None + self._mode = mode + self._replay_id = replay_id + + def initialize_replay_session(self, replay_id: str): + self._replay_id = replay_id + self._initialize_replay_session() + + def _get_replay_file_path(self): + return self._generate_file_path_from_replay_id( + self.replays_directory, self._replay_id + ) + + def _should_call_api(self): + return self._mode in ['record', 'api'] or ( + self._mode == 'auto' + and not os.path.isfile(self._get_replay_file_path()) + ) + + def _should_update_replay(self): + return self._should_call_api() and self._mode != 'api' + + def _initialize_replay_session_if_not_loaded(self): + if not self.replay_session: + self._initialize_replay_session() + + def _initialize_replay_session(self): + _debug_print('Test is using replay id: ' + self._replay_id) + self._replay_index = 0 + self._sdk_response_index = 0 + replay_file_path = self._get_replay_file_path() + # This should not be triggered from the constructor. + replay_file_exists = os.path.isfile(replay_file_path) + if self._mode == 'replay' and not replay_file_exists: + raise ValueError( + 'Replay files do not exist for replay id: ' + self._replay_id + ) + + if self._mode in ['replay', 'auto'] and replay_file_exists: + with open(replay_file_path, 'r') as f: + self.replay_session = ReplayFile.model_validate(json.loads(f.read())) + + if self._should_update_replay(): + self.replay_session = ReplayFile( + replay_id=self._replay_id, interactions=[] + ) + + def _generate_file_path_from_replay_id(self, replay_directory, replay_id): + session_parts = replay_id.split('/') + if len(session_parts) < 3: + raise ValueError( + f'{replay_id}: Session ID must be in the format of' + ' module/function/[vertex|mldev]' + ) + if replay_directory is None: + path_parts = [] + else: + path_parts = [replay_directory] + path_parts.extend(session_parts) + return os.path.join(*path_parts) + '.json' + + def close(self): + if not self._should_update_replay() or not self.replay_session: + return + replay_file_path = self._get_replay_file_path() + os.makedirs(os.path.dirname(replay_file_path), exist_ok=True) + with open(replay_file_path, 'w') as f: + f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2)) + self.replay_session = None + + def _record_interaction( + self, + http_request: HttpRequest, + http_response: Union[HttpResponse, errors.APIError, bytes], + ): + if not self._should_update_replay(): + return + redact_http_request(http_request) + request = ReplayRequest( + method=http_request.method, + url=http_request.url, + headers=http_request.headers, + body_segments=[http_request.data], + ) + if isinstance(http_response, HttpResponse): + response = ReplayResponse( + headers=dict(http_response.headers), + body_segments=list(http_response.segments()), + byte_segments=[ + seg[:100] + b'...' 
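+              # Recorded byte payloads are truncated to their first 100 bytes
+              # to keep replay files small.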
for seg in http_response.byte_segments() + ], + status_code=http_response.status_code, + sdk_response_segments=[], + ) + else: + response = ReplayResponse( + headers=dict(http_response.response.headers), + body_segments=[http_response._to_replay_record()], + status_code=http_response.code, + sdk_response_segments=[], + ) + self.replay_session.interactions.append( + ReplayInteraction(request=request, response=response) + ) + + def _match_request( + self, + http_request: HttpRequest, + interaction: ReplayInteraction, + ): + assert http_request.url == interaction.request.url + assert http_request.headers == interaction.request.headers, ( + 'Request headers mismatch:\n' + f'Actual: {http_request.headers}\n' + f'Expected: {interaction.request.headers}' + ) + assert http_request.method == interaction.request.method + + # Sanitize the request body, rewrite any fields that vary. + request_data_copy = copy.deepcopy(http_request.data) + # Both the request and recorded request must be redacted before comparing + # so that the comparison is fair. + _redact_request_body(request_data_copy) + + actual_request_body = [request_data_copy] + expected_request_body = interaction.request.body_segments + assert actual_request_body == expected_request_body, ( + 'Request body mismatch:\n' + f'Actual: {actual_request_body}\n' + f'Expected: {expected_request_body}' + ) + + def _build_response_from_replay(self, http_request: HttpRequest): + redact_http_request(http_request) + + interaction = self.replay_session.interactions[self._replay_index] + # Replay is on the right side of the assert so the diff makes more sense. + self._match_request(http_request, interaction) + self._replay_index += 1 + self._sdk_response_index = 0 + errors.APIError.raise_for_response(interaction.response) + return HttpResponse( + headers=interaction.response.headers, + response_stream=[ + json.dumps(segment) + for segment in interaction.response.body_segments + ], + byte_stream=interaction.response.byte_segments, + ) + + def _verify_response(self, response_model: BaseModel): + if self._mode == 'api': + return + # replay_index is advanced in _build_response_from_replay, so we need to -1. 
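+    # In record mode the SDK-level response is appended to the interaction so
+    # later replays can check it; in replay mode it is compared against the
+    # recorded segment below.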
+ interaction = self.replay_session.interactions[self._replay_index - 1] + if self._should_update_replay(): + if isinstance(response_model, list): + response_model = response_model[0] + interaction.response.sdk_response_segments.append( + response_model.model_dump(exclude_none=True) + ) + return + + if isinstance(response_model, list): + response_model = response_model[0] + print('response_model: ', response_model.model_dump(exclude_none=True)) + actual = response_model.model_dump(exclude_none=True, mode='json') + expected = interaction.response.sdk_response_segments[ + self._sdk_response_index + ] + assert ( + actual == expected + ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}' + self._sdk_response_index += 1 + + def _request( + self, + http_request: HttpRequest, + stream: bool = False, + ) -> HttpResponse: + self._initialize_replay_session_if_not_loaded() + if self._should_call_api(): + _debug_print('api mode request: %s' % http_request) + try: + result = super()._request(http_request, stream) + except errors.APIError as e: + self._record_interaction(http_request, e) + raise e + if stream: + result_segments = [] + for segment in result.segments(): + result_segments.append(json.dumps(segment)) + result = HttpResponse(result.headers, result_segments) + self._record_interaction(http_request, result) + # Need to return a RecordedResponse that rebuilds the response + # segments since the stream has been consumed. + else: + self._record_interaction(http_request, result) + _debug_print('api mode result: %s' % result.text) + return result + else: + return self._build_response_from_replay(http_request) + + def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int): + if isinstance(file_path, io.IOBase): + offset = file_path.tell() + content = file_path.read() + file_path.seek(offset, os.SEEK_SET) + request = HttpRequest( + method='POST', + url='', + data={'bytes': base64.b64encode(content).decode('utf-8')}, + headers={} + ) + else: + request = HttpRequest( + method='POST', url='', data={'file_path': file_path}, headers={} + ) + if self._should_call_api(): + try: + result = super().upload_file(file_path, upload_url, upload_size) + except HTTPError as e: + result = HttpResponse( + e.response.headers, [json.dumps({'reason': e.response.reason})] + ) + result.status_code = e.response.status_code + raise e + self._record_interaction(request, HttpResponse({}, [json.dumps(result)])) + return result + else: + return self._build_response_from_replay(request).text + + def _download_file_request(self, request): + self._initialize_replay_session_if_not_loaded() + if self._should_call_api(): + try: + result = super()._download_file_request(request) + except HTTPError as e: + result = HttpResponse( + e.response.headers, [json.dumps({'reason': e.response.reason})] + ) + result.status_code = e.response.status_code + raise e + self._record_interaction(request, result) + return result + else: + return self._build_response_from_replay(request) + diff --git a/.venv/lib/python3.12/site-packages/google/genai/_test_api_client.py b/.venv/lib/python3.12/site-packages/google/genai/_test_api_client.py new file mode 100644 index 00000000..3d3bf3e2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/_test_api_client.py @@ -0,0 +1,149 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import asyncio
+import time
+from unittest.mock import MagicMock, patch
+import pytest
+from ._api_client import ApiClient
+
+
+@patch('google.genai._api_client.ApiClient._build_request')
+@patch('google.genai._api_client.ApiClient._request')
+def test_request_streamed_non_blocking(mock_request, mock_build_request):
+  api_client = ApiClient(api_key='test_api_key')
+  http_method = 'GET'
+  path = 'test/path'
+  request_dict = {'key': 'value'}
+
+  mock_http_request = MagicMock()
+  mock_build_request.return_value = mock_http_request
+
+  def delayed_segments():
+    chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
+    for chunk in chunks:
+      time.sleep(0.1)  # 100ms delay
+      yield chunk
+
+  mock_response = MagicMock()
+  mock_response.segments.side_effect = delayed_segments
+  mock_request.return_value = mock_response
+
+  chunks = []
+  start_time = time.time()
+  for chunk in api_client.request_streamed(http_method, path, request_dict):
+    chunks.append(chunk)
+    assert len(chunks) <= 3
+  end_time = time.time()
+
+  mock_build_request.assert_called_once_with(
+      http_method, path, request_dict, None
+  )
+  mock_request.assert_called_once_with(mock_http_request, stream=True)
+  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
+  assert end_time - start_time > 0.3
+
+
+@patch('google.genai._api_client.ApiClient._build_request')
+@patch('google.genai._api_client.ApiClient._async_request')
+@pytest.mark.asyncio
+async def test_async_request(mock_async_request, mock_build_request):
+  api_client = ApiClient(api_key='test_api_key')
+  http_method = 'GET'
+  path = 'test/path'
+  request_dict = {'key': 'value'}
+
+  mock_http_request = MagicMock()
+  mock_build_request.return_value = mock_http_request
+
+  class MockResponse:
+
+    def __init__(self, text):
+      self.text = text
+
+  async def delayed_response(http_request, stream):
+    await asyncio.sleep(0.1)  # 100ms delay
+    return MockResponse('value')
+
+  mock_async_request.side_effect = delayed_response
+
+  async_coroutine1 = api_client.async_request(http_method, path, request_dict)
+  async_coroutine2 = api_client.async_request(http_method, path, request_dict)
+  async_coroutine3 = api_client.async_request(http_method, path, request_dict)
+
+  start_time = time.time()
+  results = await asyncio.gather(
+      async_coroutine1, async_coroutine2, async_coroutine3
+  )
+  end_time = time.time()
+
+  mock_build_request.assert_called_with(http_method, path, request_dict, None)
+  assert mock_build_request.call_count == 3
+  mock_async_request.assert_called_with(
+      http_request=mock_http_request, stream=False
+  )
+  assert mock_async_request.call_count == 3
+  assert results == ['value', 'value', 'value']
+  assert 0.1 <= end_time - start_time < 0.15
+
+
+@patch('google.genai._api_client.ApiClient._build_request')
+@patch('google.genai._api_client.ApiClient._async_request')
+@pytest.mark.asyncio
+async def test_async_request_streamed_non_blocking(
+    mock_async_request, mock_build_request
+):
+  api_client = ApiClient(api_key='test_api_key')
+  http_method = 'GET'
+  path = 'test/path'
+  request_dict = {'key': 'value'}
+
+  mock_http_request = MagicMock()
+  mock_build_request.return_value = mock_http_request
+
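+  # MockResponse below stands in for a streamed HTTP response: each segment is
+  # delayed by 100ms, so the elapsed-time assertion at the end can distinguish
+  # true streaming from buffering the whole response up front.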
+  class MockResponse:
+
+    def __init__(self, segments):
+      self._segments = segments
+
+    # Ideally this would mock an async generator, but the source code combines
+    # sync and async streaming in a single segments() method.
+    # TODO: fix the above
+    def segments(self):
+      for segment in self._segments:
+        time.sleep(0.1)  # 100ms delay
+        yield segment
+
+  async def delayed_response(http_request, stream):
+    return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'])
+
+  mock_async_request.side_effect = delayed_response
+
+  chunks = []
+  start_time = time.time()
+  async for chunk in api_client.async_request_streamed(
+      http_method, path, request_dict
+  ):
+    chunks.append(chunk)
+    assert len(chunks) <= 3
+  end_time = time.time()
+
+  mock_build_request.assert_called_once_with(
+      http_method, path, request_dict, None
+  )
+  mock_async_request.assert_called_once_with(
+      http_request=mock_http_request, stream=True
+  )
+  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
+  assert end_time - start_time > 0.3
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_transformers.py b/.venv/lib/python3.12/site-packages/google/genai/_transformers.py
new file mode 100644
index 00000000..f1b392e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_transformers.py
@@ -0,0 +1,621 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Transformers for Google GenAI SDK."""
+
+import base64
+from collections.abc import Iterable, Mapping
+import inspect
+import io
+import re
+import time
+import typing
+from typing import Any, GenericAlias, Optional, Union
+
+import PIL.Image
+import PIL.PngImagePlugin
+import pydantic
+
+from . import _api_client
+from . import types
+
+
+def _resource_name(
+    client: _api_client.ApiClient,
+    resource_name: str,
+    *,
+    collection_identifier: str,
+    collection_hierarchy_depth: int = 2,
+):
+  # pylint: disable=line-too-long
+  """Prepends resource name with project, location, collection_identifier if needed.
+
+  The collection_identifier will only be prepended if it's not present
+  and the prepending won't violate the collection hierarchy depth.
+  When the prepending condition isn't met, the input resource_name is
+  returned unchanged.
+
+  Args:
+    client: The API client.
+    resource_name: The user input resource name to be completed.
+    collection_identifier: The collection identifier to be prepended. See
+      collection identifiers in https://google.aip.dev/122.
+    collection_hierarchy_depth: The collection hierarchy depth. Only set this
+      field when the resource has nested collections. For example,
+      `users/vhugo1802/events/birthday-dinner-226`, the collection_identifier is
+      `users` and collection_hierarchy_depth is 4. See nested collections in
+      https://google.aip.dev/122.
+ + Example: + + resource_name = 'cachedContents/123' + client.vertexai = True + client.project = 'bar' + client.location = 'us-west1' + _resource_name(client, 'cachedContents/123', + collection_identifier='cachedContents') + returns: 'projects/bar/locations/us-west1/cachedContents/123' + + Example: + + resource_name = 'projects/foo/locations/us-central1/cachedContents/123' + # resource_name = 'locations/us-central1/cachedContents/123' + client.vertexai = True + client.project = 'bar' + client.location = 'us-west1' + _resource_name(client, resource_name, + collection_identifier='cachedContents') + returns: 'projects/foo/locations/us-central1/cachedContents/123' + + Example: + + resource_name = '123' + # resource_name = 'cachedContents/123' + client.vertexai = False + _resource_name(client, resource_name, + collection_identifier='cachedContents') + returns 'cachedContents/123' + + Example: + resource_name = 'some/wrong/cachedContents/resource/name/123' + resource_prefix = 'cachedContents' + client.vertexai = False + # client.vertexai = True + _resource_name(client, resource_name, + collection_identifier='cachedContents') + returns: 'some/wrong/cachedContents/resource/name/123' + + Returns: + The completed resource name. + """ + should_prepend_collection_identifier = ( + not resource_name.startswith(f'{collection_identifier}/') + # Check if prepending the collection identifier won't violate the + # collection hierarchy depth. + and f'{collection_identifier}/{resource_name}'.count('/') + 1 + == collection_hierarchy_depth + ) + if client.vertexai: + if resource_name.startswith('projects/'): + return resource_name + elif resource_name.startswith('locations/'): + return f'projects/{client.project}/{resource_name}' + elif resource_name.startswith(f'{collection_identifier}/'): + return f'projects/{client.project}/locations/{client.location}/{resource_name}' + elif should_prepend_collection_identifier: + return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}' + else: + return resource_name + else: + if should_prepend_collection_identifier: + return f'{collection_identifier}/{resource_name}' + else: + return resource_name + + +def t_model(client: _api_client.ApiClient, model: str): + if not model: + raise ValueError('model is required.') + if client.vertexai: + if ( + model.startswith('projects/') + or model.startswith('models/') + or model.startswith('publishers/') + ): + return model + elif '/' in model: + publisher, model_id = model.split('/', 1) + return f'publishers/{publisher}/models/{model_id}' + else: + return f'publishers/google/models/{model}' + else: + if model.startswith('models/'): + return model + elif model.startswith('tunedModels/'): + return model + else: + return f'models/{model}' + + +def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str: + if api_client.vertexai: + if base_models: + return 'publishers/google/models' + else: + return 'models' + else: + if base_models: + return 'models' + else: + return 'tunedModels' + + +def t_extract_models( + api_client: _api_client.ApiClient, response: dict +) -> list[types.Model]: + if not response: + return [] + elif response.get('models') is not None: + return response.get('models') + elif response.get('tunedModels') is not None: + return response.get('tunedModels') + elif response.get('publisherModels') is not None: + return response.get('publisherModels') + else: + raise ValueError('Cannot determine the models type.') + + +def t_caches_model(api_client: 
    _api_client.ApiClient, model: str):
+  model = t_model(api_client, model)
+  if not model:
+    return None
+  if model.startswith('publishers/') and api_client.vertexai:
+    # Vertex caches only support model names that start with 'projects/'.
+    return (
+        f'projects/{api_client.project}/locations/{api_client.location}/{model}'
+    )
+  elif model.startswith('models/') and api_client.vertexai:
+    return f'projects/{api_client.project}/locations/{api_client.location}/publishers/google/{model}'
+  else:
+    return model
+
+
+def pil_to_blob(img):
+  bytesio = io.BytesIO()
+  if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
+    img.save(bytesio, format='PNG')
+    mime_type = 'image/png'
+  else:
+    img.save(bytesio, format='JPEG')
+    mime_type = 'image/jpeg'
+  bytesio.seek(0)
+  data = bytesio.read()
+  return types.Blob(mime_type=mime_type, data=data)
+
+
+PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]
+
+
+def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
+  if not part:
+    raise ValueError('content part is required.')
+  if isinstance(part, str):
+    return types.Part(text=part)
+  if isinstance(part, PIL.Image.Image):
+    return types.Part(inline_data=pil_to_blob(part))
+  if isinstance(part, types.File):
+    if not part.uri or not part.mime_type:
+      raise ValueError('file uri and mime_type are required.')
+    return types.Part.from_uri(part.uri, part.mime_type)
+  else:
+    return part
+
+
+def t_parts(
+    client: _api_client.ApiClient, parts: Union[list, PartType]
+) -> list[types.Part]:
+  if parts is None:
+    raise ValueError('content parts are required.')
+  if isinstance(parts, list):
+    return [t_part(client, part) for part in parts]
+  else:
+    return [t_part(client, parts)]
+
+
+def t_image_predictions(
+    client: _api_client.ApiClient,
+    predictions: Optional[Iterable[Mapping[str, Any]]],
+) -> Optional[list[types.GeneratedImage]]:
+  if not predictions:
+    return None
+  images = []
+  for prediction in predictions:
+    if prediction.get('image'):
+      images.append(
+          types.GeneratedImage(
+              image=types.Image(
+                  gcs_uri=prediction['image']['gcsUri'],
+                  image_bytes=prediction['image']['imageBytes'],
+              )
+          )
+      )
+  return images
+
+
+ContentType = Union[types.Content, types.ContentDict, PartType]
+
+
+def t_content(
+    client: _api_client.ApiClient,
+    content: ContentType,
+):
+  if not content:
+    raise ValueError('content is required.')
+  if isinstance(content, types.Content):
+    return content
+  if isinstance(content, dict):
+    return types.Content.model_validate(content)
+  return types.Content(role='user', parts=t_parts(client, content))
+
+
+def t_contents_for_embed(
+    client: _api_client.ApiClient,
+    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
+):
+  if client.vertexai and isinstance(contents, list):
+    # TODO: Assert that only text is supported.
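+    # The Vertex embedding request consumes plain strings, so each Content is
+    # flattened to the text of its first part.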
+    return [t_content(client, content).parts[0].text for content in contents]
+  elif client.vertexai:
+    return [t_content(client, contents).parts[0].text]
+  elif isinstance(contents, list):
+    return [t_content(client, content) for content in contents]
+  else:
+    return [t_content(client, contents)]
+
+
+def t_contents(
+    client: _api_client.ApiClient,
+    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
+):
+  if not contents:
+    raise ValueError('contents are required.')
+  if isinstance(contents, list):
+    return [t_content(client, content) for content in contents]
+  else:
+    return [t_content(client, contents)]
+
+
+def process_schema(
+    data: dict[str, Any], client: Optional[_api_client.ApiClient] = None
+):
+  """Recursively removes schema fields the target API does not support, in place."""
+  if isinstance(data, dict):
+    # Iterate over a copy of keys to allow deletion
+    for key in list(data.keys()):
+      # Only delete 'title' for the Gemini API
+      if client and not client.vertexai and key == 'title':
+        del data[key]
+      else:
+        process_schema(data[key], client)
+  elif isinstance(data, list):
+    for item in data:
+      process_schema(item, client)
+
+  return data
+
+
+def _build_schema(fname: str, fields_dict: dict[str, Any]) -> dict[str, Any]:
+  parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
+  defs = parameters.pop('$defs', {})
+
+  for _, value in defs.items():
+    unpack_defs(value, defs)
+
+  unpack_defs(parameters, defs)
+  return parameters['properties']['dummy']
+
+
+def unpack_defs(schema: dict[str, Any], defs: dict[str, Any]):
+  """Unpacks the $defs values in the schema generated by pydantic so they can be understood by the API.
+
+  Example of a schema before and after unpacking:
+  Before:
+
+  `schema`
+
+  {'properties': {
+      'dummy': {
+          'items': {
+              '$ref': '#/$defs/CountryInfo'
+          },
+          'title': 'Dummy',
+          'type': 'array'
+      }
+  },
+  'required': ['dummy'],
+  'title': 'dummy',
+  'type': 'object'}
+
+  `defs`
+
+  {'CountryInfo': {'properties': {'continent': {'title': 'Continent', 'type':
+  'string'}, 'gdp': {'title': 'Gdp', 'type': 'integer'}}, 'required':
+  ['continent', 'gdp'], 'title': 'CountryInfo', 'type': 'object'}}
+
+  After:
+
+  `schema`
+  {'properties': {
+      'continent': {'title': 'Continent', 'type': 'string'},
+      'gdp': {'title': 'Gdp', 'type': 'integer'}
+  },
+  'required': ['continent', 'gdp'],
+  'title': 'CountryInfo',
+  'type': 'object'
+  }
+  """
+  properties = schema.get('properties', None)
+  if properties is None:
+    return
+
+  for name, value in properties.items():
+    ref_key = value.get('$ref', None)
+    if ref_key is not None:
+      ref = defs[ref_key.split('defs/')[-1]]
+      unpack_defs(ref, defs)
+      properties[name] = ref
+      continue
+
+    anyof = value.get('anyOf', None)
+    if anyof is not None:
+      for i, atype in enumerate(anyof):
+        ref_key = atype.get('$ref', None)
+        if ref_key is not None:
+          ref = defs[ref_key.split('defs/')[-1]]
+          unpack_defs(ref, defs)
+          anyof[i] = ref
+      continue
+
+    items = value.get('items', None)
+    if items is not None:
+      ref_key = items.get('$ref', None)
+      if ref_key is not None:
+        ref = defs[ref_key.split('defs/')[-1]]
+        unpack_defs(ref, defs)
+        value['items'] = ref
+      continue
+
+
+def t_schema(
+    client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
+) -> Optional[types.Schema]:
+  if not origin:
+    return None
+  if isinstance(origin, dict):
+    return process_schema(origin, client)
+  if isinstance(origin, types.Schema):
+    if dict(origin) == dict(types.Schema()):
+      # response_schema value was coerced to an empty Schema instance because
+      # it did not adhere to the Schema field annotation
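+      # An empty Schema here means pydantic silently coerced an unsupported
+      # value, so fail loudly instead of sending an empty schema.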
raise ValueError(f'Unsupported schema type.') + schema = process_schema(origin.model_dump(exclude_unset=True), client) + return types.Schema.model_validate(schema) + if isinstance(origin, GenericAlias): + if origin.__origin__ is list: + if isinstance(origin.__args__[0], typing.types.UnionType): + raise ValueError(f'Unsupported schema type: GenericAlias {origin}') + if issubclass(origin.__args__[0], pydantic.BaseModel): + # Handle cases where response schema is `list[pydantic.BaseModel]` + list_schema = _build_schema( + 'dummy', {'dummy': (origin, pydantic.Field())} + ) + list_schema = process_schema(list_schema, client) + return types.Schema.model_validate(list_schema) + raise ValueError(f'Unsupported schema type: GenericAlias {origin}') + if issubclass(origin, pydantic.BaseModel): + schema = process_schema(origin.model_json_schema(), client) + return types.Schema.model_validate(schema) + raise ValueError(f'Unsupported schema type: {origin}') + + +def t_speech_config( + _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any] +) -> Optional[types.SpeechConfig]: + if not origin: + return None + if isinstance(origin, types.SpeechConfig): + return origin + if isinstance(origin, str): + return types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin) + ) + ) + if ( + isinstance(origin, dict) + and 'voice_config' in origin + and 'prebuilt_voice_config' in origin['voice_config'] + ): + return types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name=origin['voice_config']['prebuilt_voice_config'].get( + 'voice_name' + ) + ) + ) + ) + raise ValueError(f'Unsupported speechConfig type: {type(origin)}') + + +def t_tool(client: _api_client.ApiClient, origin) -> types.Tool: + if not origin: + return None + if inspect.isfunction(origin) or inspect.ismethod(origin): + return types.Tool( + function_declarations=[ + types.FunctionDeclaration.from_callable(client, origin) + ] + ) + else: + return origin + + +# Only support functions now. +def t_tools( + client: _api_client.ApiClient, origin: list[Any] +) -> list[types.Tool]: + if not origin: + return [] + function_tool = types.Tool(function_declarations=[]) + tools = [] + for tool in origin: + transformed_tool = t_tool(client, tool) + # All functions should be merged into one tool. 
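+    # e.g. t_tools(client, [fn_a, fn_b, retrieval_tool]) yields
+    # [retrieval_tool, Tool(function_declarations=[fn_a, fn_b])]; the names
+    # here are purely illustrative.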
+    if transformed_tool.function_declarations:
+      function_tool.function_declarations += (
+          transformed_tool.function_declarations
+      )
+    else:
+      tools.append(transformed_tool)
+  if function_tool.function_declarations:
+    tools.append(function_tool)
+  return tools
+
+
+def t_cached_content_name(client: _api_client.ApiClient, name: str):
+  return _resource_name(client, name, collection_identifier='cachedContents')
+
+
+def t_batch_job_source(client: _api_client.ApiClient, src: str):
+  if src.startswith('gs://'):
+    return types.BatchJobSource(
+        format='jsonl',
+        gcs_uri=[src],
+    )
+  elif src.startswith('bq://'):
+    return types.BatchJobSource(
+        format='bigquery',
+        bigquery_uri=src,
+    )
+  else:
+    raise ValueError(f'Unsupported source: {src}')
+
+
+def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
+  if dest.startswith('gs://'):
+    return types.BatchJobDestination(
+        format='jsonl',
+        gcs_uri=dest,
+    )
+  elif dest.startswith('bq://'):
+    return types.BatchJobDestination(
+        format='bigquery',
+        bigquery_uri=dest,
+    )
+  else:
+    raise ValueError(f'Unsupported destination: {dest}')
+
+
+def t_batch_job_name(client: _api_client.ApiClient, name: str):
+  if not client.vertexai:
+    return name
+
+  pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
+  if re.match(pattern, name):
+    return name.split('/')[-1]
+  elif name.isdigit():
+    return name
+  else:
+    raise ValueError(f'Invalid batch job name: {name}.')
+
+
+LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
+LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
+LRO_POLLING_TIMEOUT_SECONDS = 900.0
+LRO_POLLING_MULTIPLIER = 1.5
+
+
+def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
+  if (name := struct.get('name')) and '/operations/' in name:
+    operation: dict[str, Any] = struct
+    total_seconds = 0.0
+    delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
+    while operation.get('done') != True:
+      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
+        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
+      # TODO(b/374433890): Replace with LRO module once it's available.
+      operation: dict[str, Any] = api_client.request(
+          http_method='GET', path=name, request_dict={}
+      )
+      time.sleep(delay_seconds)
+      # Accumulate elapsed time so the timeout check above can fire.
+      total_seconds += delay_seconds
+      # Exponential backoff
+      delay_seconds = min(
+          delay_seconds * LRO_POLLING_MULTIPLIER,
+          LRO_POLLING_MAXIMUM_DELAY_SECONDS,
+      )
+    if error := operation.get('error'):
+      raise RuntimeError(
+          f'Operation {name} failed with error: {error}.\n{operation}'
+      )
+    return operation.get('response')
+  else:
+    return struct
+
+
+def t_file_name(
+    api_client: _api_client.ApiClient, name: Union[str, types.File]
+):
+  # Remove the files/ prefix since it's added to the url path.
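+  # Accepted forms (values illustrative): 'files/abc123' and a download URI
+  # ending in '.../files/abc123' both normalize to 'abc123'; a bare id is
+  # returned as-is.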
+  if isinstance(name, types.File):
+    name = name.name
+
+  if name is None:
+    raise ValueError('File name is required.')
+
+  if name.startswith('https://'):
+    suffix = name.split('files/')[1]
+    match = re.match('[a-z0-9]+', suffix)
+    if match is None:
+      raise ValueError(f'Could not extract file name from URI: {name}')
+    name = match.group(0)
+  elif name.startswith('files/'):
+    name = name.split('files/')[1]
+
+  return name
+
+
+def t_tuning_job_status(
+    api_client: _api_client.ApiClient, status: str
+) -> types.JobState:
+  if status == 'STATE_UNSPECIFIED':
+    return 'JOB_STATE_UNSPECIFIED'
+  elif status == 'CREATING':
+    return 'JOB_STATE_RUNNING'
+  elif status == 'ACTIVE':
+    return 'JOB_STATE_SUCCEEDED'
+  elif status == 'FAILED':
+    return 'JOB_STATE_FAILED'
+  else:
+    return status
+
+
+# Some fields don't accept URL-safe base64 encoding.
+# We shouldn't use this transformer if the backend adheres to the Cloud Type
+# format https://cloud.google.com/docs/discovery/type-format.
+# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the
+# issue.
+def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
+  if not isinstance(data, bytes):
+    return data
+  return base64.b64encode(data).decode('ascii')
diff --git a/.venv/lib/python3.12/site-packages/google/genai/batches.py b/.venv/lib/python3.12/site-packages/google/genai/batches.py
new file mode 100644
index 00000000..6dd5b4c5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/batches.py
@@ -0,0 +1,1293 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+from typing import Optional, Union
+from urllib.parse import urlencode
+from . import _common
+from . import _extra_utils
+from . import _transformers as t
+from .
import types +from ._api_client import ApiClient +from ._common import get_value_by_path as getv +from ._common import set_value_by_path as setv +from .pagers import AsyncPager, Pager + + +def _BatchJobSource_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['format']) is not None: + raise ValueError('format parameter is not supported in Google AI.') + + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['bigquery_uri']) is not None: + raise ValueError('bigquery_uri parameter is not supported in Google AI.') + + return to_object + + +def _BatchJobSource_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['format']) is not None: + setv(to_object, ['instancesFormat'], getv(from_object, ['format'])) + + if getv(from_object, ['gcs_uri']) is not None: + setv(to_object, ['gcsSource', 'uris'], getv(from_object, ['gcs_uri'])) + + if getv(from_object, ['bigquery_uri']) is not None: + setv( + to_object, + ['bigquerySource', 'inputUri'], + getv(from_object, ['bigquery_uri']), + ) + + return to_object + + +def _BatchJobDestination_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['format']) is not None: + raise ValueError('format parameter is not supported in Google AI.') + + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['bigquery_uri']) is not None: + raise ValueError('bigquery_uri parameter is not supported in Google AI.') + + return to_object + + +def _BatchJobDestination_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['format']) is not None: + setv(to_object, ['predictionsFormat'], getv(from_object, ['format'])) + + if getv(from_object, ['gcs_uri']) is not None: + setv( + to_object, + ['gcsDestination', 'outputUriPrefix'], + getv(from_object, ['gcs_uri']), + ) + + if getv(from_object, ['bigquery_uri']) is not None: + setv( + to_object, + ['bigqueryDestination', 'outputUri'], + getv(from_object, ['bigquery_uri']), + ) + + return to_object + + +def _CreateBatchJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['display_name']) is not None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['dest']) is not None: + raise ValueError('dest parameter is not supported in Google AI.') + + return to_object + + +def _CreateBatchJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['display_name']) is not None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['dest']) is not None: + setv( + parent_object, + ['outputConfig'], + 
_BatchJobDestination_to_vertex( + api_client, + t.t_batch_job_destination(api_client, getv(from_object, ['dest'])), + to_object, + ), + ) + + return to_object + + +def _CreateBatchJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + raise ValueError('model parameter is not supported in Google AI.') + + if getv(from_object, ['src']) is not None: + raise ValueError('src parameter is not supported in Google AI.') + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateBatchJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CreateBatchJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['src']) is not None: + setv( + to_object, + ['inputConfig'], + _BatchJobSource_to_vertex( + api_client, + t.t_batch_job_source(api_client, getv(from_object, ['src'])), + to_object, + ), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateBatchJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetBatchJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetBatchJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetBatchJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Google AI.') + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetBatchJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetBatchJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_batch_job_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetBatchJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CancelBatchJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _CancelBatchJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + 
to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _CancelBatchJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Google AI.') + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CancelBatchJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CancelBatchJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_batch_job_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CancelBatchJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListBatchJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + raise ValueError('filter parameter is not supported in Google AI.') + + return to_object + + +def _ListBatchJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter'])) + + return to_object + + +def _ListBatchJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + raise ValueError('config parameter is not supported in Google AI.') + + return to_object + + +def _ListBatchJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListBatchJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DeleteBatchJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Google AI.') + + 
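+  # Batch prediction is Vertex-only: the mldev (Gemini API) converters reject
+  # their parameters, and the Batches methods below raise before calling them.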
return to_object + + +def _DeleteBatchJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_batch_job_name(api_client, getv(from_object, ['name'])), + ) + + return to_object + + +def _JobError_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _JobError_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['details']) is not None: + setv(to_object, ['details'], getv(from_object, ['details'])) + + if getv(from_object, ['code']) is not None: + setv(to_object, ['code'], getv(from_object, ['code'])) + + if getv(from_object, ['message']) is not None: + setv(to_object, ['message'], getv(from_object, ['message'])) + + return to_object + + +def _BatchJobSource_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _BatchJobSource_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['instancesFormat']) is not None: + setv(to_object, ['format'], getv(from_object, ['instancesFormat'])) + + if getv(from_object, ['gcsSource', 'uris']) is not None: + setv(to_object, ['gcs_uri'], getv(from_object, ['gcsSource', 'uris'])) + + if getv(from_object, ['bigquerySource', 'inputUri']) is not None: + setv( + to_object, + ['bigquery_uri'], + getv(from_object, ['bigquerySource', 'inputUri']), + ) + + return to_object + + +def _BatchJobDestination_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _BatchJobDestination_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictionsFormat']) is not None: + setv(to_object, ['format'], getv(from_object, ['predictionsFormat'])) + + if getv(from_object, ['gcsDestination', 'outputUriPrefix']) is not None: + setv( + to_object, + ['gcs_uri'], + getv(from_object, ['gcsDestination', 'outputUriPrefix']), + ) + + if getv(from_object, ['bigqueryDestination', 'outputUri']) is not None: + setv( + to_object, + ['bigquery_uri'], + getv(from_object, ['bigqueryDestination', 'outputUri']), + ) + + return to_object + + +def _BatchJob_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _BatchJob_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + if getv(from_object, ['state']) is not None: + setv(to_object, ['state'], getv(from_object, ['state'])) + + if getv(from_object, ['error']) is not None: + setv( + to_object, + ['error'], + _JobError_from_vertex( + api_client, getv(from_object, ['error']), to_object + ), + ) + + if getv(from_object, ['createTime']) is 
not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['startTime']) is not None: + setv(to_object, ['start_time'], getv(from_object, ['startTime'])) + + if getv(from_object, ['endTime']) is not None: + setv(to_object, ['end_time'], getv(from_object, ['endTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + if getv(from_object, ['model']) is not None: + setv(to_object, ['model'], getv(from_object, ['model'])) + + if getv(from_object, ['inputConfig']) is not None: + setv( + to_object, + ['src'], + _BatchJobSource_from_vertex( + api_client, getv(from_object, ['inputConfig']), to_object + ), + ) + + if getv(from_object, ['outputConfig']) is not None: + setv( + to_object, + ['dest'], + _BatchJobDestination_from_vertex( + api_client, getv(from_object, ['outputConfig']), to_object + ), + ) + + return to_object + + +def _ListBatchJobResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + return to_object + + +def _ListBatchJobResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['batchPredictionJobs']) is not None: + setv( + to_object, + ['batch_jobs'], + [ + _BatchJob_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['batchPredictionJobs']) + ], + ) + + return to_object + + +def _DeleteResourceJob_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DeleteResourceJob_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['done']) is not None: + setv(to_object, ['done'], getv(from_object, ['done'])) + + if getv(from_object, ['error']) is not None: + setv( + to_object, + ['error'], + _JobError_from_vertex( + api_client, getv(from_object, ['error']), to_object + ), + ) + + return to_object + + +class Batches(_common.BaseModule): + + def _create( + self, + *, + model: str, + src: str, + config: Optional[types.CreateBatchJobConfigOrDict] = None, + ) -> types.BatchJob: + parameter_model = types._CreateBatchJobParameters( + model=model, + src=src, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CreateBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
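+      # The converter stages the per-call config inside request_dict; pop it
+      # back out so only wire-format fields are sent, and lift httpOptions out
+      # for the transport layer.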
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _BatchJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _BatchJob_from_mldev(self._api_client, response_dict) + + return_value = types.BatchJob._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + def get( + self, *, name: str, config: Optional[types.GetBatchJobConfigOrDict] = None + ) -> types.BatchJob: + """Gets a batch job. + + Args: + name (str): A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or "456" + when project and location are initialized in the client. + + Returns: + A BatchJob object that contains details about the batch job. + + Usage: + + .. code-block:: python + + batch_job = client.batches.get(name='123456789') + print(f"Batch job: {batch_job.name}, state {batch_job.state}") + """ + + parameter_model = types._GetBatchJobParameters( + name=name, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _GetBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _BatchJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _BatchJob_from_mldev(self._api_client, response_dict) + + return_value = types.BatchJob._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + def cancel( + self, + *, + name: str, + config: Optional[types.CancelBatchJobConfigOrDict] = None, + ) -> None: + parameter_model = types._CancelBatchJobParameters( + name=name, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CancelBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}:cancel'.format_map( + request_dict.get('_url') + ) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + def _list( + self, *, config: types.ListBatchJobConfigOrDict + ) -> types.ListBatchJobResponse: + parameter_model = types._ListBatchJobParameters( + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _ListBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListBatchJobResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListBatchJobResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListBatchJobResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def delete(self, *, name: str) -> types.DeleteResourceJob: + """Deletes a batch job. + + Args: + name (str): A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or "456" + when project and location are initialized in the client. + + Returns: + A DeleteResourceJob object that shows the status of the deletion. + + Usage: + + .. code-block:: python + + client.batches.delete(name='123456789') + """ + + parameter_model = types._DeleteBatchJobParameters( + name=name, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _DeleteBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteResourceJob_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteResourceJob_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteResourceJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def create( + self, + *, + model: str, + src: str, + config: Optional[types.CreateBatchJobConfigOrDict] = None, + ) -> types.BatchJob: + """Creates a batch job. 
+ + Args: + model (str): The model to use for the batch job. + src (str): The source of the batch job. Currently supports GCS URI(-s) or + BigQuery URI. Example: "gs://path/to/input/data" or + "bq://projectId.bqDatasetId.bqTableId". + config (CreateBatchJobConfig): Optional configuration for the batch job. + + Returns: + A BatchJob object that contains details about the batch job. + + Usage: + + .. code-block:: python + + batch_job = client.batches.create( + model="gemini-1.5-flash", + src="gs://path/to/input/data", + ) + print(batch_job.state) + """ + config = _extra_utils.format_destination(src, config) + return self._create(model=model, src=src, config=config) + + def list( + self, *, config: Optional[types.ListBatchJobConfigOrDict] = None + ) -> Pager[types.BatchJob]: + """Lists batch jobs. + + Args: + config (ListBatchJobConfig): Optional configuration for the list request. + + Returns: + A Pager object that contains one page of batch jobs. When iterating over + the pager, it automatically fetches the next page if there are more. + + Usage: + + .. code-block:: python + + batch_jobs = client.batches.list(config={"page_size": 10}) + for batch_job in batch_jobs: + print(f"Batch job: {batch_job.name}, state {batch_job.state}") + """ + return Pager( + 'batch_jobs', + self._list, + self._list(config=config), + config, + ) + + +class AsyncBatches(_common.BaseModule): + + async def _create( + self, + *, + model: str, + src: str, + config: Optional[types.CreateBatchJobConfigOrDict] = None, + ) -> types.BatchJob: + parameter_model = types._CreateBatchJobParameters( + model=model, + src=src, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CreateBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _BatchJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _BatchJob_from_mldev(self._api_client, response_dict) + + return_value = types.BatchJob._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + async def get( + self, *, name: str, config: Optional[types.GetBatchJobConfigOrDict] = None + ) -> types.BatchJob: + """Gets a batch job. + + Args: + name (str): A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or "456" + when project and location are initialized in the client. + + Returns: + A BatchJob object that contains details about the batch job. + + Usage: + + .. 
code-block:: python + + batch_job = client.batches.get(name='123456789') + print(f"Batch job: {batch_job.name}, state {batch_job.state}") + """ + + parameter_model = types._GetBatchJobParameters( + name=name, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _GetBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _BatchJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _BatchJob_from_mldev(self._api_client, response_dict) + + return_value = types.BatchJob._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + async def cancel( + self, + *, + name: str, + config: Optional[types.CancelBatchJobConfigOrDict] = None, + ) -> None: + parameter_model = types._CancelBatchJobParameters( + name=name, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CancelBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}:cancel'.format_map( + request_dict.get('_url') + ) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + async def _list( + self, *, config: types.ListBatchJobConfigOrDict + ) -> types.ListBatchJobResponse: + parameter_model = types._ListBatchJobParameters( + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _ListBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
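+    # (Editor's note, not upstream code) What the "hack" above does: the
+    # converter serialized the request config into request_dict under the
+    # 'config' key, so it is popped back out here and any per-request
+    # 'httpOptions' it carries are handed to the transport layer instead of
+    # being sent as part of the JSON payload.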
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListBatchJobResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListBatchJobResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListBatchJobResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def delete(self, *, name: str) -> types.DeleteResourceJob: + """Deletes a batch job. + + Args: + name (str): A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or "456" + when project and location are initialized in the client. + + Returns: + A DeleteResourceJob object that shows the status of the deletion. + + Usage: + + .. code-block:: python + + client.batches.delete(name='123456789') + """ + + parameter_model = types._DeleteBatchJobParameters( + name=name, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _DeleteBatchJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'batchPredictionJobs/{name}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteResourceJob_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteResourceJob_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteResourceJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def create( + self, + *, + model: str, + src: str, + config: Optional[types.CreateBatchJobConfigOrDict] = None, + ) -> types.BatchJob: + """Creates a batch job asynchronously. + + Args: + model (str): The model to use for the batch job. + src (str): The source of the batch job. Currently supports GCS URI(-s) or + BigQuery URI. Example: "gs://path/to/input/data" or + "bq://projectId.bqDatasetId.bqTableId". + config (CreateBatchJobConfig): Optional configuration for the batch job. + + Returns: + A BatchJob object that contains details about the batch job. + + Usage: + + .. code-block:: python + + batch_job = await client.aio.batches.create( + model="gemini-1.5-flash", + src="gs://path/to/input/data", + ) + """ + config = _extra_utils.format_destination(src, config) + return await self._create(model=model, src=src, config=config) + + async def list( + self, *, config: Optional[types.ListBatchJobConfigOrDict] = None + ) -> AsyncPager[types.BatchJob]: + """Lists batch jobs asynchronously. + + Args: + config (ListBatchJobConfig): Optional configuration for the list request. 
+
+    Returns:
+      An AsyncPager object that contains one page of batch jobs. When iterating
+      over the pager, it automatically fetches the next page if there are more.
+
+    Usage:
+
+    .. code-block:: python
+
+      batch_jobs = await client.aio.batches.list(config={'page_size': 5})
+      print(f"current page: {batch_jobs.page}")
+      await batch_jobs.next_page()
+      print(f"next page: {batch_jobs.page}")
+    """
+    return AsyncPager(
+        'batch_jobs',
+        self._list,
+        await self._list(config=config),
+        config,
+    )
diff --git a/.venv/lib/python3.12/site-packages/google/genai/caches.py b/.venv/lib/python3.12/site-packages/google/genai/caches.py
new file mode 100644
index 00000000..6d0e576b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/caches.py
@@ -0,0 +1,1856 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+from typing import Optional, Union
+from urllib.parse import urlencode
+from . import _common
+from . import _transformers as t
+from . import types
+from ._api_client import ApiClient
+from ._common import get_value_by_path as getv
+from ._common import set_value_by_path as setv
+from .pagers import AsyncPager, Pager
+
+
+def _Part_to_mldev(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+  if getv(from_object, ['video_metadata']) is not None:
+    raise ValueError('video_metadata parameter is not supported in Google AI.')
+
+  if getv(from_object, ['thought']) is not None:
+    setv(to_object, ['thought'], getv(from_object, ['thought']))
+
+  if getv(from_object, ['code_execution_result']) is not None:
+    setv(
+        to_object,
+        ['codeExecutionResult'],
+        getv(from_object, ['code_execution_result']),
+    )
+
+  if getv(from_object, ['executable_code']) is not None:
+    setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
+
+  if getv(from_object, ['file_data']) is not None:
+    setv(to_object, ['fileData'], getv(from_object, ['file_data']))
+
+  if getv(from_object, ['function_call']) is not None:
+    setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
+
+  if getv(from_object, ['function_response']) is not None:
+    setv(
+        to_object,
+        ['functionResponse'],
+        getv(from_object, ['function_response']),
+    )
+
+  if getv(from_object, ['inline_data']) is not None:
+    setv(to_object, ['inlineData'], getv(from_object, ['inline_data']))
+
+  if getv(from_object, ['text']) is not None:
+    setv(to_object, ['text'], getv(from_object, ['text']))
+
+  return to_object
+
+
+def _Part_to_vertex(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+  if getv(from_object, ['video_metadata']) is not None:
+    setv(to_object, ['videoMetadata'], getv(from_object, ['video_metadata']))
+
+  if getv(from_object, ['thought']) is not None:
+    setv(to_object, ['thought'], getv(from_object, ['thought']))
+
+  if getv(from_object, ['code_execution_result']) is not
None: + setv( + to_object, + ['codeExecutionResult'], + getv(from_object, ['code_execution_result']), + ) + + if getv(from_object, ['executable_code']) is not None: + setv(to_object, ['executableCode'], getv(from_object, ['executable_code'])) + + if getv(from_object, ['file_data']) is not None: + setv(to_object, ['fileData'], getv(from_object, ['file_data'])) + + if getv(from_object, ['function_call']) is not None: + setv(to_object, ['functionCall'], getv(from_object, ['function_call'])) + + if getv(from_object, ['function_response']) is not None: + setv( + to_object, + ['functionResponse'], + getv(from_object, ['function_response']), + ) + + if getv(from_object, ['inline_data']) is not None: + setv(to_object, ['inlineData'], getv(from_object, ['inline_data'])) + + if getv(from_object, ['text']) is not None: + setv(to_object, ['text'], getv(from_object, ['text'])) + + return to_object + + +def _Content_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['parts']) is not None: + setv( + to_object, + ['parts'], + [ + _Part_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _Content_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['parts']) is not None: + setv( + to_object, + ['parts'], + [ + _Part_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _Schema_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['min_items']) is not None: + raise ValueError('min_items parameter is not supported in Google AI.') + + if getv(from_object, ['example']) is not None: + raise ValueError('example parameter is not supported in Google AI.') + + if getv(from_object, ['property_ordering']) is not None: + raise ValueError( + 'property_ordering parameter is not supported in Google AI.' 
+ ) + + if getv(from_object, ['pattern']) is not None: + raise ValueError('pattern parameter is not supported in Google AI.') + + if getv(from_object, ['minimum']) is not None: + raise ValueError('minimum parameter is not supported in Google AI.') + + if getv(from_object, ['default']) is not None: + raise ValueError('default parameter is not supported in Google AI.') + + if getv(from_object, ['any_of']) is not None: + raise ValueError('any_of parameter is not supported in Google AI.') + + if getv(from_object, ['max_length']) is not None: + raise ValueError('max_length parameter is not supported in Google AI.') + + if getv(from_object, ['title']) is not None: + raise ValueError('title parameter is not supported in Google AI.') + + if getv(from_object, ['min_length']) is not None: + raise ValueError('min_length parameter is not supported in Google AI.') + + if getv(from_object, ['min_properties']) is not None: + raise ValueError('min_properties parameter is not supported in Google AI.') + + if getv(from_object, ['max_items']) is not None: + raise ValueError('max_items parameter is not supported in Google AI.') + + if getv(from_object, ['maximum']) is not None: + raise ValueError('maximum parameter is not supported in Google AI.') + + if getv(from_object, ['nullable']) is not None: + raise ValueError('nullable parameter is not supported in Google AI.') + + if getv(from_object, ['max_properties']) is not None: + raise ValueError('max_properties parameter is not supported in Google AI.') + + if getv(from_object, ['type']) is not None: + setv(to_object, ['type'], getv(from_object, ['type'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['enum']) is not None: + setv(to_object, ['enum'], getv(from_object, ['enum'])) + + if getv(from_object, ['format']) is not None: + setv(to_object, ['format'], getv(from_object, ['format'])) + + if getv(from_object, ['items']) is not None: + setv(to_object, ['items'], getv(from_object, ['items'])) + + if getv(from_object, ['properties']) is not None: + setv(to_object, ['properties'], getv(from_object, ['properties'])) + + if getv(from_object, ['required']) is not None: + setv(to_object, ['required'], getv(from_object, ['required'])) + + return to_object + + +def _Schema_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['min_items']) is not None: + setv(to_object, ['minItems'], getv(from_object, ['min_items'])) + + if getv(from_object, ['example']) is not None: + setv(to_object, ['example'], getv(from_object, ['example'])) + + if getv(from_object, ['property_ordering']) is not None: + setv( + to_object, + ['propertyOrdering'], + getv(from_object, ['property_ordering']), + ) + + if getv(from_object, ['pattern']) is not None: + setv(to_object, ['pattern'], getv(from_object, ['pattern'])) + + if getv(from_object, ['minimum']) is not None: + setv(to_object, ['minimum'], getv(from_object, ['minimum'])) + + if getv(from_object, ['default']) is not None: + setv(to_object, ['default'], getv(from_object, ['default'])) + + if getv(from_object, ['any_of']) is not None: + setv(to_object, ['anyOf'], getv(from_object, ['any_of'])) + + if getv(from_object, ['max_length']) is not None: + setv(to_object, ['maxLength'], getv(from_object, ['max_length'])) + + if getv(from_object, ['title']) is not None: + setv(to_object, ['title'], getv(from_object, ['title'])) + + if 
getv(from_object, ['min_length']) is not None: + setv(to_object, ['minLength'], getv(from_object, ['min_length'])) + + if getv(from_object, ['min_properties']) is not None: + setv(to_object, ['minProperties'], getv(from_object, ['min_properties'])) + + if getv(from_object, ['max_items']) is not None: + setv(to_object, ['maxItems'], getv(from_object, ['max_items'])) + + if getv(from_object, ['maximum']) is not None: + setv(to_object, ['maximum'], getv(from_object, ['maximum'])) + + if getv(from_object, ['nullable']) is not None: + setv(to_object, ['nullable'], getv(from_object, ['nullable'])) + + if getv(from_object, ['max_properties']) is not None: + setv(to_object, ['maxProperties'], getv(from_object, ['max_properties'])) + + if getv(from_object, ['type']) is not None: + setv(to_object, ['type'], getv(from_object, ['type'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['enum']) is not None: + setv(to_object, ['enum'], getv(from_object, ['enum'])) + + if getv(from_object, ['format']) is not None: + setv(to_object, ['format'], getv(from_object, ['format'])) + + if getv(from_object, ['items']) is not None: + setv(to_object, ['items'], getv(from_object, ['items'])) + + if getv(from_object, ['properties']) is not None: + setv(to_object, ['properties'], getv(from_object, ['properties'])) + + if getv(from_object, ['required']) is not None: + setv(to_object, ['required'], getv(from_object, ['required'])) + + return to_object + + +def _FunctionDeclaration_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['response']) is not None: + raise ValueError('response parameter is not supported in Google AI.') + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['parameters']) is not None: + setv(to_object, ['parameters'], getv(from_object, ['parameters'])) + + return to_object + + +def _FunctionDeclaration_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['response']) is not None: + setv( + to_object, + ['response'], + _Schema_to_vertex( + api_client, getv(from_object, ['response']), to_object + ), + ) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['parameters']) is not None: + setv(to_object, ['parameters'], getv(from_object, ['parameters'])) + + return to_object + + +def _GoogleSearch_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _GoogleSearch_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DynamicRetrievalConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if 
getv(from_object, ['dynamic_threshold']) is not None: + setv( + to_object, + ['dynamicThreshold'], + getv(from_object, ['dynamic_threshold']), + ) + + return to_object + + +def _DynamicRetrievalConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['dynamic_threshold']) is not None: + setv( + to_object, + ['dynamicThreshold'], + getv(from_object, ['dynamic_threshold']), + ) + + return to_object + + +def _GoogleSearchRetrieval_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['dynamic_retrieval_config']) is not None: + setv( + to_object, + ['dynamicRetrievalConfig'], + _DynamicRetrievalConfig_to_mldev( + api_client, + getv(from_object, ['dynamic_retrieval_config']), + to_object, + ), + ) + + return to_object + + +def _GoogleSearchRetrieval_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['dynamic_retrieval_config']) is not None: + setv( + to_object, + ['dynamicRetrievalConfig'], + _DynamicRetrievalConfig_to_vertex( + api_client, + getv(from_object, ['dynamic_retrieval_config']), + to_object, + ), + ) + + return to_object + + +def _Tool_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_declarations']) is not None: + setv( + to_object, + ['functionDeclarations'], + [ + _FunctionDeclaration_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['function_declarations']) + ], + ) + + if getv(from_object, ['retrieval']) is not None: + raise ValueError('retrieval parameter is not supported in Google AI.') + + if getv(from_object, ['google_search']) is not None: + setv( + to_object, + ['googleSearch'], + _GoogleSearch_to_mldev( + api_client, getv(from_object, ['google_search']), to_object + ), + ) + + if getv(from_object, ['google_search_retrieval']) is not None: + setv( + to_object, + ['googleSearchRetrieval'], + _GoogleSearchRetrieval_to_mldev( + api_client, + getv(from_object, ['google_search_retrieval']), + to_object, + ), + ) + + if getv(from_object, ['code_execution']) is not None: + setv(to_object, ['codeExecution'], getv(from_object, ['code_execution'])) + + return to_object + + +def _Tool_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_declarations']) is not None: + setv( + to_object, + ['functionDeclarations'], + [ + _FunctionDeclaration_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['function_declarations']) + ], + ) + + if getv(from_object, ['retrieval']) is not None: + setv(to_object, ['retrieval'], getv(from_object, ['retrieval'])) + + if getv(from_object, ['google_search']) is not None: + setv( + to_object, + ['googleSearch'], + _GoogleSearch_to_vertex( + api_client, getv(from_object, ['google_search']), to_object + ), + ) + + if getv(from_object, ['google_search_retrieval']) is not None: + setv( + to_object, + ['googleSearchRetrieval'], + _GoogleSearchRetrieval_to_vertex( + api_client, + getv(from_object, ['google_search_retrieval']), + to_object, + ), + ) + + if getv(from_object, ['code_execution']) 
is not None: + setv(to_object, ['codeExecution'], getv(from_object, ['code_execution'])) + + return to_object + + +def _FunctionCallingConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['allowed_function_names']) is not None: + setv( + to_object, + ['allowedFunctionNames'], + getv(from_object, ['allowed_function_names']), + ) + + return to_object + + +def _FunctionCallingConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['allowed_function_names']) is not None: + setv( + to_object, + ['allowedFunctionNames'], + getv(from_object, ['allowed_function_names']), + ) + + return to_object + + +def _ToolConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_calling_config']) is not None: + setv( + to_object, + ['functionCallingConfig'], + _FunctionCallingConfig_to_mldev( + api_client, + getv(from_object, ['function_calling_config']), + to_object, + ), + ) + + return to_object + + +def _ToolConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_calling_config']) is not None: + setv( + to_object, + ['functionCallingConfig'], + _FunctionCallingConfig_to_vertex( + api_client, + getv(from_object, ['function_calling_config']), + to_object, + ), + ) + + return to_object + + +def _CreateCachedContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['ttl']) is not None: + setv(parent_object, ['ttl'], getv(from_object, ['ttl'])) + + if getv(from_object, ['expire_time']) is not None: + setv(parent_object, ['expireTime'], getv(from_object, ['expire_time'])) + + if getv(from_object, ['display_name']) is not None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['contents']) is not None: + setv( + parent_object, + ['contents'], + [ + _Content_to_mldev(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['systemInstruction'], + _Content_to_mldev( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['tools'], + [ + _Tool_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['tools']) + ], + ) + + if getv(from_object, ['tool_config']) is not None: + setv( + parent_object, + ['toolConfig'], + _ToolConfig_to_mldev( + api_client, getv(from_object, ['tool_config']), to_object + ), + ) + + return to_object + + +def _CreateCachedContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if 
getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['ttl']) is not None: + setv(parent_object, ['ttl'], getv(from_object, ['ttl'])) + + if getv(from_object, ['expire_time']) is not None: + setv(parent_object, ['expireTime'], getv(from_object, ['expire_time'])) + + if getv(from_object, ['display_name']) is not None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['contents']) is not None: + setv( + parent_object, + ['contents'], + [ + _Content_to_vertex(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['systemInstruction'], + _Content_to_vertex( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['tools'], + [ + _Tool_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['tools']) + ], + ) + + if getv(from_object, ['tool_config']) is not None: + setv( + parent_object, + ['toolConfig'], + _ToolConfig_to_vertex( + api_client, getv(from_object, ['tool_config']), to_object + ), + ) + + return to_object + + +def _CreateCachedContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['model'], + t.t_caches_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateCachedContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CreateCachedContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['model'], + t.t_caches_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateCachedContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetCachedContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetCachedContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetCachedContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetCachedContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + 
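+  # (Editor's note, not upstream code) Keys set under '_url' (like 'name'
+  # above) never reach the request body: the calling method substitutes them
+  # into the path template via '{name}'.format_map(request_dict.get('_url')),
+  # and keys under '_query' are urlencoded onto the path as query parameters.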
return to_object + + +def _GetCachedContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetCachedContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DeleteCachedContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _DeleteCachedContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _DeleteCachedContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _DeleteCachedContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DeleteCachedContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _DeleteCachedContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpdateCachedContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['ttl']) is not None: + setv(parent_object, ['ttl'], getv(from_object, ['ttl'])) + + if getv(from_object, ['expire_time']) is not None: + setv(parent_object, ['expireTime'], getv(from_object, ['expire_time'])) + + return to_object + + +def _UpdateCachedContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['ttl']) is not None: + setv(parent_object, ['ttl'], getv(from_object, ['ttl'])) + + if getv(from_object, ['expire_time']) is not None: + setv(parent_object, ['expireTime'], getv(from_object, ['expire_time'])) + + return to_object + + +def _UpdateCachedContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, 
['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpdateCachedContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpdateCachedContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_cached_content_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpdateCachedContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListCachedContentsConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + return to_object + + +def _ListCachedContentsConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + return to_object + + +def _ListCachedContentsParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListCachedContentsConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListCachedContentsParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListCachedContentsConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CachedContent_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + if getv(from_object, ['model']) is not None: + setv(to_object, ['model'], getv(from_object, ['model'])) + + if getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + if getv(from_object, ['expireTime']) is not None: + setv(to_object, ['expire_time'], getv(from_object, ['expireTime'])) + + if getv(from_object, ['usageMetadata']) is not None: + setv(to_object, ['usage_metadata'], 
getv(from_object, ['usageMetadata']))
+
+  return to_object
+
+
+def _CachedContent_from_vertex(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+  if getv(from_object, ['name']) is not None:
+    setv(to_object, ['name'], getv(from_object, ['name']))
+
+  if getv(from_object, ['displayName']) is not None:
+    setv(to_object, ['display_name'], getv(from_object, ['displayName']))
+
+  if getv(from_object, ['model']) is not None:
+    setv(to_object, ['model'], getv(from_object, ['model']))
+
+  if getv(from_object, ['createTime']) is not None:
+    setv(to_object, ['create_time'], getv(from_object, ['createTime']))
+
+  if getv(from_object, ['updateTime']) is not None:
+    setv(to_object, ['update_time'], getv(from_object, ['updateTime']))
+
+  if getv(from_object, ['expireTime']) is not None:
+    setv(to_object, ['expire_time'], getv(from_object, ['expireTime']))
+
+  if getv(from_object, ['usageMetadata']) is not None:
+    setv(to_object, ['usage_metadata'], getv(from_object, ['usageMetadata']))
+
+  return to_object
+
+
+def _DeleteCachedContentResponse_from_mldev(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+
+  return to_object
+
+
+def _DeleteCachedContentResponse_from_vertex(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+
+  return to_object
+
+
+def _ListCachedContentsResponse_from_mldev(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+  if getv(from_object, ['nextPageToken']) is not None:
+    setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))
+
+  if getv(from_object, ['cachedContents']) is not None:
+    setv(
+        to_object,
+        ['cached_contents'],
+        [
+            _CachedContent_from_mldev(api_client, item, to_object)
+            for item in getv(from_object, ['cachedContents'])
+        ],
+    )
+
+  return to_object
+
+
+def _ListCachedContentsResponse_from_vertex(
+    api_client: ApiClient,
+    from_object: Union[dict, object],
+    parent_object: dict = None,
+) -> dict:
+  to_object = {}
+  if getv(from_object, ['nextPageToken']) is not None:
+    setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))
+
+  if getv(from_object, ['cachedContents']) is not None:
+    setv(
+        to_object,
+        ['cached_contents'],
+        [
+            _CachedContent_from_vertex(api_client, item, to_object)
+            for item in getv(from_object, ['cachedContents'])
+        ],
+    )
+
+  return to_object
+
+
+class Caches(_common.BaseModule):
+
+  def create(
+      self,
+      *,
+      model: str,
+      config: Optional[types.CreateCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Creates cached content.
+
+    This call initializes the cached content in data storage, and users pay
+    for the cached data storage.
+
+    Usage:
+
+    .. code-block:: python
+
+      contents = ...  # Initialize the content to cache.
+      response = client.caches.create(
+          model=...,  # The publisher model id
+          config={
+              'contents': contents,
+              'display_name': 'test cache',
+              'system_instruction': 'What is the sum of the two pdfs?',
+              'ttl': '86400s',
+          },
+      )
+    """
+
+    parameter_model = types._CreateCachedContentParameters(
+        model=model,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _CreateCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _CreateCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = self._api_client.request(
+        'post', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  def get(
+      self,
+      *,
+      name: str,
+      config: Optional[types.GetCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Gets cached content configurations.
+
+    .. code-block:: python
+
+      client.caches.get(name=...)  # The server-generated resource name.
+    """
+
+    parameter_model = types._GetCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _GetCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _GetCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = self._api_client.request(
+        'get', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  def delete(
+      self,
+      *,
+      name: str,
+      config: Optional[types.DeleteCachedContentConfigOrDict] = None,
+  ) -> types.DeleteCachedContentResponse:
+    """Deletes cached content.
+
+    Usage:
+
+    .. code-block:: python
+
+      client.caches.delete(name=...)  # The server-generated resource name.
+    """
+
+    parameter_model = types._DeleteCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _DeleteCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _DeleteCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = self._api_client.request(
+        'delete', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _DeleteCachedContentResponse_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _DeleteCachedContentResponse_from_mldev(
+          self._api_client, response_dict
+      )
+
+    return_value = types.DeleteCachedContentResponse._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  def update(
+      self,
+      *,
+      name: str,
+      config: Optional[types.UpdateCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Updates cached content configurations.
+
+    .. code-block:: python
+
+      response = client.caches.update(
+          name=...,  # The server-generated resource name.
+          config={
+              'ttl': '7600s',
+          },
+      )
+    """
+
+    parameter_model = types._UpdateCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _UpdateCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _UpdateCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = self._api_client.request(
+        'patch', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  def _list(
+      self, *, config: Optional[types.ListCachedContentsConfigOrDict] = None
+  ) -> types.ListCachedContentsResponse:
+    """Lists cached content configurations.
+
+    .. code-block:: python
+
+      cached_contents = client.caches.list(config={'page_size': 2})
+      for cached_content in cached_contents:
+        print(cached_content)
+    """
+
+    parameter_model = types._ListCachedContentsParameters(
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _ListCachedContentsParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _ListCachedContentsParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = self._api_client.request(
+        'get', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _ListCachedContentsResponse_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _ListCachedContentsResponse_from_mldev(
+          self._api_client, response_dict
+      )
+
+    return_value = types.ListCachedContentsResponse._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  def list(
+      self, *, config: Optional[types.ListCachedContentsConfigOrDict] = None
+  ) -> Pager[types.CachedContent]:
+    return Pager(
+        'cached_contents',
+        self._list,
+        self._list(config=config),
+        config,
+    )
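+
+
+# (Editor's note) A minimal illustrative sketch, not part of the generated
+# module: it strings together the synchronous Caches methods above. The
+# model id is an assumption; `client` is a configured genai client and
+# `contents` is content assembled beforehand.
+def _example_caches_roundtrip(client, contents) -> None:
+  """Creates, inspects, updates, and deletes a cache (editor's sketch)."""
+  cache = client.caches.create(
+      model='gemini-1.5-flash-002',  # assumed model id
+      config={
+          'contents': contents,
+          'display_name': 'test cache',
+          'ttl': '86400s',
+      },
+  )
+  print(client.caches.get(name=cache.name))          # fetch by resource name
+  client.caches.update(name=cache.name, config={'ttl': '7600s'})
+  client.caches.delete(name=cache.name)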
+
+
+class AsyncCaches(_common.BaseModule):
+
+  async def create(
+      self,
+      *,
+      model: str,
+      config: Optional[types.CreateCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Creates cached content.
+
+    This call initializes the cached content in data storage, and users pay
+    for the cached data storage.
+
+    Usage:
+
+    .. code-block:: python
+
+      contents = ...  # Initialize the content to cache.
+      response = await client.aio.caches.create(
+          model=...,  # The publisher model id
+          config={
+              'contents': contents,
+              'display_name': 'test cache',
+              'system_instruction': 'What is the sum of the two pdfs?',
+              'ttl': '86400s',
+          },
+      )
+    """
+
+    parameter_model = types._CreateCachedContentParameters(
+        model=model,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _CreateCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _CreateCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = await self._api_client.async_request(
+        'post', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  async def get(
+      self,
+      *,
+      name: str,
+      config: Optional[types.GetCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Gets cached content configurations.
+
+    .. code-block:: python
+
+      await client.aio.caches.get(name=...)  # The server-generated resource name.
+    """
+
+    parameter_model = types._GetCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _GetCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _GetCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = await self._api_client.async_request(
+        'get', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  async def delete(
+      self,
+      *,
+      name: str,
+      config: Optional[types.DeleteCachedContentConfigOrDict] = None,
+  ) -> types.DeleteCachedContentResponse:
+    """Deletes cached content.
+
+    Usage:
+
+    .. code-block:: python
+
+      await client.aio.caches.delete(name=...)  # The server-generated resource name.
+    """
+
+    parameter_model = types._DeleteCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _DeleteCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _DeleteCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = await self._api_client.async_request(
+        'delete', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _DeleteCachedContentResponse_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _DeleteCachedContentResponse_from_mldev(
+          self._api_client, response_dict
+      )
+
+    return_value = types.DeleteCachedContentResponse._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  async def update(
+      self,
+      *,
+      name: str,
+      config: Optional[types.UpdateCachedContentConfigOrDict] = None,
+  ) -> types.CachedContent:
+    """Updates cached content configurations.
+
+    .. code-block:: python
+
+      response = await client.aio.caches.update(
+          name=...,  # The server-generated resource name.
+          config={
+              'ttl': '7600s',
+          },
+      )
+    """
+
+    parameter_model = types._UpdateCachedContentParameters(
+        name=name,
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _UpdateCachedContentParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _UpdateCachedContentParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = '{name}'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+    config = request_dict.pop('config', None)
+    http_options = config.pop('httpOptions', None) if config else None
+    request_dict = _common.convert_to_dict(request_dict)
+    request_dict = _common.encode_unserializable_types(request_dict)
+
+    response_dict = await self._api_client.async_request(
+        'patch', path, request_dict, http_options
+    )
+
+    if self._api_client.vertexai:
+      response_dict = _CachedContent_from_vertex(
+          self._api_client, response_dict
+      )
+    else:
+      response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
+
+    return_value = types.CachedContent._from_response(
+        response_dict, parameter_model
+    )
+    self._api_client._verify_response(return_value)
+    return return_value
+
+  async def _list(
+      self, *, config: Optional[types.ListCachedContentsConfigOrDict] = None
+  ) -> types.ListCachedContentsResponse:
+    """Lists cached content configurations.
+
+    .. code-block:: python
+
+      cached_contents = await client.aio.caches.list(config={'page_size': 2})
+      async for cached_content in cached_contents:
+        print(cached_content)
+    """
+
+    parameter_model = types._ListCachedContentsParameters(
+        config=config,
+    )
+
+    if self._api_client.vertexai:
+      request_dict = _ListCachedContentsParameters_to_vertex(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    else:
+      request_dict = _ListCachedContentsParameters_to_mldev(
+          self._api_client, parameter_model
+      )
+      path = 'cachedContents'.format_map(request_dict.get('_url'))
+    query_params = request_dict.get('_query')
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+    # TODO: remove the hack that pops config.
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListCachedContentsResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListCachedContentsResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListCachedContentsResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def list( + self, *, config: Optional[types.ListCachedContentsConfigOrDict] = None + ) -> AsyncPager[types.CachedContent]: + return AsyncPager( + 'cached_contents', + self._list, + await self._list(config=config), + config, + ) diff --git a/.venv/lib/python3.12/site-packages/google/genai/chats.py b/.venv/lib/python3.12/site-packages/google/genai/chats.py new file mode 100644 index 00000000..707b7c8d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/chats.py @@ -0,0 +1,266 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional +from typing import Union + +from . import _transformers as t +from .models import AsyncModels, Models +from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict + + +def _validate_response(response: GenerateContentResponse) -> bool: + if not response.candidates: + return False + if not response.candidates[0].content: + return False + if not response.candidates[0].content.parts: + return False + for part in response.candidates[0].content.parts: + if part == Part(): + return False + if part.text is not None and part.text == "": + return False + return True + + +class _BaseChat: + """Base chat session.""" + + def __init__( + self, + *, + modules: Union[Models, AsyncModels], + model: str, + config: GenerateContentConfigOrDict = None, + history: list[Content], + ): + self._modules = modules + self._model = model + self._config = config + self._curated_history = history + + +class Chat(_BaseChat): + """Chat session.""" + + def send_message( + self, message: Union[list[PartUnionDict], PartUnionDict] + ) -> GenerateContentResponse: + """Sends the conversation history with the additional message and returns the model's response. + + Args: + message: The message to send to the model. + + Returns: + The model's response. + + Usage: + + .. 
code-block:: python + + chat = client.chats.create(model='gemini-1.5-flash') + response = chat.send_message('tell me a story') + """ + + input_content = t.t_content(self._modules._api_client, message) + response = self._modules.generate_content( + model=self._model, + contents=self._curated_history + [input_content], + config=self._config, + ) + if _validate_response(response): + if response.automatic_function_calling_history: + self._curated_history.extend( + response.automatic_function_calling_history + ) + else: + self._curated_history.append(input_content) + self._curated_history.append(response.candidates[0].content) + return response + + def send_message_stream( + self, message: Union[list[PartUnionDict], PartUnionDict] + ): + """Sends the conversation history with the additional message and yields the model's response in chunks. + + Args: + message: The message to send to the model. + + Yields: + The model's response in chunks. + + Usage: + + .. code-block:: python + + chat = client.chats.create(model='gemini-1.5-flash') + for chunk in chat.send_message_stream('tell me a story'): + print(chunk.text) + """ + + input_content = t.t_content(self._modules._api_client, message) + output_contents = [] + finish_reason = None + for chunk in self._modules.generate_content_stream( + model=self._model, + contents=self._curated_history + [input_content], + config=self._config, + ): + if _validate_response(chunk): + output_contents.append(chunk.candidates[0].content) + if chunk.candidates and chunk.candidates[0].finish_reason: + finish_reason = chunk.candidates[0].finish_reason + yield chunk + if output_contents and finish_reason: + self._curated_history.append(input_content) + self._curated_history.extend(output_contents) + + +class Chats: + """A util class to create chat sessions.""" + + def __init__(self, modules: Models): + self._modules = modules + + def create( + self, + *, + model: str, + config: GenerateContentConfigOrDict = None, + history: Optional[list[Content]] = None, + ) -> Chat: + """Creates a new chat session. + + Args: + model: The model to use for the chat. + config: The configuration to use for the generate content request. + history: The history to use for the chat. + + Returns: + A new chat session. + """ + return Chat( + modules=self._modules, + model=model, + config=config, + history=history if history else [], + ) + + +class AsyncChat(_BaseChat): + """Async chat session.""" + + async def send_message( + self, message: Union[list[PartUnionDict], PartUnionDict] + ) -> GenerateContentResponse: + """Sends the conversation history with the additional message and returns the model's response. + + Args: + message: The message to send to the model. + + Returns: + The model's response. + + Usage: + + ..
code-block:: python + + chat = client.aio.chats.create(model='gemini-1.5-flash') + response = await chat.send_message('tell me a story') + """ + + input_content = t.t_content(self._modules._api_client, message) + response = await self._modules.generate_content( + model=self._model, + contents=self._curated_history + [input_content], + config=self._config, + ) + if _validate_response(response): + if response.automatic_function_calling_history: + self._curated_history.extend( + response.automatic_function_calling_history + ) + else: + self._curated_history.append(input_content) + self._curated_history.append(response.candidates[0].content) + return response + + async def send_message_stream( + self, message: Union[list[PartUnionDict], PartUnionDict] + ): + """Sends the conversation history with the additional message and yields the model's response in chunks. + + Args: + message: The message to send to the model. + + Yields: + The model's response in chunks. + + Usage: + + .. code-block:: python + chat = client.aio.chats.create(model='gemini-1.5-flash') + async for chunk in chat.send_message_stream('tell me a story'): + print(chunk.text) + """ + + input_content = t.t_content(self._modules._api_client, message) + output_contents = [] + finish_reason = None + async for chunk in self._modules.generate_content_stream( + model=self._model, + contents=self._curated_history + [input_content], + config=self._config, + ): + if _validate_response(chunk): + output_contents.append(chunk.candidates[0].content) + if chunk.candidates and chunk.candidates[0].finish_reason: + finish_reason = chunk.candidates[0].finish_reason + yield chunk + if output_contents and finish_reason: + self._curated_history.append(input_content) + self._curated_history.extend(output_contents) + + +class AsyncChats: + """A util class to create async chat sessions.""" + + def __init__(self, modules: AsyncModels): + self._modules = modules + + def create( + self, + *, + model: str, + config: GenerateContentConfigOrDict = None, + history: Optional[list[Content]] = None, + ) -> AsyncChat: + """Creates a new chat session. + + Args: + model: The model to use for the chat. + config: The configuration to use for the generate content request. + history: The history to use for the chat. + + Returns: + A new chat session. + """ + return AsyncChat( + modules=self._modules, + model=model, + config=config, + history=history if history else [], + ) diff --git a/.venv/lib/python3.12/site-packages/google/genai/client.py b/.venv/lib/python3.12/site-packages/google/genai/client.py new file mode 100644 index 00000000..f29bfe72 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/client.py @@ -0,0 +1,281 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
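# Editorial end-to-end sketch of the chat surface defined above, before the
# client plumbing below; the API key, model name, and prompts are
# placeholders, not values from this diff.
from google import genai

client = genai.Client(api_key='my-api-key')  # placeholder key
chat = client.chats.create(model='gemini-1.5-flash')

first = chat.send_message('tell me a story')
print(first.text)
# Each turn resends the curated history, so follow-ups keep context:
follow_up = chat.send_message('now tell it in one sentence')
print(follow_up.text)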
+# + +import os +from typing import Optional, Union + +import google.auth +import pydantic + +from ._api_client import ApiClient, HttpOptions, HttpOptionsDict +from ._replay_api_client import ReplayApiClient +from .batches import AsyncBatches, Batches +from .caches import AsyncCaches, Caches +from .chats import AsyncChats, Chats +from .files import AsyncFiles, Files +from .live import AsyncLive +from .models import AsyncModels, Models +from .tunings import AsyncTunings, Tunings + + +class AsyncClient: + """Client for making asynchronous (non-blocking) requests.""" + + def __init__(self, api_client: ApiClient): + + self._api_client = api_client + self._models = AsyncModels(self._api_client) + self._tunings = AsyncTunings(self._api_client) + self._caches = AsyncCaches(self._api_client) + self._batches = AsyncBatches(self._api_client) + self._files = AsyncFiles(self._api_client) + self._live = AsyncLive(self._api_client) + + @property + def models(self) -> AsyncModels: + return self._models + + @property + def tunings(self) -> AsyncTunings: + return self._tunings + + @property + def caches(self) -> AsyncCaches: + return self._caches + + @property + def batches(self) -> AsyncBatches: + return self._batches + + @property + def chats(self) -> AsyncChats: + return AsyncChats(modules=self.models) + + @property + def files(self) -> AsyncFiles: + return self._files + + @property + def live(self) -> AsyncLive: + return self._live + + +class DebugConfig(pydantic.BaseModel): + """Configuration options that change client network behavior when testing.""" + + client_mode: Optional[str] = pydantic.Field( + default_factory=lambda: os.getenv('GOOGLE_GENAI_CLIENT_MODE', None) + ) + + replays_directory: Optional[str] = pydantic.Field( + default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAYS_DIRECTORY', None) + ) + + replay_id: Optional[str] = pydantic.Field( + default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAY_ID', None) + ) + + +class Client: + """Client for making synchronous requests. + + Use this client to make a request to the Gemini Developer API or Vertex AI + API and then wait for the response. + + Attributes: + api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to + use for authentication. Applies to the Gemini Developer API only. + vertexai: Indicates whether the client should use the Vertex AI + API endpoints. Defaults to False (uses Gemini Developer API endpoints). + Applies to the Vertex AI API only. + credentials: The credentials to use for authentication when calling the + Vertex AI APIs. Credentials can be obtained from environment variables and + default credentials. For more information, see + `Set up Application Default Credentials + <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_. + Applies to the Vertex AI API only. + project: The `Google Cloud project ID <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to + use for quota. Can be obtained from environment variables (for example, + ``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only. + location: The `location <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_ + to send API requests to (for example, ``us-central1``). Can be obtained + from environment variables. Applies to the Vertex AI API only. + debug_config: Config settings that control network behavior of the client. + This is typically used when running test code. + http_options: Http options to use for the client. 
Response_payload can't be + set when passing to the client constructor. + + Usage for the Gemini Developer API: + + .. code-block:: python + + from google import genai + + client = genai.Client(api_key='my-api-key') + + Usage for the Vertex AI API: + + .. code-block:: python + + from google import genai + + client = genai.Client( + vertexai=True, project='my-project-id', location='us-central1' + ) + """ + + def __init__( + self, + *, + vertexai: Optional[bool] = None, + api_key: Optional[str] = None, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[DebugConfig] = None, + http_options: Optional[Union[HttpOptions, HttpOptionsDict]] = None, + ): + """Initializes the client. + + Args: + vertexai (bool): + Indicates whether the client should use the Vertex AI + API endpoints. Defaults to False (uses Gemini Developer API + endpoints). Applies to the Vertex AI API only. + api_key (str): + The `API key + <https://ai.google.dev/gemini-api/docs/api-key>`_ to use for + authentication. Applies to the Gemini Developer API only. + credentials (google.auth.credentials.Credentials): + The credentials to + use for authentication when calling the Vertex AI APIs. Credentials + can be obtained from environment variables and default credentials. + For more information, see `Set up Application Default Credentials + <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_. + Applies to the Vertex AI API only. + project (str): + The `Google Cloud project ID + <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to + use for quota. Can be obtained from environment variables (for + example, ``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only. + location (str): + The `location + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_ + to send API requests to (for example, ``us-central1``). Can be + obtained from environment variables. Applies to the Vertex AI API + only. + debug_config (DebugConfig): + Config settings that control network + behavior of the client. This is typically used when running test code. + http_options (Union[HttpOptions, HttpOptionsDict]): + Http options to use for the client. Response_payload can't be + set when passing to the client constructor. + """ + + self._debug_config = debug_config or DebugConfig() + + # Throw ValueError if response_payload is set in http_options due to + # unpredictable behavior when running multiple coroutines through + # client.aio. + if http_options and 'response_payload' in http_options: + raise ValueError( + 'Setting response_payload in http_options is not supported.' 
+ ) + + self._api_client = self._get_api_client( + vertexai=vertexai, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + debug_config=self._debug_config, + http_options=http_options, + ) + + self._aio = AsyncClient(self._api_client) + self._models = Models(self._api_client) + self._tunings = Tunings(self._api_client) + self._caches = Caches(self._api_client) + self._batches = Batches(self._api_client) + self._files = Files(self._api_client) + + @staticmethod + def _get_api_client( + vertexai: Optional[bool] = None, + api_key: Optional[str] = None, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[DebugConfig] = None, + http_options: Optional[HttpOptions] = None, + ): + if debug_config and debug_config.client_mode in [ + 'record', + 'replay', + 'auto', + ]: + return ReplayApiClient( + mode=debug_config.client_mode, + replay_id=debug_config.replay_id, + replays_directory=debug_config.replays_directory, + vertexai=vertexai, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + http_options=http_options, + ) + + return ApiClient( + vertexai=vertexai, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + http_options=http_options, + ) + + @property + def chats(self) -> Chats: + return Chats(modules=self.models) + + @property + def aio(self) -> AsyncClient: + return self._aio + + @property + def models(self) -> Models: + return self._models + + @property + def tunings(self) -> Tunings: + return self._tunings + + @property + def caches(self) -> Caches: + return self._caches + + @property + def batches(self) -> Batches: + return self._batches + + @property + def files(self) -> Files: + return self._files + + @property + def vertexai(self) -> bool: + """Returns whether the client is using the Vertex AI API.""" + return self._api_client.vertexai or False diff --git a/.venv/lib/python3.12/site-packages/google/genai/errors.py b/.venv/lib/python3.12/site-packages/google/genai/errors.py new file mode 100644 index 00000000..12e03dfc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/errors.py @@ -0,0 +1,130 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Error classes for the GenAI SDK.""" + +from typing import Any, Optional, TYPE_CHECKING, Union + +import requests + + +if TYPE_CHECKING: + from .replay_api_client import ReplayResponse + + +class APIError(Exception): + """General errors raised by the GenAI API.""" + code: int + response: requests.Response + + status: Optional[str] = None + message: Optional[str] = None + response: Optional[Any] = None + + def __init__( + self, code: int, response: Union[requests.Response, 'ReplayResponse'] + ): + self.response = response + + if isinstance(response, requests.Response): + try: + # do not do any extra manipulation on the response. + # return the raw response json as is.
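# Editorial sketch of the two error payload shapes that the _get_* helpers
# below both tolerate; the field values are invented for illustration.
flat = {'code': 404, 'message': 'Cached content not found.', 'status': 'NOT_FOUND'}
nested = {'error': {'code': 404, 'message': 'Cached content not found.', 'status': 'NOT_FOUND'}}
# Each helper tries the top level first, then falls back to the nested
# 'error' object, so both shapes yield the same code/message/status.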
+ response_json = response.json() + except requests.exceptions.JSONDecodeError: + response_json = { + 'message': response.text, + 'status': response.reason, + } + else: + response_json = response.body_segments[0].get('error', {}) + + self.details = response_json + self.message = self._get_message(response_json) + self.status = self._get_status(response_json) + self.code = code if code else self._get_code(response_json) + + super().__init__(f'{self.code} {self.status}. {self.details}') + + def _get_status(self, response_json): + return response_json.get( + 'status', response_json.get('error', {}).get('status', None) + ) + + def _get_message(self, response_json): + return response_json.get( + 'message', response_json.get('error', {}).get('message', None) + ) + + def _get_code(self, response_json): + return response_json.get( + 'code', response_json.get('error', {}).get('code', None) + ) + + def _to_replay_record(self): + """Returns a dictionary representation of the error for replay recording. + + details is not included since it may expose internal information in the + replay file. + """ + return { + 'error': { + 'code': self.code, + 'message': self.message, + 'status': self.status, + } + } + + @classmethod + def raise_for_response( + cls, response: Union[requests.Response, 'ReplayResponse'] + ): + """Raises an error with detailed error message if the response has an error status.""" + if response.status_code == 200: + return + + status_code = response.status_code + if 400 <= status_code < 500: + raise ClientError(status_code, response) + elif 500 <= status_code < 600: + raise ServerError(status_code, response) + else: + raise cls(status_code, response) + + +class ClientError(APIError): + """Client error raised by the GenAI API.""" + pass + + +class ServerError(APIError): + """Server error raised by the GenAI API.""" + pass + + +class UnknownFunctionCallArgumentError(ValueError): + """Raised when the function call argument cannot be converted to the parameter annotation.""" + + pass + + +class UnsupportedFunctionError(ValueError): + """Raised when the function is not supported.""" + + +class FunctionInvocationError(ValueError): + """Raised when the function cannot be invoked with the given arguments.""" + + pass diff --git a/.venv/lib/python3.12/site-packages/google/genai/files.py b/.venv/lib/python3.12/site-packages/google/genai/files.py new file mode 100644 index 00000000..20cf30af --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/files.py @@ -0,0 +1,1417 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import io +import mimetypes +import os +import pathlib +from typing import Optional, Union +from urllib.parse import urlencode +from . import _common +from . import _transformers as t +from . 
import types +from ._api_client import ApiClient +from ._common import get_value_by_path as getv +from ._common import set_value_by_path as setv +from .pagers import AsyncPager, Pager + + +def _ListFilesConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + return to_object + + +def _ListFilesConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + return to_object + + +def _ListFilesParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListFilesConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListFilesParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + raise ValueError('config parameter is not supported in Vertex AI.') + + return to_object + + +def _FileStatus_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['details']) is not None: + setv(to_object, ['details'], getv(from_object, ['details'])) + + if getv(from_object, ['message']) is not None: + setv(to_object, ['message'], getv(from_object, ['message'])) + + if getv(from_object, ['code']) is not None: + setv(to_object, ['code'], getv(from_object, ['code'])) + + return to_object + + +def _FileStatus_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['details']) is not None: + raise ValueError('details parameter is not supported in Vertex AI.') + + if getv(from_object, ['message']) is not None: + raise ValueError('message parameter is not supported in Vertex AI.') + + if getv(from_object, ['code']) is not None: + raise ValueError('code parameter is not supported in Vertex AI.') + + return to_object + + +def _File_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['display_name']) is not None: + setv(to_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['mime_type']) is not None: + setv(to_object, ['mimeType'], getv(from_object, ['mime_type'])) + + 
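# Editorial sketch of the behavior these converters rely on from the
# getv/setv path helpers imported from _common. This is a simplified
# re-implementation for illustration, not the actual _common code.
def getv_sketch(obj, path):
  # Walk a list of keys, returning None as soon as a segment is missing.
  for key in path:
    if obj is None:
      return None
    obj = obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
  return obj

def setv_sketch(obj, path, value):
  # Create intermediate dicts so deep paths like ['_url', 'file'] work.
  for key in path[:-1]:
    obj = obj.setdefault(key, {})
  obj[path[-1]] = value

d = {}
setv_sketch(d, ['_url', 'file'], 'files/abc-123')
assert getv_sketch(d, ['_url', 'file']) == 'files/abc-123'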
if getv(from_object, ['size_bytes']) is not None: + setv(to_object, ['sizeBytes'], getv(from_object, ['size_bytes'])) + + if getv(from_object, ['create_time']) is not None: + setv(to_object, ['createTime'], getv(from_object, ['create_time'])) + + if getv(from_object, ['expiration_time']) is not None: + setv(to_object, ['expirationTime'], getv(from_object, ['expiration_time'])) + + if getv(from_object, ['update_time']) is not None: + setv(to_object, ['updateTime'], getv(from_object, ['update_time'])) + + if getv(from_object, ['sha256_hash']) is not None: + setv(to_object, ['sha256Hash'], getv(from_object, ['sha256_hash'])) + + if getv(from_object, ['uri']) is not None: + setv(to_object, ['uri'], getv(from_object, ['uri'])) + + if getv(from_object, ['download_uri']) is not None: + setv(to_object, ['downloadUri'], getv(from_object, ['download_uri'])) + + if getv(from_object, ['state']) is not None: + setv(to_object, ['state'], getv(from_object, ['state'])) + + if getv(from_object, ['source']) is not None: + setv(to_object, ['source'], getv(from_object, ['source'])) + + if getv(from_object, ['video_metadata']) is not None: + setv(to_object, ['videoMetadata'], getv(from_object, ['video_metadata'])) + + if getv(from_object, ['error']) is not None: + setv( + to_object, + ['error'], + _FileStatus_to_mldev( + api_client, getv(from_object, ['error']), to_object + ), + ) + + return to_object + + +def _File_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Vertex AI.') + + if getv(from_object, ['display_name']) is not None: + raise ValueError('display_name parameter is not supported in Vertex AI.') + + if getv(from_object, ['mime_type']) is not None: + raise ValueError('mime_type parameter is not supported in Vertex AI.') + + if getv(from_object, ['size_bytes']) is not None: + raise ValueError('size_bytes parameter is not supported in Vertex AI.') + + if getv(from_object, ['create_time']) is not None: + raise ValueError('create_time parameter is not supported in Vertex AI.') + + if getv(from_object, ['expiration_time']) is not None: + raise ValueError('expiration_time parameter is not supported in Vertex AI.') + + if getv(from_object, ['update_time']) is not None: + raise ValueError('update_time parameter is not supported in Vertex AI.') + + if getv(from_object, ['sha256_hash']) is not None: + raise ValueError('sha256_hash parameter is not supported in Vertex AI.') + + if getv(from_object, ['uri']) is not None: + raise ValueError('uri parameter is not supported in Vertex AI.') + + if getv(from_object, ['download_uri']) is not None: + raise ValueError('download_uri parameter is not supported in Vertex AI.') + + if getv(from_object, ['state']) is not None: + raise ValueError('state parameter is not supported in Vertex AI.') + + if getv(from_object, ['source']) is not None: + raise ValueError('source parameter is not supported in Vertex AI.') + + if getv(from_object, ['video_metadata']) is not None: + raise ValueError('video_metadata parameter is not supported in Vertex AI.') + + if getv(from_object, ['error']) is not None: + raise ValueError('error parameter is not supported in Vertex AI.') + + return to_object + + +def _CreateFileConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + 
setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _CreateFileConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _CreateFileParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['file']) is not None: + setv( + to_object, + ['file'], + _File_to_mldev(api_client, getv(from_object, ['file']), to_object), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateFileConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CreateFileParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['file']) is not None: + raise ValueError('file parameter is not supported in Vertex AI.') + + if getv(from_object, ['config']) is not None: + raise ValueError('config parameter is not supported in Vertex AI.') + + return to_object + + +def _GetFileConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetFileConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetFileParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'file'], + t.t_file_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetFileConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetFileParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Vertex AI.') + + if getv(from_object, ['config']) is not None: + raise ValueError('config parameter is not supported in Vertex AI.') + + return to_object + + +def _DeleteFileConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _DeleteFileConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _DeleteFileParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, 
object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv( + to_object, + ['_url', 'file'], + t.t_file_name(api_client, getv(from_object, ['name'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _DeleteFileConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DeleteFileParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + raise ValueError('name parameter is not supported in Vertex AI.') + + if getv(from_object, ['config']) is not None: + raise ValueError('config parameter is not supported in Vertex AI.') + + return to_object + + +def _FileStatus_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['details']) is not None: + setv(to_object, ['details'], getv(from_object, ['details'])) + + if getv(from_object, ['message']) is not None: + setv(to_object, ['message'], getv(from_object, ['message'])) + + if getv(from_object, ['code']) is not None: + setv(to_object, ['code'], getv(from_object, ['code'])) + + return to_object + + +def _FileStatus_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _File_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + if getv(from_object, ['mimeType']) is not None: + setv(to_object, ['mime_type'], getv(from_object, ['mimeType'])) + + if getv(from_object, ['sizeBytes']) is not None: + setv(to_object, ['size_bytes'], getv(from_object, ['sizeBytes'])) + + if getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['expirationTime']) is not None: + setv(to_object, ['expiration_time'], getv(from_object, ['expirationTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + if getv(from_object, ['sha256Hash']) is not None: + setv(to_object, ['sha256_hash'], getv(from_object, ['sha256Hash'])) + + if getv(from_object, ['uri']) is not None: + setv(to_object, ['uri'], getv(from_object, ['uri'])) + + if getv(from_object, ['downloadUri']) is not None: + setv(to_object, ['download_uri'], getv(from_object, ['downloadUri'])) + + if getv(from_object, ['state']) is not None: + setv(to_object, ['state'], getv(from_object, ['state'])) + + if getv(from_object, ['source']) is not None: + setv(to_object, ['source'], getv(from_object, ['source'])) + + if getv(from_object, ['videoMetadata']) is not None: + setv(to_object, ['video_metadata'], getv(from_object, ['videoMetadata'])) + + if getv(from_object, ['error']) is not None: + setv( + to_object, + ['error'], + _FileStatus_from_mldev( + api_client, getv(from_object, ['error']), to_object + ), + ) + + return to_object + + +def _File_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object 
= {} + + return to_object + + +def _ListFilesResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['files']) is not None: + setv( + to_object, + ['files'], + [ + _File_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['files']) + ], + ) + + return to_object + + +def _ListFilesResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _CreateFileResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _CreateFileResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DeleteFileResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DeleteFileResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +class Files(_common.BaseModule): + + def _list( + self, *, config: Optional[types.ListFilesConfigOrDict] = None + ) -> types.ListFilesResponse: + """Lists all files from the service. + + Args: + config (ListFilesConfig): Optional, configuration for the list method. + + Returns: + ListFilesResponse: The response for the list method. + + Usage: + + .. code-block:: python + + pager = client.files.list(config={'page_size': 10}) + for file in pager.page: + print(file.name) + """ + + parameter_model = types._ListFilesParameters( + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _ListFilesParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
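# Editorial sketch of how the pager built from _list is typically consumed;
# assumes a configured `client`, and the page size is arbitrary.
pager = client.files.list(config={'page_size': 10})
for f in pager:  # iterating the pager is expected to fetch later pages lazily
  print(f.name)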
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListFilesResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListFilesResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListFilesResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def _create( + self, + *, + file: types.FileOrDict, + config: Optional[types.CreateFileConfigOrDict] = None, + ) -> types.CreateFileResponse: + parameter_model = types._CreateFileParameters( + file=file, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _CreateFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'upload/v1beta/files'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _CreateFileResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _CreateFileResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.CreateFileResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def get( + self, *, name: str, config: Optional[types.GetFileConfigOrDict] = None + ) -> types.File: + """Retrieves the file information from the service. + + Args: + name (str): The name identifier for the file to retrieve. + config (GetFileConfig): Optional, configuration for the get method. + + Returns: + File: The file information. + + Usage: + + .. code-block:: python + + file = client.files.get(name='files/...') + print(file.uri) + """ + + parameter_model = types._GetFileParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _GetFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files/{file}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _File_from_vertex(self._api_client, response_dict) + else: + response_dict = _File_from_mldev(self._api_client, response_dict) + + return_value = types.File._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + def delete( + self, *, name: str, config: Optional[types.DeleteFileConfigOrDict] = None + ) -> types.DeleteFileResponse: + """Deletes an existing file from the service. + + Args: + name (str): The name identifier for the file to delete. + config (DeleteFileConfig): Optional, configuration for the delete method. + + Returns: + DeleteFileResponse: The response for the delete method. + + Usage: + + .. code-block:: python + + client.files.delete(name='files/...') + """ + + parameter_model = types._DeleteFileParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _DeleteFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files/{file}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteFileResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteFileResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteFileResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def upload( + self, + *, + path: Union[str, pathlib.Path, os.PathLike, io.IOBase], + config: Optional[types.UploadFileConfigOrDict] = None, + ) -> types.File: + """Calls the API to upload a file using a supported file service. + + Args: + path: The path to the file or an `IOBase` object to be uploaded. If it's + an IOBase object, it must be opened in blocking mode and binary mode. In + other words, do not use non-blocking mode or text mode. The given stream + must be seekable, that is, it must be able to call seek() on 'path'. + config: Optional parameters to set `display_name`, `mime_type`, and `name`. + """ + if self._api_client.vertexai: + raise ValueError( + 'Vertex AI does not support creating files. You can upload files to' + ' GCS instead.'
+ ) + config_model = None + if config: + if isinstance(config, dict): + config_model = types.UploadFileConfig(**config) + else: + config_model = config + file = types.File( + mime_type=config_model.mime_type, + name=config_model.name, + display_name=config_model.display_name, + ) + else: # if not config + file = types.File() + if file.name is not None and not file.name.startswith('files/'): + file.name = f'files/{file.name}' + + if isinstance(path, io.IOBase): + if file.mime_type is None: + raise ValueError( + 'Unknown mime type: Could not determine the mimetype for your' + ' file\n please set the `mime_type` argument' + ) + if hasattr(path, 'mode'): + if 'b' not in path.mode: + raise ValueError('The file must be opened in binary mode.') + offset = path.tell() + path.seek(0, os.SEEK_END) + file.size_bytes = path.tell() - offset + path.seek(offset, os.SEEK_SET) + else: + fs_path = os.fspath(path) + if not fs_path or not os.path.isfile(fs_path): + raise FileNotFoundError(f'{path} is not a valid file path.') + file.size_bytes = os.path.getsize(fs_path) + if file.mime_type is None: + file.mime_type, _ = mimetypes.guess_type(fs_path) + if file.mime_type is None: + raise ValueError( + 'Unknown mime type: Could not determine the mimetype for your' + ' file\n please set the `mime_type` argument' + ) + response = {} + if config_model and config_model.http_options: + http_options = config_model.http_options + else: + http_options = { + 'api_version': '', # api-version is set in the path. + 'headers': { + 'Content-Type': 'application/json', + 'X-Goog-Upload-Protocol': 'resumable', + 'X-Goog-Upload-Command': 'start', + 'X-Goog-Upload-Header-Content-Length': f'{file.size_bytes}', + 'X-Goog-Upload-Header-Content-Type': f'{file.mime_type}', + }, + 'response_payload': response, + } + self._create(file=file, config={'http_options': http_options}) + + if ( + 'headers' not in response + or 'X-Goog-Upload-URL' not in response['headers'] + ): + raise KeyError( + 'Failed to create file. Upload URL was not returned from the create' + ' file request.' + ) + upload_url = response['headers']['X-Goog-Upload-URL'] + + if isinstance(path, io.IOBase): + return_file = self._api_client.upload_file( + path, upload_url, file.size_bytes + ) + else: + return_file = self._api_client.upload_file( + fs_path, upload_url, file.size_bytes + ) + + return types.File._from_response( + _File_from_mldev(self._api_client, return_file['file']), None + ) + + def list( + self, *, config: Optional[types.ListFilesConfigOrDict] = None + ) -> Pager[types.File]: + return Pager( + 'files', + self._list, + self._list(config=config), + config, + ) + + def download( + self, + *, + file: Union[str, types.File], + config: Optional[types.DownloadFileConfigOrDict] = None, + ) -> bytes: + """Downloads a file's data from storage. + + Files created by `upload` can't be downloaded. You can tell which files are + downloadable by checking the `source` or `download_uri` property. + + Args: + file (str): A file name, uri, or file object, identifying which file to + download. + config (DownloadFileConfigOrDict): Optional, configuration for the get + method. + + Returns: + bytes: The file data as bytes. + + Usage: + + ..
code-block:: python + + for file in client.files.list(): + if file.download_uri is not None: + break + else: + raise ValueError('No files found with a `download_uri`.') + data = client.files.download(file=file) + # data = client.files.download(file=file.name) + # data = client.files.download(file=file.download_uri) + """ + if self._api_client.vertexai: + raise ValueError( + 'Vertex AI does not support the Files API. Use GCS files instead.' + ) + + config_model = None + if config: + if isinstance(config, dict): + config_model = types.DownloadFileConfig(**config) + else: + config_model = config + + if isinstance(file, types.File) and file.download_uri is None: + raise ValueError( + "Only generated files can be downloaded, uploaded files can't be " + 'downloaded. You can tell which files are downloadable by checking ' + 'the `source` or `download_uri` property.' + ) + name = t.t_file_name(self, file) + + path = f'files/{name}:download' + + query_params = {'alt': 'media'} + path = f'{path}?{urlencode(query_params)}' + http_options = None + if getv(config_model, ['http_options']) is not None: + http_options = getv(config_model, ['http_options']) + + data = self._api_client.download_file( + path, + http_options, + ) + + return data + + +class AsyncFiles(_common.BaseModule): + + async def _list( + self, *, config: Optional[types.ListFilesConfigOrDict] = None + ) -> types.ListFilesResponse: + """Lists all files from the service. + + Args: + config (ListFilesConfig): Optional, configuration for the list method. + + Returns: + ListFilesResponse: The response for the list method. + + Usage: + + .. code-block:: python + + pager = client.files.list(config={'page_size': 10}) + for file in pager.page: + print(file.name) + """ + + parameter_model = types._ListFilesParameters( + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _ListFilesParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config.
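# Editorial summary of the two-step resumable upload that upload() above
# performs, plus a usage sketch; the local path is a placeholder.
# 1) _create() POSTs the file metadata with X-Goog-Upload-Protocol: resumable
#    and X-Goog-Upload-Command: start, and the response_payload hook captures
#    the response headers, including X-Goog-Upload-URL.
# 2) The raw bytes are then streamed to that URL via
#    _api_client.upload_file(path, upload_url, size_bytes).
file = client.files.upload(path='data/report.pdf')  # placeholder path
print(file.name, file.mime_type)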
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListFilesResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListFilesResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListFilesResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def _create( + self, + *, + file: types.FileOrDict, + config: Optional[types.CreateFileConfigOrDict] = None, + ) -> types.CreateFileResponse: + parameter_model = types._CreateFileParameters( + file=file, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _CreateFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'upload/v1beta/files'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _CreateFileResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _CreateFileResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.CreateFileResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def get( + self, *, name: str, config: Optional[types.GetFileConfigOrDict] = None + ) -> types.File: + """Retrieves the file information from the service. + + Args: + name (str): The name identifier for the file to retrieve. + config (GetFileConfig): Optional, configuration for the get method. + + Returns: + File: The file information. + + Usage: + + .. code-block:: python + + file = client.files.get(name='files/...') + print(file.uri) + """ + + parameter_model = types._GetFileParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _GetFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files/{file}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _File_from_vertex(self._api_client, response_dict) + else: + response_dict = _File_from_mldev(self._api_client, response_dict) + + return_value = types.File._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + async def delete( + self, *, name: str, config: Optional[types.DeleteFileConfigOrDict] = None + ) -> types.DeleteFileResponse: + """Deletes an existing file from the service. + + Args: + name (str): The name identifier for the file to delete. + config (DeleteFileConfig): Optional, configuration for the delete method. + + Returns: + DeleteFileResponse: The response for the delete method. + + Usage: + + .. code-block:: python + + client.files.delete(name='files/...') + """ + + parameter_model = types._DeleteFileParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + raise ValueError('This method is only supported in the default client.') + else: + request_dict = _DeleteFileParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'files/{file}'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteFileResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteFileResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteFileResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def upload( + self, + *, + path: Union[str, pathlib.Path, os.PathLike, io.IOBase], + config: Optional[types.UploadFileConfigOrDict] = None, + ) -> types.File: + """Calls the API to upload a file asynchronously using a supported file service. + + Args: + path: The path to the file or an `IOBase` object to be uploaded. If it's + an IOBase object, it must be opened in blocking mode and binary mode. In + other words, do not use non-blocking mode or text mode. The given stream + must be seekable, that is, it must be able to call seek() on 'path'. + config: Optional parameters to set `display_name`, `mime_type`, and `name`. + """ + if self._api_client.vertexai: + raise ValueError( + 'Vertex AI does not support creating files. You can upload files to' + ' GCS instead.'
+ ) + config_model = None + if config: + if isinstance(config, dict): + config_model = types.UploadFileConfig(**config) + else: + config_model = config + file = types.File( + mime_type=config_model.mime_type, + name=config_model.name, + display_name=config_model.display_name, + ) + else: # if not config + file = types.File() + if file.name is not None and not file.name.startswith('files/'): + file.name = f'files/{file.name}' + + if isinstance(path, io.IOBase): + if file.mime_type is None: + raise ValueError( + 'Unknown mime type: Could not determine the mimetype for your' + ' file\n please set the `mime_type` argument' + ) + if hasattr(path, 'mode'): + if 'b' not in path.mode: + raise ValueError('The file must be opened in binary mode.') + offset = path.tell() + path.seek(0, os.SEEK_END) + file.size_bytes = path.tell() - offset + path.seek(offset, os.SEEK_SET) + else: + fs_path = os.fspath(path) + if not fs_path or not os.path.isfile(fs_path): + raise FileNotFoundError(f'{path} is not a valid file path.') + file.size_bytes = os.path.getsize(fs_path) + if file.mime_type is None: + file.mime_type, _ = mimetypes.guess_type(fs_path) + if file.mime_type is None: + raise ValueError( + 'Unknown mime type: Could not determine the mimetype for your' + ' file\n please set the `mime_type` argument' + ) + + response = {} + if config_model and config_model.http_options: + http_options = config_model.http_options + else: + http_options = { + 'api_version': '', # api-version is set in the path. + 'headers': { + 'Content-Type': 'application/json', + 'X-Goog-Upload-Protocol': 'resumable', + 'X-Goog-Upload-Command': 'start', + 'X-Goog-Upload-Header-Content-Length': f'{file.size_bytes}', + 'X-Goog-Upload-Header-Content-Type': f'{file.mime_type}', + }, + 'response_payload': response, + } + await self._create(file=file, config={'http_options': http_options}) + if ( + 'headers' not in response + or 'X-Goog-Upload-URL' not in response['headers'] + ): + raise KeyError( + 'Failed to create file. Upload URL was not returned from the create' + ' file request.' + ) + upload_url = response['headers']['X-Goog-Upload-URL'] + + if isinstance(path, io.IOBase): + return_file = await self._api_client.async_upload_file( + path, upload_url, file.size_bytes + ) + else: + return_file = await self._api_client.async_upload_file( + fs_path, upload_url, file.size_bytes + ) + + return types.File._from_response( + _File_from_mldev(self._api_client, return_file['file']), None + ) + + async def list( + self, *, config: Optional[types.ListFilesConfigOrDict] = None + ) -> AsyncPager[types.File]: + return AsyncPager( + 'files', + self._list, + await self._list(config=config), + config, + ) + + async def download( + self, + *, + file: Union[str, types.File], + config: Optional[types.DownloadFileConfigOrDict] = None, + ) -> bytes: + """Downloads a file's data from the file service. + + The Vertex AI implementation of the API does not include the file service. + + Files created by `upload` can't be downloaded. You can tell which files are + downloadable by checking the `download_uri` property. + + Args: + file (str): A file name, uri, or file object, identifying which file to + download. + config (DownloadFileConfigOrDict): Optional, configuration for the get + method. + + Returns: + bytes: The file data as bytes. + + Usage: + + ..
+    .. code-block:: python
+
+      for file in client.files.list():
+        if file.download_uri is not None:
+          break
+      else:
+        raise ValueError('No files found with a `download_uri`.')
+      data = client.files.download(file=file)
+      # data = client.files.download(file=file.name)
+      # data = client.files.download(file=file.uri)
+    """
+    if self._api_client.vertexai:
+      raise ValueError(
+          'Vertex AI does not support the Files API. Use GCS files instead.'
+      )
+
+    config_model = None
+    if config:
+      if isinstance(config, dict):
+        config_model = types.DownloadFileConfig(**config)
+      else:
+        config_model = config
+
+    name = t.t_file_name(self, file)
+
+    path = f'files/{name}:download'
+
+    http_options = None
+    if getv(config_model, ['http_options']) is not None:
+      http_options = getv(config_model, ['http_options'])
+
+    query_params = {'alt': 'media'}
+    if query_params:
+      path = f'{path}?{urlencode(query_params)}'
+
+    data = await self._api_client.async_download_file(
+        path,
+        http_options,
+    )
+
+    return data
diff --git a/.venv/lib/python3.12/site-packages/google/genai/live.py b/.venv/lib/python3.12/site-packages/google/genai/live.py
new file mode 100644
index 00000000..586ba0ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/live.py
@@ -0,0 +1,696 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Live client."""
+
+import asyncio
+import base64
+import contextlib
+import json
+import logging
+from typing import AsyncIterator, Optional, Sequence, Union
+
+import google.auth
+from websockets import ConnectionClosed
+
+from . import _common
+from . import _transformers as t
+from . import client
+from . import types
+from ._api_client import ApiClient
+from ._common import get_value_by_path as getv
+from ._common import set_value_by_path as setv
+from .models import _Content_from_mldev
+from .models import _Content_from_vertex
+from .models import _Content_to_mldev
+from .models import _Content_to_vertex
+from .models import _GenerateContentConfig_to_mldev
+from .models import _GenerateContentConfig_to_vertex
+from .models import _SafetySetting_to_mldev
+from .models import _SafetySetting_to_vertex
+from .models import _SpeechConfig_to_mldev
+from .models import _SpeechConfig_to_vertex
+from .models import _Tool_to_mldev
+from .models import _Tool_to_vertex
+
+try:
+  from websockets.asyncio.client import ClientConnection
+  from websockets.asyncio.client import connect
+except ModuleNotFoundError:
+  from websockets.client import ClientConnection
+  from websockets.client import connect
+
+
+_FUNCTION_RESPONSE_REQUIRES_ID = (
+    'FunctionResponse request must have an `id` field from the'
+    ' response of a ToolCall.FunctionCalls in Google AI.'
+)
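+
+# A minimal end-to-end usage sketch for the live API (illustrative; assumes a
+# valid `API_KEY` and an available live-capable model name):
+#
+#   import asyncio
+#   from google import genai
+#
+#   async def main():
+#     client = genai.Client(api_key=API_KEY)
+#     async with client.aio.live.connect(model='...') as session:
+#       await session.send(input='Hello world!', end_of_turn=True)
+#       async for message in session.receive():
+#         print(message)
+#
+#   asyncio.run(main())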
+
+
+class AsyncSession:
+  """AsyncSession."""
+
+  def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
+    self._api_client = api_client
+    self._ws = websocket
+
+  async def send(
+      self,
+      *,
+      input: Union[
+          types.ContentListUnion,
+          types.ContentListUnionDict,
+          types.LiveClientContentOrDict,
+          types.LiveClientRealtimeInputOrDict,
+          types.LiveClientToolResponseOrDict,
+          types.FunctionResponseOrDict,
+          Sequence[types.FunctionResponseOrDict],
+      ],
+      end_of_turn: Optional[bool] = False,
+  ):
+    """Send input to the model.
+
+    The method sends the input request to the server.
+
+    Args:
+      input: The input request to the model.
+      end_of_turn: Whether the input is the last message in a turn.
+
+    Example usage:
+
+    .. code-block:: python
+
+      client = genai.Client(api_key=API_KEY)
+
+      async with client.aio.live.connect(model='...') as session:
+        await session.send(input='Hello world!', end_of_turn=True)
+        async for message in session.receive():
+          print(message)
+    """
+    client_message = self._parse_client_message(input, end_of_turn)
+    await self._ws.send(json.dumps(client_message))
+
+  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
+    """Receive model responses from the server.
+
+    The method yields the model responses from the server. The returned
+    responses represent a complete model turn. When the returned message is a
+    function call, the user must call `send` with the function response to
+    continue the turn.
+
+    Yields:
+      The model responses from the server.
+
+    Example usage:
+
+    .. code-block:: python
+
+      client = genai.Client(api_key=API_KEY)
+
+      async with client.aio.live.connect(model='...') as session:
+        await session.send(input='Hello world!', end_of_turn=True)
+        async for message in session.receive():
+          print(message)
+    """
+    # TODO(b/365983264) Handle intermittent issues for the user.
+    while result := await self._receive():
+      if result.server_content and result.server_content.turn_complete:
+        yield result
+        break
+      yield result
+
+  async def start_stream(
+      self, *, stream: AsyncIterator[bytes], mime_type: str
+  ) -> AsyncIterator[types.LiveServerMessage]:
+    """Starts a live session from a data stream.
+
+    The interaction terminates when the input stream is complete. This method
+    starts two async tasks: one task sends the input stream to the model, and
+    the other receives the responses from the model.
+
+    Args:
+      stream: An iterator that yields the input data to send to the model.
+      mime_type: The MIME type of the data in the stream.
+
+    Yields:
+      The audio bytes received from the model and server response messages.
+
+    Example usage:
+
+    .. code-block:: python
+
+      client = genai.Client(api_key=API_KEY)
+      config = {'response_modalities': ['AUDIO']}
+      async def audio_stream():
+        stream = read_audio()
+        for data in stream:
+          yield data
+      async with client.aio.live.connect(model='...') as session:
+        async for audio in session.start_stream(stream=audio_stream(),
+                                                mime_type='audio/pcm'):
+          play_audio_chunk(audio.data)
+    """
+    stop_event = asyncio.Event()
+    # Start the send loop. When the stream is complete, stop_event is set.
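+    # Pattern used below (descriptive): the send loop runs as a background
+    # task while this coroutine repeatedly races a `_receive()` task against
+    # `stop_event.wait()`, so responses are yielded as they arrive and the
+    # generator still exits promptly once the input stream is exhausted.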
+ asyncio.create_task(self._send_loop(stream, mime_type, stop_event)) + recv_task = None + while not stop_event.is_set(): + try: + recv_task = asyncio.create_task(self._receive()) + await asyncio.wait( + [ + recv_task, + asyncio.create_task(stop_event.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if recv_task.done(): + yield recv_task.result() + # Give a chance for the send loop to process requests. + await asyncio.sleep(10**-12) + except ConnectionClosed: + break + if recv_task is not None and not recv_task.done(): + recv_task.cancel() + # Wait for the task to finish (cancelled or not) + try: + await recv_task + except asyncio.CancelledError: + pass + + async def _receive(self) -> types.LiveServerMessage: + parameter_model = types.LiveServerMessage() + raw_response = await self._ws.recv(decode=False) + if raw_response: + try: + response = json.loads(raw_response) + except json.decoder.JSONDecodeError: + raise ValueError(f'Failed to parse response: {raw_response}') + else: + response = {} + if self._api_client.vertexai: + response_dict = self._LiveServerMessage_from_vertex(response) + else: + response_dict = self._LiveServerMessage_from_mldev(response) + + return types.LiveServerMessage._from_response( + response_dict, parameter_model + ) + + async def _send_loop( + self, + data_stream: AsyncIterator[bytes], + mime_type: str, + stop_event: asyncio.Event, + ): + async for data in data_stream: + input = {'data': data, 'mimeType': mime_type} + await self.send(input=input) + # Give a chance for the receive loop to process responses. + await asyncio.sleep(10**-12) + # Give a chance for the receiver to process the last response. + stop_event.set() + + def _LiveServerContent_from_mldev( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if getv(from_object, ['modelTurn']) is not None: + setv( + to_object, + ['model_turn'], + _Content_from_mldev( + self._api_client, + getv(from_object, ['modelTurn']), + ), + ) + if getv(from_object, ['turnComplete']) is not None: + setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete'])) + if getv(from_object, ['interrupted']) is not None: + setv(to_object, ['interrupted'], getv(from_object, ['interrupted'])) + return to_object + + def _LiveToolCall_from_mldev( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if getv(from_object, ['functionCalls']) is not None: + setv( + to_object, + ['function_calls'], + getv(from_object, ['functionCalls']), + ) + return to_object + + def _LiveToolCall_from_vertex( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if getv(from_object, ['functionCalls']) is not None: + setv( + to_object, + ['function_calls'], + getv(from_object, ['functionCalls']), + ) + return to_object + + def _LiveServerMessage_from_mldev( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if getv(from_object, ['serverContent']) is not None: + setv( + to_object, + ['server_content'], + self._LiveServerContent_from_mldev( + getv(from_object, ['serverContent']) + ), + ) + if getv(from_object, ['toolCall']) is not None: + setv( + to_object, + ['tool_call'], + self._LiveToolCall_from_mldev(getv(from_object, ['toolCall'])), + ) + if getv(from_object, ['toolCallCancellation']) is not None: + setv( + to_object, + ['tool_call_cancellation'], + getv(from_object, ['toolCallCancellation']), + ) + return to_object + + def _LiveServerContent_from_vertex( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if 
getv(from_object, ['modelTurn']) is not None: + setv( + to_object, + ['model_turn'], + _Content_from_vertex( + self._api_client, + getv(from_object, ['modelTurn']), + ), + ) + if getv(from_object, ['turnComplete']) is not None: + setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete'])) + if getv(from_object, ['interrupted']) is not None: + setv(to_object, ['interrupted'], getv(from_object, ['interrupted'])) + return to_object + + def _LiveServerMessage_from_vertex( + self, + from_object: Union[dict, object], + ) -> dict: + to_object = {} + if getv(from_object, ['serverContent']) is not None: + setv( + to_object, + ['server_content'], + self._LiveServerContent_from_vertex( + getv(from_object, ['serverContent']) + ), + ) + + if getv(from_object, ['toolCall']) is not None: + setv( + to_object, + ['tool_call'], + self._LiveToolCall_from_vertex(getv(from_object, ['toolCall'])), + ) + if getv(from_object, ['toolCallCancellation']) is not None: + setv( + to_object, + ['tool_call_cancellation'], + getv(from_object, ['toolCallCancellation']), + ) + return to_object + + def _parse_client_message( + self, + input: Union[ + types.ContentListUnion, + types.ContentListUnionDict, + types.LiveClientContentOrDict, + types.LiveClientRealtimeInputOrDict, + types.LiveClientRealtimeInputOrDict, + types.LiveClientToolResponseOrDict, + types.FunctionResponseOrDict, + Sequence[types.FunctionResponseOrDict], + ], + end_of_turn: Optional[bool] = False, + ) -> dict: + if isinstance(input, str): + input = [input] + elif isinstance(input, dict) and 'data' in input: + if isinstance(input['data'], bytes): + decoded_data = base64.b64encode(input['data']).decode('utf-8') + input['data'] = decoded_data + input = [input] + elif isinstance(input, types.Blob): + input.data = base64.b64encode(input.data).decode('utf-8') + input = [input] + elif isinstance(input, dict) and 'name' in input and 'response' in input: + # ToolResponse.FunctionResponse + if not (self._api_client.vertexai) and 'id' not in input: + raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID) + input = [input] + + if isinstance(input, Sequence) and any( + isinstance(c, dict) and 'name' in c and 'response' in c for c in input + ): + # ToolResponse.FunctionResponse + if not (self._api_client.vertexai): + for item in input: + if 'id' not in item: + raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID) + client_message = {'tool_response': {'function_responses': input}} + elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input): + to_object = {} + if self._api_client.vertexai: + contents = [ + _Content_to_vertex(self._api_client, item, to_object) + for item in t.t_contents(self._api_client, input) + ] + else: + contents = [ + _Content_to_mldev(self._api_client, item, to_object) + for item in t.t_contents(self._api_client, input) + ] + + client_message = { + 'client_content': {'turns': contents, 'turn_complete': end_of_turn} + } + elif isinstance(input, Sequence): + if any((isinstance(b, dict) and 'data' in b) for b in input): + pass + elif any(isinstance(b, types.Blob) for b in input): + input = [b.model_dump(exclude_none=True) for b in input] + else: + raise ValueError( + f'Unsupported input type "{type(input)}" or input content "{input}"' + ) + + client_message = {'realtime_input': {'media_chunks': input}} + + elif isinstance(input, dict) and 'content' in input: + # TODO(b/365983264) Add validation checks for content_update input_dict. 
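+      # Dicts carrying a 'content' key are forwarded as the `client_content`
+      # payload without further validation (see the TODO above).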
+ client_message = {'client_content': input} + elif isinstance(input, types.LiveClientRealtimeInput): + client_message = {'realtime_input': input.model_dump(exclude_none=True)} + if isinstance( + client_message['realtime_input']['media_chunks'][0]['data'], bytes + ): + client_message['realtime_input']['media_chunks'] = [ + { + 'data': base64.b64encode(item['data']).decode('utf-8'), + 'mime_type': item['mime_type'], + } + for item in client_message['realtime_input']['media_chunks'] + ] + + elif isinstance(input, types.LiveClientContent): + client_message = {'client_content': input.model_dump(exclude_none=True)} + elif isinstance(input, types.LiveClientToolResponse): + # ToolResponse.FunctionResponse + if not (self._api_client.vertexai) and not ( + input.function_responses[0].id + ): + raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID) + client_message = {'tool_response': input.model_dump(exclude_none=True)} + elif isinstance(input, types.FunctionResponse): + if not (self._api_client.vertexai) and not (input.id): + raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID) + client_message = { + 'tool_response': { + 'function_responses': [input.model_dump(exclude_none=True)] + } + } + elif isinstance(input, Sequence) and isinstance( + input[0], types.FunctionResponse + ): + if not (self._api_client.vertexai) and not (input[0].id): + raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID) + client_message = { + 'tool_response': { + 'function_responses': [ + c.model_dump(exclude_none=True) for c in input + ] + } + } + else: + raise ValueError( + f'Unsupported input type "{type(input)}" or input content "{input}"' + ) + + return client_message + + async def close(self): + # Close the websocket connection. + await self._ws.close() + + +class AsyncLive(_common.BaseModule): + """AsyncLive.""" + + def _LiveSetup_to_mldev( + self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None + ): + if isinstance(config, types.LiveConnectConfig): + from_object = config.model_dump(exclude_none=True) + else: + from_object = config + + to_object = {} + if getv(from_object, ['generation_config']) is not None: + setv( + to_object, + ['generationConfig'], + _GenerateContentConfig_to_mldev( + self._api_client, + getv(from_object, ['generation_config']), + to_object, + ), + ) + if getv(from_object, ['response_modalities']) is not None: + if getv(to_object, ['generationConfig']) is not None: + to_object['generationConfig']['responseModalities'] = from_object[ + 'response_modalities' + ] + else: + to_object['generationConfig'] = { + 'responseModalities': from_object['response_modalities'] + } + if getv(from_object, ['speech_config']) is not None: + if getv(to_object, ['generationConfig']) is not None: + to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev( + self._api_client, + t.t_speech_config( + self._api_client, getv(from_object, ['speech_config']) + ), + to_object, + ) + else: + to_object['generationConfig'] = { + 'speechConfig': _SpeechConfig_to_mldev( + self._api_client, + t.t_speech_config( + self._api_client, getv(from_object, ['speech_config']) + ), + to_object, + ) + } + + if getv(from_object, ['system_instruction']) is not None: + setv( + to_object, + ['systemInstruction'], + _Content_to_mldev( + self._api_client, + t.t_content( + self._api_client, getv(from_object, ['system_instruction']) + ), + to_object, + ), + ) + if getv(from_object, ['tools']) is not None: + setv( + to_object, + ['tools'], + [ + _Tool_to_mldev(self._api_client, item, to_object) + for item in getv(from_object, ['tools']) + 
], + ) + + return_value = {'setup': {'model': model}} + return_value['setup'].update(to_object) + return return_value + + def _LiveSetup_to_vertex( + self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None + ): + if isinstance(config, types.LiveConnectConfig): + from_object = config.model_dump(exclude_none=True) + else: + from_object = config + + to_object = {} + + if getv(from_object, ['generation_config']) is not None: + setv( + to_object, + ['generationConfig'], + _GenerateContentConfig_to_vertex( + self._api_client, + getv(from_object, ['generation_config']), + to_object, + ), + ) + if getv(from_object, ['response_modalities']) is not None: + if getv(to_object, ['generationConfig']) is not None: + to_object['generationConfig']['responseModalities'] = from_object[ + 'response_modalities' + ] + else: + to_object['generationConfig'] = { + 'responseModalities': from_object['response_modalities'] + } + else: + # Set default to AUDIO to align with MLDev API. + if getv(to_object, ['generationConfig']) is not None: + to_object['generationConfig'].update({'responseModalities': ['AUDIO']}) + else: + to_object.update( + {'generationConfig': {'responseModalities': ['AUDIO']}} + ) + if getv(from_object, ['speech_config']) is not None: + if getv(to_object, ['generationConfig']) is not None: + to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex( + self._api_client, + t.t_speech_config( + self._api_client, getv(from_object, ['speech_config']) + ), + to_object, + ) + else: + to_object['generationConfig'] = { + 'speechConfig': _SpeechConfig_to_vertex( + self._api_client, + t.t_speech_config( + self._api_client, getv(from_object, ['speech_config']) + ), + to_object, + ) + } + if getv(from_object, ['system_instruction']) is not None: + setv( + to_object, + ['systemInstruction'], + _Content_to_vertex( + self._api_client, + t.t_content( + self._api_client, getv(from_object, ['system_instruction']) + ), + to_object, + ), + ) + if getv(from_object, ['tools']) is not None: + setv( + to_object, + ['tools'], + [ + _Tool_to_vertex(self._api_client, item, to_object) + for item in getv(from_object, ['tools']) + ], + ) + + return_value = {'setup': {'model': model}} + return_value['setup'].update(to_object) + return return_value + + @contextlib.asynccontextmanager + async def connect( + self, + *, + model: str, + config: Optional[types.LiveConnectConfigOrDict] = None, + ) -> AsyncSession: + """Connect to the live server. + + Usage: + + .. code-block:: python + + client = genai.Client(api_key=API_KEY) + config = {} + async with client.aio.live.connect(model='...', config=config) as session: + await session.send(input='Hello world!', end_of_turn=True) + async for message in session.receive(): + print(message) + """ + base_url = self._api_client._websocket_base_url() + if self._api_client.api_key: + api_key = self._api_client.api_key + version = self._api_client._http_options['api_version'] + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}' + headers = self._api_client._http_options['headers'] + + transformed_model = t.t_model(self._api_client, model) + request = json.dumps( + self._LiveSetup_to_mldev(model=transformed_model, config=config) + ) + else: + # Get bearer token through Application Default Credentials. 
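+      # Assumes Application Default Credentials are available in the
+      # environment (e.g. via `gcloud auth application-default login` or a
+      # service-account key file); the refreshed token is attached below as a
+      # Bearer header.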
+ creds, _ = google.auth.default( + scopes=['https://www.googleapis.com/auth/cloud-platform'] + ) + + # creds.valid is False, and creds.token is None + # Need to refresh credentials to populate those + auth_req = google.auth.transport.requests.Request() + creds.refresh(auth_req) + bearer_token = creds.token + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer {}'.format(bearer_token), + } + version = self._api_client._http_options['api_version'] + uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' + location = self._api_client.location + project = self._api_client.project + transformed_model = t.t_model(self._api_client, model) + if transformed_model.startswith('publishers/'): + transformed_model = ( + f'projects/{project}/locations/{location}/' + transformed_model + ) + + request = json.dumps( + self._LiveSetup_to_vertex(model=transformed_model, config=config) + ) + + async with connect(uri, additional_headers=headers) as ws: + await ws.send(request) + logging.info(await ws.recv(decode=False)) + + yield AsyncSession(api_client=self._api_client, websocket=ws) diff --git a/.venv/lib/python3.12/site-packages/google/genai/models.py b/.venv/lib/python3.12/site-packages/google/genai/models.py new file mode 100644 index 00000000..b23428b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/models.py @@ -0,0 +1,5567 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import logging +from typing import AsyncIterator, Iterator, Optional, Union +from urllib.parse import urlencode +from . import _common +from . import _extra_utils +from . import _transformers as t +from . 
import types +from ._api_client import ApiClient, HttpOptionsDict +from ._common import get_value_by_path as getv +from ._common import set_value_by_path as setv +from .pagers import AsyncPager, Pager + + +def _Part_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['video_metadata']) is not None: + raise ValueError('video_metadata parameter is not supported in Google AI.') + + if getv(from_object, ['thought']) is not None: + setv(to_object, ['thought'], getv(from_object, ['thought'])) + + if getv(from_object, ['code_execution_result']) is not None: + setv( + to_object, + ['codeExecutionResult'], + getv(from_object, ['code_execution_result']), + ) + + if getv(from_object, ['executable_code']) is not None: + setv(to_object, ['executableCode'], getv(from_object, ['executable_code'])) + + if getv(from_object, ['file_data']) is not None: + setv(to_object, ['fileData'], getv(from_object, ['file_data'])) + + if getv(from_object, ['function_call']) is not None: + setv(to_object, ['functionCall'], getv(from_object, ['function_call'])) + + if getv(from_object, ['function_response']) is not None: + setv( + to_object, + ['functionResponse'], + getv(from_object, ['function_response']), + ) + + if getv(from_object, ['inline_data']) is not None: + setv(to_object, ['inlineData'], getv(from_object, ['inline_data'])) + + if getv(from_object, ['text']) is not None: + setv(to_object, ['text'], getv(from_object, ['text'])) + + return to_object + + +def _Part_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['video_metadata']) is not None: + setv(to_object, ['videoMetadata'], getv(from_object, ['video_metadata'])) + + if getv(from_object, ['thought']) is not None: + setv(to_object, ['thought'], getv(from_object, ['thought'])) + + if getv(from_object, ['code_execution_result']) is not None: + setv( + to_object, + ['codeExecutionResult'], + getv(from_object, ['code_execution_result']), + ) + + if getv(from_object, ['executable_code']) is not None: + setv(to_object, ['executableCode'], getv(from_object, ['executable_code'])) + + if getv(from_object, ['file_data']) is not None: + setv(to_object, ['fileData'], getv(from_object, ['file_data'])) + + if getv(from_object, ['function_call']) is not None: + setv(to_object, ['functionCall'], getv(from_object, ['function_call'])) + + if getv(from_object, ['function_response']) is not None: + setv( + to_object, + ['functionResponse'], + getv(from_object, ['function_response']), + ) + + if getv(from_object, ['inline_data']) is not None: + setv(to_object, ['inlineData'], getv(from_object, ['inline_data'])) + + if getv(from_object, ['text']) is not None: + setv(to_object, ['text'], getv(from_object, ['text'])) + + return to_object + + +def _Content_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['parts']) is not None: + setv( + to_object, + ['parts'], + [ + _Part_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _Content_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['parts']) is not None: + 
setv( + to_object, + ['parts'], + [ + _Part_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _Schema_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['min_items']) is not None: + raise ValueError('min_items parameter is not supported in Google AI.') + + if getv(from_object, ['example']) is not None: + raise ValueError('example parameter is not supported in Google AI.') + + if getv(from_object, ['property_ordering']) is not None: + raise ValueError( + 'property_ordering parameter is not supported in Google AI.' + ) + + if getv(from_object, ['pattern']) is not None: + raise ValueError('pattern parameter is not supported in Google AI.') + + if getv(from_object, ['minimum']) is not None: + raise ValueError('minimum parameter is not supported in Google AI.') + + if getv(from_object, ['default']) is not None: + raise ValueError('default parameter is not supported in Google AI.') + + if getv(from_object, ['any_of']) is not None: + raise ValueError('any_of parameter is not supported in Google AI.') + + if getv(from_object, ['max_length']) is not None: + raise ValueError('max_length parameter is not supported in Google AI.') + + if getv(from_object, ['title']) is not None: + raise ValueError('title parameter is not supported in Google AI.') + + if getv(from_object, ['min_length']) is not None: + raise ValueError('min_length parameter is not supported in Google AI.') + + if getv(from_object, ['min_properties']) is not None: + raise ValueError('min_properties parameter is not supported in Google AI.') + + if getv(from_object, ['max_items']) is not None: + raise ValueError('max_items parameter is not supported in Google AI.') + + if getv(from_object, ['maximum']) is not None: + raise ValueError('maximum parameter is not supported in Google AI.') + + if getv(from_object, ['nullable']) is not None: + raise ValueError('nullable parameter is not supported in Google AI.') + + if getv(from_object, ['max_properties']) is not None: + raise ValueError('max_properties parameter is not supported in Google AI.') + + if getv(from_object, ['type']) is not None: + setv(to_object, ['type'], getv(from_object, ['type'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['enum']) is not None: + setv(to_object, ['enum'], getv(from_object, ['enum'])) + + if getv(from_object, ['format']) is not None: + setv(to_object, ['format'], getv(from_object, ['format'])) + + if getv(from_object, ['items']) is not None: + setv(to_object, ['items'], getv(from_object, ['items'])) + + if getv(from_object, ['properties']) is not None: + setv(to_object, ['properties'], getv(from_object, ['properties'])) + + if getv(from_object, ['required']) is not None: + setv(to_object, ['required'], getv(from_object, ['required'])) + + return to_object + + +def _Schema_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['min_items']) is not None: + setv(to_object, ['minItems'], getv(from_object, ['min_items'])) + + if getv(from_object, ['example']) is not None: + setv(to_object, ['example'], getv(from_object, ['example'])) + + if getv(from_object, ['property_ordering']) is 
not None: + setv( + to_object, + ['propertyOrdering'], + getv(from_object, ['property_ordering']), + ) + + if getv(from_object, ['pattern']) is not None: + setv(to_object, ['pattern'], getv(from_object, ['pattern'])) + + if getv(from_object, ['minimum']) is not None: + setv(to_object, ['minimum'], getv(from_object, ['minimum'])) + + if getv(from_object, ['default']) is not None: + setv(to_object, ['default'], getv(from_object, ['default'])) + + if getv(from_object, ['any_of']) is not None: + setv(to_object, ['anyOf'], getv(from_object, ['any_of'])) + + if getv(from_object, ['max_length']) is not None: + setv(to_object, ['maxLength'], getv(from_object, ['max_length'])) + + if getv(from_object, ['title']) is not None: + setv(to_object, ['title'], getv(from_object, ['title'])) + + if getv(from_object, ['min_length']) is not None: + setv(to_object, ['minLength'], getv(from_object, ['min_length'])) + + if getv(from_object, ['min_properties']) is not None: + setv(to_object, ['minProperties'], getv(from_object, ['min_properties'])) + + if getv(from_object, ['max_items']) is not None: + setv(to_object, ['maxItems'], getv(from_object, ['max_items'])) + + if getv(from_object, ['maximum']) is not None: + setv(to_object, ['maximum'], getv(from_object, ['maximum'])) + + if getv(from_object, ['nullable']) is not None: + setv(to_object, ['nullable'], getv(from_object, ['nullable'])) + + if getv(from_object, ['max_properties']) is not None: + setv(to_object, ['maxProperties'], getv(from_object, ['max_properties'])) + + if getv(from_object, ['type']) is not None: + setv(to_object, ['type'], getv(from_object, ['type'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['enum']) is not None: + setv(to_object, ['enum'], getv(from_object, ['enum'])) + + if getv(from_object, ['format']) is not None: + setv(to_object, ['format'], getv(from_object, ['format'])) + + if getv(from_object, ['items']) is not None: + setv(to_object, ['items'], getv(from_object, ['items'])) + + if getv(from_object, ['properties']) is not None: + setv(to_object, ['properties'], getv(from_object, ['properties'])) + + if getv(from_object, ['required']) is not None: + setv(to_object, ['required'], getv(from_object, ['required'])) + + return to_object + + +def _SafetySetting_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['method']) is not None: + raise ValueError('method parameter is not supported in Google AI.') + + if getv(from_object, ['category']) is not None: + setv(to_object, ['category'], getv(from_object, ['category'])) + + if getv(from_object, ['threshold']) is not None: + setv(to_object, ['threshold'], getv(from_object, ['threshold'])) + + return to_object + + +def _SafetySetting_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['method']) is not None: + setv(to_object, ['method'], getv(from_object, ['method'])) + + if getv(from_object, ['category']) is not None: + setv(to_object, ['category'], getv(from_object, ['category'])) + + if getv(from_object, ['threshold']) is not None: + setv(to_object, ['threshold'], getv(from_object, ['threshold'])) + + return to_object + + +def _FunctionDeclaration_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + 
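+  # The Google AI (mldev) backend has no equivalent of a function `response`
+  # schema, so the converter rejects it explicitly instead of dropping it.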
if getv(from_object, ['response']) is not None: + raise ValueError('response parameter is not supported in Google AI.') + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['parameters']) is not None: + setv(to_object, ['parameters'], getv(from_object, ['parameters'])) + + return to_object + + +def _FunctionDeclaration_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['response']) is not None: + setv( + to_object, + ['response'], + _Schema_to_vertex( + api_client, getv(from_object, ['response']), to_object + ), + ) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['parameters']) is not None: + setv(to_object, ['parameters'], getv(from_object, ['parameters'])) + + return to_object + + +def _GoogleSearch_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _GoogleSearch_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DynamicRetrievalConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['dynamic_threshold']) is not None: + setv( + to_object, + ['dynamicThreshold'], + getv(from_object, ['dynamic_threshold']), + ) + + return to_object + + +def _DynamicRetrievalConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['dynamic_threshold']) is not None: + setv( + to_object, + ['dynamicThreshold'], + getv(from_object, ['dynamic_threshold']), + ) + + return to_object + + +def _GoogleSearchRetrieval_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['dynamic_retrieval_config']) is not None: + setv( + to_object, + ['dynamicRetrievalConfig'], + _DynamicRetrievalConfig_to_mldev( + api_client, + getv(from_object, ['dynamic_retrieval_config']), + to_object, + ), + ) + + return to_object + + +def _GoogleSearchRetrieval_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['dynamic_retrieval_config']) is not None: + setv( + to_object, + ['dynamicRetrievalConfig'], + _DynamicRetrievalConfig_to_vertex( + api_client, + getv(from_object, ['dynamic_retrieval_config']), + to_object, + ), + ) + + return to_object + + +def _Tool_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_declarations']) is not None: + setv( + to_object, + ['functionDeclarations'], + [ + 
_FunctionDeclaration_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['function_declarations']) + ], + ) + + if getv(from_object, ['retrieval']) is not None: + raise ValueError('retrieval parameter is not supported in Google AI.') + + if getv(from_object, ['google_search']) is not None: + setv( + to_object, + ['googleSearch'], + _GoogleSearch_to_mldev( + api_client, getv(from_object, ['google_search']), to_object + ), + ) + + if getv(from_object, ['google_search_retrieval']) is not None: + setv( + to_object, + ['googleSearchRetrieval'], + _GoogleSearchRetrieval_to_mldev( + api_client, + getv(from_object, ['google_search_retrieval']), + to_object, + ), + ) + + if getv(from_object, ['code_execution']) is not None: + setv(to_object, ['codeExecution'], getv(from_object, ['code_execution'])) + + return to_object + + +def _Tool_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_declarations']) is not None: + setv( + to_object, + ['functionDeclarations'], + [ + _FunctionDeclaration_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['function_declarations']) + ], + ) + + if getv(from_object, ['retrieval']) is not None: + setv(to_object, ['retrieval'], getv(from_object, ['retrieval'])) + + if getv(from_object, ['google_search']) is not None: + setv( + to_object, + ['googleSearch'], + _GoogleSearch_to_vertex( + api_client, getv(from_object, ['google_search']), to_object + ), + ) + + if getv(from_object, ['google_search_retrieval']) is not None: + setv( + to_object, + ['googleSearchRetrieval'], + _GoogleSearchRetrieval_to_vertex( + api_client, + getv(from_object, ['google_search_retrieval']), + to_object, + ), + ) + + if getv(from_object, ['code_execution']) is not None: + setv(to_object, ['codeExecution'], getv(from_object, ['code_execution'])) + + return to_object + + +def _FunctionCallingConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['allowed_function_names']) is not None: + setv( + to_object, + ['allowedFunctionNames'], + getv(from_object, ['allowed_function_names']), + ) + + return to_object + + +def _FunctionCallingConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mode']) is not None: + setv(to_object, ['mode'], getv(from_object, ['mode'])) + + if getv(from_object, ['allowed_function_names']) is not None: + setv( + to_object, + ['allowedFunctionNames'], + getv(from_object, ['allowed_function_names']), + ) + + return to_object + + +def _ToolConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_calling_config']) is not None: + setv( + to_object, + ['functionCallingConfig'], + _FunctionCallingConfig_to_mldev( + api_client, + getv(from_object, ['function_calling_config']), + to_object, + ), + ) + + return to_object + + +def _ToolConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['function_calling_config']) is not None: + setv( + to_object, + ['functionCallingConfig'], + 
_FunctionCallingConfig_to_vertex( + api_client, + getv(from_object, ['function_calling_config']), + to_object, + ), + ) + + return to_object + + +def _PrebuiltVoiceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['voice_name']) is not None: + setv(to_object, ['voiceName'], getv(from_object, ['voice_name'])) + + return to_object + + +def _PrebuiltVoiceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['voice_name']) is not None: + setv(to_object, ['voiceName'], getv(from_object, ['voice_name'])) + + return to_object + + +def _VoiceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['prebuilt_voice_config']) is not None: + setv( + to_object, + ['prebuiltVoiceConfig'], + _PrebuiltVoiceConfig_to_mldev( + api_client, getv(from_object, ['prebuilt_voice_config']), to_object + ), + ) + + return to_object + + +def _VoiceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['prebuilt_voice_config']) is not None: + setv( + to_object, + ['prebuiltVoiceConfig'], + _PrebuiltVoiceConfig_to_vertex( + api_client, getv(from_object, ['prebuilt_voice_config']), to_object + ), + ) + + return to_object + + +def _SpeechConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['voice_config']) is not None: + setv( + to_object, + ['voiceConfig'], + _VoiceConfig_to_mldev( + api_client, getv(from_object, ['voice_config']), to_object + ), + ) + + return to_object + + +def _SpeechConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['voice_config']) is not None: + setv( + to_object, + ['voiceConfig'], + _VoiceConfig_to_vertex( + api_client, getv(from_object, ['voice_config']), to_object + ), + ) + + return to_object + + +def _ThinkingConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['include_thoughts']) is not None: + setv( + to_object, ['includeThoughts'], getv(from_object, ['include_thoughts']) + ) + + return to_object + + +def _ThinkingConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['include_thoughts']) is not None: + setv( + to_object, ['includeThoughts'], getv(from_object, ['include_thoughts']) + ) + + return to_object + + +def _GenerateContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['systemInstruction'], + _Content_to_mldev( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['temperature']) is not None: + setv(to_object, ['temperature'], getv(from_object, ['temperature'])) + + if getv(from_object, ['top_p']) is not None: + setv(to_object, ['topP'], getv(from_object, ['top_p'])) + + if 
getv(from_object, ['top_k']) is not None: + setv(to_object, ['topK'], getv(from_object, ['top_k'])) + + if getv(from_object, ['candidate_count']) is not None: + setv(to_object, ['candidateCount'], getv(from_object, ['candidate_count'])) + + if getv(from_object, ['max_output_tokens']) is not None: + setv( + to_object, ['maxOutputTokens'], getv(from_object, ['max_output_tokens']) + ) + + if getv(from_object, ['stop_sequences']) is not None: + setv(to_object, ['stopSequences'], getv(from_object, ['stop_sequences'])) + + if getv(from_object, ['response_logprobs']) is not None: + setv( + to_object, + ['responseLogprobs'], + getv(from_object, ['response_logprobs']), + ) + + if getv(from_object, ['logprobs']) is not None: + setv(to_object, ['logprobs'], getv(from_object, ['logprobs'])) + + if getv(from_object, ['presence_penalty']) is not None: + setv( + to_object, ['presencePenalty'], getv(from_object, ['presence_penalty']) + ) + + if getv(from_object, ['frequency_penalty']) is not None: + setv( + to_object, + ['frequencyPenalty'], + getv(from_object, ['frequency_penalty']), + ) + + if getv(from_object, ['seed']) is not None: + setv(to_object, ['seed'], getv(from_object, ['seed'])) + + if getv(from_object, ['response_mime_type']) is not None: + setv( + to_object, + ['responseMimeType'], + getv(from_object, ['response_mime_type']), + ) + + if getv(from_object, ['response_schema']) is not None: + setv( + to_object, + ['responseSchema'], + _Schema_to_mldev( + api_client, + t.t_schema(api_client, getv(from_object, ['response_schema'])), + to_object, + ), + ) + + if getv(from_object, ['routing_config']) is not None: + raise ValueError('routing_config parameter is not supported in Google AI.') + + if getv(from_object, ['safety_settings']) is not None: + setv( + parent_object, + ['safetySettings'], + [ + _SafetySetting_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['safety_settings']) + ], + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['tools'], + [ + _Tool_to_mldev(api_client, t.t_tool(api_client, item), to_object) + for item in t.t_tools(api_client, getv(from_object, ['tools'])) + ], + ) + + if getv(from_object, ['tool_config']) is not None: + setv( + parent_object, + ['toolConfig'], + _ToolConfig_to_mldev( + api_client, getv(from_object, ['tool_config']), to_object + ), + ) + + if getv(from_object, ['cached_content']) is not None: + setv( + parent_object, + ['cachedContent'], + t.t_cached_content_name( + api_client, getv(from_object, ['cached_content']) + ), + ) + + if getv(from_object, ['response_modalities']) is not None: + setv( + to_object, + ['responseModalities'], + getv(from_object, ['response_modalities']), + ) + + if getv(from_object, ['media_resolution']) is not None: + raise ValueError( + 'media_resolution parameter is not supported in Google AI.' 
+ ) + + if getv(from_object, ['speech_config']) is not None: + setv( + to_object, + ['speechConfig'], + _SpeechConfig_to_mldev( + api_client, + t.t_speech_config(api_client, getv(from_object, ['speech_config'])), + to_object, + ), + ) + + if getv(from_object, ['audio_timestamp']) is not None: + raise ValueError('audio_timestamp parameter is not supported in Google AI.') + + if getv(from_object, ['thinking_config']) is not None: + setv( + to_object, + ['thinkingConfig'], + _ThinkingConfig_to_mldev( + api_client, getv(from_object, ['thinking_config']), to_object + ), + ) + + return to_object + + +def _GenerateContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['systemInstruction'], + _Content_to_vertex( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['temperature']) is not None: + setv(to_object, ['temperature'], getv(from_object, ['temperature'])) + + if getv(from_object, ['top_p']) is not None: + setv(to_object, ['topP'], getv(from_object, ['top_p'])) + + if getv(from_object, ['top_k']) is not None: + setv(to_object, ['topK'], getv(from_object, ['top_k'])) + + if getv(from_object, ['candidate_count']) is not None: + setv(to_object, ['candidateCount'], getv(from_object, ['candidate_count'])) + + if getv(from_object, ['max_output_tokens']) is not None: + setv( + to_object, ['maxOutputTokens'], getv(from_object, ['max_output_tokens']) + ) + + if getv(from_object, ['stop_sequences']) is not None: + setv(to_object, ['stopSequences'], getv(from_object, ['stop_sequences'])) + + if getv(from_object, ['response_logprobs']) is not None: + setv( + to_object, + ['responseLogprobs'], + getv(from_object, ['response_logprobs']), + ) + + if getv(from_object, ['logprobs']) is not None: + setv(to_object, ['logprobs'], getv(from_object, ['logprobs'])) + + if getv(from_object, ['presence_penalty']) is not None: + setv( + to_object, ['presencePenalty'], getv(from_object, ['presence_penalty']) + ) + + if getv(from_object, ['frequency_penalty']) is not None: + setv( + to_object, + ['frequencyPenalty'], + getv(from_object, ['frequency_penalty']), + ) + + if getv(from_object, ['seed']) is not None: + setv(to_object, ['seed'], getv(from_object, ['seed'])) + + if getv(from_object, ['response_mime_type']) is not None: + setv( + to_object, + ['responseMimeType'], + getv(from_object, ['response_mime_type']), + ) + + if getv(from_object, ['response_schema']) is not None: + setv( + to_object, + ['responseSchema'], + _Schema_to_vertex( + api_client, + t.t_schema(api_client, getv(from_object, ['response_schema'])), + to_object, + ), + ) + + if getv(from_object, ['routing_config']) is not None: + setv(to_object, ['routingConfig'], getv(from_object, ['routing_config'])) + + if getv(from_object, ['safety_settings']) is not None: + setv( + parent_object, + ['safetySettings'], + [ + _SafetySetting_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['safety_settings']) + ], + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['tools'], + [ + _Tool_to_vertex(api_client, t.t_tool(api_client, item), to_object) + for item in t.t_tools(api_client, getv(from_object, ['tools'])) + ], + ) + + if getv(from_object, ['tool_config']) is not None: + setv( + parent_object, + ['toolConfig'], + _ToolConfig_to_vertex( + api_client, 
getv(from_object, ['tool_config']), to_object + ), + ) + + if getv(from_object, ['cached_content']) is not None: + setv( + parent_object, + ['cachedContent'], + t.t_cached_content_name( + api_client, getv(from_object, ['cached_content']) + ), + ) + + if getv(from_object, ['response_modalities']) is not None: + setv( + to_object, + ['responseModalities'], + getv(from_object, ['response_modalities']), + ) + + if getv(from_object, ['media_resolution']) is not None: + setv( + to_object, ['mediaResolution'], getv(from_object, ['media_resolution']) + ) + + if getv(from_object, ['speech_config']) is not None: + setv( + to_object, + ['speechConfig'], + _SpeechConfig_to_vertex( + api_client, + t.t_speech_config(api_client, getv(from_object, ['speech_config'])), + to_object, + ), + ) + + if getv(from_object, ['audio_timestamp']) is not None: + setv(to_object, ['audioTimestamp'], getv(from_object, ['audio_timestamp'])) + + if getv(from_object, ['thinking_config']) is not None: + setv( + to_object, + ['thinkingConfig'], + _ThinkingConfig_to_vertex( + api_client, getv(from_object, ['thinking_config']), to_object + ), + ) + + return to_object + + +def _GenerateContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['contents'], + [ + _Content_to_mldev(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['generationConfig'], + _GenerateContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GenerateContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['contents'], + [ + _Content_to_vertex(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['generationConfig'], + _GenerateContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _EmbedContentConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['task_type']) is not None: + setv( + parent_object, + ['requests[]', 'taskType'], + getv(from_object, ['task_type']), + ) + + if getv(from_object, ['title']) is not None: + setv(parent_object, ['requests[]', 'title'], getv(from_object, ['title'])) + + if getv(from_object, ['output_dimensionality']) is not None: + setv( + parent_object, + ['requests[]', 'outputDimensionality'], + getv(from_object, ['output_dimensionality']), + ) + + if getv(from_object, ['mime_type']) is not None: + raise ValueError('mime_type parameter is not supported in Google AI.') + + if 
getv(from_object, ['auto_truncate']) is not None: + raise ValueError('auto_truncate parameter is not supported in Google AI.') + + return to_object + + +def _EmbedContentConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['task_type']) is not None: + setv( + parent_object, + ['instances[]', 'task_type'], + getv(from_object, ['task_type']), + ) + + if getv(from_object, ['title']) is not None: + setv(parent_object, ['instances[]', 'title'], getv(from_object, ['title'])) + + if getv(from_object, ['output_dimensionality']) is not None: + setv( + parent_object, + ['parameters', 'outputDimensionality'], + getv(from_object, ['output_dimensionality']), + ) + + if getv(from_object, ['mime_type']) is not None: + setv( + parent_object, + ['instances[]', 'mimeType'], + getv(from_object, ['mime_type']), + ) + + if getv(from_object, ['auto_truncate']) is not None: + setv( + parent_object, + ['parameters', 'autoTruncate'], + getv(from_object, ['auto_truncate']), + ) + + return to_object + + +def _EmbedContentParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['requests[]', 'content'], + t.t_contents_for_embed(api_client, getv(from_object, ['contents'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _EmbedContentConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + setv( + to_object, + ['requests[]', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + return to_object + + +def _EmbedContentParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['instances[]', 'content'], + t.t_contents_for_embed(api_client, getv(from_object, ['contents'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _EmbedContentConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GenerateImageConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['output_gcs_uri']) is not None: + raise ValueError('output_gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['negative_prompt']) is not None: + setv( + parent_object, + ['parameters', 'negativePrompt'], + getv(from_object, ['negative_prompt']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['guidance_scale']) is not None: + setv( + parent_object, + 
['parameters', 'guidanceScale'], + getv(from_object, ['guidance_scale']), + ) + + if getv(from_object, ['seed']) is not None: + raise ValueError('seed parameter is not supported in Google AI.') + + if getv(from_object, ['safety_filter_level']) is not None: + setv( + parent_object, + ['parameters', 'safetySetting'], + getv(from_object, ['safety_filter_level']), + ) + + if getv(from_object, ['person_generation']) is not None: + setv( + parent_object, + ['parameters', 'personGeneration'], + getv(from_object, ['person_generation']), + ) + + if getv(from_object, ['include_safety_attributes']) is not None: + setv( + parent_object, + ['parameters', 'includeSafetyAttributes'], + getv(from_object, ['include_safety_attributes']), + ) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if getv(from_object, ['language']) is not None: + setv( + parent_object, + ['parameters', 'language'], + getv(from_object, ['language']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['add_watermark']) is not None: + raise ValueError('add_watermark parameter is not supported in Google AI.') + + if getv(from_object, ['aspect_ratio']) is not None: + setv( + parent_object, + ['parameters', 'aspectRatio'], + getv(from_object, ['aspect_ratio']), + ) + + return to_object + + +def _GenerateImageConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['output_gcs_uri']) is not None: + setv( + parent_object, + ['parameters', 'storageUri'], + getv(from_object, ['output_gcs_uri']), + ) + + if getv(from_object, ['negative_prompt']) is not None: + setv( + parent_object, + ['parameters', 'negativePrompt'], + getv(from_object, ['negative_prompt']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['guidance_scale']) is not None: + setv( + parent_object, + ['parameters', 'guidanceScale'], + getv(from_object, ['guidance_scale']), + ) + + if getv(from_object, ['seed']) is not None: + setv(parent_object, ['parameters', 'seed'], getv(from_object, ['seed'])) + + if getv(from_object, ['safety_filter_level']) is not None: + setv( + parent_object, + ['parameters', 'safetySetting'], + getv(from_object, ['safety_filter_level']), + ) + + if getv(from_object, ['person_generation']) is not None: + setv( + parent_object, + ['parameters', 'personGeneration'], + getv(from_object, ['person_generation']), + ) + + if getv(from_object, ['include_safety_attributes']) is not None: + setv( + parent_object, + ['parameters', 'includeSafetyAttributes'], + getv(from_object, ['include_safety_attributes']), + ) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if 
getv(from_object, ['language']) is not None: + setv( + parent_object, + ['parameters', 'language'], + getv(from_object, ['language']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['add_watermark']) is not None: + setv( + parent_object, + ['parameters', 'addWatermark'], + getv(from_object, ['add_watermark']), + ) + + if getv(from_object, ['aspect_ratio']) is not None: + setv( + parent_object, + ['parameters', 'aspectRatio'], + getv(from_object, ['aspect_ratio']), + ) + + return to_object + + +def _GenerateImageParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['prompt']) is not None: + setv(to_object, ['instances', 'prompt'], getv(from_object, ['prompt'])) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GenerateImageConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GenerateImageParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['prompt']) is not None: + setv(to_object, ['instances', 'prompt'], getv(from_object, ['prompt'])) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GenerateImageConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _Image_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['image_bytes']) is not None: + setv( + to_object, + ['bytesBase64Encoded'], + t.t_bytes(api_client, getv(from_object, ['image_bytes'])), + ) + + return to_object + + +def _Image_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + setv(to_object, ['gcsUri'], getv(from_object, ['gcs_uri'])) + + if getv(from_object, ['image_bytes']) is not None: + setv( + to_object, + ['bytesBase64Encoded'], + t.t_bytes(api_client, getv(from_object, ['image_bytes'])), + ) + + return to_object + + +def _MaskReferenceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mask_mode']) is not None: + raise ValueError('mask_mode parameter is not supported in Google AI.') + + if getv(from_object, ['segmentation_classes']) is not None: + raise ValueError( + 'segmentation_classes parameter is not supported in Google AI.' 
+ ) + + if getv(from_object, ['mask_dilation']) is not None: + raise ValueError('mask_dilation parameter is not supported in Google AI.') + + return to_object + + +def _MaskReferenceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['mask_mode']) is not None: + setv(to_object, ['maskMode'], getv(from_object, ['mask_mode'])) + + if getv(from_object, ['segmentation_classes']) is not None: + setv( + to_object, ['maskClasses'], getv(from_object, ['segmentation_classes']) + ) + + if getv(from_object, ['mask_dilation']) is not None: + setv(to_object, ['dilation'], getv(from_object, ['mask_dilation'])) + + return to_object + + +def _ControlReferenceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['control_type']) is not None: + raise ValueError('control_type parameter is not supported in Google AI.') + + if getv(from_object, ['enable_control_image_computation']) is not None: + raise ValueError( + 'enable_control_image_computation parameter is not supported in' + ' Google AI.' + ) + + return to_object + + +def _ControlReferenceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['control_type']) is not None: + setv(to_object, ['controlType'], getv(from_object, ['control_type'])) + + if getv(from_object, ['enable_control_image_computation']) is not None: + setv( + to_object, + ['computeControl'], + getv(from_object, ['enable_control_image_computation']), + ) + + return to_object + + +def _StyleReferenceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['style_description']) is not None: + raise ValueError( + 'style_description parameter is not supported in Google AI.' + ) + + return to_object + + +def _StyleReferenceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['style_description']) is not None: + setv( + to_object, + ['styleDescription'], + getv(from_object, ['style_description']), + ) + + return to_object + + +def _SubjectReferenceConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['subject_type']) is not None: + raise ValueError('subject_type parameter is not supported in Google AI.') + + if getv(from_object, ['subject_description']) is not None: + raise ValueError( + 'subject_description parameter is not supported in Google AI.' 
+ ) + + return to_object + + +def _SubjectReferenceConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['subject_type']) is not None: + setv(to_object, ['subjectType'], getv(from_object, ['subject_type'])) + + if getv(from_object, ['subject_description']) is not None: + setv( + to_object, + ['subjectDescription'], + getv(from_object, ['subject_description']), + ) + + return to_object + + +def _ReferenceImageAPI_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['reference_image']) is not None: + raise ValueError('reference_image parameter is not supported in Google AI.') + + if getv(from_object, ['reference_id']) is not None: + raise ValueError('reference_id parameter is not supported in Google AI.') + + if getv(from_object, ['reference_type']) is not None: + raise ValueError('reference_type parameter is not supported in Google AI.') + + if getv(from_object, ['mask_image_config']) is not None: + raise ValueError( + 'mask_image_config parameter is not supported in Google AI.' + ) + + if getv(from_object, ['control_image_config']) is not None: + raise ValueError( + 'control_image_config parameter is not supported in Google AI.' + ) + + if getv(from_object, ['style_image_config']) is not None: + raise ValueError( + 'style_image_config parameter is not supported in Google AI.' + ) + + if getv(from_object, ['subject_image_config']) is not None: + raise ValueError( + 'subject_image_config parameter is not supported in Google AI.' + ) + + return to_object + + +def _ReferenceImageAPI_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['reference_image']) is not None: + setv( + to_object, + ['referenceImage'], + _Image_to_vertex( + api_client, getv(from_object, ['reference_image']), to_object + ), + ) + + if getv(from_object, ['reference_id']) is not None: + setv(to_object, ['referenceId'], getv(from_object, ['reference_id'])) + + if getv(from_object, ['reference_type']) is not None: + setv(to_object, ['referenceType'], getv(from_object, ['reference_type'])) + + if getv(from_object, ['mask_image_config']) is not None: + setv( + to_object, + ['maskImageConfig'], + _MaskReferenceConfig_to_vertex( + api_client, getv(from_object, ['mask_image_config']), to_object + ), + ) + + if getv(from_object, ['control_image_config']) is not None: + setv( + to_object, + ['controlImageConfig'], + _ControlReferenceConfig_to_vertex( + api_client, getv(from_object, ['control_image_config']), to_object + ), + ) + + if getv(from_object, ['style_image_config']) is not None: + setv( + to_object, + ['styleImageConfig'], + _StyleReferenceConfig_to_vertex( + api_client, getv(from_object, ['style_image_config']), to_object + ), + ) + + if getv(from_object, ['subject_image_config']) is not None: + setv( + to_object, + ['subjectImageConfig'], + _SubjectReferenceConfig_to_vertex( + api_client, getv(from_object, ['subject_image_config']), to_object + ), + ) + + return to_object + + +def _EditImageConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['output_gcs_uri']) is not None: + raise 
ValueError('output_gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['negative_prompt']) is not None: + setv( + parent_object, + ['parameters', 'negativePrompt'], + getv(from_object, ['negative_prompt']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['guidance_scale']) is not None: + setv( + parent_object, + ['parameters', 'guidanceScale'], + getv(from_object, ['guidance_scale']), + ) + + if getv(from_object, ['seed']) is not None: + raise ValueError('seed parameter is not supported in Google AI.') + + if getv(from_object, ['safety_filter_level']) is not None: + setv( + parent_object, + ['parameters', 'safetySetting'], + getv(from_object, ['safety_filter_level']), + ) + + if getv(from_object, ['person_generation']) is not None: + setv( + parent_object, + ['parameters', 'personGeneration'], + getv(from_object, ['person_generation']), + ) + + if getv(from_object, ['include_safety_attributes']) is not None: + setv( + parent_object, + ['parameters', 'includeSafetyAttributes'], + getv(from_object, ['include_safety_attributes']), + ) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if getv(from_object, ['language']) is not None: + setv( + parent_object, + ['parameters', 'language'], + getv(from_object, ['language']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['edit_mode']) is not None: + setv( + parent_object, + ['parameters', 'editMode'], + getv(from_object, ['edit_mode']), + ) + + return to_object + + +def _EditImageConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['output_gcs_uri']) is not None: + setv( + parent_object, + ['parameters', 'storageUri'], + getv(from_object, ['output_gcs_uri']), + ) + + if getv(from_object, ['negative_prompt']) is not None: + setv( + parent_object, + ['parameters', 'negativePrompt'], + getv(from_object, ['negative_prompt']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['guidance_scale']) is not None: + setv( + parent_object, + ['parameters', 'guidanceScale'], + getv(from_object, ['guidance_scale']), + ) + + if getv(from_object, ['seed']) is not None: + setv(parent_object, ['parameters', 'seed'], getv(from_object, ['seed'])) + + if getv(from_object, ['safety_filter_level']) is not None: + setv( + parent_object, + ['parameters', 'safetySetting'], + getv(from_object, ['safety_filter_level']), + ) + + if getv(from_object, ['person_generation']) is not None: + setv( + parent_object, + ['parameters', 'personGeneration'], + getv(from_object, ['person_generation']), + ) + + if getv(from_object, 
['include_safety_attributes']) is not None: + setv( + parent_object, + ['parameters', 'includeSafetyAttributes'], + getv(from_object, ['include_safety_attributes']), + ) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if getv(from_object, ['language']) is not None: + setv( + parent_object, + ['parameters', 'language'], + getv(from_object, ['language']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['edit_mode']) is not None: + setv( + parent_object, + ['parameters', 'editMode'], + getv(from_object, ['edit_mode']), + ) + + return to_object + + +def _EditImageParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['prompt']) is not None: + setv(to_object, ['instances', 'prompt'], getv(from_object, ['prompt'])) + + if getv(from_object, ['reference_images']) is not None: + setv( + to_object, + ['instances', 'referenceImages'], + [ + _ReferenceImageAPI_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['reference_images']) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _EditImageConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _EditImageParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['prompt']) is not None: + setv(to_object, ['instances', 'prompt'], getv(from_object, ['prompt'])) + + if getv(from_object, ['reference_images']) is not None: + setv( + to_object, + ['instances', 'referenceImages'], + [ + _ReferenceImageAPI_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['reference_images']) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _EditImageConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpscaleImageAPIConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 
'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['mode']) is not None: + setv(parent_object, ['parameters', 'mode'], getv(from_object, ['mode'])) + + return to_object + + +def _UpscaleImageAPIConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['include_rai_reason']) is not None: + setv( + parent_object, + ['parameters', 'includeRaiReason'], + getv(from_object, ['include_rai_reason']), + ) + + if getv(from_object, ['output_mime_type']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'mimeType'], + getv(from_object, ['output_mime_type']), + ) + + if getv(from_object, ['output_compression_quality']) is not None: + setv( + parent_object, + ['parameters', 'outputOptions', 'compressionQuality'], + getv(from_object, ['output_compression_quality']), + ) + + if getv(from_object, ['number_of_images']) is not None: + setv( + parent_object, + ['parameters', 'sampleCount'], + getv(from_object, ['number_of_images']), + ) + + if getv(from_object, ['mode']) is not None: + setv(parent_object, ['parameters', 'mode'], getv(from_object, ['mode'])) + + return to_object + + +def _UpscaleImageAPIParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['image']) is not None: + setv( + to_object, + ['instances', 'image'], + _Image_to_mldev(api_client, getv(from_object, ['image']), to_object), + ) + + if getv(from_object, ['upscale_factor']) is not None: + setv( + to_object, + ['parameters', 'upscaleConfig', 'upscaleFactor'], + getv(from_object, ['upscale_factor']), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpscaleImageAPIConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpscaleImageAPIParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['image']) is not None: + setv( + to_object, + ['instances', 'image'], + _Image_to_vertex(api_client, getv(from_object, ['image']), to_object), + ) + + if getv(from_object, ['upscale_factor']) is not None: + setv( + to_object, + ['parameters', 'upscaleConfig', 'upscaleFactor'], + getv(from_object, ['upscale_factor']), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpscaleImageAPIConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetModelParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'name'], + 
t.t_model(api_client, getv(from_object, ['model'])), + ) + + return to_object + + +def _GetModelParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + return to_object + + +def _ListModelsConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter'])) + + if getv(from_object, ['query_base']) is not None: + setv( + parent_object, + ['_url', 'models_url'], + t.t_models_url(api_client, getv(from_object, ['query_base'])), + ) + + return to_object + + +def _ListModelsConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter'])) + + if getv(from_object, ['query_base']) is not None: + setv( + parent_object, + ['_url', 'models_url'], + t.t_models_url(api_client, getv(from_object, ['query_base'])), + ) + + return to_object + + +def _ListModelsParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListModelsConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListModelsParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListModelsConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpdateModelConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['display_name']) is not None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['description']) is not None: + setv(parent_object, ['description'], getv(from_object, ['description'])) + + return to_object + + +def _UpdateModelConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['display_name']) is not 
None: + setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + + if getv(from_object, ['description']) is not None: + setv(parent_object, ['description'], getv(from_object, ['description'])) + + return to_object + + +def _UpdateModelParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpdateModelConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _UpdateModelParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _UpdateModelConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DeleteModelParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + return to_object + + +def _DeleteModelParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'name'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + return to_object + + +def _CountTokensConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['generateContentRequest', 'systemInstruction'], + _Content_to_mldev( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['generateContentRequest', 'tools'], + [ + _Tool_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['tools']) + ], + ) + + if getv(from_object, ['generation_config']) is not None: + raise ValueError( + 'generation_config parameter is not supported in Google AI.' 
+ ) + + return to_object + + +def _CountTokensConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['system_instruction']) is not None: + setv( + parent_object, + ['systemInstruction'], + _Content_to_vertex( + api_client, + t.t_content(api_client, getv(from_object, ['system_instruction'])), + to_object, + ), + ) + + if getv(from_object, ['tools']) is not None: + setv( + parent_object, + ['tools'], + [ + _Tool_to_vertex(api_client, item, to_object) + for item in getv(from_object, ['tools']) + ], + ) + + if getv(from_object, ['generation_config']) is not None: + setv( + parent_object, + ['generationConfig'], + getv(from_object, ['generation_config']), + ) + + return to_object + + +def _CountTokensParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['contents'], + [ + _Content_to_mldev(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CountTokensConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CountTokensParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['contents'], + [ + _Content_to_vertex(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CountTokensConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ComputeTokensConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _ComputeTokensConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _ComputeTokensParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + raise ValueError('contents parameter is not supported in Google AI.') + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + 
_ComputeTokensConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ComputeTokensParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + setv( + to_object, + ['contents'], + [ + _Content_to_vertex(api_client, item, to_object) + for item in t.t_contents( + api_client, getv(from_object, ['contents']) + ) + ], + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ComputeTokensConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _Part_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + if getv(from_object, ['thought']) is not None: + setv(to_object, ['thought'], getv(from_object, ['thought'])) + + if getv(from_object, ['codeExecutionResult']) is not None: + setv( + to_object, + ['code_execution_result'], + getv(from_object, ['codeExecutionResult']), + ) + + if getv(from_object, ['executableCode']) is not None: + setv(to_object, ['executable_code'], getv(from_object, ['executableCode'])) + + if getv(from_object, ['fileData']) is not None: + setv(to_object, ['file_data'], getv(from_object, ['fileData'])) + + if getv(from_object, ['functionCall']) is not None: + setv(to_object, ['function_call'], getv(from_object, ['functionCall'])) + + if getv(from_object, ['functionResponse']) is not None: + setv( + to_object, + ['function_response'], + getv(from_object, ['functionResponse']), + ) + + if getv(from_object, ['inlineData']) is not None: + setv(to_object, ['inline_data'], getv(from_object, ['inlineData'])) + + if getv(from_object, ['text']) is not None: + setv(to_object, ['text'], getv(from_object, ['text'])) + + return to_object + + +def _Part_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['videoMetadata']) is not None: + setv(to_object, ['video_metadata'], getv(from_object, ['videoMetadata'])) + + if getv(from_object, ['thought']) is not None: + setv(to_object, ['thought'], getv(from_object, ['thought'])) + + if getv(from_object, ['codeExecutionResult']) is not None: + setv( + to_object, + ['code_execution_result'], + getv(from_object, ['codeExecutionResult']), + ) + + if getv(from_object, ['executableCode']) is not None: + setv(to_object, ['executable_code'], getv(from_object, ['executableCode'])) + + if getv(from_object, ['fileData']) is not None: + setv(to_object, ['file_data'], getv(from_object, ['fileData'])) + + if getv(from_object, ['functionCall']) is not None: + setv(to_object, ['function_call'], getv(from_object, ['functionCall'])) + + if getv(from_object, ['functionResponse']) is not None: + setv( + to_object, + ['function_response'], + getv(from_object, ['functionResponse']), + ) + + if getv(from_object, ['inlineData']) is not None: + setv(to_object, ['inline_data'], getv(from_object, ['inlineData'])) + + if getv(from_object, ['text']) is not None: + setv(to_object, ['text'], getv(from_object, ['text'])) + + return to_object + + +def _Content_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + 
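+  # Maps a REST-style (camelCase) Content dict onto the SDK-style
+  # (snake_case) dict, converting each part with _Part_from_mldev.
+  # Illustrative shape (hypothetical values):
+  #   {'role': 'model', 'parts': [{'functionCall': {...}}]}
+  #   -> {'role': 'model', 'parts': [{'function_call': {...}}]}
+  # Fields absent from the response are omitted rather than set to None.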
to_object = {} + if getv(from_object, ['parts']) is not None: + setv( + to_object, + ['parts'], + [ + _Part_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _Content_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['parts']) is not None: + setv( + to_object, + ['parts'], + [ + _Part_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['parts']) + ], + ) + + if getv(from_object, ['role']) is not None: + setv(to_object, ['role'], getv(from_object, ['role'])) + + return to_object + + +def _CitationMetadata_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['citationSources']) is not None: + setv(to_object, ['citations'], getv(from_object, ['citationSources'])) + + return to_object + + +def _CitationMetadata_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['citations']) is not None: + setv(to_object, ['citations'], getv(from_object, ['citations'])) + + return to_object + + +def _Candidate_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['content']) is not None: + setv( + to_object, + ['content'], + _Content_from_mldev( + api_client, getv(from_object, ['content']), to_object + ), + ) + + if getv(from_object, ['citationMetadata']) is not None: + setv( + to_object, + ['citation_metadata'], + _CitationMetadata_from_mldev( + api_client, getv(from_object, ['citationMetadata']), to_object + ), + ) + + if getv(from_object, ['tokenCount']) is not None: + setv(to_object, ['token_count'], getv(from_object, ['tokenCount'])) + + if getv(from_object, ['avgLogprobs']) is not None: + setv(to_object, ['avg_logprobs'], getv(from_object, ['avgLogprobs'])) + + if getv(from_object, ['finishReason']) is not None: + setv(to_object, ['finish_reason'], getv(from_object, ['finishReason'])) + + if getv(from_object, ['groundingMetadata']) is not None: + setv( + to_object, + ['grounding_metadata'], + getv(from_object, ['groundingMetadata']), + ) + + if getv(from_object, ['index']) is not None: + setv(to_object, ['index'], getv(from_object, ['index'])) + + if getv(from_object, ['logprobsResult']) is not None: + setv(to_object, ['logprobs_result'], getv(from_object, ['logprobsResult'])) + + if getv(from_object, ['safetyRatings']) is not None: + setv(to_object, ['safety_ratings'], getv(from_object, ['safetyRatings'])) + + return to_object + + +def _Candidate_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['content']) is not None: + setv( + to_object, + ['content'], + _Content_from_vertex( + api_client, getv(from_object, ['content']), to_object + ), + ) + + if getv(from_object, ['citationMetadata']) is not None: + setv( + to_object, + ['citation_metadata'], + _CitationMetadata_from_vertex( + api_client, getv(from_object, ['citationMetadata']), to_object + ), + ) + + if getv(from_object, ['finishMessage']) is not None: + setv(to_object, ['finish_message'], getv(from_object, ['finishMessage'])) + + if 
getv(from_object, ['avgLogprobs']) is not None: + setv(to_object, ['avg_logprobs'], getv(from_object, ['avgLogprobs'])) + + if getv(from_object, ['finishReason']) is not None: + setv(to_object, ['finish_reason'], getv(from_object, ['finishReason'])) + + if getv(from_object, ['groundingMetadata']) is not None: + setv( + to_object, + ['grounding_metadata'], + getv(from_object, ['groundingMetadata']), + ) + + if getv(from_object, ['index']) is not None: + setv(to_object, ['index'], getv(from_object, ['index'])) + + if getv(from_object, ['logprobsResult']) is not None: + setv(to_object, ['logprobs_result'], getv(from_object, ['logprobsResult'])) + + if getv(from_object, ['safetyRatings']) is not None: + setv(to_object, ['safety_ratings'], getv(from_object, ['safetyRatings'])) + + return to_object + + +def _GenerateContentResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['candidates']) is not None: + setv( + to_object, + ['candidates'], + [ + _Candidate_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['candidates']) + ], + ) + + if getv(from_object, ['modelVersion']) is not None: + setv(to_object, ['model_version'], getv(from_object, ['modelVersion'])) + + if getv(from_object, ['promptFeedback']) is not None: + setv(to_object, ['prompt_feedback'], getv(from_object, ['promptFeedback'])) + + if getv(from_object, ['usageMetadata']) is not None: + setv(to_object, ['usage_metadata'], getv(from_object, ['usageMetadata'])) + + return to_object + + +def _GenerateContentResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['candidates']) is not None: + setv( + to_object, + ['candidates'], + [ + _Candidate_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['candidates']) + ], + ) + + if getv(from_object, ['modelVersion']) is not None: + setv(to_object, ['model_version'], getv(from_object, ['modelVersion'])) + + if getv(from_object, ['promptFeedback']) is not None: + setv(to_object, ['prompt_feedback'], getv(from_object, ['promptFeedback'])) + + if getv(from_object, ['usageMetadata']) is not None: + setv(to_object, ['usage_metadata'], getv(from_object, ['usageMetadata'])) + + return to_object + + +def _ContentEmbeddingStatistics_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _ContentEmbeddingStatistics_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['truncated']) is not None: + setv(to_object, ['truncated'], getv(from_object, ['truncated'])) + + if getv(from_object, ['token_count']) is not None: + setv(to_object, ['token_count'], getv(from_object, ['token_count'])) + + return to_object + + +def _ContentEmbedding_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['values']) is not None: + setv(to_object, ['values'], getv(from_object, ['values'])) + + return to_object + + +def _ContentEmbedding_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['values']) is not None: + setv(to_object, ['values'], 
getv(from_object, ['values'])) + + if getv(from_object, ['statistics']) is not None: + setv( + to_object, + ['statistics'], + _ContentEmbeddingStatistics_from_vertex( + api_client, getv(from_object, ['statistics']), to_object + ), + ) + + return to_object + + +def _EmbedContentMetadata_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _EmbedContentMetadata_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['billableCharacterCount']) is not None: + setv( + to_object, + ['billable_character_count'], + getv(from_object, ['billableCharacterCount']), + ) + + return to_object + + +def _EmbedContentResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['embeddings']) is not None: + setv( + to_object, + ['embeddings'], + [ + _ContentEmbedding_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['embeddings']) + ], + ) + + if getv(from_object, ['metadata']) is not None: + setv( + to_object, + ['metadata'], + _EmbedContentMetadata_from_mldev( + api_client, getv(from_object, ['metadata']), to_object + ), + ) + + return to_object + + +def _EmbedContentResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions[]', 'embeddings']) is not None: + setv( + to_object, + ['embeddings'], + [ + _ContentEmbedding_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['predictions[]', 'embeddings']) + ], + ) + + if getv(from_object, ['metadata']) is not None: + setv( + to_object, + ['metadata'], + _EmbedContentMetadata_from_vertex( + api_client, getv(from_object, ['metadata']), to_object + ), + ) + + return to_object + + +def _Image_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + if getv(from_object, ['bytesBase64Encoded']) is not None: + setv( + to_object, + ['image_bytes'], + t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])), + ) + + return to_object + + +def _Image_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcsUri']) is not None: + setv(to_object, ['gcs_uri'], getv(from_object, ['gcsUri'])) + + if getv(from_object, ['bytesBase64Encoded']) is not None: + setv( + to_object, + ['image_bytes'], + t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])), + ) + + return to_object + + +def _GeneratedImage_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['image'], + _Image_from_mldev(api_client, getv(from_object, ['_self']), to_object), + ) + + if getv(from_object, ['raiFilteredReason']) is not None: + setv( + to_object, + ['rai_filtered_reason'], + getv(from_object, ['raiFilteredReason']), + ) + + return to_object + + +def _GeneratedImage_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['image'], + 
_Image_from_vertex(api_client, getv(from_object, ['_self']), to_object), + ) + + if getv(from_object, ['raiFilteredReason']) is not None: + setv( + to_object, + ['rai_filtered_reason'], + getv(from_object, ['raiFilteredReason']), + ) + + return to_object + + +def _GenerateImageResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _GenerateImageResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _EditImageResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _EditImageResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _UpscaleImageResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _UpscaleImageResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['predictions']) is not None: + setv( + to_object, + ['generated_images'], + [ + _GeneratedImage_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['predictions']) + ], + ) + + return to_object + + +def _Endpoint_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _Endpoint_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['endpoint']) is not None: + setv(to_object, ['name'], getv(from_object, ['endpoint'])) + + if getv(from_object, ['deployedModelId']) is not None: + setv( + to_object, ['deployed_model_id'], getv(from_object, ['deployedModelId']) + ) + + return to_object + + +def _TunedModelInfo_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['baseModel']) is not None: + setv(to_object, ['base_model'], getv(from_object, ['baseModel'])) + + if 
getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + return to_object + + +def _TunedModelInfo_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if ( + getv(from_object, ['labels', 'google-vertex-llm-tuning-base-model-id']) + is not None + ): + setv( + to_object, + ['base_model'], + getv(from_object, ['labels', 'google-vertex-llm-tuning-base-model-id']), + ) + + if getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + return to_object + + +def _Model_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['version']) is not None: + setv(to_object, ['version'], getv(from_object, ['version'])) + + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['tuned_model_info'], + _TunedModelInfo_from_mldev( + api_client, getv(from_object, ['_self']), to_object + ), + ) + + if getv(from_object, ['inputTokenLimit']) is not None: + setv( + to_object, ['input_token_limit'], getv(from_object, ['inputTokenLimit']) + ) + + if getv(from_object, ['outputTokenLimit']) is not None: + setv( + to_object, + ['output_token_limit'], + getv(from_object, ['outputTokenLimit']), + ) + + if getv(from_object, ['supportedGenerationMethods']) is not None: + setv( + to_object, + ['supported_actions'], + getv(from_object, ['supportedGenerationMethods']), + ) + + return to_object + + +def _Model_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['versionId']) is not None: + setv(to_object, ['version'], getv(from_object, ['versionId'])) + + if getv(from_object, ['deployedModels']) is not None: + setv( + to_object, + ['endpoints'], + [ + _Endpoint_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['deployedModels']) + ], + ) + + if getv(from_object, ['labels']) is not None: + setv(to_object, ['labels'], getv(from_object, ['labels'])) + + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['tuned_model_info'], + _TunedModelInfo_from_vertex( + api_client, getv(from_object, ['_self']), to_object + ), + ) + + return to_object + + +def _ListModelsResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, 
['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['models'], + [ + _Model_from_mldev(api_client, item, to_object) + for item in t.t_extract_models( + api_client, getv(from_object, ['_self']) + ) + ], + ) + + return to_object + + +def _ListModelsResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['models'], + [ + _Model_from_vertex(api_client, item, to_object) + for item in t.t_extract_models( + api_client, getv(from_object, ['_self']) + ) + ], + ) + + return to_object + + +def _DeleteModelResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _DeleteModelResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + + return to_object + + +def _CountTokensResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['totalTokens']) is not None: + setv(to_object, ['total_tokens'], getv(from_object, ['totalTokens'])) + + if getv(from_object, ['cachedContentTokenCount']) is not None: + setv( + to_object, + ['cached_content_token_count'], + getv(from_object, ['cachedContentTokenCount']), + ) + + return to_object + + +def _CountTokensResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['totalTokens']) is not None: + setv(to_object, ['total_tokens'], getv(from_object, ['totalTokens'])) + + return to_object + + +def _ComputeTokensResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['tokensInfo']) is not None: + setv(to_object, ['tokens_info'], getv(from_object, ['tokensInfo'])) + + return to_object + + +def _ComputeTokensResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['tokensInfo']) is not None: + setv(to_object, ['tokens_info'], getv(from_object, ['tokensInfo'])) + + return to_object + + +class Models(_common.BaseModule): + + def _generate_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> types.GenerateContentResponse: + parameter_model = types._GenerateContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:generateContent'.format_map(request_dict.get('_url')) + else: + request_dict = _GenerateContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:generateContent'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the 
hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _GenerateContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def generate_content_stream( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> Iterator[types.GenerateContentResponse]: + parameter_model = types._GenerateContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:streamGenerateContent?alt=sse'.format_map( + request_dict.get('_url') + ) + else: + request_dict = _GenerateContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:streamGenerateContent?alt=sse'.format_map( + request_dict.get('_url') + ) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + for response_dict in self._api_client.request_streamed( + 'post', path, request_dict, http_options + ): + + if self._api_client.vertexai: + response_dict = _GenerateContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + yield return_value + + def embed_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.EmbedContentConfigOrDict] = None, + ) -> types.EmbedContentResponse: + """Calculates embeddings for the given contents(only text is supported). + + Args: + model (str): The model to use. + contents (list[Content]): The contents to embed. + config (EmbedContentConfig): Optional configuration for embeddings. + + Usage: + + .. 
code-block:: python + + embeddings = client.models.embed_content( + model= 'text-embedding-004', + contents=[ + 'What is your name?', + 'What is your favorite color?', + ], + config={ + 'output_dimensionality': 64 + }, + ) + """ + + parameter_model = types._EmbedContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _EmbedContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + else: + request_dict = _EmbedContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:batchEmbedContents'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EmbedContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _EmbedContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.EmbedContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def generate_image( + self, + *, + model: str, + prompt: str, + config: Optional[types.GenerateImageConfigOrDict] = None, + ) -> types.GenerateImageResponse: + """Generates an image based on a text description and configuration. + + Args: + model (str): The model to use. + prompt (str): A text description of the image to generate. + config (GenerateImageConfig): Configuration for generation. + + Usage: + + .. code-block:: python + + response = client.models.generate_image( + model='imagen-3.0-generate-001', + prompt='Man with a dog', + config=types.GenerateImageConfig( + number_of_images= 1, + include_rai_reason= True, + ) + ) + response.generated_images[0].image.show() + # Shows a man with a dog. + """ + + parameter_model = types._GenerateImageParameters( + model=model, + prompt=prompt, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateImageParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + else: + request_dict = _GenerateImageParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
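+ # Note: at this point request_dict still carries the transformed 'config' + # entry; the pop below strips it from the JSON payload and forwards any + # per-request 'httpOptions' override separately to the transport call, so + # only wire-format fields remain in the request body.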
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _GenerateImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def edit_image( + self, + *, + model: str, + prompt: str, + reference_images: list[types._ReferenceImageAPIOrDict], + config: Optional[types.EditImageConfigOrDict] = None, + ) -> types.EditImageResponse: + """Edits an image based on a text description and configuration. + + Args: + model (str): The model to use. + prompt (str): A text description of the edit to apply to the image. + reference_images (list[Union[RawReferenceImage, MaskReferenceImage, + ControlReferenceImage, StyleReferenceImage, SubjectReferenceImage]): The + reference images for editing. + config (EditImageConfig): Configuration for editing. + + Usage: + + .. code-block:: python + + from google.genai.types import RawReferenceImage, MaskReferenceImage + + raw_ref_image = RawReferenceImage( + reference_id=1, + reference_image=types.Image.from_file(IMAGE_FILE_PATH), + ) + + mask_ref_image = MaskReferenceImage( + reference_id=2, + config=types.MaskReferenceConfig( + mask_mode='MASK_MODE_FOREGROUND', + mask_dilation=0.06, + ), + ) + response = client.models.edit_image( + model='imagen-3.0-capability-preview-0930', + prompt='man with dog', + reference_images=[raw_ref_image, mask_ref_image], + config=types.EditImageConfig( + edit_mode= "EDIT_MODE_INPAINT_INSERTION", + number_of_images= 1, + include_rai_reason= True, + ) + ) + response.generated_images[0].image.show() + # Shows a man with a dog instead of a cat. + """ + + parameter_model = types._EditImageParameters( + model=model, + prompt=prompt, + reference_images=reference_images, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _EditImageParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
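+ # Note: this branch is only reachable on Vertex AI (the guard above raises + # ValueError for the Gemini Developer API client), so edit requests always + # target the Vertex '{model}:predict' endpoint with the prompt and + # reference images serialized into the predict payload.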
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EditImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _EditImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.EditImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def _upscale_image( + self, + *, + model: str, + image: types.ImageOrDict, + upscale_factor: str, + config: Optional[types._UpscaleImageAPIConfigOrDict] = None, + ) -> types.UpscaleImageResponse: + """Upscales an image. + + Args: + model (str): The model to use. + image (Image): The input image for upscaling. + upscale_factor (str): The factor to upscale the image (x2 or x4). + config (_UpscaleImageAPIConfig): Configuration for upscaling. + """ + + parameter_model = types._UpscaleImageAPIParameters( + model=model, + image=image, + upscale_factor=upscale_factor, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _UpscaleImageAPIParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _UpscaleImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _UpscaleImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.UpscaleImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def get(self, *, model: str) -> types.Model: + parameter_model = types._GetModelParameters( + model=model, + ) + + if self._api_client.vertexai: + request_dict = _GetModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _GetModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
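+ # Note: get() issues a plain GET against the fully qualified resource name + # substituted into '{name}'; the pop below only serves to extract a + # per-request 'httpOptions' override, since the lookup is addressed purely + # by the URL.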
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _Model_from_vertex(self._api_client, response_dict) + else: + response_dict = _Model_from_mldev(self._api_client, response_dict) + + return_value = types.Model._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, *, config: Optional[types.ListModelsConfigOrDict] = None + ) -> types.ListModelsResponse: + parameter_model = types._ListModelsParameters( + config=config, + ) + + if self._api_client.vertexai: + request_dict = _ListModelsParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{models_url}'.format_map(request_dict.get('_url')) + else: + request_dict = _ListModelsParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{models_url}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListModelsResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListModelsResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListModelsResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def update( + self, + *, + model: str, + config: Optional[types.UpdateModelConfigOrDict] = None, + ) -> types.Model: + parameter_model = types._UpdateModelParameters( + model=model, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _UpdateModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}'.format_map(request_dict.get('_url')) + else: + request_dict = _UpdateModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
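+ # Note: the two backends template the PATCH URL differently: Vertex AI + # fills '{model}' while the Gemini Developer API fills '{name}', which is + # why the branches above read different keys from '_url'.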
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'patch', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _Model_from_vertex(self._api_client, response_dict) + else: + response_dict = _Model_from_mldev(self._api_client, response_dict) + + return_value = types.Model._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + def delete(self, *, model: str) -> types.DeleteModelResponse: + parameter_model = types._DeleteModelParameters( + model=model, + ) + + if self._api_client.vertexai: + request_dict = _DeleteModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _DeleteModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteModelResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteModelResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteModelResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def count_tokens( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.CountTokensConfigOrDict] = None, + ) -> types.CountTokensResponse: + """Counts the number of tokens in the given content. + + Args: + model (str): The model to use for counting tokens. + contents (list[types.Content]): The content to count tokens for. + Multimodal input is supported for Gemini models. + config (CountTokensConfig): The configuration for counting tokens. + + Usage: + + .. code-block:: python + + response = client.models.count_tokens( + model='gemini-1.5-flash', + contents='What is your name?', + ) + print(response) + # total_tokens=5 cached_content_token_count=None + """ + + parameter_model = types._CountTokensParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _CountTokensParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:countTokens'.format_map(request_dict.get('_url')) + else: + request_dict = _CountTokensParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:countTokens'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
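+ # Note: both backends expose '{model}:countTokens', but per the response + # converters above only the Gemini Developer API response surfaces + # 'cachedContentTokenCount'; the Vertex converter maps 'totalTokens' only.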
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _CountTokensResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _CountTokensResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.CountTokensResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def compute_tokens( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.ComputeTokensConfigOrDict] = None, + ) -> types.ComputeTokensResponse: + """Return a list of tokens based on the input text. + + This method is not supported by the Gemini Developer API. + + Args: + model (str): The model to use. + contents (list[shared.Content]): The content to compute tokens for. Only + text is supported. + + Usage: + + .. code-block:: python + + response = client.models.compute_tokens( + model='gemini-1.5-flash', + contents='What is your name?', + ) + print(response) + # tokens_info=[TokensInfo(role='user', token_ids=['1841', ...], + # tokens=[b'What', b' is', b' your', b' name', b'?'])] + """ + + parameter_model = types._ComputeTokensParameters( + model=model, + contents=contents, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _ComputeTokensParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:computeTokens'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ComputeTokensResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ComputeTokensResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ComputeTokensResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def generate_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> types.GenerateContentResponse: + """Makes an API request to generate content using a model. + + Some models support multimodal input and output. + + Usage: + + .. 
code-block:: python + + from google.genai import types + from google import genai + + client = genai.Client( + vertexai=True, project='my-project-id', location='us-central1' + ) + + response = client.models.generate_content( + model='gemini-1.5-flash-002', + contents='''What is a good name for a flower shop that specializes in + selling bouquets of dried flowers?''' + ) + print(response.text) + # **Elegant & Classic:** + # * The Dried Bloom + # * Everlasting Florals + # * Timeless Petals + + response = client.models.generate_content( + model='gemini-1.5-flash-002', + contents=[ + types.Part.from_text('What is shown in this image?'), + types.Part.from_uri('gs://generativeai-downloads/images/scones.jpg', + 'image/jpeg') + ] + ) + print(response.text) + # The image shows a flat lay arrangement of freshly baked blueberry + # scones. + """ + + if _extra_utils.should_disable_afc(config): + return self._generate_content( + model=model, contents=contents, config=config + ) + remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config) + logging.info( + f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' + ) + automatic_function_calling_history = [] + while remaining_remote_calls_afc > 0: + response = self._generate_content( + model=model, contents=contents, config=config + ) + remaining_remote_calls_afc -= 1 + if remaining_remote_calls_afc == 0: + logging.info('Reached max remote calls for automatic function calling.') + + function_map = _extra_utils.get_function_map(config) + if not function_map: + break + if ( + not response.candidates + or not response.candidates[0].content + or not response.candidates[0].content.parts + ): + break + func_response_parts = _extra_utils.get_function_response_parts( + response, function_map + ) + if not func_response_parts: + break + contents = t.t_contents(self._api_client, contents) + contents.append(response.candidates[0].content) + contents.append( + types.Content( + role='user', + parts=func_response_parts, + ) + ) + automatic_function_calling_history.extend(contents) + if _extra_utils.should_append_afc_history(config): + response.automatic_function_calling_history = ( + automatic_function_calling_history + ) + return response + + def upscale_image( + self, + *, + model: str, + image: types.ImageOrDict, + upscale_factor: str, + config: Optional[types.UpscaleImageConfigOrDict] = None, + ) -> types.UpscaleImageResponse: + """Makes an API request to upscale a provided image. + + Args: + model (str): The model to use. + image (Image): The input image for upscaling. + upscale_factor (str): The factor to upscale the image (x2 or x4). + config (UpscaleImageConfig): Configuration for upscaling. + + Usage: + + .. code-block:: python + + from google.genai.types import Image + + IMAGE_FILE_PATH="my-image.png" + response=client.models.upscale_image( + model='imagen-3.0-generate-001', + image=types.Image.from_file(IMAGE_FILE_PATH), + upscale_factor='x2', + ) + response.generated_images[0].image.show() + # Opens my-image.png which is upscaled by a factor of 2. + """ + + # Validate config. + types.UpscaleImageParameters( + model=model, + image=image, + upscale_factor=upscale_factor, + config=config, + ) + + # Convert to API config. + config = config or {} + config_dct = config if isinstance(config, dict) else config.dict() + api_config = types._UpscaleImageAPIConfigDict(**config_dct) # pylint: disable=protected-access + + # Provide default values through API config. 
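+ # The upscale predict call is always issued with mode='upscale' and a + # single output image; both values are pinned on the internal API config + # here before delegating to the private _upscale_image helper.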
+ api_config['mode'] = 'upscale' + api_config['number_of_images'] = 1 + + return self._upscale_image( + model=model, + image=image, + upscale_factor=upscale_factor, + config=api_config, + ) + + def list( + self, + *, + config: Optional[types.ListModelsConfigOrDict] = None, + ) -> Pager[types.Model]: + """Makes an API request to list the available models. + + If `query_base` is set to True in the config, the API will return all + available base models. If set to False or not set (default), it will return + all tuned models. + + Args: + config (ListModelsConfigOrDict): Configuration for retrieving models. + + Usage: + + .. code-block:: python + + response=client.models.list(config={'page_size': 5}) + print(response.page) + # [Model(name='projects/./locations/./models/123', display_name='my_model' + + response=client.models.list(config={'page_size': 5, 'query_base': True}) + print(response.page) + # [Model(name='publishers/google/models/gemini-2.0-flash-exp' ... + """ + + config = ( + types._ListModelsParameters(config=config).config + or types.ListModelsConfig() + ) + if self._api_client.vertexai: + config = config.copy() + if config.query_base: + http_options = ( + config.http_options if config.http_options else HttpOptionsDict() + ) + http_options['skip_project_and_location_in_path'] = True + config.http_options = http_options + else: + # Filter for tuning jobs artifacts by labels. + filter_value = config.filter + config.filter = ( + filter_value + '&filter=labels.tune-type:*' + if filter_value + else 'labels.tune-type:*' + ) + if not config.query_base: + config.query_base = False + return Pager( + 'models', + self._list, + self._list(config=config), + config, + ) + + +class AsyncModels(_common.BaseModule): + + async def _generate_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> types.GenerateContentResponse: + parameter_model = types._GenerateContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:generateContent'.format_map(request_dict.get('_url')) + else: + request_dict = _GenerateContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:generateContent'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
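+ # Note: this async variant mirrors Models._generate_content step for step; + # the only difference is the awaited async_request transport call below.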
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _GenerateContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def generate_content_stream( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> AsyncIterator[types.GenerateContentResponse]: + parameter_model = types._GenerateContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:streamGenerateContent?alt=sse'.format_map( + request_dict.get('_url') + ) + else: + request_dict = _GenerateContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:streamGenerateContent?alt=sse'.format_map( + request_dict.get('_url') + ) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + async for response_dict in self._api_client.async_request_streamed( + 'post', path, request_dict, http_options + ): + + if self._api_client.vertexai: + response_dict = _GenerateContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + yield return_value + + async def embed_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.EmbedContentConfigOrDict] = None, + ) -> types.EmbedContentResponse: + """Calculates embeddings for the given contents(only text is supported). + + Args: + model (str): The model to use. + contents (list[Content]): The contents to embed. + config (EmbedContentConfig): Optional configuration for embeddings. + + Usage: + + .. 
code-block:: python + + embeddings = client.models.embed_content( + model= 'text-embedding-004', + contents=[ + 'What is your name?', + 'What is your favorite color?', + ], + config={ + 'output_dimensionality': 64 + }, + ) + """ + + parameter_model = types._EmbedContentParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _EmbedContentParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + else: + request_dict = _EmbedContentParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:batchEmbedContents'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EmbedContentResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _EmbedContentResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.EmbedContentResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def generate_image( + self, + *, + model: str, + prompt: str, + config: Optional[types.GenerateImageConfigOrDict] = None, + ) -> types.GenerateImageResponse: + """Generates an image based on a text description and configuration. + + Args: + model (str): The model to use. + prompt (str): A text description of the image to generate. + config (GenerateImageConfig): Configuration for generation. + + Usage: + + .. code-block:: python + + response = client.models.generate_image( + model='imagen-3.0-generate-001', + prompt='Man with a dog', + config=types.GenerateImageConfig( + number_of_images= 1, + include_rai_reason= True, + ) + ) + response.generated_images[0].image.show() + # Shows a man with a dog. + """ + + parameter_model = types._GenerateImageParameters( + model=model, + prompt=prompt, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GenerateImageParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + else: + request_dict = _GenerateImageParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _GenerateImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _GenerateImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.GenerateImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def edit_image( + self, + *, + model: str, + prompt: str, + reference_images: list[types._ReferenceImageAPIOrDict], + config: Optional[types.EditImageConfigOrDict] = None, + ) -> types.EditImageResponse: + """Edits an image based on a text description and configuration. + + Args: + model (str): The model to use. + prompt (str): A text description of the edit to apply to the image. + reference_images (list[Union[RawReferenceImage, MaskReferenceImage, + ControlReferenceImage, StyleReferenceImage, SubjectReferenceImage]): The + reference images for editing. + config (EditImageConfig): Configuration for editing. + + Usage: + + .. code-block:: python + + from google.genai.types import RawReferenceImage, MaskReferenceImage + + raw_ref_image = RawReferenceImage( + reference_id=1, + reference_image=types.Image.from_file(IMAGE_FILE_PATH), + ) + + mask_ref_image = MaskReferenceImage( + reference_id=2, + config=types.MaskReferenceConfig( + mask_mode='MASK_MODE_FOREGROUND', + mask_dilation=0.06, + ), + ) + response = client.models.edit_image( + model='imagen-3.0-capability-preview-0930', + prompt='man with dog', + reference_images=[raw_ref_image, mask_ref_image], + config=types.EditImageConfig( + edit_mode= "EDIT_MODE_INPAINT_INSERTION", + number_of_images= 1, + include_rai_reason= True, + ) + ) + response.generated_images[0].image.show() + # Shows a man with a dog instead of a cat. + """ + + parameter_model = types._EditImageParameters( + model=model, + prompt=prompt, + reference_images=reference_images, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _EditImageParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EditImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _EditImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.EditImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def _upscale_image( + self, + *, + model: str, + image: types.ImageOrDict, + upscale_factor: str, + config: Optional[types._UpscaleImageAPIConfigOrDict] = None, + ) -> types.UpscaleImageResponse: + """Upscales an image. + + Args: + model (str): The model to use. + image (Image): The input image for upscaling. + upscale_factor (str): The factor to upscale the image (x2 or x4). + config (_UpscaleImageAPIConfig): Configuration for upscaling. + """ + + parameter_model = types._UpscaleImageAPIParameters( + model=model, + image=image, + upscale_factor=upscale_factor, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _UpscaleImageAPIParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:predict'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _UpscaleImageResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _UpscaleImageResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.UpscaleImageResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def get(self, *, model: str) -> types.Model: + parameter_model = types._GetModelParameters( + model=model, + ) + + if self._api_client.vertexai: + request_dict = _GetModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _GetModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _Model_from_vertex(self._api_client, response_dict) + else: + response_dict = _Model_from_mldev(self._api_client, response_dict) + + return_value = types.Model._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, *, config: Optional[types.ListModelsConfigOrDict] = None + ) -> types.ListModelsResponse: + parameter_model = types._ListModelsParameters( + config=config, + ) + + if self._api_client.vertexai: + request_dict = _ListModelsParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{models_url}'.format_map(request_dict.get('_url')) + else: + request_dict = _ListModelsParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{models_url}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListModelsResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListModelsResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListModelsResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def update( + self, + *, + model: str, + config: Optional[types.UpdateModelConfigOrDict] = None, + ) -> types.Model: + parameter_model = types._UpdateModelParameters( + model=model, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _UpdateModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}'.format_map(request_dict.get('_url')) + else: + request_dict = _UpdateModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'patch', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _Model_from_vertex(self._api_client, response_dict) + else: + response_dict = _Model_from_mldev(self._api_client, response_dict) + + return_value = types.Model._from_response(response_dict, parameter_model) + self._api_client._verify_response(return_value) + return return_value + + async def delete(self, *, model: str) -> types.DeleteModelResponse: + parameter_model = types._DeleteModelParameters( + model=model, + ) + + if self._api_client.vertexai: + request_dict = _DeleteModelParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _DeleteModelParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'delete', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _DeleteModelResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _DeleteModelResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.DeleteModelResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def count_tokens( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.CountTokensConfigOrDict] = None, + ) -> types.CountTokensResponse: + """Counts the number of tokens in the given content. + + Args: + model (str): The model to use for counting tokens. + contents (list[types.Content]): The content to count tokens for. + Multimodal input is supported for Gemini models. + config (CountTokensConfig): The configuration for counting tokens. + + Usage: + + .. code-block:: python + + response = client.models.count_tokens( + model='gemini-1.5-flash', + contents='What is your name?', + ) + print(response) + # total_tokens=5 cached_content_token_count=None + """ + + parameter_model = types._CountTokensParameters( + model=model, + contents=contents, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _CountTokensParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:countTokens'.format_map(request_dict.get('_url')) + else: + request_dict = _CountTokensParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{model}:countTokens'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
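+ # Note: count_tokens is served by both backends, whereas compute_tokens + # (below) is Vertex AI only and raises ValueError on the Gemini Developer + # API client.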
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _CountTokensResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _CountTokensResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.CountTokensResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def compute_tokens( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.ComputeTokensConfigOrDict] = None, + ) -> types.ComputeTokensResponse: + """Return a list of tokens based on the input text. + + This method is not supported by the Gemini Developer API. + + Args: + model (str): The model to use. + contents (list[shared.Content]): The content to compute tokens for. Only + text is supported. + + Usage: + + .. code-block:: python + + response = client.models.compute_tokens( + model='gemini-1.5-flash', + contents='What is your name?', + ) + print(response) + # tokens_info=[TokensInfo(role='user', token_ids=['1841', ...], + # tokens=[b'What', b' is', b' your', b' name', b'?'])] + """ + + parameter_model = types._ComputeTokensParameters( + model=model, + contents=contents, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _ComputeTokensParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{model}:computeTokens'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ComputeTokensResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ComputeTokensResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ComputeTokensResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def generate_content( + self, + *, + model: str, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], + config: Optional[types.GenerateContentConfigOrDict] = None, + ) -> types.GenerateContentResponse: + """Makes an API request to generate content using a model. + + Some models support multimodal input and output. + + Usage: + + .. code-block:: python + + from google.genai import types + from google import genai + + client = genai.Client( + vertexai=True, project='my-project-id', location='us-central1' + ) + + response = await client.aio.models.generate_content( + model='gemini-1.5-flash-002', + contents='User input: I like bagels. 
Answer:', + config=types.GenerateContentConfig( + system_instruction= + [ + 'You are a helpful language translator.', + 'Your mission is to translate text in English to French.' + ] + ), + ) + print(response.text) + # J'aime les bagels. + """ + if _extra_utils.should_disable_afc(config): + return await self._generate_content( + model=model, contents=contents, config=config + ) + remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config) + logging.info( + f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' + ) + automatic_function_calling_history = [] + while remaining_remote_calls_afc > 0: + response = await self._generate_content( + model=model, contents=contents, config=config + ) + remaining_remote_calls_afc -= 1 + if remaining_remote_calls_afc == 0: + logging.info('Reached max remote calls for automatic function calling.') + + function_map = _extra_utils.get_function_map(config) + if not function_map: + break + if ( + not response.candidates + or not response.candidates[0].content + or not response.candidates[0].content.parts + ): + break + func_response_parts = _extra_utils.get_function_response_parts( + response, function_map + ) + if not func_response_parts: + break + contents = t.t_contents(self._api_client, contents) + contents.append(response.candidates[0].content) + contents.append( + types.Content( + role='user', + parts=func_response_parts, + ) + ) + automatic_function_calling_history.extend(contents) + + if _extra_utils.should_append_afc_history(config): + response.automatic_function_calling_history = ( + automatic_function_calling_history + ) + return response + + async def list( + self, + *, + config: Optional[types.ListModelsConfigOrDict] = None, + ) -> AsyncPager[types.Model]: + """Makes an API request to list the available models. + + If `query_base` is set to True in the config, the API will return all + available base models. If set to False or not set (default), it will return + all tuned models. + + Args: + config (ListModelsConfigOrDict): Configuration for retrieving models. + + Usage: + + .. code-block:: python + + response = await client.aio.models.list(config={'page_size': 5}) + print(response.page) + # [Model(name='projects/./locations/./models/123', display_name='my_model' + + response = await client.aio.models.list( + config={'page_size': 5, 'query_base': True} + ) + print(response.page) + # [Model(name='publishers/google/models/gemini-2.0-flash-exp' ... + """ + + config = ( + types._ListModelsParameters(config=config).config + or types.ListModelsConfig() + ) + if self._api_client.vertexai: + config = config.copy() + if config.query_base: + http_options = ( + config.http_options if config.http_options else HttpOptionsDict() + ) + http_options['skip_project_and_location_in_path'] = True + config.http_options = http_options + else: + # Filter for tuning jobs artifacts by labels. + filter_value = config.filter + config.filter = ( + filter_value + '&filter=labels.tune-type:*' + if filter_value + else 'labels.tune-type:*' + ) + if not config.query_base: + config.query_base = False + return AsyncPager( + 'models', + self._list, + await self._list(config=config), + config, + ) + + async def upscale_image( + self, + *, + model: str, + image: types.ImageOrDict, + upscale_factor: str, + config: Optional[types.UpscaleImageConfigOrDict] = None, + ) -> types.UpscaleImageResponse: + """Makes an API request to upscale a provided image. + + Args: + model (str): The model to use. + image (Image): The input image for upscaling. 
+ upscale_factor (str): The factor to upscale the image (x2 or x4). + config (UpscaleImageConfig): Configuration for upscaling. + + Usage: + + .. code-block:: python + + from google.genai.types import Image + + IMAGE_FILE_PATH="my-image.png" + response = await client.aio.models.upscale_image( + model='imagen-3.0-generate-001', + image=types.Image.from_file(IMAGE_FILE_PATH), + upscale_factor='x2', + ) + response.generated_images[0].image.show() + # Opens my-image.png which is upscaled by a factor of 2. + """ + + # Validate config. + types.UpscaleImageParameters( + model=model, + image=image, + upscale_factor=upscale_factor, + config=config, + ) + + # Convert to API config. + config = config or {} + config_dct = config if isinstance(config, dict) else config.dict() + api_config = types._UpscaleImageAPIConfigDict(**config_dct) # pylint: disable=protected-access + + # Provide default values through API config. + api_config['mode'] = 'upscale' + api_config['number_of_images'] = 1 + + return await self._upscale_image( + model=model, + image=image, + upscale_factor=upscale_factor, + config=api_config, + ) diff --git a/.venv/lib/python3.12/site-packages/google/genai/pagers.py b/.venv/lib/python3.12/site-packages/google/genai/pagers.py new file mode 100644 index 00000000..a90465c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/pagers.py @@ -0,0 +1,245 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Pagers for the GenAI List APIs.""" + +# pylint: disable=protected-access + +import copy +from typing import Any, AsyncIterator, Awaitable, Callable, Generic, Iterator, Literal, TypeVar + +T = TypeVar('T') + +PagedItem = Literal[ + 'batch_jobs', 'models', 'tuning_jobs', 'files', 'cached_contents' +] + + +class _BasePager(Generic[T]): + """Base pager class for iterating through paginated results.""" + + def __init__( + self, + name: PagedItem, + request: Callable[Any, Any], + response: Any, + config: Any, + ): + self._name = name + self._request = request + + self._page = getattr(response, self._name) or [] + self._idx = 0 + + if not config: + request_config = {} + elif isinstance(config, dict): + request_config = copy.deepcopy(config) + else: + request_config = dict(config) + request_config['page_token'] = getattr(response, 'next_page_token') + self._config = request_config + + self._page_size = request_config.get('page_size', len(self._page)) + + @property + def page(self) -> list[T]: + """Returns the current page, which is a list of items. + + The returned list of items is a subset of the entire list. + + Usage: + + .. code-block:: python + + batch_jobs_pager = client.batches.list(config={'page_size': 5}) + print(f"first page: {batch_jobs_pager.page}") + # first page: [BatchJob(name='projects/./locations/./batchPredictionJobs/1 + """ + + return self._page + + @property + def name(self) -> str: + """Returns the type of paged item (for example, ``batch_jobs``). + + Usage: + + .. 
code-block:: python + + batch_jobs_pager = client.batches.list(config={'page_size': 5}) + print(f"name: {batch_jobs_pager.name}") + # name: batch_jobs + """ + + return self._name + + @property + def page_size(self) -> int: + """Returns the length of the page fetched each time by this pager. + + The number of items in the page is less than or equal to the page length. + + Usage: + + .. code-block:: python + + batch_jobs_pager = client.batches.list(config={'page_size': 5}) + print(f"page_size: {batch_jobs_pager.page_size}") + # page_size: 5 + """ + + return self._page_size + + @property + def config(self) -> dict[str, Any]: + """Returns the configuration when making the API request for the next page. + + A configuration is a set of optional parameters and arguments that can be + used to customize the API request. For example, the ``page_token`` parameter + contains the token to request the next page. + + Usage: + + .. code-block:: python + + batch_jobs_pager = client.batches.list(config={'page_size': 5}) + print(f"config: {batch_jobs_pager.config}") + # config: {'page_size': 5, 'page_token': 'AMEw9yO5jnsGnZJLHSKDFHJJu'} + """ + + return self._config + + def __len__(self) -> int: + """Returns the total number of items in the current page.""" + return len(self.page) + + def __getitem__(self, index: int) -> T: + """Returns the item at the given index.""" + return self.page[index] + + def _init_next_page(self, response: Any) -> None: + """Initializes the next page from the response. + + This is an internal method that should be called by subclasses after + fetching the next page. + + Args: + response: The response object from the API request. + """ + self.__init__(self.name, self._request, response, self.config) + + +class Pager(_BasePager[T]): + """Pager class for iterating through paginated results.""" + + def __next__(self) -> T: + """Returns the next item.""" + if self._idx >= len(self): + try: + self.next_page() + except IndexError: + raise StopIteration + + item = self.page[self._idx] + self._idx += 1 + return item + + def __iter__(self) -> Iterator[T]: + """Returns an iterator over the items.""" + self._idx = 0 + return self + + def next_page(self) -> list[T]: + """Fetches the next page of items. This makes a new API request. + + Usage: + + .. 
code-block:: python + + batch_jobs_pager = client.batches.list(config={'page_size': 5}) + print(f"current page: {batch_jobs_pager.page}") + batch_jobs_pager.next_page() + print(f"next page: {batch_jobs_pager.page}") + # current page: [BatchJob(name='projects/.../batchPredictionJobs/1 + # next page: [BatchJob(name='projects/.../batchPredictionJobs/6 + """ + + if not self.config.get('page_token'): + raise IndexError('No more pages to fetch.') + + response = self._request(config=self.config) + self._init_next_page(response) + return self.page + + +class AsyncPager(_BasePager[T]): + """AsyncPager class for iterating through paginated results.""" + + def __init__( + self, + name: PagedItem, + request: Callable[..., Awaitable[Any]], + response: Any, + config: Any, + ): + super().__init__(name, request, response, config) + + def __aiter__(self) -> AsyncIterator[T]: + """Returns an async iterator over the items.""" + self._idx = 0 + return self + + async def __anext__(self) -> T: + """Returns the next item asynchronously.""" + if self._idx >= len(self): + try: + await self.next_page() + except IndexError: + raise StopAsyncIteration + + item = self.page[self._idx] + self._idx += 1 + return item + + async def next_page(self) -> list[T]: + """Fetches the next page of items asynchronously. + + This makes a new API request. + + Returns: + The next page of items. + + Raises: + IndexError: No more pages to fetch. + + Usage: + + .. code-block:: python + + batch_jobs_pager = await client.aio.batches.list(config={'page_size': 5}) + print(f"current page: {batch_jobs_pager.page}") + await batch_jobs_pager.next_page() + print(f"next page: {batch_jobs_pager.page}") + # current page: [BatchJob(name='projects/.../batchPredictionJobs/1 + # next page: [BatchJob(name='projects/.../batchPredictionJobs/6 + """ + + if not self.config.get('page_token'): + raise IndexError('No more pages to fetch.') + + response = await self._request(config=self.config) + self._init_next_page(response) + return self.page diff --git a/.venv/lib/python3.12/site-packages/google/genai/tunings.py b/.venv/lib/python3.12/site-packages/google/genai/tunings.py new file mode 100644 index 00000000..a215a195 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/tunings.py @@ -0,0 +1,1681 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +from typing import Optional, Union +from urllib.parse import urlencode +from . import _common +from . import _transformers as t +from . 
import types +from ._api_client import ApiClient +from ._common import get_value_by_path as getv +from ._common import set_value_by_path as setv +from .pagers import AsyncPager, Pager + + +def _GetTuningJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetTuningJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + return to_object + + +def _GetTuningJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['_url', 'name'], getv(from_object, ['name'])) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetTuningJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _GetTuningJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['_url', 'name'], getv(from_object, ['name'])) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _GetTuningJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListTuningJobsConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter'])) + + return to_object + + +def _ListTuningJobsConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['page_size']) is not None: + setv( + parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size']) + ) + + if getv(from_object, ['page_token']) is not None: + setv( + parent_object, + ['_query', 'pageToken'], + getv(from_object, ['page_token']), + ) + + if getv(from_object, ['filter']) is not None: + setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter'])) + + return to_object + + +def _ListTuningJobsParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _ListTuningJobsConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _ListTuningJobsParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + 
_ListTuningJobsConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _TuningExample_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['text_input']) is not None: + setv(to_object, ['textInput'], getv(from_object, ['text_input'])) + + if getv(from_object, ['output']) is not None: + setv(to_object, ['output'], getv(from_object, ['output'])) + + return to_object + + +def _TuningExample_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['text_input']) is not None: + raise ValueError('text_input parameter is not supported in Vertex AI.') + + if getv(from_object, ['output']) is not None: + raise ValueError('output parameter is not supported in Vertex AI.') + + return to_object + + +def _TuningDataset_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + if getv(from_object, ['examples']) is not None: + setv( + to_object, + ['examples', 'examples'], + [ + _TuningExample_to_mldev(api_client, item, to_object) + for item in getv(from_object, ['examples']) + ], + ) + + return to_object + + +def _TuningDataset_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + setv( + parent_object, + ['supervisedTuningSpec', 'trainingDatasetUri'], + getv(from_object, ['gcs_uri']), + ) + + if getv(from_object, ['examples']) is not None: + raise ValueError('examples parameter is not supported in Vertex AI.') + + return to_object + + +def _TuningValidationDataset_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + return to_object + + +def _TuningValidationDataset_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + setv(to_object, ['validationDatasetUri'], getv(from_object, ['gcs_uri'])) + + return to_object + + +def _CreateTuningJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['validation_dataset']) is not None: + raise ValueError( + 'validation_dataset parameter is not supported in Google AI.' 
+ ) + + if getv(from_object, ['tuned_model_display_name']) is not None: + setv( + parent_object, + ['displayName'], + getv(from_object, ['tuned_model_display_name']), + ) + + if getv(from_object, ['description']) is not None: + raise ValueError('description parameter is not supported in Google AI.') + + if getv(from_object, ['epoch_count']) is not None: + setv( + parent_object, + ['tuningTask', 'hyperparameters', 'epochCount'], + getv(from_object, ['epoch_count']), + ) + + if getv(from_object, ['learning_rate_multiplier']) is not None: + setv( + to_object, + ['tuningTask', 'hyperparameters', 'learningRateMultiplier'], + getv(from_object, ['learning_rate_multiplier']), + ) + + if getv(from_object, ['adapter_size']) is not None: + raise ValueError('adapter_size parameter is not supported in Google AI.') + + if getv(from_object, ['batch_size']) is not None: + setv( + parent_object, + ['tuningTask', 'hyperparameters', 'batchSize'], + getv(from_object, ['batch_size']), + ) + + if getv(from_object, ['learning_rate']) is not None: + setv( + parent_object, + ['tuningTask', 'hyperparameters', 'learningRate'], + getv(from_object, ['learning_rate']), + ) + + return to_object + + +def _CreateTuningJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['validation_dataset']) is not None: + setv( + parent_object, + ['supervisedTuningSpec'], + _TuningValidationDataset_to_vertex( + api_client, getv(from_object, ['validation_dataset']), to_object + ), + ) + + if getv(from_object, ['tuned_model_display_name']) is not None: + setv( + parent_object, + ['tunedModelDisplayName'], + getv(from_object, ['tuned_model_display_name']), + ) + + if getv(from_object, ['description']) is not None: + setv(parent_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['epoch_count']) is not None: + setv( + parent_object, + ['supervisedTuningSpec', 'hyperParameters', 'epochCount'], + getv(from_object, ['epoch_count']), + ) + + if getv(from_object, ['learning_rate_multiplier']) is not None: + setv( + to_object, + ['supervisedTuningSpec', 'hyperParameters', 'learningRateMultiplier'], + getv(from_object, ['learning_rate_multiplier']), + ) + + if getv(from_object, ['adapter_size']) is not None: + setv( + parent_object, + ['supervisedTuningSpec', 'hyperParameters', 'adapterSize'], + getv(from_object, ['adapter_size']), + ) + + if getv(from_object, ['batch_size']) is not None: + raise ValueError('batch_size parameter is not supported in Vertex AI.') + + if getv(from_object, ['learning_rate']) is not None: + raise ValueError('learning_rate parameter is not supported in Vertex AI.') + + return to_object + + +def _CreateTuningJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['base_model']) is not None: + setv(to_object, ['baseModel'], getv(from_object, ['base_model'])) + + if getv(from_object, ['training_dataset']) is not None: + setv( + to_object, + ['tuningTask', 'trainingData'], + _TuningDataset_to_mldev( + api_client, getv(from_object, ['training_dataset']), to_object + ), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateTuningJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object 
+ ), + ) + + return to_object + + +def _CreateTuningJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['base_model']) is not None: + setv(to_object, ['baseModel'], getv(from_object, ['base_model'])) + + if getv(from_object, ['training_dataset']) is not None: + setv( + to_object, + ['supervisedTuningSpec', 'trainingDatasetUri'], + _TuningDataset_to_vertex( + api_client, getv(from_object, ['training_dataset']), to_object + ), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateTuningJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _DistillationDataset_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + return to_object + + +def _DistillationDataset_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + setv( + parent_object, + ['distillationSpec', 'trainingDatasetUri'], + getv(from_object, ['gcs_uri']), + ) + + return to_object + + +def _DistillationValidationDataset_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + raise ValueError('gcs_uri parameter is not supported in Google AI.') + + return to_object + + +def _DistillationValidationDataset_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['gcs_uri']) is not None: + setv(to_object, ['validationDatasetUri'], getv(from_object, ['gcs_uri'])) + + return to_object + + +def _CreateDistillationJobConfig_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['validation_dataset']) is not None: + raise ValueError( + 'validation_dataset parameter is not supported in Google AI.' + ) + + if getv(from_object, ['tuned_model_display_name']) is not None: + setv( + parent_object, + ['displayName'], + getv(from_object, ['tuned_model_display_name']), + ) + + if getv(from_object, ['epoch_count']) is not None: + setv( + parent_object, + ['tuningTask', 'hyperparameters', 'epochCount'], + getv(from_object, ['epoch_count']), + ) + + if getv(from_object, ['learning_rate_multiplier']) is not None: + setv( + parent_object, + ['tuningTask', 'hyperparameters', 'learningRateMultiplier'], + getv(from_object, ['learning_rate_multiplier']), + ) + + if getv(from_object, ['adapter_size']) is not None: + raise ValueError('adapter_size parameter is not supported in Google AI.') + + if getv(from_object, ['pipeline_root_directory']) is not None: + raise ValueError( + 'pipeline_root_directory parameter is not supported in Google AI.' 
+ ) + + return to_object + + +def _CreateDistillationJobConfig_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['http_options']) is not None: + setv(to_object, ['httpOptions'], getv(from_object, ['http_options'])) + + if getv(from_object, ['validation_dataset']) is not None: + setv( + parent_object, + ['distillationSpec'], + _DistillationValidationDataset_to_vertex( + api_client, getv(from_object, ['validation_dataset']), to_object + ), + ) + + if getv(from_object, ['tuned_model_display_name']) is not None: + setv( + parent_object, + ['tunedModelDisplayName'], + getv(from_object, ['tuned_model_display_name']), + ) + + if getv(from_object, ['epoch_count']) is not None: + setv( + parent_object, + ['distillationSpec', 'hyperParameters', 'epochCount'], + getv(from_object, ['epoch_count']), + ) + + if getv(from_object, ['learning_rate_multiplier']) is not None: + setv( + parent_object, + ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'], + getv(from_object, ['learning_rate_multiplier']), + ) + + if getv(from_object, ['adapter_size']) is not None: + setv( + parent_object, + ['distillationSpec', 'hyperParameters', 'adapterSize'], + getv(from_object, ['adapter_size']), + ) + + if getv(from_object, ['pipeline_root_directory']) is not None: + setv( + parent_object, + ['distillationSpec', 'pipelineRootDirectory'], + getv(from_object, ['pipeline_root_directory']), + ) + + return to_object + + +def _CreateDistillationJobParameters_to_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['student_model']) is not None: + raise ValueError('student_model parameter is not supported in Google AI.') + + if getv(from_object, ['teacher_model']) is not None: + raise ValueError('teacher_model parameter is not supported in Google AI.') + + if getv(from_object, ['training_dataset']) is not None: + setv( + to_object, + ['tuningTask', 'trainingData'], + _DistillationDataset_to_mldev( + api_client, getv(from_object, ['training_dataset']), to_object + ), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateDistillationJobConfig_to_mldev( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _CreateDistillationJobParameters_to_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['student_model']) is not None: + setv( + to_object, + ['distillationSpec', 'studentModel'], + getv(from_object, ['student_model']), + ) + + if getv(from_object, ['teacher_model']) is not None: + setv( + to_object, + ['distillationSpec', 'baseTeacherModel'], + getv(from_object, ['teacher_model']), + ) + + if getv(from_object, ['training_dataset']) is not None: + setv( + to_object, + ['distillationSpec', 'trainingDatasetUri'], + _DistillationDataset_to_vertex( + api_client, getv(from_object, ['training_dataset']), to_object + ), + ) + + if getv(from_object, ['config']) is not None: + setv( + to_object, + ['config'], + _CreateDistillationJobConfig_to_vertex( + api_client, getv(from_object, ['config']), to_object + ), + ) + + return to_object + + +def _TunedModel_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + 
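# In the mldev mapping, the tuned model's resource `name` doubles as both the + # SDK's `model` and `endpoint` fields, which are set from it below. +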
setv(to_object, ['model'], getv(from_object, ['name'])) + + if getv(from_object, ['name']) is not None: + setv(to_object, ['endpoint'], getv(from_object, ['name'])) + + return to_object + + +def _TunedModel_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['model']) is not None: + setv(to_object, ['model'], getv(from_object, ['model'])) + + if getv(from_object, ['endpoint']) is not None: + setv(to_object, ['endpoint'], getv(from_object, ['endpoint'])) + + return to_object + + +def _TuningJob_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['state']) is not None: + setv( + to_object, + ['state'], + t.t_tuning_job_status(api_client, getv(from_object, ['state'])), + ) + + if getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['tuningTask', 'startTime']) is not None: + setv( + to_object, + ['start_time'], + getv(from_object, ['tuningTask', 'startTime']), + ) + + if getv(from_object, ['tuningTask', 'completeTime']) is not None: + setv( + to_object, + ['end_time'], + getv(from_object, ['tuningTask', 'completeTime']), + ) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if getv(from_object, ['baseModel']) is not None: + setv(to_object, ['base_model'], getv(from_object, ['baseModel'])) + + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['tuned_model'], + _TunedModel_from_mldev( + api_client, getv(from_object, ['_self']), to_object + ), + ) + + if getv(from_object, ['experiment']) is not None: + setv(to_object, ['experiment'], getv(from_object, ['experiment'])) + + if getv(from_object, ['labels']) is not None: + setv(to_object, ['labels'], getv(from_object, ['labels'])) + + if getv(from_object, ['tunedModelDisplayName']) is not None: + setv( + to_object, + ['tuned_model_display_name'], + getv(from_object, ['tunedModelDisplayName']), + ) + + return to_object + + +def _TuningJob_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['name']) is not None: + setv(to_object, ['name'], getv(from_object, ['name'])) + + if getv(from_object, ['state']) is not None: + setv( + to_object, + ['state'], + t.t_tuning_job_status(api_client, getv(from_object, ['state'])), + ) + + if getv(from_object, ['createTime']) is not None: + setv(to_object, ['create_time'], getv(from_object, ['createTime'])) + + if getv(from_object, ['startTime']) is not None: + setv(to_object, ['start_time'], getv(from_object, ['startTime'])) + + if getv(from_object, ['endTime']) is not None: + setv(to_object, ['end_time'], getv(from_object, ['endTime'])) + + if getv(from_object, ['updateTime']) is not None: + setv(to_object, ['update_time'], getv(from_object, ['updateTime'])) + + if getv(from_object, ['error']) is not None: + setv(to_object, ['error'], getv(from_object, ['error'])) + + if getv(from_object, ['description']) is not None: + setv(to_object, ['description'], getv(from_object, ['description'])) + + if 
getv(from_object, ['baseModel']) is not None: + setv(to_object, ['base_model'], getv(from_object, ['baseModel'])) + + if getv(from_object, ['tunedModel']) is not None: + setv( + to_object, + ['tuned_model'], + _TunedModel_from_vertex( + api_client, getv(from_object, ['tunedModel']), to_object + ), + ) + + if getv(from_object, ['supervisedTuningSpec']) is not None: + setv( + to_object, + ['supervised_tuning_spec'], + getv(from_object, ['supervisedTuningSpec']), + ) + + if getv(from_object, ['tuningDataStats']) is not None: + setv( + to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats']) + ) + + if getv(from_object, ['encryptionSpec']) is not None: + setv(to_object, ['encryption_spec'], getv(from_object, ['encryptionSpec'])) + + if getv(from_object, ['distillationSpec']) is not None: + setv( + to_object, + ['distillation_spec'], + getv(from_object, ['distillationSpec']), + ) + + if getv(from_object, ['partnerModelTuningSpec']) is not None: + setv( + to_object, + ['partner_model_tuning_spec'], + getv(from_object, ['partnerModelTuningSpec']), + ) + + if getv(from_object, ['pipelineJob']) is not None: + setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob'])) + + if getv(from_object, ['experiment']) is not None: + setv(to_object, ['experiment'], getv(from_object, ['experiment'])) + + if getv(from_object, ['labels']) is not None: + setv(to_object, ['labels'], getv(from_object, ['labels'])) + + if getv(from_object, ['tunedModelDisplayName']) is not None: + setv( + to_object, + ['tuned_model_display_name'], + getv(from_object, ['tunedModelDisplayName']), + ) + + return to_object + + +def _ListTuningJobsResponse_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['tunedModels']) is not None: + setv( + to_object, + ['tuning_jobs'], + [ + _TuningJob_from_mldev(api_client, item, to_object) + for item in getv(from_object, ['tunedModels']) + ], + ) + + return to_object + + +def _ListTuningJobsResponse_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['nextPageToken']) is not None: + setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken'])) + + if getv(from_object, ['tuningJobs']) is not None: + setv( + to_object, + ['tuning_jobs'], + [ + _TuningJob_from_vertex(api_client, item, to_object) + for item in getv(from_object, ['tuningJobs']) + ], + ) + + return to_object + + +def _TuningJobOrOperation_from_mldev( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['tuning_job'], + _TuningJob_from_mldev( + api_client, + t.t_resolve_operation(api_client, getv(from_object, ['_self'])), + to_object, + ), + ) + + return to_object + + +def _TuningJobOrOperation_from_vertex( + api_client: ApiClient, + from_object: Union[dict, object], + parent_object: dict = None, +) -> dict: + to_object = {} + if getv(from_object, ['_self']) is not None: + setv( + to_object, + ['tuning_job'], + _TuningJob_from_vertex( + api_client, + t.t_resolve_operation(api_client, getv(from_object, ['_self'])), + to_object, + ), + ) + + return to_object + + +class Tunings(_common.BaseModule): + + def _get( + self, + *, + 
name: str, + config: Optional[types.GetTuningJobConfigOrDict] = None, + ) -> types.TuningJob: + """Gets a TuningJob. + + Args: + name: The resource name of the tuning job. + + Returns: + A TuningJob object. + """ + + parameter_model = types._GetTuningJobParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GetTuningJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _GetTuningJobParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _TuningJob_from_mldev(self._api_client, response_dict) + + return_value = types.TuningJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None + ) -> types.ListTuningJobsResponse: + """Lists tuning jobs. + + Args: + config: The configuration for the list request. + + Returns: + A list of tuning jobs. + """ + + parameter_model = types._ListTuningJobsParameters( + config=config, + ) + + if self._api_client.vertexai: + request_dict = _ListTuningJobsParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + else: + request_dict = _ListTuningJobsParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'tunedModels'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListTuningJobsResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListTuningJobsResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListTuningJobsResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def _tune( + self, + *, + base_model: str, + training_dataset: types.TuningDatasetOrDict, + config: Optional[types.CreateTuningJobConfigOrDict] = None, + ) -> types.TuningJobOrOperation: + """Creates a supervised fine-tuning job. + + Args: + base_model: The name of the model to tune. + training_dataset: The training dataset to use. + config: The configuration to use for the tuning job. + + Returns: + A TuningJob object. 
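+ + Usage (a minimal sketch via the public ``tune`` wrapper; the model name, bucket path and display name below are illustrative placeholders, and the GCS dataset form applies to Vertex AI): + + .. code-block:: python + + # Placeholder values; a Google AI (mldev) job would pass inline examples instead. + tuning_job = client.tunings.tune( + base_model='gemini-1.0-pro-002', + training_dataset=types.TuningDataset( + gcs_uri='gs://my-bucket/training_data.jsonl', + ), + config=types.CreateTuningJobConfig( + tuned_model_display_name='my tuned model', + epoch_count=1, + ), + ) + print(tuning_job.state)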
+ """ + + parameter_model = types._CreateTuningJobParameters( + base_model=base_model, + training_dataset=training_dataset, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _CreateTuningJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + else: + request_dict = _CreateTuningJobParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'tunedModels'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJobOrOperation_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _TuningJobOrOperation_from_mldev( + self._api_client, response_dict + ) + + return_value = types.TuningJobOrOperation._from_response( + response_dict, parameter_model + ).tuning_job + self._api_client._verify_response(return_value) + return return_value + + def distill( + self, + *, + student_model: str, + teacher_model: str, + training_dataset: types.DistillationDatasetOrDict, + config: Optional[types.CreateDistillationJobConfigOrDict] = None, + ) -> types.TuningJob: + """Creates a distillation job. + + Args: + student_model: The name of the model to tune. + teacher_model: The name of the model to distill from. + training_dataset: The training dataset to use. + config: The configuration to use for the distillation job. + + Returns: + A TuningJob object. + """ + + parameter_model = types._CreateDistillationJobParameters( + student_model=student_model, + teacher_model=teacher_model, + training_dataset=training_dataset, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CreateDistillationJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. 
+ config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _TuningJob_from_mldev(self._api_client, response_dict) + + return_value = types.TuningJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + def list( + self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None + ) -> Pager[types.TuningJob]: + return Pager( + 'tuning_jobs', + self._list, + self._list(config=config), + config, + ) + + def get( + self, + *, + name: str, + config: Optional[types.GetTuningJobConfigOrDict] = None, + ) -> types.TuningJob: + job = self._get(name=name, config=config) + if job.experiment and self._api_client.vertexai: + _IpythonUtils.display_experiment_button( + experiment=job.experiment, + project=self._api_client.project, + ) + return job + + def tune( + self, + *, + base_model: str, + training_dataset: types.TuningDatasetOrDict, + config: Optional[types.CreateTuningJobConfigOrDict] = None, + ) -> types.TuningJobOrOperation: + result = self._tune( + base_model=base_model, + training_dataset=training_dataset, + config=config, + ) + if result.name and self._api_client.vertexai: + _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name) + return result + + +class AsyncTunings(_common.BaseModule): + + async def _get( + self, + *, + name: str, + config: Optional[types.GetTuningJobConfigOrDict] = None, + ) -> types.TuningJob: + """Gets a TuningJob. + + Args: + name: The resource name of the tuning job. + + Returns: + A TuningJob object. + """ + + parameter_model = types._GetTuningJobParameters( + name=name, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _GetTuningJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + else: + request_dict = _GetTuningJobParameters_to_mldev( + self._api_client, parameter_model + ) + path = '{name}'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _TuningJob_from_mldev(self._api_client, response_dict) + + return_value = types.TuningJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None + ) -> types.ListTuningJobsResponse: + """Lists tuning jobs. + + Args: + config: The configuration for the list request. + + Returns: + A list of tuning jobs. 
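+ + Usage (a minimal sketch, via the public ``list`` wrapper, which drives this method through an ``AsyncPager``): + + .. code-block:: python + + async for job in await client.aio.tunings.list(config={'page_size': 10}): + print(job.name)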
+ """ + + parameter_model = types._ListTuningJobsParameters( + config=config, + ) + + if self._api_client.vertexai: + request_dict = _ListTuningJobsParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + else: + request_dict = _ListTuningJobsParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'tunedModels'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'get', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _ListTuningJobsResponse_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _ListTuningJobsResponse_from_mldev( + self._api_client, response_dict + ) + + return_value = types.ListTuningJobsResponse._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def _tune( + self, + *, + base_model: str, + training_dataset: types.TuningDatasetOrDict, + config: Optional[types.CreateTuningJobConfigOrDict] = None, + ) -> types.TuningJobOrOperation: + """Creates a supervised fine-tuning job. + + Args: + base_model: The name of the model to tune. + training_dataset: The training dataset to use. + config: The configuration to use for the tuning job. + + Returns: + A TuningJob object. + """ + + parameter_model = types._CreateTuningJobParameters( + base_model=base_model, + training_dataset=training_dataset, + config=config, + ) + + if self._api_client.vertexai: + request_dict = _CreateTuningJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + else: + request_dict = _CreateTuningJobParameters_to_mldev( + self._api_client, parameter_model + ) + path = 'tunedModels'.format_map(request_dict.get('_url')) + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJobOrOperation_from_vertex( + self._api_client, response_dict + ) + else: + response_dict = _TuningJobOrOperation_from_mldev( + self._api_client, response_dict + ) + + return_value = types.TuningJobOrOperation._from_response( + response_dict, parameter_model + ).tuning_job + self._api_client._verify_response(return_value) + return return_value + + async def distill( + self, + *, + student_model: str, + teacher_model: str, + training_dataset: types.DistillationDatasetOrDict, + config: Optional[types.CreateDistillationJobConfigOrDict] = None, + ) -> types.TuningJob: + """Creates a distillation job. + + Args: + student_model: The name of the model to tune. + teacher_model: The name of the model to distill from. 
+ training_dataset: The training dataset to use. + config: The configuration to use for the distillation job. + + Returns: + A TuningJob object. + """ + + parameter_model = types._CreateDistillationJobParameters( + student_model=student_model, + teacher_model=teacher_model, + training_dataset=training_dataset, + config=config, + ) + + if not self._api_client.vertexai: + raise ValueError('This method is only supported in the Vertex AI client.') + else: + request_dict = _CreateDistillationJobParameters_to_vertex( + self._api_client, parameter_model + ) + path = 'tuningJobs'.format_map(request_dict.get('_url')) + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + config = request_dict.pop('config', None) + http_options = config.pop('httpOptions', None) if config else None + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _TuningJob_from_vertex(self._api_client, response_dict) + else: + response_dict = _TuningJob_from_mldev(self._api_client, response_dict) + + return_value = types.TuningJob._from_response( + response_dict, parameter_model + ) + self._api_client._verify_response(return_value) + return return_value + + async def list( + self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None + ) -> AsyncPager[types.TuningJob]: + return AsyncPager( + 'tuning_jobs', + self._list, + await self._list(config=config), + config, + ) + + async def get( + self, + *, + name: str, + config: Optional[types.GetTuningJobConfigOrDict] = None, + ) -> types.TuningJob: + job = await self._get(name=name, config=config) + if job.experiment and self._api_client.vertexai: + _IpythonUtils.display_experiment_button( + experiment=job.experiment, + project=self._api_client.project, + ) + return job + + async def tune( + self, + *, + base_model: str, + training_dataset: types.TuningDatasetOrDict, + config: Optional[types.CreateTuningJobConfigOrDict] = None, + ) -> types.TuningJobOrOperation: + result = await self._tune( + base_model=base_model, + training_dataset=training_dataset, + config=config, + ) + if result.name and self._api_client.vertexai: + _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name) + return result + + +class _IpythonUtils: + """Temporary class to hold the IPython related functions.""" + + displayed_experiments = set() + + @staticmethod + def _get_ipython_shell_name() -> str: + import sys + + if 'IPython' in sys.modules: + from IPython import get_ipython + + return get_ipython().__class__.__name__ + return '' + + @staticmethod + def is_ipython_available() -> bool: + return bool(_IpythonUtils._get_ipython_shell_name()) + + @staticmethod + def _get_styles() -> str: + """Returns the HTML style markup to support custom buttons.""" + return """ + <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons"> + <style> + .view-vertex-resource, + .view-vertex-resource:hover, + .view-vertex-resource:visited { + position: relative; + display: inline-flex; + flex-direction: row; + height: 32px; + padding: 0 12px; + margin: 4px 18px; + gap: 4px; + border-radius: 4px; + + align-items: center; + justify-content: center; + background-color: rgb(255, 255, 255); + color: rgb(51, 103, 214); + + font-family: Roboto,"Helvetica 
Neue",sans-serif; + font-size: 13px; + font-weight: 500; + text-transform: uppercase; + text-decoration: none !important; + + transition: box-shadow 280ms cubic-bezier(0.4, 0, 0.2, 1) 0s; + box-shadow: 0px 3px 1px -2px rgba(0,0,0,0.2), 0px 2px 2px 0px rgba(0,0,0,0.14), 0px 1px 5px 0px rgba(0,0,0,0.12); + } + .view-vertex-resource:active { + box-shadow: 0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12); + } + .view-vertex-resource:active .view-vertex-ripple::before { + position: absolute; + top: 0; + bottom: 0; + left: 0; + right: 0; + border-radius: 4px; + pointer-events: none; + + content: ''; + background-color: rgb(51, 103, 214); + opacity: 0.12; + } + .view-vertex-icon { + font-size: 18px; + } + </style> + """ + + @staticmethod + def _parse_resource_name(marker: str, resource_parts: list[str]) -> str: + """Returns the part after the marker text part.""" + for i in range(len(resource_parts)): + if resource_parts[i] == marker and i + 1 < len(resource_parts): + return resource_parts[i + 1] + return '' + + @staticmethod + def _display_link( + text: str, url: str, icon: Optional[str] = 'open_in_new' + ) -> None: + """Creates and displays the link to open the Vertex resource. + + Args: + text: The text displayed on the clickable button. + url: The url that the button will lead to. Only cloud console URIs are + allowed. + icon: The icon name on the button (from material-icons library) + """ + CLOUD_UI_URL = 'https://console.cloud.google.com' # pylint: disable=invalid-name + if not url.startswith(CLOUD_UI_URL): + raise ValueError(f'Only urls starting with {CLOUD_UI_URL} are allowed.') + + import uuid + + button_id = f'view-vertex-resource-{str(uuid.uuid4())}' + + # Add the markup for the CSS and link component + html = f""" + {_IpythonUtils._get_styles()} + <a class="view-vertex-resource" id="{button_id}" href="#view-{button_id}"> + <span class="material-icons view-vertex-icon">{icon}</span> + <span>{text}</span> + </a> + """ + + # Add the click handler for the link + html += f""" + <script> + (function () {{ + const link = document.getElementById('{button_id}'); + link.addEventListener('click', (e) => {{ + if (window.google?.colab?.openUrl) {{ + window.google.colab.openUrl('{url}'); + }} else {{ + window.open('{url}', '_blank'); + }} + e.stopPropagation(); + e.preventDefault(); + }}); + }})(); + </script> + """ + + from IPython.core.display import display + from IPython.display import HTML + + display(HTML(html)) + + @staticmethod + def display_experiment_button(experiment: str, project: str) -> None: + """Function to generate a link bound to the Vertex experiment. + + Args: + experiment: The Vertex experiment name. Example format: + projects/{project_id}/locations/{location}/metadataStores/default/contexts/{experiment_name} + project: The project (alphanumeric) name. + """ + if ( + not _IpythonUtils.is_ipython_available() + or experiment in _IpythonUtils.displayed_experiments + ): + return + # Experiment gives the numeric id, but we need the alphanumeric project + # name. So we get the project from the api client object as an argument. 
+ resource_parts = experiment.split('/') + location = resource_parts[3] + experiment_name = resource_parts[-1] + + uri = ( + 'https://console.cloud.google.com/vertex-ai/experiments/locations/' + + f'{location}/experiments/{experiment_name}/' + + f'runs?project={project}' + ) + _IpythonUtils._display_link('View Experiment', uri, 'science') + + # Avoid repeatedly showing the button + _IpythonUtils.displayed_experiments.add(experiment) + + @staticmethod + def display_model_tuning_button(tuning_job_resource: str) -> None: + """Function to generate a link bound to the Vertex model tuning job. + + Args: + tuning_job_resource: The Vertex tuning job name. Example format: + projects/{project_id}/locations/{location}/tuningJobs/{tuning_job_id} + """ + if not _IpythonUtils.is_ipython_available(): + return + + resource_parts = tuning_job_resource.split('/') + project = resource_parts[1] + location = resource_parts[3] + tuning_job_id = resource_parts[-1] + + uri = ( + 'https://console.cloud.google.com/vertex-ai/generative/language/' + + f'locations/{location}/tuning/tuningJob/{tuning_job_id}' + + f'?project={project}' + ) + _IpythonUtils._display_link('View Tuning Job', uri, 'tune') diff --git a/.venv/lib/python3.12/site-packages/google/genai/types.py b/.venv/lib/python3.12/site-packages/google/genai/types.py new file mode 100644 index 00000000..e8cb4620 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/types.py @@ -0,0 +1,8332 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import datetime +import inspect +import json +import logging +from types import GenericAlias +from typing import Any, Callable, Literal, Optional, TypedDict, Union +import PIL.Image +import pydantic +from pydantic import Field +from . import _common + + +class Outcome(_common.CaseInSensitiveEnum): + """Required. Outcome of the code execution.""" + + OUTCOME_UNSPECIFIED = 'OUTCOME_UNSPECIFIED' + OUTCOME_OK = 'OUTCOME_OK' + OUTCOME_FAILED = 'OUTCOME_FAILED' + OUTCOME_DEADLINE_EXCEEDED = 'OUTCOME_DEADLINE_EXCEEDED' + + +class Language(_common.CaseInSensitiveEnum): + """Required. Programming language of the `code`.""" + + LANGUAGE_UNSPECIFIED = 'LANGUAGE_UNSPECIFIED' + PYTHON = 'PYTHON' + + +class Type(_common.CaseInSensitiveEnum): + """A basic data type.""" + + TYPE_UNSPECIFIED = 'TYPE_UNSPECIFIED' + STRING = 'STRING' + NUMBER = 'NUMBER' + INTEGER = 'INTEGER' + BOOLEAN = 'BOOLEAN' + ARRAY = 'ARRAY' + OBJECT = 'OBJECT' + + +class HarmCategory(_common.CaseInSensitiveEnum): + """Required. 
Harm category.""" + + HARM_CATEGORY_UNSPECIFIED = 'HARM_CATEGORY_UNSPECIFIED' + HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH' + HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT' + HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT' + HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT' + HARM_CATEGORY_CIVIC_INTEGRITY = 'HARM_CATEGORY_CIVIC_INTEGRITY' + + +class HarmBlockMethod(_common.CaseInSensitiveEnum): + """Optional. + + Specify if the threshold is used for probability or severity score. If not + specified, the threshold is used for probability score. + """ + + HARM_BLOCK_METHOD_UNSPECIFIED = 'HARM_BLOCK_METHOD_UNSPECIFIED' + SEVERITY = 'SEVERITY' + PROBABILITY = 'PROBABILITY' + + +class HarmBlockThreshold(_common.CaseInSensitiveEnum): + """Required. The harm block threshold.""" + + HARM_BLOCK_THRESHOLD_UNSPECIFIED = 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE' + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE' + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH' + BLOCK_NONE = 'BLOCK_NONE' + OFF = 'OFF' + + +class Mode(_common.CaseInSensitiveEnum): + """The mode of the predictor to be used in dynamic retrieval.""" + + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED' + MODE_DYNAMIC = 'MODE_DYNAMIC' + + +class State(_common.CaseInSensitiveEnum): + """Output only. RagFile state.""" + + STATE_UNSPECIFIED = 'STATE_UNSPECIFIED' + ACTIVE = 'ACTIVE' + ERROR = 'ERROR' + + +class FinishReason(_common.CaseInSensitiveEnum): + """Output only. + + The reason why the model stopped generating tokens. If empty, the model has + not stopped generating the tokens. + """ + + FINISH_REASON_UNSPECIFIED = 'FINISH_REASON_UNSPECIFIED' + STOP = 'STOP' + MAX_TOKENS = 'MAX_TOKENS' + SAFETY = 'SAFETY' + RECITATION = 'RECITATION' + OTHER = 'OTHER' + BLOCKLIST = 'BLOCKLIST' + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT' + SPII = 'SPII' + MALFORMED_FUNCTION_CALL = 'MALFORMED_FUNCTION_CALL' + + +class HarmProbability(_common.CaseInSensitiveEnum): + """Output only. Harm probability levels in the content.""" + + HARM_PROBABILITY_UNSPECIFIED = 'HARM_PROBABILITY_UNSPECIFIED' + NEGLIGIBLE = 'NEGLIGIBLE' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + + +class HarmSeverity(_common.CaseInSensitiveEnum): + """Output only. Harm severity levels in the content.""" + + HARM_SEVERITY_UNSPECIFIED = 'HARM_SEVERITY_UNSPECIFIED' + HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE' + HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW' + HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM' + HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH' + + +class BlockedReason(_common.CaseInSensitiveEnum): + """Output only. 
Blocked reason.""" + + BLOCKED_REASON_UNSPECIFIED = 'BLOCKED_REASON_UNSPECIFIED' + SAFETY = 'SAFETY' + OTHER = 'OTHER' + BLOCKLIST = 'BLOCKLIST' + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT' + + +class DeploymentResourcesType(_common.CaseInSensitiveEnum): + """""" + + DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = ( + 'DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED' + ) + DEDICATED_RESOURCES = 'DEDICATED_RESOURCES' + AUTOMATIC_RESOURCES = 'AUTOMATIC_RESOURCES' + SHARED_RESOURCES = 'SHARED_RESOURCES' + + +class JobState(_common.CaseInSensitiveEnum): + """Config class for the job state.""" + + JOB_STATE_UNSPECIFIED = 'JOB_STATE_UNSPECIFIED' + JOB_STATE_QUEUED = 'JOB_STATE_QUEUED' + JOB_STATE_PENDING = 'JOB_STATE_PENDING' + JOB_STATE_RUNNING = 'JOB_STATE_RUNNING' + JOB_STATE_SUCCEEDED = 'JOB_STATE_SUCCEEDED' + JOB_STATE_FAILED = 'JOB_STATE_FAILED' + JOB_STATE_CANCELLING = 'JOB_STATE_CANCELLING' + JOB_STATE_CANCELLED = 'JOB_STATE_CANCELLED' + JOB_STATE_PAUSED = 'JOB_STATE_PAUSED' + JOB_STATE_EXPIRED = 'JOB_STATE_EXPIRED' + JOB_STATE_UPDATING = 'JOB_STATE_UPDATING' + JOB_STATE_PARTIALLY_SUCCEEDED = 'JOB_STATE_PARTIALLY_SUCCEEDED' + + +class AdapterSize(_common.CaseInSensitiveEnum): + """Optional. Adapter size for tuning.""" + + ADAPTER_SIZE_UNSPECIFIED = 'ADAPTER_SIZE_UNSPECIFIED' + ADAPTER_SIZE_ONE = 'ADAPTER_SIZE_ONE' + ADAPTER_SIZE_FOUR = 'ADAPTER_SIZE_FOUR' + ADAPTER_SIZE_EIGHT = 'ADAPTER_SIZE_EIGHT' + ADAPTER_SIZE_SIXTEEN = 'ADAPTER_SIZE_SIXTEEN' + ADAPTER_SIZE_THIRTY_TWO = 'ADAPTER_SIZE_THIRTY_TWO' + + +class DynamicRetrievalConfigMode(_common.CaseInSensitiveEnum): + """Config class for the dynamic retrieval config mode.""" + + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED' + MODE_DYNAMIC = 'MODE_DYNAMIC' + + +class FunctionCallingConfigMode(_common.CaseInSensitiveEnum): + """Config class for the function calling config mode.""" + + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED' + AUTO = 'AUTO' + ANY = 'ANY' + NONE = 'NONE' + + +class MediaResolution(_common.CaseInSensitiveEnum): + """The media resolution to use.""" + + MEDIA_RESOLUTION_UNSPECIFIED = 'MEDIA_RESOLUTION_UNSPECIFIED' + MEDIA_RESOLUTION_LOW = 'MEDIA_RESOLUTION_LOW' + MEDIA_RESOLUTION_MEDIUM = 'MEDIA_RESOLUTION_MEDIUM' + MEDIA_RESOLUTION_HIGH = 'MEDIA_RESOLUTION_HIGH' + + +class SafetyFilterLevel(_common.CaseInSensitiveEnum): + """Enum that controls the safety filter level for objectionable content.""" + + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE' + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE' + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH' + BLOCK_NONE = 'BLOCK_NONE' + + +class PersonGeneration(_common.CaseInSensitiveEnum): + """Enum that controls the generation of people.""" + + DONT_ALLOW = 'DONT_ALLOW' + ALLOW_ADULT = 'ALLOW_ADULT' + ALLOW_ALL = 'ALLOW_ALL' + + +class ImagePromptLanguage(_common.CaseInSensitiveEnum): + """Enum that specifies the language of the text in the prompt.""" + + auto = 'auto' + en = 'en' + ja = 'ja' + ko = 'ko' + hi = 'hi' + + +class MaskReferenceMode(_common.CaseInSensitiveEnum): + """Enum representing the mask mode of a mask reference image.""" + + MASK_MODE_DEFAULT = 'MASK_MODE_DEFAULT' + MASK_MODE_USER_PROVIDED = 'MASK_MODE_USER_PROVIDED' + MASK_MODE_BACKGROUND = 'MASK_MODE_BACKGROUND' + MASK_MODE_FOREGROUND = 'MASK_MODE_FOREGROUND' + MASK_MODE_SEMANTIC = 'MASK_MODE_SEMANTIC' + + +class ControlReferenceType(_common.CaseInSensitiveEnum): + """Enum representing the control type of a control reference image.""" + + CONTROL_TYPE_DEFAULT = 'CONTROL_TYPE_DEFAULT' + CONTROL_TYPE_CANNY = 'CONTROL_TYPE_CANNY' + CONTROL_TYPE_SCRIBBLE = 
'CONTROL_TYPE_SCRIBBLE' + CONTROL_TYPE_FACE_MESH = 'CONTROL_TYPE_FACE_MESH' + + +class SubjectReferenceType(_common.CaseInSensitiveEnum): + """Enum representing the subject type of a subject reference image.""" + + SUBJECT_TYPE_DEFAULT = 'SUBJECT_TYPE_DEFAULT' + SUBJECT_TYPE_PERSON = 'SUBJECT_TYPE_PERSON' + SUBJECT_TYPE_ANIMAL = 'SUBJECT_TYPE_ANIMAL' + SUBJECT_TYPE_PRODUCT = 'SUBJECT_TYPE_PRODUCT' + + +class EditMode(_common.CaseInSensitiveEnum): + """Enum representing the Imagen 3 Edit mode.""" + + EDIT_MODE_DEFAULT = 'EDIT_MODE_DEFAULT' + EDIT_MODE_INPAINT_REMOVAL = 'EDIT_MODE_INPAINT_REMOVAL' + EDIT_MODE_INPAINT_INSERTION = 'EDIT_MODE_INPAINT_INSERTION' + EDIT_MODE_OUTPAINT = 'EDIT_MODE_OUTPAINT' + EDIT_MODE_CONTROLLED_EDITING = 'EDIT_MODE_CONTROLLED_EDITING' + EDIT_MODE_STYLE = 'EDIT_MODE_STYLE' + EDIT_MODE_BGSWAP = 'EDIT_MODE_BGSWAP' + EDIT_MODE_PRODUCT_IMAGE = 'EDIT_MODE_PRODUCT_IMAGE' + + +class FileState(_common.CaseInSensitiveEnum): + """State for the lifecycle of a File.""" + + STATE_UNSPECIFIED = 'STATE_UNSPECIFIED' + PROCESSING = 'PROCESSING' + ACTIVE = 'ACTIVE' + FAILED = 'FAILED' + + +class FileSource(_common.CaseInSensitiveEnum): + """Source of the File.""" + + SOURCE_UNSPECIFIED = 'SOURCE_UNSPECIFIED' + UPLOADED = 'UPLOADED' + GENERATED = 'GENERATED' + + +class Modality(_common.CaseInSensitiveEnum): + """Config class for the server content modalities.""" + + MODALITY_UNSPECIFIED = 'MODALITY_UNSPECIFIED' + TEXT = 'TEXT' + IMAGE = 'IMAGE' + AUDIO = 'AUDIO' + + +class VideoMetadata(_common.BaseModel): + """Metadata describes the input video content.""" + + end_offset: Optional[str] = Field( + default=None, description="""Optional. The end offset of the video.""" + ) + start_offset: Optional[str] = Field( + default=None, description="""Optional. The start offset of the video.""" + ) + + +class VideoMetadataDict(TypedDict, total=False): + """Metadata describes the input video content.""" + + end_offset: Optional[str] + """Optional. The end offset of the video.""" + + start_offset: Optional[str] + """Optional. The start offset of the video.""" + + +VideoMetadataOrDict = Union[VideoMetadata, VideoMetadataDict] + + +class CodeExecutionResult(_common.BaseModel): + """Result of executing the [ExecutableCode]. + + Always follows a `part` containing the [ExecutableCode]. + """ + + outcome: Optional[Outcome] = Field( + default=None, description="""Required. Outcome of the code execution.""" + ) + output: Optional[str] = Field( + default=None, + description="""Optional. Contains stdout when code execution is successful, stderr or other description otherwise.""", + ) + + +class CodeExecutionResultDict(TypedDict, total=False): + """Result of executing the [ExecutableCode]. + + Always follows a `part` containing the [ExecutableCode]. + """ + + outcome: Optional[Outcome] + """Required. Outcome of the code execution.""" + + output: Optional[str] + """Optional. Contains stdout when code execution is successful, stderr or other description otherwise.""" + + +CodeExecutionResultOrDict = Union[CodeExecutionResult, CodeExecutionResultDict] + + +class ExecutableCode(_common.BaseModel): + """Code generated by the model that is meant to be executed, and the result returned to the model. + + Generated when using the [FunctionDeclaration] tool and + [FunctionCallingConfig] mode is set to [Mode.CODE]. + """ + + code: Optional[str] = Field( + default=None, description="""Required. The code to be executed.""" + ) + language: Optional[Language] = Field( + default=None, + description="""Required. 
Programming language of the `code`.""",
+  )
+
+
+class ExecutableCodeDict(TypedDict, total=False):
+  """Code generated by the model that is meant to be executed, and the result returned to the model.
+
+  Generated when using the [FunctionDeclaration] tool and
+  [FunctionCallingConfig] mode is set to [Mode.CODE].
+  """
+
+  code: Optional[str]
+  """Required. The code to be executed."""
+
+  language: Optional[Language]
+  """Required. Programming language of the `code`."""
+
+
+ExecutableCodeOrDict = Union[ExecutableCode, ExecutableCodeDict]
+
+
+class FileData(_common.BaseModel):
+  """URI based data."""
+
+  file_uri: Optional[str] = Field(
+      default=None, description="""Required. URI."""
+  )
+  mime_type: Optional[str] = Field(
+      default=None,
+      description="""Required. The IANA standard MIME type of the source data.""",
+  )
+
+
+class FileDataDict(TypedDict, total=False):
+  """URI based data."""
+
+  file_uri: Optional[str]
+  """Required. URI."""
+
+  mime_type: Optional[str]
+  """Required. The IANA standard MIME type of the source data."""
+
+
+FileDataOrDict = Union[FileData, FileDataDict]
+
+
+class FunctionCall(_common.BaseModel):
+  """A function call."""
+
+  id: Optional[str] = Field(
+      default=None,
+      description="""The unique id of the function call. If populated, the client is expected to execute the
+      `function_call` and return the response with the matching `id`.""",
+  )
+  args: Optional[dict[str, Any]] = Field(
+      default=None,
+      description="""Optional. The function parameters and values in JSON object format. See [FunctionDeclaration.parameters] for parameter details.""",
+  )
+  name: Optional[str] = Field(
+      default=None,
+      description="""Required. The name of the function to call. Matches [FunctionDeclaration.name].""",
+  )
+
+
+class FunctionCallDict(TypedDict, total=False):
+  """A function call."""
+
+  id: Optional[str]
+  """The unique id of the function call. If populated, the client is expected to execute the
+  `function_call` and return the response with the matching `id`."""
+
+  args: Optional[dict[str, Any]]
+  """Optional. The function parameters and values in JSON object format. See [FunctionDeclaration.parameters] for parameter details."""
+
+  name: Optional[str]
+  """Required. The name of the function to call. Matches [FunctionDeclaration.name]."""
+
+
+FunctionCallOrDict = Union[FunctionCall, FunctionCallDict]
+
+
+class FunctionResponse(_common.BaseModel):
+  """A function response."""
+
+  id: Optional[str] = Field(
+      default=None,
+      description="""The id of the function call this response is for. Populated by the client
+      to match the corresponding function call `id`.""",
+  )
+  name: Optional[str] = Field(
+      default=None,
+      description="""Required. The name of the function to call. 
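
A minimal sketch of the round trip these two types support, at the type level only (the function name and values are illustrative, and the surrounding model/client machinery is assumed):

    from google.genai import types

    # A call as the model might emit it.
    call = types.FunctionCall(
        id='call-1',
        name='get_weather',
        args={'city': 'Paris'},
    )

    # The client executes the function, then echoes the call's `id` and
    # `name` back so the model can match the response to its request. Per
    # the field docs, use the "output"/"error" keys inside `response`.
    result = types.FunctionResponse(
        id=call.id,
        name=call.name,
        response={'output': {'temp_c': 18}},
    )
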
Matches [FunctionDeclaration.name] and [FunctionCall.name].""" + + response: Optional[dict[str, Any]] + """Required. The function response in JSON object format. Use "output" key to specify function output and "error" key to specify error details (if any). If "output" and "error" keys are not specified, then whole "response" is treated as function output.""" + + +FunctionResponseOrDict = Union[FunctionResponse, FunctionResponseDict] + + +class Blob(_common.BaseModel): + """Content blob. + + It's preferred to send as text directly rather than raw bytes. + """ + + data: Optional[bytes] = Field( + default=None, description="""Required. Raw bytes.""" + ) + mime_type: Optional[str] = Field( + default=None, + description="""Required. The IANA standard MIME type of the source data.""", + ) + + +class BlobDict(TypedDict, total=False): + """Content blob. + + It's preferred to send as text directly rather than raw bytes. + """ + + data: Optional[bytes] + """Required. Raw bytes.""" + + mime_type: Optional[str] + """Required. The IANA standard MIME type of the source data.""" + + +BlobOrDict = Union[Blob, BlobDict] + + +class Part(_common.BaseModel): + """A datatype containing media content. + + Exactly one field within a Part should be set, representing the specific type + of content being conveyed. Using multiple fields within the same `Part` + instance is considered invalid. + """ + + video_metadata: Optional[VideoMetadata] = Field( + default=None, description="""Metadata for a given video.""" + ) + thought: Optional[bool] = Field( + default=None, + description="""Indicates if the part is thought from the model.""", + ) + code_execution_result: Optional[CodeExecutionResult] = Field( + default=None, + description="""Optional. Result of executing the [ExecutableCode].""", + ) + executable_code: Optional[ExecutableCode] = Field( + default=None, + description="""Optional. Code generated by the model that is meant to be executed.""", + ) + file_data: Optional[FileData] = Field( + default=None, description="""Optional. URI based data.""" + ) + function_call: Optional[FunctionCall] = Field( + default=None, + description="""Optional. A predicted [FunctionCall] returned from the model that contains a string representing the [FunctionDeclaration.name] with the parameters and their values.""", + ) + function_response: Optional[FunctionResponse] = Field( + default=None, + description="""Optional. The result output of a [FunctionCall] that contains a string representing the [FunctionDeclaration.name] and a structured JSON object containing any output from the function call. It is used as context to the model.""", + ) + inline_data: Optional[Blob] = Field( + default=None, description="""Optional. Inlined bytes data.""" + ) + text: Optional[str] = Field( + default=None, description="""Optional. 
Text part (can be code).""" + ) + + @classmethod + def from_uri(cls, file_uri: str, mime_type: str) -> 'Part': + file_data = FileData(file_uri=file_uri, mime_type=mime_type) + return cls(file_data=file_data) + + @classmethod + def from_text(cls, text: str) -> 'Part': + return cls(text=text) + + @classmethod + def from_bytes(cls, data: bytes, mime_type: str) -> 'Part': + inline_data = Blob( + data=data, + mime_type=mime_type, + ) + return cls(inline_data=inline_data) + + @classmethod + def from_function_call(cls, name: str, args: dict[str, Any]) -> 'Part': + function_call = FunctionCall(name=name, args=args) + return cls(function_call=function_call) + + @classmethod + def from_function_response( + cls, name: str, response: dict[str, Any] + ) -> 'Part': + function_response = FunctionResponse(name=name, response=response) + return cls(function_response=function_response) + + @classmethod + def from_video_metadata(cls, end_offset: str, start_offset: str) -> 'Part': + video_metadata = VideoMetadata( + end_offset=end_offset, start_offset=start_offset + ) + return cls(video_metadata=video_metadata) + + @classmethod + def from_executable_code(cls, code: str, language: Language) -> 'Part': + executable_code = ExecutableCode(code=code, language=language) + return cls(executable_code=executable_code) + + @classmethod + def from_code_execution_result(cls, outcome: Outcome, output: str) -> 'Part': + code_execution_result = CodeExecutionResult(outcome=outcome, output=output) + return cls(code_execution_result=code_execution_result) + + +class PartDict(TypedDict, total=False): + """A datatype containing media content. + + Exactly one field within a Part should be set, representing the specific type + of content being conveyed. Using multiple fields within the same `Part` + instance is considered invalid. + """ + + video_metadata: Optional[VideoMetadataDict] + """Metadata for a given video.""" + + thought: Optional[bool] + """Indicates if the part is thought from the model.""" + + code_execution_result: Optional[CodeExecutionResultDict] + """Optional. Result of executing the [ExecutableCode].""" + + executable_code: Optional[ExecutableCodeDict] + """Optional. Code generated by the model that is meant to be executed.""" + + file_data: Optional[FileDataDict] + """Optional. URI based data.""" + + function_call: Optional[FunctionCallDict] + """Optional. A predicted [FunctionCall] returned from the model that contains a string representing the [FunctionDeclaration.name] with the parameters and their values.""" + + function_response: Optional[FunctionResponseDict] + """Optional. The result output of a [FunctionCall] that contains a string representing the [FunctionDeclaration.name] and a structured JSON object containing any output from the function call. It is used as context to the model.""" + + inline_data: Optional[BlobDict] + """Optional. Inlined bytes data.""" + + text: Optional[str] + """Optional. Text part (can be code).""" + + +PartOrDict = Union[Part, PartDict] + + +class Content(_common.BaseModel): + """Contains the multi-part content of a message.""" + + parts: Optional[list[Part]] = Field( + default=None, + description="""List of parts that constitute a single message. Each part may have + a different IANA MIME type.""", + ) + role: Optional[str] = Field( + default=None, + description="""Optional. The producer of the content. Must be either 'user' or + 'model'. Useful to set for multi-turn conversations, otherwise can be + left blank or unset. 
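
The classmethods above are convenience constructors for the one-field-per-Part rule; a short sketch of assembling a multi-part user message from them (the URI and MIME type are illustrative):

    from google.genai import types

    text_part = types.Part.from_text('Summarize this file.')
    file_part = types.Part.from_uri(
        'gs://my-bucket/report.pdf',  # illustrative URI
        'application/pdf',
    )
    message = types.Content(role='user', parts=[text_part, file_part])
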
If role is not specified, SDK will determine the role.""",
+  )
+
+
+class ContentDict(TypedDict, total=False):
+  """Contains the multi-part content of a message."""
+
+  parts: Optional[list[PartDict]]
+  """List of parts that constitute a single message. Each part may have
+  a different IANA MIME type."""
+
+  role: Optional[str]
+  """Optional. The producer of the content. Must be either 'user' or
+  'model'. Useful to set for multi-turn conversations, otherwise can be
+  left blank or unset. If role is not specified, SDK will determine the role."""
+
+
+ContentOrDict = Union[Content, ContentDict]
+
+
+class Schema(_common.BaseModel):
+  """Schema that defines the format of input and output data.
+
+  Represents a select subset of an OpenAPI 3.0 schema object.
+  """
+
+  min_items: Optional[int] = Field(
+      default=None,
+      description="""Optional. Minimum number of the elements for Type.ARRAY.""",
+  )
+  example: Optional[Any] = Field(
+      default=None,
+      description="""Optional. Example of the object. Will only be populated when the object is the root.""",
+  )
+  property_ordering: Optional[list[str]] = Field(
+      default=None,
+      description="""Optional. The order of the properties. Not a standard field in open api spec. Only used to support the order of the properties.""",
+  )
+  pattern: Optional[str] = Field(
+      default=None,
+      description="""Optional. Pattern of the Type.STRING to restrict a string to a regular expression.""",
+  )
+  minimum: Optional[float] = Field(
+      default=None,
+      description="""Optional. SCHEMA FIELDS FOR TYPE INTEGER and NUMBER Minimum value of the Type.INTEGER and Type.NUMBER""",
+  )
+  default: Optional[Any] = Field(
+      default=None, description="""Optional. Default value of the data."""
+  )
+  any_of: list['Schema'] = Field(
+      default=None,
+      description="""Optional. The value should be validated against any (one or more) of the subschemas in the list.""",
+  )
+  max_length: Optional[int] = Field(
+      default=None,
+      description="""Optional. Maximum length of the Type.STRING""",
+  )
+  title: Optional[str] = Field(
+      default=None, description="""Optional. The title of the Schema."""
+  )
+  min_length: Optional[int] = Field(
+      default=None,
+      description="""Optional. SCHEMA FIELDS FOR TYPE STRING Minimum length of the Type.STRING""",
+  )
+  min_properties: Optional[int] = Field(
+      default=None,
+      description="""Optional. Minimum number of the properties for Type.OBJECT.""",
+  )
+  max_items: Optional[int] = Field(
+      default=None,
+      description="""Optional. Maximum number of the elements for Type.ARRAY.""",
+  )
+  maximum: Optional[float] = Field(
+      default=None,
+      description="""Optional. Maximum value of the Type.INTEGER and Type.NUMBER""",
+  )
+  nullable: Optional[bool] = Field(
+      default=None,
+      description="""Optional. Indicates if the value may be null.""",
+  )
+  max_properties: Optional[int] = Field(
+      default=None,
+      description="""Optional. Maximum number of the properties for Type.OBJECT.""",
+  )
+  type: Optional[Type] = Field(
+      default=None, description="""Optional. The type of the data."""
+  )
+  description: Optional[str] = Field(
+      default=None, description="""Optional. The description of the data."""
+  )
+  enum: Optional[list[str]] = Field(
+      default=None,
+      description="""Optional. Possible values of the element of primitive type with enum format. Examples: 1. We can define direction as : {type:STRING, format:enum, enum:["EAST", "NORTH", "SOUTH", "WEST"]} 2. We can define apartment number as : {type:INTEGER, format:enum, enum:["101", "201", "301"]}""",
+  )
+  format: Optional[str] = Field(
+      default=None,
+      description="""Optional. The format of the data. Supported formats: for NUMBER type: "float", "double" for INTEGER type: "int32", "int64" for STRING type: "email", "byte", etc""",
+  )
+  items: 'Schema' = Field(
+      default=None,
+      description="""Optional. SCHEMA FIELDS FOR TYPE ARRAY Schema of the elements of Type.ARRAY.""",
+  )
+  properties: dict[str, 'Schema'] = Field(
+      default=None,
+      description="""Optional. SCHEMA FIELDS FOR TYPE OBJECT Properties of Type.OBJECT.""",
+  )
+  required: Optional[list[str]] = Field(
+      default=None,
+      description="""Optional. Required properties of Type.OBJECT.""",
+  )
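
A sketch of a `Schema` describing structured JSON output (the fields chosen are illustrative); in practice such a schema is passed as `response_schema` in `GenerateContentConfig`, defined later in this module:

    from google.genai import types

    recipe_schema = types.Schema(
        type='OBJECT',
        properties={
            'name': types.Schema(type='STRING', description='Recipe name'),
            'servings': types.Schema(type='INTEGER', minimum=1),
            'steps': types.Schema(
                type='ARRAY',
                items=types.Schema(type='STRING'),
            ),
        },
        required=['name', 'steps'],
    )
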
+
+
+class SchemaDict(TypedDict, total=False):
+  """Schema that defines the format of input and output data.
+
+  Represents a select subset of an OpenAPI 3.0 schema object.
+  """
+
+  min_items: Optional[int]
+  """Optional. Minimum number of the elements for Type.ARRAY."""
+
+  example: Optional[Any]
+  """Optional. Example of the object. Will only be populated when the object is the root."""
+
+  property_ordering: Optional[list[str]]
+  """Optional. The order of the properties. Not a standard field in open api spec. Only used to support the order of the properties."""
+
+  pattern: Optional[str]
+  """Optional. Pattern of the Type.STRING to restrict a string to a regular expression."""
+
+  minimum: Optional[float]
+  """Optional. SCHEMA FIELDS FOR TYPE INTEGER and NUMBER Minimum value of the Type.INTEGER and Type.NUMBER"""
+
+  default: Optional[Any]
+  """Optional. Default value of the data."""
+
+  any_of: list['SchemaDict']
+  """Optional. The value should be validated against any (one or more) of the subschemas in the list."""
+
+  max_length: Optional[int]
+  """Optional. Maximum length of the Type.STRING"""
+
+  title: Optional[str]
+  """Optional. The title of the Schema."""
+
+  min_length: Optional[int]
+  """Optional. SCHEMA FIELDS FOR TYPE STRING Minimum length of the Type.STRING"""
+
+  min_properties: Optional[int]
+  """Optional. Minimum number of the properties for Type.OBJECT."""
+
+  max_items: Optional[int]
+  """Optional. Maximum number of the elements for Type.ARRAY."""
+
+  maximum: Optional[float]
+  """Optional. Maximum value of the Type.INTEGER and Type.NUMBER"""
+
+  nullable: Optional[bool]
+  """Optional. Indicates if the value may be null."""
+
+  max_properties: Optional[int]
+  """Optional. Maximum number of the properties for Type.OBJECT."""
+
+  type: Optional[Type]
+  """Optional. The type of the data."""
+
+  description: Optional[str]
+  """Optional. The description of the data."""
+
+  enum: Optional[list[str]]
+  """Optional. Possible values of the element of primitive type with enum format. Examples: 1. We can define direction as : {type:STRING, format:enum, enum:["EAST", "NORTH", "SOUTH", "WEST"]} 2. We can define apartment number as : {type:INTEGER, format:enum, enum:["101", "201", "301"]}"""
+
+  format: Optional[str]
+  """Optional. The format of the data. Supported formats: for NUMBER type: "float", "double" for INTEGER type: "int32", "int64" for STRING type: "email", "byte", etc"""
+
+  items: 'SchemaDict'
+  """Optional. SCHEMA FIELDS FOR TYPE ARRAY Schema of the elements of Type.ARRAY."""
+
+  properties: dict[str, 'SchemaDict']
+  """Optional. SCHEMA FIELDS FOR TYPE OBJECT Properties of Type.OBJECT."""
+
+  required: Optional[list[str]]
+  """Optional. 
Required properties of Type.OBJECT.""" + + +SchemaOrDict = Union[Schema, SchemaDict] + + +class SafetySetting(_common.BaseModel): + """Safety settings.""" + + method: Optional[HarmBlockMethod] = Field( + default=None, + description="""Determines if the harm block method uses probability or probability + and severity scores.""", + ) + category: Optional[HarmCategory] = Field( + default=None, description="""Required. Harm category.""" + ) + threshold: Optional[HarmBlockThreshold] = Field( + default=None, description="""Required. The harm block threshold.""" + ) + + +class SafetySettingDict(TypedDict, total=False): + """Safety settings.""" + + method: Optional[HarmBlockMethod] + """Determines if the harm block method uses probability or probability + and severity scores.""" + + category: Optional[HarmCategory] + """Required. Harm category.""" + + threshold: Optional[HarmBlockThreshold] + """Required. The harm block threshold.""" + + +SafetySettingOrDict = Union[SafetySetting, SafetySettingDict] + + +class FunctionDeclaration(_common.BaseModel): + """Defines a function that the model can generate JSON inputs for. + + The inputs are based on `OpenAPI 3.0 specifications + <https://spec.openapis.org/oas/v3.0.3>`_. + """ + + response: Optional[Schema] = Field( + default=None, + description="""Describes the output from the function in the OpenAPI JSON Schema + Object format.""", + ) + description: Optional[str] = Field( + default=None, + description="""Optional. Description and purpose of the function. Model uses it to decide how and whether to call the function.""", + ) + name: Optional[str] = Field( + default=None, + description="""Required. The name of the function to call. Must start with a letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores, dots and dashes, with a maximum length of 64.""", + ) + parameters: Optional[Schema] = Field( + default=None, + description="""Optional. Describes the parameters to this function in JSON Schema Object format. Reflects the Open API 3.03 Parameter Object. string Key: the name of the parameter. Parameter names are case sensitive. Schema Value: the Schema defining the type used for the parameter. For function with no parameters, this can be left unset. Parameter names must start with a letter or an underscore and must only contain chars a-z, A-Z, 0-9, or underscores with a maximum length of 64. Example with 1 required and 1 optional parameter: type: OBJECT properties: param1: type: STRING param2: type: INTEGER required: - param1""", + ) + + @classmethod + def _get_variant(cls, client) -> str: + """Returns the function variant based on the provided client object.""" + if client.vertexai: + return 'VERTEX_AI' + else: + return 'GOOGLE_AI' + + @classmethod + def from_function_with_options( + cls, + func: Callable, + variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI', + ) -> 'FunctionDeclaration': + """Converts a function to a FunctionDeclaration based on an API endpoint. + + Supported endpoints are: 'GOOGLE_AI', 'VERTEX_AI', or 'DEFAULT'. + """ + from . import _automatic_function_calling_util + + supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] + if variant not in supported_variants: + raise ValueError( + f'Unsupported variant: {variant}. Supported variants are:' + f' {", ".join(supported_variants)}' + ) + + # TODO: b/382524014 - Add support for DEFAULT API endpoint. 
+ + parameters_properties = {} + for name, param in inspect.signature(func).parameters.items(): + if param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ): + schema = _automatic_function_calling_util._parse_schema_from_parameter( + variant, param, func.__name__ + ) + parameters_properties[name] = schema + declaration = FunctionDeclaration( + name=func.__name__, + description=func.__doc__, + ) + if parameters_properties: + declaration.parameters = Schema( + type='OBJECT', + properties=parameters_properties, + ) + if variant == 'VERTEX_AI': + declaration.parameters.required = ( + _automatic_function_calling_util._get_required_fields( + declaration.parameters + ) + ) + if not variant == 'VERTEX_AI': + return declaration + + return_annotation = inspect.signature(func).return_annotation + if return_annotation is inspect._empty: + return declaration + + declaration.response = ( + _automatic_function_calling_util._parse_schema_from_parameter( + variant, + inspect.Parameter( + 'return_value', + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=return_annotation, + ), + func.__name__, + ) + ) + return declaration + + @classmethod + def from_callable(cls, client, func: Callable) -> 'FunctionDeclaration': + """Converts a function to a FunctionDeclaration.""" + return cls.from_function_with_options( + variant=cls._get_variant(client), + func=func, + ) + + +class FunctionDeclarationDict(TypedDict, total=False): + """Defines a function that the model can generate JSON inputs for. + + The inputs are based on `OpenAPI 3.0 specifications + <https://spec.openapis.org/oas/v3.0.3>`_. + """ + + response: Optional[SchemaDict] + """Describes the output from the function in the OpenAPI JSON Schema + Object format.""" + + description: Optional[str] + """Optional. Description and purpose of the function. Model uses it to decide how and whether to call the function.""" + + name: Optional[str] + """Required. The name of the function to call. Must start with a letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores, dots and dashes, with a maximum length of 64.""" + + parameters: Optional[SchemaDict] + """Optional. Describes the parameters to this function in JSON Schema Object format. Reflects the Open API 3.03 Parameter Object. string Key: the name of the parameter. Parameter names are case sensitive. Schema Value: the Schema defining the type used for the parameter. For function with no parameters, this can be left unset. Parameter names must start with a letter or an underscore and must only contain chars a-z, A-Z, 0-9, or underscores with a maximum length of 64. Example with 1 required and 1 optional parameter: type: OBJECT properties: param1: type: STRING param2: type: INTEGER required: - param1""" + + +FunctionDeclarationOrDict = Union[FunctionDeclaration, FunctionDeclarationDict] + + +class GoogleSearch(_common.BaseModel): + """Tool to support Google Search in Model. Powered by Google.""" + + pass + + +class GoogleSearchDict(TypedDict, total=False): + """Tool to support Google Search in Model. 
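
A sketch of `from_callable` in use: it introspects the signature and docstring of an annotated Python function into a declaration (the client is assumed to be a configured `genai.Client`, and the resulting parameter schema follows the conversion logic above):

    from google import genai
    from google.genai import types

    client = genai.Client()  # assumes credentials/API key in the environment

    def get_weather(city: str) -> str:
      """Returns a short weather report for `city`."""
      return f'Sunny in {city}'

    decl = types.FunctionDeclaration.from_callable(client, get_weather)
    # decl.name == 'get_weather'; decl.parameters describes `city` as a STRING.
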
Powered by Google.""" + + pass + + +GoogleSearchOrDict = Union[GoogleSearch, GoogleSearchDict] + + +class DynamicRetrievalConfig(_common.BaseModel): + """Describes the options to customize dynamic retrieval.""" + + mode: Optional[DynamicRetrievalConfigMode] = Field( + default=None, + description="""The mode of the predictor to be used in dynamic retrieval.""", + ) + dynamic_threshold: Optional[float] = Field( + default=None, + description="""Optional. The threshold to be used in dynamic retrieval. If not set, a system default value is used.""", + ) + + +class DynamicRetrievalConfigDict(TypedDict, total=False): + """Describes the options to customize dynamic retrieval.""" + + mode: Optional[DynamicRetrievalConfigMode] + """The mode of the predictor to be used in dynamic retrieval.""" + + dynamic_threshold: Optional[float] + """Optional. The threshold to be used in dynamic retrieval. If not set, a system default value is used.""" + + +DynamicRetrievalConfigOrDict = Union[ + DynamicRetrievalConfig, DynamicRetrievalConfigDict +] + + +class GoogleSearchRetrieval(_common.BaseModel): + """Tool to retrieve public web data for grounding, powered by Google.""" + + dynamic_retrieval_config: Optional[DynamicRetrievalConfig] = Field( + default=None, + description="""Specifies the dynamic retrieval configuration for the given source.""", + ) + + +class GoogleSearchRetrievalDict(TypedDict, total=False): + """Tool to retrieve public web data for grounding, powered by Google.""" + + dynamic_retrieval_config: Optional[DynamicRetrievalConfigDict] + """Specifies the dynamic retrieval configuration for the given source.""" + + +GoogleSearchRetrievalOrDict = Union[ + GoogleSearchRetrieval, GoogleSearchRetrievalDict +] + + +class VertexAISearch(_common.BaseModel): + """Retrieve from Vertex AI Search datastore for grounding. + + See https://cloud.google.com/products/agent-builder + """ + + datastore: Optional[str] = Field( + default=None, + description="""Required. Fully-qualified Vertex AI Search data store resource ID. Format: `projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}`""", + ) + + +class VertexAISearchDict(TypedDict, total=False): + """Retrieve from Vertex AI Search datastore for grounding. + + See https://cloud.google.com/products/agent-builder + """ + + datastore: Optional[str] + """Required. Fully-qualified Vertex AI Search data store resource ID. Format: `projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}`""" + + +VertexAISearchOrDict = Union[VertexAISearch, VertexAISearchDict] + + +class VertexRagStoreRagResource(_common.BaseModel): + """The definition of the Rag resource.""" + + rag_corpus: Optional[str] = Field( + default=None, + description="""Optional. RagCorpora resource name. Format: `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`""", + ) + rag_file_ids: Optional[list[str]] = Field( + default=None, + description="""Optional. rag_file_id. The files should be in the same rag_corpus set in rag_corpus field.""", + ) + + +class VertexRagStoreRagResourceDict(TypedDict, total=False): + """The definition of the Rag resource.""" + + rag_corpus: Optional[str] + """Optional. RagCorpora resource name. Format: `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`""" + + rag_file_ids: Optional[list[str]] + """Optional. rag_file_id. 
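
A sketch of wiring dynamic retrieval into a `Tool` (the threshold value is illustrative; `Tool` itself is defined further below):

    from google.genai import types

    search_tool = types.Tool(
        google_search_retrieval=types.GoogleSearchRetrieval(
            dynamic_retrieval_config=types.DynamicRetrievalConfig(
                mode='MODE_DYNAMIC',
                dynamic_threshold=0.7,  # illustrative; service default if unset
            )
        )
    )
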
The files should be in the same rag_corpus set in rag_corpus field."""
+
+
+VertexRagStoreRagResourceOrDict = Union[
+    VertexRagStoreRagResource, VertexRagStoreRagResourceDict
+]
+
+
+class VertexRagStore(_common.BaseModel):
+  """Retrieve from Vertex RAG Store for grounding."""
+
+  rag_corpora: Optional[list[str]] = Field(
+      default=None,
+      description="""Optional. Deprecated. Please use rag_resources instead.""",
+  )
+  rag_resources: Optional[list[VertexRagStoreRagResource]] = Field(
+      default=None,
+      description="""Optional. The representation of the rag source. It can be used to specify corpus only or ragfiles. Currently only supports one corpus or multiple files from one corpus. In the future we may open up multiple corpora support.""",
+  )
+  similarity_top_k: Optional[int] = Field(
+      default=None,
+      description="""Optional. Number of top k results to return from the selected corpora.""",
+  )
+  vector_distance_threshold: Optional[float] = Field(
+      default=None,
+      description="""Optional. Only return results with vector distance smaller than the threshold.""",
+  )
+
+
+class VertexRagStoreDict(TypedDict, total=False):
+  """Retrieve from Vertex RAG Store for grounding."""
+
+  rag_corpora: Optional[list[str]]
+  """Optional. Deprecated. Please use rag_resources instead."""
+
+  rag_resources: Optional[list[VertexRagStoreRagResourceDict]]
+  """Optional. The representation of the rag source. It can be used to specify corpus only or ragfiles. Currently only supports one corpus or multiple files from one corpus. In the future we may open up multiple corpora support."""
+
+  similarity_top_k: Optional[int]
+  """Optional. Number of top k results to return from the selected corpora."""
+
+  vector_distance_threshold: Optional[float]
+  """Optional. Only return results with vector distance smaller than the threshold."""
+
+
+VertexRagStoreOrDict = Union[VertexRagStore, VertexRagStoreDict]
+
+
+class Retrieval(_common.BaseModel):
+  """Defines a retrieval tool that model can call to access external knowledge."""
+
+  disable_attribution: Optional[bool] = Field(
+      default=None,
+      description="""Optional. Deprecated. This option is no longer supported.""",
+  )
+  vertex_ai_search: Optional[VertexAISearch] = Field(
+      default=None,
+      description="""Set to use data source powered by Vertex AI Search.""",
+  )
+  vertex_rag_store: Optional[VertexRagStore] = Field(
+      default=None,
+      description="""Set to use data source powered by Vertex RAG store. User data is uploaded via the VertexRagDataService.""",
+  )
+
+
+class RetrievalDict(TypedDict, total=False):
+  """Defines a retrieval tool that model can call to access external knowledge."""
+
+  disable_attribution: Optional[bool]
+  """Optional. Deprecated. This option is no longer supported."""
+
+  vertex_ai_search: Optional[VertexAISearchDict]
+  """Set to use data source powered by Vertex AI Search."""
+
+  vertex_rag_store: Optional[VertexRagStoreDict]
+  """Set to use data source powered by Vertex RAG store. User data is uploaded via the VertexRagDataService."""
+
+
+RetrievalOrDict = Union[Retrieval, RetrievalDict]
+
+
+class ToolCodeExecution(_common.BaseModel):
+  """Tool that executes code generated by the model, and automatically returns the result to the model.
+
+  See also [ExecutableCode] and [CodeExecutionResult] which are input and output
+  to this tool.
+  """
+
+  pass
+
+
+class ToolCodeExecutionDict(TypedDict, total=False):
+  """Tool that executes code generated by the model, and automatically returns the result to the model.
+
+  See also [ExecutableCode] and [CodeExecutionResult] which are input and output
+  to this tool.
+  """
+
+  pass
+
+
+ToolCodeExecutionOrDict = Union[ToolCodeExecution, ToolCodeExecutionDict]
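
A sketch of a Vertex AI Search-backed retrieval tool (the datastore resource name is an illustrative placeholder; `Tool` is defined just below):

    from google.genai import types

    retrieval_tool = types.Tool(
        retrieval=types.Retrieval(
            vertex_ai_search=types.VertexAISearch(
                # Illustrative fully-qualified data store resource name.
                datastore='projects/my-project/locations/global/collections/default_collection/dataStores/my-datastore',
            )
        )
    )
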
+
+
+class Tool(_common.BaseModel):
+  """Tool details of a tool that the model may use to generate a response."""
+
+  function_declarations: Optional[list[FunctionDeclaration]] = Field(
+      default=None,
+      description="""List of function declarations that the tool supports.""",
+  )
+  retrieval: Optional[Retrieval] = Field(
+      default=None,
+      description="""Optional. Retrieval tool type. System will always execute the provided retrieval tool(s) to get external knowledge to answer the prompt. Retrieval results are presented to the model for generation.""",
+  )
+  google_search: Optional[GoogleSearch] = Field(
+      default=None,
+      description="""Optional. Google Search tool type. Specialized retrieval tool
+      that is powered by Google Search.""",
+  )
+  google_search_retrieval: Optional[GoogleSearchRetrieval] = Field(
+      default=None,
+      description="""Optional. GoogleSearchRetrieval tool type. Specialized retrieval tool that is powered by Google search.""",
+  )
+  code_execution: Optional[ToolCodeExecution] = Field(
+      default=None,
+      description="""Optional. CodeExecution tool type. Enables the model to execute code as part of generation. This field is only used by the Gemini Developer API services.""",
+  )
+
+
+class ToolDict(TypedDict, total=False):
+  """Tool details of a tool that the model may use to generate a response."""
+
+  function_declarations: Optional[list[FunctionDeclarationDict]]
+  """List of function declarations that the tool supports."""
+
+  retrieval: Optional[RetrievalDict]
+  """Optional. Retrieval tool type. System will always execute the provided retrieval tool(s) to get external knowledge to answer the prompt. Retrieval results are presented to the model for generation."""
+
+  google_search: Optional[GoogleSearchDict]
+  """Optional. Google Search tool type. Specialized retrieval tool
+  that is powered by Google Search."""
+
+  google_search_retrieval: Optional[GoogleSearchRetrievalDict]
+  """Optional. GoogleSearchRetrieval tool type. Specialized retrieval tool that is powered by Google search."""
+
+  code_execution: Optional[ToolCodeExecutionDict]
+  """Optional. CodeExecution tool type. Enables the model to execute code as part of generation. This field is only used by the Gemini Developer API services."""
+
+
+ToolOrDict = Union[Tool, ToolDict]
+ToolListUnion = list[Union[Tool, Callable]]
+ToolListUnionDict = list[Union[ToolDict, Callable]]
+
+
+class FunctionCallingConfig(_common.BaseModel):
+  """Function calling config."""
+
+  mode: Optional[FunctionCallingConfigMode] = Field(
+      default=None, description="""Optional. Function calling mode."""
+  )
+  allowed_function_names: Optional[list[str]] = Field(
+      default=None,
+      description="""Optional. Function names to call. Only set when the Mode is ANY. Function names should match [FunctionDeclaration.name]. With mode set to ANY, the model will predict a function call from the set of function names provided.""",
+  )
+
+
+class FunctionCallingConfigDict(TypedDict, total=False):
+  """Function calling config."""
+
+  mode: Optional[FunctionCallingConfigMode]
+  """Optional. Function calling mode."""
+
+  allowed_function_names: Optional[list[str]]
+  """Optional. Function names to call. Only set when the Mode is ANY. Function names should match [FunctionDeclaration.name]. With mode set to ANY, the model will predict a function call from the set of function names provided."""
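
A sketch of forcing the model to pick from a fixed set of functions with `ANY` mode (function names are illustrative; `ToolConfig`, which carries this config in a request, is defined just below):

    from google.genai import types

    tool_config = types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            mode='ANY',
            allowed_function_names=['get_weather', 'get_forecast'],
        )
    )
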
+
+
+FunctionCallingConfigOrDict = Union[
+    FunctionCallingConfig, FunctionCallingConfigDict
+]
+
+
+class ToolConfig(_common.BaseModel):
+  """Tool config.
+
+  This config is shared for all tools provided in the request.
+  """
+
+  function_calling_config: Optional[FunctionCallingConfig] = Field(
+      default=None, description="""Optional. Function calling config."""
+  )
+
+
+class ToolConfigDict(TypedDict, total=False):
+  """Tool config.
+
+  This config is shared for all tools provided in the request.
+  """
+
+  function_calling_config: Optional[FunctionCallingConfigDict]
+  """Optional. Function calling config."""
+
+
+ToolConfigOrDict = Union[ToolConfig, ToolConfigDict]
+
+
+class PrebuiltVoiceConfig(_common.BaseModel):
+  """The configuration for the prebuilt speaker to use."""
+
+  voice_name: Optional[str] = Field(
+      default=None,
+      description="""The name of the prebuilt voice to use.
+      """,
+  )
+
+
+class PrebuiltVoiceConfigDict(TypedDict, total=False):
+  """The configuration for the prebuilt speaker to use."""
+
+  voice_name: Optional[str]
+  """The name of the prebuilt voice to use.
+  """
+
+
+PrebuiltVoiceConfigOrDict = Union[PrebuiltVoiceConfig, PrebuiltVoiceConfigDict]
+
+
+class VoiceConfig(_common.BaseModel):
+  """The configuration for the voice to use."""
+
+  prebuilt_voice_config: Optional[PrebuiltVoiceConfig] = Field(
+      default=None,
+      description="""The configuration for the speaker to use.
+      """,
+  )
+
+
+class VoiceConfigDict(TypedDict, total=False):
+  """The configuration for the voice to use."""
+
+  prebuilt_voice_config: Optional[PrebuiltVoiceConfigDict]
+  """The configuration for the speaker to use.
+  """
+
+
+VoiceConfigOrDict = Union[VoiceConfig, VoiceConfigDict]
+
+
+class SpeechConfig(_common.BaseModel):
+  """The speech generation configuration."""
+
+  voice_config: Optional[VoiceConfig] = Field(
+      default=None,
+      description="""The configuration for the speaker to use.
+      """,
+  )
+
+
+class SpeechConfigDict(TypedDict, total=False):
+  """The speech generation configuration."""
+
+  voice_config: Optional[VoiceConfigDict]
+  """The configuration for the speaker to use.
+  """
+
+
+SpeechConfigOrDict = Union[SpeechConfig, SpeechConfigDict]
+
+
+class AutomaticFunctionCallingConfig(_common.BaseModel):
+  """The configuration for automatic function calling."""
+
+  disable: Optional[bool] = Field(
+      default=None,
+      description="""Whether to disable automatic function calling.
+      If not set or set to False, automatic function calling is enabled.
+      If set to True, automatic function calling is disabled.
+      """,
+  )
+  maximum_remote_calls: Optional[int] = Field(
+      default=10,
+      description="""If automatic function calling is enabled,
+      the maximum number of remote calls for automatic function calling.
+      This number should be a positive integer.
+      If not set, SDK will set maximum number of remote calls to 10.
+      """,
+  )
+  ignore_call_history: Optional[bool] = Field(
+      default=None,
+      description="""If automatic function calling is enabled,
+      whether to omit the call history from the response.
+      If not set, SDK will set ignore_call_history to false,
+      and will append the call history to
+      GenerateContentResponse.automatic_function_calling_history.
+      """,
+  )
+
+
+class AutomaticFunctionCallingConfigDict(TypedDict, total=False):
+  """The configuration for automatic function calling."""
+
+  disable: Optional[bool]
+  """Whether to disable automatic function calling.
+  If not set or set to False, automatic function calling is enabled.
+  If set to True, automatic function calling is disabled.
+  """
+
+  maximum_remote_calls: Optional[int]
+  """If automatic function calling is enabled,
+  the maximum number of remote calls for automatic function calling.
+  This number should be a positive integer.
+  If not set, SDK will set maximum number of remote calls to 10.
+  """
+
+  ignore_call_history: Optional[bool]
+  """If automatic function calling is enabled,
+  whether to omit the call history from the response.
+  If not set, SDK will set ignore_call_history to false,
+  and will append the call history to
+  GenerateContentResponse.automatic_function_calling_history.
+  """
+
+
+AutomaticFunctionCallingConfigOrDict = Union[
+    AutomaticFunctionCallingConfig, AutomaticFunctionCallingConfigDict
+]
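
A sketch of automatic function calling: passing a plain Python callable in `tools` lets the SDK execute it on the model's behalf, with `maximum_remote_calls` capping the round trips (client setup and model id are illustrative assumptions):

    from google import genai
    from google.genai import types

    client = genai.Client()  # assumes credentials/API key in the environment

    def get_weather(city: str) -> str:
      """Returns a short weather report for `city`."""
      return f'Sunny in {city}'

    response = client.models.generate_content(
        model='gemini-1.5-flash',  # illustrative model id
        contents='What is the weather in Paris?',
        config=types.GenerateContentConfig(
            tools=[get_weather],
            automatic_function_calling=types.AutomaticFunctionCallingConfig(
                maximum_remote_calls=3,
            ),
        ),
    )
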
+
+
+class ThinkingConfig(_common.BaseModel):
+  """The thinking features configuration."""
+
+  include_thoughts: Optional[bool] = Field(
+      default=None,
+      description="""Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available.
+      """,
+  )
+
+
+class ThinkingConfigDict(TypedDict, total=False):
+  """The thinking features configuration."""
+
+  include_thoughts: Optional[bool]
+  """Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available.
+  """
+
+
+ThinkingConfigOrDict = Union[ThinkingConfig, ThinkingConfigDict]
+
+
+class FileStatus(_common.BaseModel):
+  """Status of a File that uses a common error model."""
+
+  details: Optional[list[dict[str, Any]]] = Field(
+      default=None,
+      description="""A list of messages that carry the error details. There is a common set of message types for APIs to use.""",
+  )
+  message: Optional[str] = Field(
+      default=None,
+      description="""A developer-facing error message.""",
+  )
+  code: Optional[int] = Field(
+      default=None, description="""The status code. 0 for OK, 1 for CANCELLED"""
+  )
+
+
+class FileStatusDict(TypedDict, total=False):
+  """Status of a File that uses a common error model."""
+
+  details: Optional[list[dict[str, Any]]]
+  """A list of messages that carry the error details. There is a common set of message types for APIs to use."""
+
+  message: Optional[str]
+  """A developer-facing error message."""
+
+  code: Optional[int]
+  """The status code. 0 for OK, 1 for CANCELLED"""
+
+
+FileStatusOrDict = Union[FileStatus, FileStatusDict]
+
+
+class File(_common.BaseModel):
+  """A file uploaded to the API."""
+
+  name: Optional[str] = Field(
+      default=None,
+      description="""The `File` resource name. The ID (name excluding the "files/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-). The ID cannot start or end with a dash. If the name is empty on create, a unique name will be generated. Example: `files/123-456`""",
+  )
+  display_name: Optional[str] = Field(
+      default=None,
+      description="""Optional. The human-readable display name for the `File`. The display name must be no more than 512 characters in length, including spaces. Example: 'Welcome Image'""",
+  )
+  mime_type: Optional[str] = Field(
+      default=None, description="""Output only. MIME type of the file."""
+  )
+  size_bytes: Optional[int] = Field(
+      default=None, description="""Output only. Size of the file in bytes."""
+  )
+  create_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. The timestamp of when the `File` was created.""",
+  )
+  expiration_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. The timestamp of when the `File` will expire.""",
+  )
+  update_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. The timestamp of when the `File` was last updated.""",
+  )
+  sha256_hash: Optional[str] = Field(
+      default=None,
+      description="""Output only. SHA-256 hash of the uploaded bytes.""",
+  )
+  uri: Optional[str] = Field(
+      default=None, description="""Output only. The URI of the `File`."""
+  )
+  download_uri: Optional[str] = Field(
+      default=None,
+      description="""Output only. The URI of the `File`, only set for downloadable (generated) files.""",
+  )
+  state: Optional[FileState] = Field(
+      default=None, description="""Output only. Processing state of the File."""
+  )
+  source: Optional[FileSource] = Field(
+      default=None, description="""Output only. The source of the `File`."""
+  )
+  video_metadata: Optional[dict[str, Any]] = Field(
+      default=None, description="""Output only. Metadata for a video."""
+  )
+  error: Optional[FileStatus] = Field(
+      default=None,
+      description="""Output only. Error status if File processing failed.""",
+  )
+
+
+class FileDict(TypedDict, total=False):
+  """A file uploaded to the API."""
+
+  name: Optional[str]
+  """The `File` resource name. The ID (name excluding the "files/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-). The ID cannot start or end with a dash. If the name is empty on create, a unique name will be generated. Example: `files/123-456`"""
+
+  display_name: Optional[str]
+  """Optional. The human-readable display name for the `File`. The display name must be no more than 512 characters in length, including spaces. Example: 'Welcome Image'"""
+
+  mime_type: Optional[str]
+  """Output only. MIME type of the file."""
+
+  size_bytes: Optional[int]
+  """Output only. Size of the file in bytes."""
+
+  create_time: Optional[datetime.datetime]
+  """Output only. The timestamp of when the `File` was created."""
+
+  expiration_time: Optional[datetime.datetime]
+  """Output only. The timestamp of when the `File` will expire."""
+
+  update_time: Optional[datetime.datetime]
+  """Output only. The timestamp of when the `File` was last updated."""
+
+  sha256_hash: Optional[str]
+  """Output only. SHA-256 hash of the uploaded bytes."""
+
+  uri: Optional[str]
+  """Output only. The URI of the `File`."""
+
+  download_uri: Optional[str]
+  """Output only. The URI of the `File`, only set for downloadable (generated) files."""
+
+  state: Optional[FileState]
+  """Output only. Processing state of the File."""
+
+  source: Optional[FileSource]
+  """Output only. The source of the `File`."""
+
+  video_metadata: Optional[dict[str, Any]]
+  """Output only. Metadata for a video."""
+
+  error: Optional[FileStatusDict]
+  """Output only. 
Error status if File processing failed.""" + + +FileOrDict = Union[File, FileDict] + + +PartUnion = Union[File, Part, PIL.Image.Image, str] + + +PartUnionDict = Union[PartUnion, PartDict] + + +ContentUnion = Union[Content, list[PartUnion], PartUnion] + + +ContentUnionDict = Union[ContentUnion, ContentDict] + + +SchemaUnion = Union[dict, type, Schema, GenericAlias] + + +SchemaUnionDict = Union[SchemaUnion, SchemaDict] + + +class GenerationConfigRoutingConfigAutoRoutingMode(_common.BaseModel): + """When automated routing is specified, the routing will be determined by the pretrained routing model and customer provided model routing preference.""" + + model_routing_preference: Optional[ + Literal['UNKNOWN', 'PRIORITIZE_QUALITY', 'BALANCED', 'PRIORITIZE_COST'] + ] = Field(default=None, description="""The model routing preference.""") + + +class GenerationConfigRoutingConfigAutoRoutingModeDict(TypedDict, total=False): + """When automated routing is specified, the routing will be determined by the pretrained routing model and customer provided model routing preference.""" + + model_routing_preference: Optional[ + Literal['UNKNOWN', 'PRIORITIZE_QUALITY', 'BALANCED', 'PRIORITIZE_COST'] + ] + """The model routing preference.""" + + +GenerationConfigRoutingConfigAutoRoutingModeOrDict = Union[ + GenerationConfigRoutingConfigAutoRoutingMode, + GenerationConfigRoutingConfigAutoRoutingModeDict, +] + + +class GenerationConfigRoutingConfigManualRoutingMode(_common.BaseModel): + """When manual routing is set, the specified model will be used directly.""" + + model_name: Optional[str] = Field( + default=None, + description="""The model name to use. Only the public LLM models are accepted. e.g. 'gemini-1.5-pro-001'.""", + ) + + +class GenerationConfigRoutingConfigManualRoutingModeDict( + TypedDict, total=False +): + """When manual routing is set, the specified model will be used directly.""" + + model_name: Optional[str] + """The model name to use. Only the public LLM models are accepted. e.g. 'gemini-1.5-pro-001'.""" + + +GenerationConfigRoutingConfigManualRoutingModeOrDict = Union[ + GenerationConfigRoutingConfigManualRoutingMode, + GenerationConfigRoutingConfigManualRoutingModeDict, +] + + +class GenerationConfigRoutingConfig(_common.BaseModel): + """The configuration for routing the request to a specific model.""" + + auto_mode: Optional[GenerationConfigRoutingConfigAutoRoutingMode] = Field( + default=None, description="""Automated routing.""" + ) + manual_mode: Optional[GenerationConfigRoutingConfigManualRoutingMode] = Field( + default=None, description="""Manual routing.""" + ) + + +class GenerationConfigRoutingConfigDict(TypedDict, total=False): + """The configuration for routing the request to a specific model.""" + + auto_mode: Optional[GenerationConfigRoutingConfigAutoRoutingModeDict] + """Automated routing.""" + + manual_mode: Optional[GenerationConfigRoutingConfigManualRoutingModeDict] + """Manual routing.""" + + +GenerationConfigRoutingConfigOrDict = Union[ + GenerationConfigRoutingConfig, GenerationConfigRoutingConfigDict +] + + +SpeechConfigUnion = Union[SpeechConfig, str] + + +SpeechConfigUnionDict = Union[SpeechConfigUnion, SpeechConfigDict] + + +class GenerateContentConfig(_common.BaseModel): + """Class for configuring optional model parameters. + + For more information, see `Content generation parameters + <https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/content-generation-parameters>`_. 
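
A sketch of the two routing modes defined above; a request sets exactly one of `auto_mode` or `manual_mode` (the model name is illustrative):

    from google.genai import types

    # Let the pretrained router pick a model, biased toward quality.
    auto = types.GenerationConfigRoutingConfig(
        auto_mode=types.GenerationConfigRoutingConfigAutoRoutingMode(
            model_routing_preference='PRIORITIZE_QUALITY'
        )
    )

    # Or pin a specific public model.
    manual = types.GenerationConfigRoutingConfig(
        manual_mode=types.GenerationConfigRoutingConfigManualRoutingMode(
            model_name='gemini-1.5-pro-001'
        )
    )
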
+ """ + + system_instruction: Optional[ContentUnion] = Field( + default=None, + description="""Instructions for the model to steer it toward better performance. + For example, "Answer as concisely as possible" or "Don't use technical + terms in your response". + """, + ) + temperature: Optional[float] = Field( + default=None, + description="""Value that controls the degree of randomness in token selection. + Lower temperatures are good for prompts that require a less open-ended or + creative response, while higher temperatures can lead to more diverse or + creative results. + """, + ) + top_p: Optional[float] = Field( + default=None, + description="""Tokens are selected from the most to least probable until the sum + of their probabilities equals this value. Use a lower value for less + random responses and a higher value for more random responses. + """, + ) + top_k: Optional[float] = Field( + default=None, + description="""For each token selection step, the ``top_k`` tokens with the + highest probabilities are sampled. Then tokens are further filtered based + on ``top_p`` with the final token selected using temperature sampling. Use + a lower number for less random responses and a higher number for more + random responses. + """, + ) + candidate_count: Optional[int] = Field( + default=None, + description="""Number of response variations to return. + """, + ) + max_output_tokens: Optional[int] = Field( + default=None, + description="""Maximum number of tokens that can be generated in the response. + """, + ) + stop_sequences: Optional[list[str]] = Field( + default=None, + description="""List of strings that tells the model to stop generating text if one + of the strings is encountered in the response. + """, + ) + response_logprobs: Optional[bool] = Field( + default=None, + description="""Whether to return the log probabilities of the tokens that were + chosen by the model at each step. + """, + ) + logprobs: Optional[int] = Field( + default=None, + description="""Number of top candidate tokens to return the log probabilities for + at each generation step. + """, + ) + presence_penalty: Optional[float] = Field( + default=None, + description="""Positive values penalize tokens that already appear in the + generated text, increasing the probability of generating more diverse + content. + """, + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="""Positive values penalize tokens that repeatedly appear in the + generated text, increasing the probability of generating more diverse + content. + """, + ) + seed: Optional[int] = Field( + default=None, + description="""When ``seed`` is fixed to a specific number, the model makes a best + effort to provide the same response for repeated requests. By default, a + random number is used. + """, + ) + response_mime_type: Optional[str] = Field( + default=None, + description="""Output response media type of the generated candidate text. + """, + ) + response_schema: Optional[SchemaUnion] = Field( + default=None, + description="""Schema that the generated candidate text must adhere to. + """, + ) + routing_config: Optional[GenerationConfigRoutingConfig] = Field( + default=None, + description="""Configuration for model router requests. + """, + ) + safety_settings: Optional[list[SafetySetting]] = Field( + default=None, + description="""Safety settings in the request to block unsafe content in the + response. 
+ """, + ) + tools: Optional[ToolListUnion] = Field( + default=None, + description="""Code that enables the system to interact with external systems to + perform an action outside of the knowledge and scope of the model. + """, + ) + tool_config: Optional[ToolConfig] = Field( + default=None, + description="""Associates model output to a specific function call. + """, + ) + cached_content: Optional[str] = Field( + default=None, + description="""Resource name of a context cache that can be used in subsequent + requests. + """, + ) + response_modalities: Optional[list[str]] = Field( + default=None, + description="""The requested modalities of the response. Represents the set of + modalities that the model can return. + """, + ) + media_resolution: Optional[MediaResolution] = Field( + default=None, + description="""If specified, the media resolution specified will be used. + """, + ) + speech_config: Optional[SpeechConfigUnion] = Field( + default=None, + description="""The speech generation configuration. + """, + ) + audio_timestamp: Optional[bool] = Field( + default=None, + description="""If enabled, audio timestamp will be included in the request to the + model. + """, + ) + automatic_function_calling: Optional[AutomaticFunctionCallingConfig] = Field( + default=None, + description="""The configuration for automatic function calling. + """, + ) + thinking_config: Optional[ThinkingConfig] = Field( + default=None, + description="""The thinking features configuration. + """, + ) + + +class GenerateContentConfigDict(TypedDict, total=False): + """Class for configuring optional model parameters. + + For more information, see `Content generation parameters + <https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/content-generation-parameters>`_. + """ + + system_instruction: Optional[ContentUnionDict] + """Instructions for the model to steer it toward better performance. + For example, "Answer as concisely as possible" or "Don't use technical + terms in your response". + """ + + temperature: Optional[float] + """Value that controls the degree of randomness in token selection. + Lower temperatures are good for prompts that require a less open-ended or + creative response, while higher temperatures can lead to more diverse or + creative results. + """ + + top_p: Optional[float] + """Tokens are selected from the most to least probable until the sum + of their probabilities equals this value. Use a lower value for less + random responses and a higher value for more random responses. + """ + + top_k: Optional[float] + """For each token selection step, the ``top_k`` tokens with the + highest probabilities are sampled. Then tokens are further filtered based + on ``top_p`` with the final token selected using temperature sampling. Use + a lower number for less random responses and a higher number for more + random responses. + """ + + candidate_count: Optional[int] + """Number of response variations to return. + """ + + max_output_tokens: Optional[int] + """Maximum number of tokens that can be generated in the response. + """ + + stop_sequences: Optional[list[str]] + """List of strings that tells the model to stop generating text if one + of the strings is encountered in the response. + """ + + response_logprobs: Optional[bool] + """Whether to return the log probabilities of the tokens that were + chosen by the model at each step. + """ + + logprobs: Optional[int] + """Number of top candidate tokens to return the log probabilities for + at each generation step. 
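
A sketch pulling several of these knobs together in a single `generate_content` call (client setup, model id, and parameter values are illustrative; the safety enum values are standard `HarmCategory`/`HarmBlockThreshold` members defined earlier in this module):

    from google import genai
    from google.genai import types

    client = genai.Client()  # assumes credentials/API key in the environment

    response = client.models.generate_content(
        model='gemini-1.5-flash',  # illustrative model id
        contents='List three colors as JSON.',
        config=types.GenerateContentConfig(
            temperature=0.2,
            top_p=0.95,
            max_output_tokens=256,
            safety_settings=[
                types.SafetySetting(
                    category='HARM_CATEGORY_HATE_SPEECH',
                    threshold='BLOCK_MEDIUM_AND_ABOVE',
                )
            ],
            response_mime_type='application/json',
        ),
    )
    print(response.text)
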
+ """ + + presence_penalty: Optional[float] + """Positive values penalize tokens that already appear in the + generated text, increasing the probability of generating more diverse + content. + """ + + frequency_penalty: Optional[float] + """Positive values penalize tokens that repeatedly appear in the + generated text, increasing the probability of generating more diverse + content. + """ + + seed: Optional[int] + """When ``seed`` is fixed to a specific number, the model makes a best + effort to provide the same response for repeated requests. By default, a + random number is used. + """ + + response_mime_type: Optional[str] + """Output response media type of the generated candidate text. + """ + + response_schema: Optional[SchemaUnionDict] + """Schema that the generated candidate text must adhere to. + """ + + routing_config: Optional[GenerationConfigRoutingConfigDict] + """Configuration for model router requests. + """ + + safety_settings: Optional[list[SafetySettingDict]] + """Safety settings in the request to block unsafe content in the + response. + """ + + tools: Optional[ToolListUnionDict] + """Code that enables the system to interact with external systems to + perform an action outside of the knowledge and scope of the model. + """ + + tool_config: Optional[ToolConfigDict] + """Associates model output to a specific function call. + """ + + cached_content: Optional[str] + """Resource name of a context cache that can be used in subsequent + requests. + """ + + response_modalities: Optional[list[str]] + """The requested modalities of the response. Represents the set of + modalities that the model can return. + """ + + media_resolution: Optional[MediaResolution] + """If specified, the media resolution specified will be used. + """ + + speech_config: Optional[SpeechConfigUnionDict] + """The speech generation configuration. + """ + + audio_timestamp: Optional[bool] + """If enabled, audio timestamp will be included in the request to the + model. + """ + + automatic_function_calling: Optional[AutomaticFunctionCallingConfigDict] + """The configuration for automatic function calling. + """ + + thinking_config: Optional[ThinkingConfigDict] + """The thinking features configuration. + """ + + +GenerateContentConfigOrDict = Union[ + GenerateContentConfig, GenerateContentConfigDict +] + + +ContentListUnion = Union[list[ContentUnion], ContentUnion] + + +ContentListUnionDict = Union[list[ContentUnionDict], ContentUnionDict] + + +class _GenerateContentParameters(_common.BaseModel): + """Class for configuring the content of the request to the model.""" + + model: Optional[str] = Field( + default=None, + description="""ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""", + ) + contents: Optional[ContentListUnion] = Field( + default=None, + description="""Content of the request. + """, + ) + config: Optional[GenerateContentConfig] = Field( + default=None, + description="""Configuration that contains optional model parameters. + """, + ) + + +class _GenerateContentParametersDict(TypedDict, total=False): + """Class for configuring the content of the request to the model.""" + + model: Optional[str] + """ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""" + + contents: Optional[ContentListUnionDict] + """Content of the request. 
+ """ + + config: Optional[GenerateContentConfigDict] + """Configuration that contains optional model parameters. + """ + + +_GenerateContentParametersOrDict = Union[ + _GenerateContentParameters, _GenerateContentParametersDict +] + + +class GoogleTypeDate(_common.BaseModel): + """Represents a whole or partial calendar date, such as a birthday. + + The time of day and time zone are either specified elsewhere or are + insignificant. The date is relative to the Gregorian Calendar. This can + represent one of the following: * A full date, with non-zero year, month, and + day values. * A month and day, with a zero year (for example, an anniversary). + * A year on its own, with a zero month and a zero day. * A year and month, + with a zero day (for example, a credit card expiration date). Related types: * + google.type.TimeOfDay * google.type.DateTime * google.protobuf.Timestamp + """ + + day: Optional[int] = Field( + default=None, + description="""Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant.""", + ) + month: Optional[int] = Field( + default=None, + description="""Month of a year. Must be from 1 to 12, or 0 to specify a year without a month and day.""", + ) + year: Optional[int] = Field( + default=None, + description="""Year of the date. Must be from 1 to 9999, or 0 to specify a date without a year.""", + ) + + +class GoogleTypeDateDict(TypedDict, total=False): + """Represents a whole or partial calendar date, such as a birthday. + + The time of day and time zone are either specified elsewhere or are + insignificant. The date is relative to the Gregorian Calendar. This can + represent one of the following: * A full date, with non-zero year, month, and + day values. * A month and day, with a zero year (for example, an anniversary). + * A year on its own, with a zero month and a zero day. * A year and month, + with a zero day (for example, a credit card expiration date). Related types: * + google.type.TimeOfDay * google.type.DateTime * google.protobuf.Timestamp + """ + + day: Optional[int] + """Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant.""" + + month: Optional[int] + """Month of a year. Must be from 1 to 12, or 0 to specify a year without a month and day.""" + + year: Optional[int] + """Year of the date. Must be from 1 to 9999, or 0 to specify a date without a year.""" + + +GoogleTypeDateOrDict = Union[GoogleTypeDate, GoogleTypeDateDict] + + +class Citation(_common.BaseModel): + """Source attributions for content.""" + + end_index: Optional[int] = Field( + default=None, description="""Output only. End index into the content.""" + ) + license: Optional[str] = Field( + default=None, description="""Output only. License of the attribution.""" + ) + publication_date: Optional[GoogleTypeDate] = Field( + default=None, + description="""Output only. Publication date of the attribution.""", + ) + start_index: Optional[int] = Field( + default=None, description="""Output only. Start index into the content.""" + ) + title: Optional[str] = Field( + default=None, description="""Output only. Title of the attribution.""" + ) + uri: Optional[str] = Field( + default=None, + description="""Output only. Url reference of the attribution.""", + ) + + +class CitationDict(TypedDict, total=False): + """Source attributions for content.""" + + end_index: Optional[int] + """Output only. 
End index into the content.""" + + license: Optional[str] + """Output only. License of the attribution.""" + + publication_date: Optional[GoogleTypeDateDict] + """Output only. Publication date of the attribution.""" + + start_index: Optional[int] + """Output only. Start index into the content.""" + + title: Optional[str] + """Output only. Title of the attribution.""" + + uri: Optional[str] + """Output only. Url reference of the attribution.""" + + +CitationOrDict = Union[Citation, CitationDict] + + +class CitationMetadata(_common.BaseModel): + """Class for citation information when the model quotes another source.""" + + citations: Optional[list[Citation]] = Field( + default=None, + description="""Contains citation information when the model directly quotes, at + length, from another source. Can include traditional websites and code + repositories. + """, + ) + + +class CitationMetadataDict(TypedDict, total=False): + """Class for citation information when the model quotes another source.""" + + citations: Optional[list[CitationDict]] + """Contains citation information when the model directly quotes, at + length, from another source. Can include traditional websites and code + repositories. + """ + + +CitationMetadataOrDict = Union[CitationMetadata, CitationMetadataDict] + + +class GroundingChunkRetrievedContext(_common.BaseModel): + """Chunk from context retrieved by the retrieval tools.""" + + text: Optional[str] = Field( + default=None, description="""Text of the attribution.""" + ) + title: Optional[str] = Field( + default=None, description="""Title of the attribution.""" + ) + uri: Optional[str] = Field( + default=None, description="""URI reference of the attribution.""" + ) + + +class GroundingChunkRetrievedContextDict(TypedDict, total=False): + """Chunk from context retrieved by the retrieval tools.""" + + text: Optional[str] + """Text of the attribution.""" + + title: Optional[str] + """Title of the attribution.""" + + uri: Optional[str] + """URI reference of the attribution.""" + + +GroundingChunkRetrievedContextOrDict = Union[ + GroundingChunkRetrievedContext, GroundingChunkRetrievedContextDict +] + + +class GroundingChunkWeb(_common.BaseModel): + """Chunk from the web.""" + + title: Optional[str] = Field( + default=None, description="""Title of the chunk.""" + ) + uri: Optional[str] = Field( + default=None, description="""URI reference of the chunk.""" + ) + + +class GroundingChunkWebDict(TypedDict, total=False): + """Chunk from the web.""" + + title: Optional[str] + """Title of the chunk.""" + + uri: Optional[str] + """URI reference of the chunk.""" + + +GroundingChunkWebOrDict = Union[GroundingChunkWeb, GroundingChunkWebDict] + + +class GroundingChunk(_common.BaseModel): + """Grounding chunk.""" + + retrieved_context: Optional[GroundingChunkRetrievedContext] = Field( + default=None, + description="""Grounding chunk from context retrieved by the retrieval tools.""", + ) + web: Optional[GroundingChunkWeb] = Field( + default=None, description="""Grounding chunk from the web.""" + ) + + +class GroundingChunkDict(TypedDict, total=False): + """Grounding chunk.""" + + retrieved_context: Optional[GroundingChunkRetrievedContextDict] + """Grounding chunk from context retrieved by the retrieval tools.""" + + web: Optional[GroundingChunkWebDict] + """Grounding chunk from the web.""" + + +GroundingChunkOrDict = Union[GroundingChunk, GroundingChunkDict] + + +class Segment(_common.BaseModel): + """Segment of the content.""" + + end_index: Optional[int] = Field( + default=None, + 
description="""Output only. End index in the given Part, measured in bytes. Offset from the start of the Part, exclusive, starting at zero.""", + ) + part_index: Optional[int] = Field( + default=None, + description="""Output only. The index of a Part object within its parent Content object.""", + ) + start_index: Optional[int] = Field( + default=None, + description="""Output only. Start index in the given Part, measured in bytes. Offset from the start of the Part, inclusive, starting at zero.""", + ) + text: Optional[str] = Field( + default=None, + description="""Output only. The text corresponding to the segment from the response.""", + ) + + +class SegmentDict(TypedDict, total=False): + """Segment of the content.""" + + end_index: Optional[int] + """Output only. End index in the given Part, measured in bytes. Offset from the start of the Part, exclusive, starting at zero.""" + + part_index: Optional[int] + """Output only. The index of a Part object within its parent Content object.""" + + start_index: Optional[int] + """Output only. Start index in the given Part, measured in bytes. Offset from the start of the Part, inclusive, starting at zero.""" + + text: Optional[str] + """Output only. The text corresponding to the segment from the response.""" + + +SegmentOrDict = Union[Segment, SegmentDict] + + +class GroundingSupport(_common.BaseModel): + """Grounding support.""" + + confidence_scores: Optional[list[float]] = Field( + default=None, + description="""Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. This list must have the same size as the grounding_chunk_indices.""", + ) + grounding_chunk_indices: Optional[list[int]] = Field( + default=None, + description="""A list of indices (into 'grounding_chunk') specifying the citations associated with the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], grounding_chunk[4] are the retrieved content attributed to the claim.""", + ) + segment: Optional[Segment] = Field( + default=None, + description="""Segment of the content this support belongs to.""", + ) + + +class GroundingSupportDict(TypedDict, total=False): + """Grounding support.""" + + confidence_scores: Optional[list[float]] + """Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. This list must have the same size as the grounding_chunk_indices.""" + + grounding_chunk_indices: Optional[list[int]] + """A list of indices (into 'grounding_chunk') specifying the citations associated with the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], grounding_chunk[4] are the retrieved content attributed to the claim.""" + + segment: Optional[SegmentDict] + """Segment of the content this support belongs to.""" + + +GroundingSupportOrDict = Union[GroundingSupport, GroundingSupportDict] + + +class RetrievalMetadata(_common.BaseModel): + """Metadata related to retrieval in the grounding flow.""" + + google_search_dynamic_retrieval_score: Optional[float] = Field( + default=None, + description="""Optional. Score indicating how likely information from Google Search could help answer the prompt. The score is in the range `[0, 1]`, where 0 is the least likely and 1 is the most likely. This score is only populated when Google Search grounding and dynamic retrieval is enabled. 
It will be compared to the threshold to determine whether to trigger Google Search.""", + ) + + +class RetrievalMetadataDict(TypedDict, total=False): + """Metadata related to retrieval in the grounding flow.""" + + google_search_dynamic_retrieval_score: Optional[float] + """Optional. Score indicating how likely information from Google Search could help answer the prompt. The score is in the range `[0, 1]`, where 0 is the least likely and 1 is the most likely. This score is only populated when Google Search grounding and dynamic retrieval is enabled. It will be compared to the threshold to determine whether to trigger Google Search.""" + + +RetrievalMetadataOrDict = Union[RetrievalMetadata, RetrievalMetadataDict] + + +class SearchEntryPoint(_common.BaseModel): + """Google search entry point.""" + + rendered_content: Optional[str] = Field( + default=None, + description="""Optional. Web content snippet that can be embedded in a web page or an app webview.""", + ) + sdk_blob: Optional[bytes] = Field( + default=None, + description="""Optional. Base64 encoded JSON representing array of tuple.""", + ) + + +class SearchEntryPointDict(TypedDict, total=False): + """Google search entry point.""" + + rendered_content: Optional[str] + """Optional. Web content snippet that can be embedded in a web page or an app webview.""" + + sdk_blob: Optional[bytes] + """Optional. Base64 encoded JSON representing array of tuple.""" + + +SearchEntryPointOrDict = Union[SearchEntryPoint, SearchEntryPointDict] + + +class GroundingMetadata(_common.BaseModel): + """Metadata returned to client when grounding is enabled.""" + + grounding_chunks: Optional[list[GroundingChunk]] = Field( + default=None, + description="""List of supporting references retrieved from specified grounding source.""", + ) + grounding_supports: Optional[list[GroundingSupport]] = Field( + default=None, description="""Optional. List of grounding support.""" + ) + retrieval_metadata: Optional[RetrievalMetadata] = Field( + default=None, description="""Optional. Output only. Retrieval metadata.""" + ) + retrieval_queries: Optional[list[str]] = Field( + default=None, + description="""Optional. Queries executed by the retrieval tools.""", + ) + search_entry_point: Optional[SearchEntryPoint] = Field( + default=None, + description="""Optional. Google search entry for the following-up web searches.""", + ) + web_search_queries: Optional[list[str]] = Field( + default=None, + description="""Optional. Web search queries for the following-up web search.""", + ) + + +class GroundingMetadataDict(TypedDict, total=False): + """Metadata returned to client when grounding is enabled.""" + + grounding_chunks: Optional[list[GroundingChunkDict]] + """List of supporting references retrieved from specified grounding source.""" + + grounding_supports: Optional[list[GroundingSupportDict]] + """Optional. List of grounding support.""" + + retrieval_metadata: Optional[RetrievalMetadataDict] + """Optional. Output only. Retrieval metadata.""" + + retrieval_queries: Optional[list[str]] + """Optional. Queries executed by the retrieval tools.""" + + search_entry_point: Optional[SearchEntryPointDict] + """Optional. Google search entry for the following-up web searches.""" + + web_search_queries: Optional[list[str]] + """Optional. 
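# Sketch of resolving grounding supports back to their sources, following the
# index semantics documented above (grounding_chunk_indices points into
# grounding_metadata.grounding_chunks). Assumes `response` is a
# GenerateContentResponse from a grounded request, as in the earlier sketch.
metadata = response.candidates[0].grounding_metadata
if metadata and metadata.grounding_supports:
  for support in metadata.grounding_supports:
    claim = support.segment.text if support.segment else ''
    for i in support.grounding_chunk_indices or []:
      chunk = metadata.grounding_chunks[i]
      source = chunk.web or chunk.retrieved_context
      print(claim, '->', source.uri if source else None)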
Web search queries for the following-up web search.""" + + +GroundingMetadataOrDict = Union[GroundingMetadata, GroundingMetadataDict] + + +class LogprobsResultCandidate(_common.BaseModel): + """Candidate for the logprobs token and score.""" + + log_probability: Optional[float] = Field( + default=None, description="""The candidate's log probability.""" + ) + token: Optional[str] = Field( + default=None, description="""The candidate's token string value.""" + ) + token_id: Optional[int] = Field( + default=None, description="""The candidate's token id value.""" + ) + + +class LogprobsResultCandidateDict(TypedDict, total=False): + """Candidate for the logprobs token and score.""" + + log_probability: Optional[float] + """The candidate's log probability.""" + + token: Optional[str] + """The candidate's token string value.""" + + token_id: Optional[int] + """The candidate's token id value.""" + + +LogprobsResultCandidateOrDict = Union[ + LogprobsResultCandidate, LogprobsResultCandidateDict +] + + +class LogprobsResultTopCandidates(_common.BaseModel): + """Candidates with top log probabilities at each decoding step.""" + + candidates: Optional[list[LogprobsResultCandidate]] = Field( + default=None, + description="""Sorted by log probability in descending order.""", + ) + + +class LogprobsResultTopCandidatesDict(TypedDict, total=False): + """Candidates with top log probabilities at each decoding step.""" + + candidates: Optional[list[LogprobsResultCandidateDict]] + """Sorted by log probability in descending order.""" + + +LogprobsResultTopCandidatesOrDict = Union[ + LogprobsResultTopCandidates, LogprobsResultTopCandidatesDict +] + + +class LogprobsResult(_common.BaseModel): + """Logprobs Result""" + + chosen_candidates: Optional[list[LogprobsResultCandidate]] = Field( + default=None, + description="""Length = total number of decoding steps. The chosen candidates may or may not be in top_candidates.""", + ) + top_candidates: Optional[list[LogprobsResultTopCandidates]] = Field( + default=None, description="""Length = total number of decoding steps.""" + ) + + +class LogprobsResultDict(TypedDict, total=False): + """Logprobs Result""" + + chosen_candidates: Optional[list[LogprobsResultCandidateDict]] + """Length = total number of decoding steps. The chosen candidates may or may not be in top_candidates.""" + + top_candidates: Optional[list[LogprobsResultTopCandidatesDict]] + """Length = total number of decoding steps.""" + + +LogprobsResultOrDict = Union[LogprobsResult, LogprobsResultDict] + + +class SafetyRating(_common.BaseModel): + """Safety rating corresponding to the generated content.""" + + blocked: Optional[bool] = Field( + default=None, + description="""Output only. Indicates whether the content was filtered out because of this rating.""", + ) + category: Optional[HarmCategory] = Field( + default=None, description="""Output only. Harm category.""" + ) + probability: Optional[HarmProbability] = Field( + default=None, + description="""Output only. Harm probability levels in the content.""", + ) + probability_score: Optional[float] = Field( + default=None, description="""Output only. Harm probability score.""" + ) + severity: Optional[HarmSeverity] = Field( + default=None, + description="""Output only. Harm severity levels in the content.""", + ) + severity_score: Optional[float] = Field( + default=None, description="""Output only. 
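# Sketch of reading the logprobs structures above. chosen_candidates and
# top_candidates are parallel lists, one entry per decoding step; the request
# must have set response_logprobs (and optionally logprobs) in
# GenerateContentConfig. Assumes `response` from a prior generate_content call.
lp = response.candidates[0].logprobs_result
if lp and lp.chosen_candidates and lp.top_candidates:
  for chosen, top in zip(lp.chosen_candidates, lp.top_candidates):
    alternatives = [(c.token, c.log_probability) for c in top.candidates or []]
    print(chosen.token, chosen.log_probability, alternatives)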
Harm severity score.""" + ) + + +class SafetyRatingDict(TypedDict, total=False): + """Safety rating corresponding to the generated content.""" + + blocked: Optional[bool] + """Output only. Indicates whether the content was filtered out because of this rating.""" + + category: Optional[HarmCategory] + """Output only. Harm category.""" + + probability: Optional[HarmProbability] + """Output only. Harm probability levels in the content.""" + + probability_score: Optional[float] + """Output only. Harm probability score.""" + + severity: Optional[HarmSeverity] + """Output only. Harm severity levels in the content.""" + + severity_score: Optional[float] + """Output only. Harm severity score.""" + + +SafetyRatingOrDict = Union[SafetyRating, SafetyRatingDict] + + +class Candidate(_common.BaseModel): + """Class containing a response candidate generated from the model.""" + + content: Optional[Content] = Field( + default=None, + description="""Contains the multi-part content of the response. + """, + ) + citation_metadata: Optional[CitationMetadata] = Field( + default=None, + description="""Source attribution of the generated content. + """, + ) + finish_message: Optional[str] = Field( + default=None, + description="""Describes the reason the model stopped generating tokens. + """, + ) + token_count: Optional[int] = Field( + default=None, + description="""Number of tokens for this candidate. + """, + ) + avg_logprobs: Optional[float] = Field( + default=None, + description="""Output only. Average log probability score of the candidate.""", + ) + finish_reason: Optional[FinishReason] = Field( + default=None, + description="""Output only. The reason why the model stopped generating tokens. If empty, the model has not stopped generating the tokens.""", + ) + grounding_metadata: Optional[GroundingMetadata] = Field( + default=None, + description="""Output only. Metadata specifies sources used to ground generated content.""", + ) + index: Optional[int] = Field( + default=None, description="""Output only. Index of the candidate.""" + ) + logprobs_result: Optional[LogprobsResult] = Field( + default=None, + description="""Output only. Log-likelihood scores for the response tokens and top tokens""", + ) + safety_ratings: Optional[list[SafetyRating]] = Field( + default=None, + description="""Output only. List of ratings for the safety of a response candidate. There is at most one rating per category.""", + ) + + +class CandidateDict(TypedDict, total=False): + """Class containing a response candidate generated from the model.""" + + content: Optional[ContentDict] + """Contains the multi-part content of the response. + """ + + citation_metadata: Optional[CitationMetadataDict] + """Source attribution of the generated content. + """ + + finish_message: Optional[str] + """Describes the reason the model stopped generating tokens. + """ + + token_count: Optional[int] + """Number of tokens for this candidate. + """ + + avg_logprobs: Optional[float] + """Output only. Average log probability score of the candidate.""" + + finish_reason: Optional[FinishReason] + """Output only. The reason why the model stopped generating tokens. If empty, the model has not stopped generating the tokens.""" + + grounding_metadata: Optional[GroundingMetadataDict] + """Output only. Metadata specifies sources used to ground generated content.""" + + index: Optional[int] + """Output only. Index of the candidate.""" + + logprobs_result: Optional[LogprobsResultDict] + """Output only. 
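# Sketch of typical per-candidate checks using the fields defined above.
# Assumes `response` from a prior generate_content call.
candidate = response.candidates[0]
if candidate.finish_reason:
  print('finish reason:', candidate.finish_reason)
for rating in candidate.safety_ratings or []:
  status = 'blocked' if rating.blocked else 'ok'
  print(rating.category, rating.probability, status)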
Log-likelihood scores for the response tokens and top tokens""" + + safety_ratings: Optional[list[SafetyRatingDict]] + """Output only. List of ratings for the safety of a response candidate. There is at most one rating per category.""" + + +CandidateOrDict = Union[Candidate, CandidateDict] + + +class GenerateContentResponsePromptFeedback(_common.BaseModel): + """Content filter results for a prompt sent in the request.""" + + block_reason: Optional[BlockedReason] = Field( + default=None, description="""Output only. Blocked reason.""" + ) + block_reason_message: Optional[str] = Field( + default=None, + description="""Output only. A readable block reason message.""", + ) + safety_ratings: Optional[list[SafetyRating]] = Field( + default=None, description="""Output only. Safety ratings.""" + ) + + +class GenerateContentResponsePromptFeedbackDict(TypedDict, total=False): + """Content filter results for a prompt sent in the request.""" + + block_reason: Optional[BlockedReason] + """Output only. Blocked reason.""" + + block_reason_message: Optional[str] + """Output only. A readable block reason message.""" + + safety_ratings: Optional[list[SafetyRatingDict]] + """Output only. Safety ratings.""" + + +GenerateContentResponsePromptFeedbackOrDict = Union[ + GenerateContentResponsePromptFeedback, + GenerateContentResponsePromptFeedbackDict, +] + + +class GenerateContentResponseUsageMetadata(_common.BaseModel): + """Usage metadata about response(s).""" + + cached_content_token_count: Optional[int] = Field( + default=None, + description="""Output only. Number of tokens in the cached part in the input (the cached content).""", + ) + candidates_token_count: Optional[int] = Field( + default=None, description="""Number of tokens in the response(s).""" + ) + prompt_token_count: Optional[int] = Field( + default=None, + description="""Number of tokens in the request. When `cached_content` is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.""", + ) + total_token_count: Optional[int] = Field( + default=None, + description="""Total token count for prompt and response candidates.""", + ) + + +class GenerateContentResponseUsageMetadataDict(TypedDict, total=False): + """Usage metadata about response(s).""" + + cached_content_token_count: Optional[int] + """Output only. Number of tokens in the cached part in the input (the cached content).""" + + candidates_token_count: Optional[int] + """Number of tokens in the response(s).""" + + prompt_token_count: Optional[int] + """Number of tokens in the request. When `cached_content` is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.""" + + total_token_count: Optional[int] + """Total token count for prompt and response candidates.""" + + +GenerateContentResponseUsageMetadataOrDict = Union[ + GenerateContentResponseUsageMetadata, + GenerateContentResponseUsageMetadataDict, +] + + +class GenerateContentResponse(_common.BaseModel): + """Response message for PredictionService.GenerateContent.""" + + candidates: Optional[list[Candidate]] = Field( + default=None, + description="""Response variations returned by the model. + """, + ) + model_version: Optional[str] = Field( + default=None, + description="""Output only. The model version used to generate the response.""", + ) + prompt_feedback: Optional[GenerateContentResponsePromptFeedback] = Field( + default=None, + description="""Output only. Content filter results for a prompt sent in the request. 
Note: Sent only in the first stream chunk. Only happens when no candidates were generated due to content violations.""",
+  )
+  usage_metadata: Optional[GenerateContentResponseUsageMetadata] = Field(
+      default=None, description="""Usage metadata about the response(s)."""
+  )
+  automatic_function_calling_history: Optional[list[Content]] = None
+  parsed: Optional[Union[pydantic.BaseModel, dict]] = Field(
+      default=None,
+      description="""Parsed response if response_schema is provided. Not available for streaming.""",
+  )
+
+  @property
+  def text(self) -> Optional[str]:
+    """Returns the concatenation of all text parts in the response."""
+    if (
+        not self.candidates
+        or not self.candidates[0].content
+        or not self.candidates[0].content.parts
+    ):
+      return None
+    if len(self.candidates) > 1:
+      logging.warning(
+          f'there are {len(self.candidates)} candidates, returning text from'
+          ' the first candidate. Access response.candidates directly to get'
+          ' text from other candidates.'
+      )
+    text = ''
+    any_text_part_text = False
+    for part in self.candidates[0].content.parts:
+      for field_name, field_value in part.dict(
+          exclude={'text', 'thought'}
+      ).items():
+        if field_value is not None:
+          raise ValueError(
+              'GenerateContentResponse.text only supports text parts, but got'
+              f' {field_name} part: {part}'
+          )
+      if isinstance(part.text, str):
+        if isinstance(part.thought, bool) and part.thought:
+          continue
+        any_text_part_text = True
+        text += part.text
+    # part.text == '' is different from part.text is None
+    return text if any_text_part_text else None
+
+  @property
+  def function_calls(self) -> Optional[list[FunctionCall]]:
+    """Returns the list of function calls in the response."""
+    if (
+        not self.candidates
+        or not self.candidates[0].content
+        or not self.candidates[0].content.parts
+    ):
+      return None
+    if len(self.candidates) > 1:
+      logging.warning(
+          'there are multiple candidates in the response, returning function'
+          ' calls from the first candidate. Access response.candidates'
+          ' directly to get function calls from other candidates.'
+      )
+    function_calls = [
+        part.function_call
+        for part in self.candidates[0].content.parts
+        if part.function_call is not None
+    ]
+
+    return function_calls if function_calls else None
+
+  @classmethod
+  def _from_response(
+      cls, response: dict[str, object], kwargs: dict[str, object]
+  ):
+    result = super()._from_response(response, kwargs)
+
+    # Handles response schema.
+    response_schema = _common.get_value_by_path(
+        kwargs, ['config', 'response_schema']
+    )
+    if inspect.isclass(response_schema) and issubclass(
+        response_schema, pydantic.BaseModel
+    ):
+      # Pydantic schema.
+      try:
+        result.parsed = response_schema.model_validate_json(result.text)
+      # may not be a valid json per stream response
+      except pydantic.ValidationError:
+        pass
+
+    elif isinstance(response_schema, GenericAlias) and issubclass(
+        response_schema.__args__[0], pydantic.BaseModel
+    ):
+      # Handle cases where `list[pydantic.BaseModel]` was provided.
+      result.parsed = []
+      pydantic_model_class = response_schema.__args__[0]
+      response_list_json = json.loads(result.text)
+      for json_instance in response_list_json:
+        try:
+          pydantic_model_instance = pydantic_model_class.model_validate_json(
+              json.dumps(json_instance)
+          )
+          result.parsed.append(pydantic_model_instance)
+        # may not be a valid json per stream response
+        except pydantic.ValidationError:
+          pass
+
+    elif isinstance(response_schema, dict) or isinstance(
+        response_schema, pydantic.BaseModel
+    ):
+      # JSON schema.
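+      # A plain dict is treated as a raw JSON schema (and a pydantic
+      # *instance* is handled the same way): the response text is parsed
+      # with json.loads into a dict rather than validated into a model
+      # class. A pydantic *class* was already handled by the first branch;
+      # invalid JSON from a partial streaming chunk is skipped below.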
+ try: + result.parsed = json.loads(result.text) + # may not be a valid json per stream response + except json.decoder.JSONDecodeError: + pass + + return result + + +class GenerateContentResponseDict(TypedDict, total=False): + """Response message for PredictionService.GenerateContent.""" + + candidates: Optional[list[CandidateDict]] + """Response variations returned by the model. + """ + + model_version: Optional[str] + """Output only. The model version used to generate the response.""" + + prompt_feedback: Optional[GenerateContentResponsePromptFeedbackDict] + """Output only. Content filter results for a prompt sent in the request. Note: Sent only in the first stream chunk. Only happens when no candidates were generated due to content violations.""" + + usage_metadata: Optional[GenerateContentResponseUsageMetadataDict] + """Usage metadata about the response(s).""" + + +GenerateContentResponseOrDict = Union[ + GenerateContentResponse, GenerateContentResponseDict +] + + +class EmbedContentConfig(_common.BaseModel): + """Class for configuring optional parameters for the embed_content method.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + task_type: Optional[str] = Field( + default=None, + description="""Type of task for which the embedding will be used. + """, + ) + title: Optional[str] = Field( + default=None, + description="""Title for the text. Only applicable when TaskType is + `RETRIEVAL_DOCUMENT`. + """, + ) + output_dimensionality: Optional[int] = Field( + default=None, + description="""Reduced dimension for the output embedding. If set, + excessive values in the output embedding are truncated from the end. + Supported by newer models since 2024 only. You cannot set this value if + using the earlier model (`models/embedding-001`). + """, + ) + mime_type: Optional[str] = Field( + default=None, + description="""Vertex API only. The MIME type of the input. + """, + ) + auto_truncate: Optional[bool] = Field( + default=None, + description="""Vertex API only. Whether to silently truncate inputs longer than + the max sequence length. If this option is set to false, oversized inputs + will lead to an INVALID_ARGUMENT error, similar to other text APIs. + """, + ) + + +class EmbedContentConfigDict(TypedDict, total=False): + """Class for configuring optional parameters for the embed_content method.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + task_type: Optional[str] + """Type of task for which the embedding will be used. + """ + + title: Optional[str] + """Title for the text. Only applicable when TaskType is + `RETRIEVAL_DOCUMENT`. + """ + + output_dimensionality: Optional[int] + """Reduced dimension for the output embedding. If set, + excessive values in the output embedding are truncated from the end. + Supported by newer models since 2024 only. You cannot set this value if + using the earlier model (`models/embedding-001`). + """ + + mime_type: Optional[str] + """Vertex API only. The MIME type of the input. + """ + + auto_truncate: Optional[bool] + """Vertex API only. Whether to silently truncate inputs longer than + the max sequence length. If this option is set to false, oversized inputs + will lead to an INVALID_ARGUMENT error, similar to other text APIs. 
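# Sketch of the embed_content call these config/response types support,
# assuming the public client API (`client` as in the earlier sketch); the
# embedding model ID is illustrative.
result = client.models.embed_content(
    model='text-embedding-004',  # illustrative model ID
    contents='What is the meaning of life?',
    config=types.EmbedContentConfig(output_dimensionality=256),
)
print(len(result.embeddings[0].values))  # 256 if truncation was applied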
+ """ + + +EmbedContentConfigOrDict = Union[EmbedContentConfig, EmbedContentConfigDict] + + +class _EmbedContentParameters(_common.BaseModel): + """Parameters for the embed_content method.""" + + model: Optional[str] = Field( + default=None, + description="""ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""", + ) + contents: Optional[ContentListUnion] = Field( + default=None, + description="""The content to embed. Only the `parts.text` fields will be counted. + """, + ) + config: Optional[EmbedContentConfig] = Field( + default=None, + description="""Configuration that contains optional parameters. + """, + ) + + +class _EmbedContentParametersDict(TypedDict, total=False): + """Parameters for the embed_content method.""" + + model: Optional[str] + """ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""" + + contents: Optional[ContentListUnionDict] + """The content to embed. Only the `parts.text` fields will be counted. + """ + + config: Optional[EmbedContentConfigDict] + """Configuration that contains optional parameters. + """ + + +_EmbedContentParametersOrDict = Union[ + _EmbedContentParameters, _EmbedContentParametersDict +] + + +class ContentEmbeddingStatistics(_common.BaseModel): + """Statistics of the input text associated with the result of content embedding.""" + + truncated: Optional[bool] = Field( + default=None, + description="""Vertex API only. If the input text was truncated due to having + a length longer than the allowed maximum input. + """, + ) + token_count: Optional[float] = Field( + default=None, + description="""Vertex API only. Number of tokens of the input text. + """, + ) + + +class ContentEmbeddingStatisticsDict(TypedDict, total=False): + """Statistics of the input text associated with the result of content embedding.""" + + truncated: Optional[bool] + """Vertex API only. If the input text was truncated due to having + a length longer than the allowed maximum input. + """ + + token_count: Optional[float] + """Vertex API only. Number of tokens of the input text. + """ + + +ContentEmbeddingStatisticsOrDict = Union[ + ContentEmbeddingStatistics, ContentEmbeddingStatisticsDict +] + + +class ContentEmbedding(_common.BaseModel): + """The embedding generated from an input content.""" + + values: Optional[list[float]] = Field( + default=None, + description="""A list of floats representing an embedding. + """, + ) + statistics: Optional[ContentEmbeddingStatistics] = Field( + default=None, + description="""Vertex API only. Statistics of the input text associated with this + embedding. + """, + ) + + +class ContentEmbeddingDict(TypedDict, total=False): + """The embedding generated from an input content.""" + + values: Optional[list[float]] + """A list of floats representing an embedding. + """ + + statistics: Optional[ContentEmbeddingStatisticsDict] + """Vertex API only. Statistics of the input text associated with this + embedding. + """ + + +ContentEmbeddingOrDict = Union[ContentEmbedding, ContentEmbeddingDict] + + +class EmbedContentMetadata(_common.BaseModel): + """Request-level metadata for the Vertex Embed Content API.""" + + billable_character_count: Optional[int] = Field( + default=None, + description="""Vertex API only. The total number of billable characters included + in the request. 
+ """, + ) + + +class EmbedContentMetadataDict(TypedDict, total=False): + """Request-level metadata for the Vertex Embed Content API.""" + + billable_character_count: Optional[int] + """Vertex API only. The total number of billable characters included + in the request. + """ + + +EmbedContentMetadataOrDict = Union[ + EmbedContentMetadata, EmbedContentMetadataDict +] + + +class EmbedContentResponse(_common.BaseModel): + """Response for the embed_content method.""" + + embeddings: Optional[list[ContentEmbedding]] = Field( + default=None, + description="""The embeddings for each request, in the same order as provided in + the batch request. + """, + ) + metadata: Optional[EmbedContentMetadata] = Field( + default=None, + description="""Vertex API only. Metadata about the request. + """, + ) + + +class EmbedContentResponseDict(TypedDict, total=False): + """Response for the embed_content method.""" + + embeddings: Optional[list[ContentEmbeddingDict]] + """The embeddings for each request, in the same order as provided in + the batch request. + """ + + metadata: Optional[EmbedContentMetadataDict] + """Vertex API only. Metadata about the request. + """ + + +EmbedContentResponseOrDict = Union[ + EmbedContentResponse, EmbedContentResponseDict +] + + +class GenerateImageConfig(_common.BaseModel): + """Class that represents the config for generating an image.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + output_gcs_uri: Optional[str] = Field( + default=None, + description="""Cloud Storage URI used to store the generated images. + """, + ) + negative_prompt: Optional[str] = Field( + default=None, + description="""Description of what to discourage in the generated images. + """, + ) + number_of_images: Optional[int] = Field( + default=None, + description="""Number of images to generate. + """, + ) + guidance_scale: Optional[float] = Field( + default=None, + description="""Controls how much the model adheres to the text prompt. Large + values increase output and prompt alignment, but may compromise image + quality. + """, + ) + seed: Optional[int] = Field( + default=None, + description="""Random seed for image generation. This is not available when + ``add_watermark`` is set to true. + """, + ) + safety_filter_level: Optional[SafetyFilterLevel] = Field( + default=None, + description="""Filter level for safety filtering. + """, + ) + person_generation: Optional[PersonGeneration] = Field( + default=None, + description="""Allows generation of people by the model. + """, + ) + include_safety_attributes: Optional[bool] = Field( + default=None, + description="""Whether to report the safety scores of each image in the response. + """, + ) + include_rai_reason: Optional[bool] = Field( + default=None, + description="""Whether to include the Responsible AI filter reason if the image + is filtered out of the response. + """, + ) + language: Optional[ImagePromptLanguage] = Field( + default=None, + description="""Language of the text in the prompt. + """, + ) + output_mime_type: Optional[str] = Field( + default=None, + description="""MIME type of the generated image. + """, + ) + output_compression_quality: Optional[int] = Field( + default=None, + description="""Compression quality of the generated image (for ``image/jpeg`` + only). + """, + ) + add_watermark: Optional[bool] = Field( + default=None, + description="""Whether to add a watermark to the generated image. 
+ """, + ) + aspect_ratio: Optional[str] = Field( + default=None, + description="""Aspect ratio of the generated image. + """, + ) + + +class GenerateImageConfigDict(TypedDict, total=False): + """Class that represents the config for generating an image.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + output_gcs_uri: Optional[str] + """Cloud Storage URI used to store the generated images. + """ + + negative_prompt: Optional[str] + """Description of what to discourage in the generated images. + """ + + number_of_images: Optional[int] + """Number of images to generate. + """ + + guidance_scale: Optional[float] + """Controls how much the model adheres to the text prompt. Large + values increase output and prompt alignment, but may compromise image + quality. + """ + + seed: Optional[int] + """Random seed for image generation. This is not available when + ``add_watermark`` is set to true. + """ + + safety_filter_level: Optional[SafetyFilterLevel] + """Filter level for safety filtering. + """ + + person_generation: Optional[PersonGeneration] + """Allows generation of people by the model. + """ + + include_safety_attributes: Optional[bool] + """Whether to report the safety scores of each image in the response. + """ + + include_rai_reason: Optional[bool] + """Whether to include the Responsible AI filter reason if the image + is filtered out of the response. + """ + + language: Optional[ImagePromptLanguage] + """Language of the text in the prompt. + """ + + output_mime_type: Optional[str] + """MIME type of the generated image. + """ + + output_compression_quality: Optional[int] + """Compression quality of the generated image (for ``image/jpeg`` + only). + """ + + add_watermark: Optional[bool] + """Whether to add a watermark to the generated image. + """ + + aspect_ratio: Optional[str] + """Aspect ratio of the generated image. + """ + + +GenerateImageConfigOrDict = Union[GenerateImageConfig, GenerateImageConfigDict] + + +class _GenerateImageParameters(_common.BaseModel): + """Class that represents the parameters for generating an image.""" + + model: Optional[str] = Field( + default=None, + description="""ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""", + ) + prompt: Optional[str] = Field( + default=None, + description="""Text prompt that typically describes the image to output. + """, + ) + config: Optional[GenerateImageConfig] = Field( + default=None, + description="""Configuration for generating an image. + """, + ) + + +class _GenerateImageParametersDict(TypedDict, total=False): + """Class that represents the parameters for generating an image.""" + + model: Optional[str] + """ID of the model to use. For a list of models, see `Google models + <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""" + + prompt: Optional[str] + """Text prompt that typically describes the image to output. + """ + + config: Optional[GenerateImageConfigDict] + """Configuration for generating an image. + """ + + +_GenerateImageParametersOrDict = Union[ + _GenerateImageParameters, _GenerateImageParametersDict +] + + +class Image(_common.BaseModel): + """Class that represents an image.""" + + gcs_uri: Optional[str] = Field( + default=None, + description="""The Cloud Storage URI of the image. ``Image`` can contain a value + for this field or the ``image_bytes`` field but not both. 
+ """, + ) + image_bytes: Optional[bytes] = Field( + default=None, + description="""The image bytes data. ``Image`` can contain a value for this field + or the ``gcs_uri`` field but not both. + """, + ) + + _loaded_image = None + + """Image.""" + + @staticmethod + def from_file(location: str) -> 'Image': + """Lazy-loads an image from a local file or Google Cloud Storage. + + Args: + location: The local path or Google Cloud Storage URI from which to load + the image. + + Returns: + A loaded image as an `Image` object. + """ + import urllib + import pathlib + + parsed_url = urllib.parse.urlparse(location) + if ( + parsed_url.scheme == 'https' + and parsed_url.netloc == 'storage.googleapis.com' + ): + parsed_url = parsed_url._replace( + scheme='gs', + netloc='', + path=f'/{urllib.parse.unquote(parsed_url.path)}', + ) + location = urllib.parse.urlunparse(parsed_url) + + if parsed_url.scheme == 'gs': + return Image(gcs_uri=location) + + # Load image from local path + image_bytes = pathlib.Path(location).read_bytes() + image = Image(image_bytes=image_bytes) + return image + + def show(self): + """Shows the image. + + This method only works in a notebook environment. + """ + try: + from IPython import display as IPython_display + except ImportError: + IPython_display = None + + try: + from PIL import Image as PIL_Image + except ImportError: + PIL_Image = None + if PIL_Image and IPython_display: + IPython_display.display(self._pil_image) + + @property + def _pil_image(self) -> 'PIL_Image.Image': + try: + from PIL import Image as PIL_Image + except ImportError: + PIL_Image = None + import io + + if self._loaded_image is None: + if not PIL_Image: + raise RuntimeError( + 'The PIL module is not available. Please install the Pillow' + ' package.' + ) + self._loaded_image = PIL_Image.open(io.BytesIO(self.image_bytes)) + return self._loaded_image + + def save(self, location: str): + """Saves the image to a file. + + Args: + location: Local path where to save the image. + """ + import pathlib + + pathlib.Path(location).write_bytes(self.image_bytes) + + +JOB_STATES_SUCCEEDED_VERTEX = [ + 'JOB_STATE_SUCCEEDED', +] + +JOB_STATES_SUCCEEDED_MLDEV = [ + 'ACTIVE', +] + +JOB_STATES_SUCCEEDED = JOB_STATES_SUCCEEDED_VERTEX + JOB_STATES_SUCCEEDED_MLDEV + + +JOB_STATES_ENDED_VERTEX = [ + 'JOB_STATE_SUCCEEDED', + 'JOB_STATE_FAILED', + 'JOB_STATE_CANCELLED', + 'JOB_STATE_EXPIRED', +] + +JOB_STATES_ENDED_MLDEV = [ + 'ACTIVE', + 'FAILED', +] + +JOB_STATES_ENDED = JOB_STATES_ENDED_VERTEX + JOB_STATES_ENDED_MLDEV + + +class ImageDict(TypedDict, total=False): + """Class that represents an image.""" + + gcs_uri: Optional[str] + """The Cloud Storage URI of the image. ``Image`` can contain a value + for this field or the ``image_bytes`` field but not both. + """ + + image_bytes: Optional[bytes] + """The image bytes data. ``Image`` can contain a value for this field + or the ``gcs_uri`` field but not both. + """ + + +ImageOrDict = Union[Image, ImageDict] + + +class GeneratedImage(_common.BaseModel): + """Class that represents an output image.""" + + image: Optional[Image] = Field( + default=None, + description="""The output image data. + """, + ) + rai_filtered_reason: Optional[str] = Field( + default=None, + description="""Responsible AI filter reason if the image is filtered out of the + response. + """, + ) + + +class GeneratedImageDict(TypedDict, total=False): + """Class that represents an output image.""" + + image: Optional[ImageDict] + """The output image data. 
+ """ + + rai_filtered_reason: Optional[str] + """Responsible AI filter reason if the image is filtered out of the + response. + """ + + +GeneratedImageOrDict = Union[GeneratedImage, GeneratedImageDict] + + +class GenerateImageResponse(_common.BaseModel): + """Class that represents the output image response.""" + + generated_images: Optional[list[GeneratedImage]] = Field( + default=None, + description="""List of generated images. + """, + ) + + +class GenerateImageResponseDict(TypedDict, total=False): + """Class that represents the output image response.""" + + generated_images: Optional[list[GeneratedImageDict]] + """List of generated images. + """ + + +GenerateImageResponseOrDict = Union[ + GenerateImageResponse, GenerateImageResponseDict +] + + +class MaskReferenceConfig(_common.BaseModel): + """Configuration for a Mask reference image.""" + + mask_mode: Optional[MaskReferenceMode] = Field( + default=None, + description="""Prompts the model to generate a mask instead of you needing to + provide one (unless MASK_MODE_USER_PROVIDED is used).""", + ) + segmentation_classes: Optional[list[int]] = Field( + default=None, + description="""A list of up to 5 class ids to use for semantic segmentation. + Automatically creates an image mask based on specific objects.""", + ) + mask_dilation: Optional[float] = Field( + default=None, + description="""Dilation percentage of the mask provided. + Float between 0 and 1.""", + ) + + +class MaskReferenceConfigDict(TypedDict, total=False): + """Configuration for a Mask reference image.""" + + mask_mode: Optional[MaskReferenceMode] + """Prompts the model to generate a mask instead of you needing to + provide one (unless MASK_MODE_USER_PROVIDED is used).""" + + segmentation_classes: Optional[list[int]] + """A list of up to 5 class ids to use for semantic segmentation. + Automatically creates an image mask based on specific objects.""" + + mask_dilation: Optional[float] + """Dilation percentage of the mask provided. + Float between 0 and 1.""" + + +MaskReferenceConfigOrDict = Union[MaskReferenceConfig, MaskReferenceConfigDict] + + +class ControlReferenceConfig(_common.BaseModel): + """Configuration for a Control reference image.""" + + control_type: Optional[ControlReferenceType] = Field( + default=None, + description="""The type of control reference image to use.""", + ) + enable_control_image_computation: Optional[bool] = Field( + default=None, + description="""Defaults to False. When set to True, the control image will be + computed by the model based on the control type. When set to False, + the control image must be provided by the user.""", + ) + + +class ControlReferenceConfigDict(TypedDict, total=False): + """Configuration for a Control reference image.""" + + control_type: Optional[ControlReferenceType] + """The type of control reference image to use.""" + + enable_control_image_computation: Optional[bool] + """Defaults to False. When set to True, the control image will be + computed by the model based on the control type. 
When set to False, + the control image must be provided by the user.""" + + +ControlReferenceConfigOrDict = Union[ + ControlReferenceConfig, ControlReferenceConfigDict +] + + +class StyleReferenceConfig(_common.BaseModel): + """Configuration for a Style reference image.""" + + style_description: Optional[str] = Field( + default=None, + description="""A text description of the style to use for the generated image.""", + ) + + +class StyleReferenceConfigDict(TypedDict, total=False): + """Configuration for a Style reference image.""" + + style_description: Optional[str] + """A text description of the style to use for the generated image.""" + + +StyleReferenceConfigOrDict = Union[ + StyleReferenceConfig, StyleReferenceConfigDict +] + + +class SubjectReferenceConfig(_common.BaseModel): + """Configuration for a Subject reference image.""" + + subject_type: Optional[SubjectReferenceType] = Field( + default=None, + description="""The subject type of a subject reference image.""", + ) + subject_description: Optional[str] = Field( + default=None, description="""Subject description for the image.""" + ) + + +class SubjectReferenceConfigDict(TypedDict, total=False): + """Configuration for a Subject reference image.""" + + subject_type: Optional[SubjectReferenceType] + """The subject type of a subject reference image.""" + + subject_description: Optional[str] + """Subject description for the image.""" + + +SubjectReferenceConfigOrDict = Union[ + SubjectReferenceConfig, SubjectReferenceConfigDict +] + + +class _ReferenceImageAPI(_common.BaseModel): + """Private class that represents a Reference image that is sent to API.""" + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + mask_image_config: Optional[MaskReferenceConfig] = Field( + default=None, + description="""Configuration for the mask reference image.""", + ) + control_image_config: Optional[ControlReferenceConfig] = Field( + default=None, + description="""Configuration for the control reference image.""", + ) + style_image_config: Optional[StyleReferenceConfig] = Field( + default=None, + description="""Configuration for the style reference image.""", + ) + subject_image_config: Optional[SubjectReferenceConfig] = Field( + default=None, + description="""Configuration for the subject reference image.""", + ) + + +class _ReferenceImageAPIDict(TypedDict, total=False): + """Private class that represents a Reference image that is sent to API.""" + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + mask_image_config: Optional[MaskReferenceConfigDict] + """Configuration for the mask reference image.""" + + control_image_config: Optional[ControlReferenceConfigDict] + """Configuration for the control reference image.""" + + style_image_config: Optional[StyleReferenceConfigDict] + """Configuration for the style reference image.""" + + subject_image_config: Optional[SubjectReferenceConfigDict] + """Configuration for the subject reference image.""" + + +_ReferenceImageAPIOrDict = Union[_ReferenceImageAPI, _ReferenceImageAPIDict] + + +class 
EditImageConfig(_common.BaseModel): + """Configuration for editing an image.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + output_gcs_uri: Optional[str] = Field( + default=None, + description="""Cloud Storage URI used to store the generated images. + """, + ) + negative_prompt: Optional[str] = Field( + default=None, + description="""Description of what to discourage in the generated images. + """, + ) + number_of_images: Optional[int] = Field( + default=None, + description="""Number of images to generate. + """, + ) + guidance_scale: Optional[float] = Field( + default=None, + description="""Controls how much the model adheres to the text prompt. Large + values increase output and prompt alignment, but may compromise image + quality. + """, + ) + seed: Optional[int] = Field( + default=None, + description="""Random seed for image generation. This is not available when + ``add_watermark`` is set to true. + """, + ) + safety_filter_level: Optional[SafetyFilterLevel] = Field( + default=None, + description="""Filter level for safety filtering. + """, + ) + person_generation: Optional[PersonGeneration] = Field( + default=None, + description="""Allows generation of people by the model. + """, + ) + include_safety_attributes: Optional[bool] = Field( + default=None, + description="""Whether to report the safety scores of each image in the response. + """, + ) + include_rai_reason: Optional[bool] = Field( + default=None, + description="""Whether to include the Responsible AI filter reason if the image + is filtered out of the response. + """, + ) + language: Optional[ImagePromptLanguage] = Field( + default=None, + description="""Language of the text in the prompt. + """, + ) + output_mime_type: Optional[str] = Field( + default=None, + description="""MIME type of the generated image. + """, + ) + output_compression_quality: Optional[int] = Field( + default=None, + description="""Compression quality of the generated image (for ``image/jpeg`` + only). + """, + ) + edit_mode: Optional[EditMode] = Field( + default=None, + description="""Describes the editing mode for the request.""", + ) + + +class EditImageConfigDict(TypedDict, total=False): + """Configuration for editing an image.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + output_gcs_uri: Optional[str] + """Cloud Storage URI used to store the generated images. + """ + + negative_prompt: Optional[str] + """Description of what to discourage in the generated images. + """ + + number_of_images: Optional[int] + """Number of images to generate. + """ + + guidance_scale: Optional[float] + """Controls how much the model adheres to the text prompt. Large + values increase output and prompt alignment, but may compromise image + quality. + """ + + seed: Optional[int] + """Random seed for image generation. This is not available when + ``add_watermark`` is set to true. + """ + + safety_filter_level: Optional[SafetyFilterLevel] + """Filter level for safety filtering. + """ + + person_generation: Optional[PersonGeneration] + """Allows generation of people by the model. + """ + + include_safety_attributes: Optional[bool] + """Whether to report the safety scores of each image in the response. + """ + + include_rai_reason: Optional[bool] + """Whether to include the Responsible AI filter reason if the image + is filtered out of the response. + """ + + language: Optional[ImagePromptLanguage] + """Language of the text in the prompt. 
+ """ + + output_mime_type: Optional[str] + """MIME type of the generated image. + """ + + output_compression_quality: Optional[int] + """Compression quality of the generated image (for ``image/jpeg`` + only). + """ + + edit_mode: Optional[EditMode] + """Describes the editing mode for the request.""" + + +EditImageConfigOrDict = Union[EditImageConfig, EditImageConfigDict] + + +class _EditImageParameters(_common.BaseModel): + """Parameters for the request to edit an image.""" + + model: Optional[str] = Field( + default=None, description="""The model to use.""" + ) + prompt: Optional[str] = Field( + default=None, + description="""A text description of the edit to apply to the image.""", + ) + reference_images: Optional[list[_ReferenceImageAPI]] = Field( + default=None, description="""The reference images for Imagen 3 editing.""" + ) + config: Optional[EditImageConfig] = Field( + default=None, description="""Configuration for editing.""" + ) + + +class _EditImageParametersDict(TypedDict, total=False): + """Parameters for the request to edit an image.""" + + model: Optional[str] + """The model to use.""" + + prompt: Optional[str] + """A text description of the edit to apply to the image.""" + + reference_images: Optional[list[_ReferenceImageAPIDict]] + """The reference images for Imagen 3 editing.""" + + config: Optional[EditImageConfigDict] + """Configuration for editing.""" + + +_EditImageParametersOrDict = Union[ + _EditImageParameters, _EditImageParametersDict +] + + +class EditImageResponse(_common.BaseModel): + """Response for the request to edit an image.""" + + generated_images: Optional[list[GeneratedImage]] = Field( + default=None, description="""Generated images.""" + ) + + +class EditImageResponseDict(TypedDict, total=False): + """Response for the request to edit an image.""" + + generated_images: Optional[list[GeneratedImageDict]] + """Generated images.""" + + +EditImageResponseOrDict = Union[EditImageResponse, EditImageResponseDict] + + +class _UpscaleImageAPIConfig(_common.BaseModel): + """API config for UpscaleImage with fields not exposed to users. + + These fields require default values sent to the API which are not intended + to be modifiable or exposed to users in the SDK method. + """ + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + include_rai_reason: Optional[bool] = Field( + default=None, + description="""Whether to include a reason for filtered-out images in the + response.""", + ) + output_mime_type: Optional[str] = Field( + default=None, + description="""The image format that the output should be saved as.""", + ) + output_compression_quality: Optional[int] = Field( + default=None, + description="""The level of compression if the ``output_mime_type`` is + ``image/jpeg``.""", + ) + number_of_images: Optional[int] = Field(default=None, description="""""") + mode: Optional[str] = Field(default=None, description="""""") + + +class _UpscaleImageAPIConfigDict(TypedDict, total=False): + """API config for UpscaleImage with fields not exposed to users. + + These fields require default values sent to the API which are not intended + to be modifiable or exposed to users in the SDK method. 
+ """ + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + include_rai_reason: Optional[bool] + """Whether to include a reason for filtered-out images in the + response.""" + + output_mime_type: Optional[str] + """The image format that the output should be saved as.""" + + output_compression_quality: Optional[int] + """The level of compression if the ``output_mime_type`` is + ``image/jpeg``.""" + + number_of_images: Optional[int] + """""" + + mode: Optional[str] + """""" + + +_UpscaleImageAPIConfigOrDict = Union[ + _UpscaleImageAPIConfig, _UpscaleImageAPIConfigDict +] + + +class _UpscaleImageAPIParameters(_common.BaseModel): + """API parameters for UpscaleImage.""" + + model: Optional[str] = Field( + default=None, description="""The model to use.""" + ) + image: Optional[Image] = Field( + default=None, description="""The input image to upscale.""" + ) + upscale_factor: Optional[str] = Field( + default=None, + description="""The factor to upscale the image (x2 or x4).""", + ) + config: Optional[_UpscaleImageAPIConfig] = Field( + default=None, description="""Configuration for upscaling.""" + ) + + +class _UpscaleImageAPIParametersDict(TypedDict, total=False): + """API parameters for UpscaleImage.""" + + model: Optional[str] + """The model to use.""" + + image: Optional[ImageDict] + """The input image to upscale.""" + + upscale_factor: Optional[str] + """The factor to upscale the image (x2 or x4).""" + + config: Optional[_UpscaleImageAPIConfigDict] + """Configuration for upscaling.""" + + +_UpscaleImageAPIParametersOrDict = Union[ + _UpscaleImageAPIParameters, _UpscaleImageAPIParametersDict +] + + +class UpscaleImageResponse(_common.BaseModel): + + generated_images: Optional[list[GeneratedImage]] = Field( + default=None, description="""Generated images.""" + ) + + +class UpscaleImageResponseDict(TypedDict, total=False): + + generated_images: Optional[list[GeneratedImageDict]] + """Generated images.""" + + +UpscaleImageResponseOrDict = Union[ + UpscaleImageResponse, UpscaleImageResponseDict +] + + +class _GetModelParameters(_common.BaseModel): + + model: Optional[str] = Field(default=None, description="""""") + + +class _GetModelParametersDict(TypedDict, total=False): + + model: Optional[str] + """""" + + +_GetModelParametersOrDict = Union[_GetModelParameters, _GetModelParametersDict] + + +class Endpoint(_common.BaseModel): + """An endpoint where you deploy models.""" + + name: Optional[str] = Field( + default=None, description="""Resource name of the endpoint.""" + ) + deployed_model_id: Optional[str] = Field( + default=None, + description="""ID of the model that's deployed to the endpoint.""", + ) + + +class EndpointDict(TypedDict, total=False): + """An endpoint where you deploy models.""" + + name: Optional[str] + """Resource name of the endpoint.""" + + deployed_model_id: Optional[str] + """ID of the model that's deployed to the endpoint.""" + + +EndpointOrDict = Union[Endpoint, EndpointDict] + + +class TunedModelInfo(_common.BaseModel): + """A tuned machine learning model.""" + + base_model: Optional[str] = Field( + default=None, + description="""ID of the base model that you want to tune.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Date and time when the base model was created.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Date and time when the base model was last updated.""", + ) + + +class TunedModelInfoDict(TypedDict, total=False): + """A tuned 
machine learning model.""" + + base_model: Optional[str] + """ID of the base model that you want to tune.""" + + create_time: Optional[datetime.datetime] + """Date and time when the base model was created.""" + + update_time: Optional[datetime.datetime] + """Date and time when the base model was last updated.""" + + +TunedModelInfoOrDict = Union[TunedModelInfo, TunedModelInfoDict] + + +class Model(_common.BaseModel): + """A trained machine learning model.""" + + name: Optional[str] = Field( + default=None, description="""Resource name of the model.""" + ) + display_name: Optional[str] = Field( + default=None, description="""Display name of the model.""" + ) + description: Optional[str] = Field( + default=None, description="""Description of the model.""" + ) + version: Optional[str] = Field( + default=None, + description="""Version ID of the model. A new version is committed when a new + model version is uploaded or trained under an existing model ID. The + version ID is an auto-incrementing decimal number in string + representation.""", + ) + endpoints: Optional[list[Endpoint]] = Field( + default=None, + description="""List of deployed models created from this base model. Note that a + model could have been deployed to endpoints in different locations.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Labels with user-defined metadata to organize your models.""", + ) + tuned_model_info: Optional[TunedModelInfo] = Field( + default=None, + description="""Information about the tuned model from the base model.""", + ) + input_token_limit: Optional[int] = Field( + default=None, + description="""The maximum number of input tokens that the model can handle.""", + ) + output_token_limit: Optional[int] = Field( + default=None, + description="""The maximum number of output tokens that the model can generate.""", + ) + supported_actions: Optional[list[str]] = Field( + default=None, + description="""List of actions that are supported by the model.""", + ) + + +class ModelDict(TypedDict, total=False): + """A trained machine learning model.""" + + name: Optional[str] + """Resource name of the model.""" + + display_name: Optional[str] + """Display name of the model.""" + + description: Optional[str] + """Description of the model.""" + + version: Optional[str] + """Version ID of the model. A new version is committed when a new + model version is uploaded or trained under an existing model ID. The + version ID is an auto-incrementing decimal number in string + representation.""" + + endpoints: Optional[list[EndpointDict]] + """List of deployed models created from this base model. 
Note that a + model could have been deployed to endpoints in different locations.""" + + labels: Optional[dict[str, str]] + """Labels with user-defined metadata to organize your models.""" + + tuned_model_info: Optional[TunedModelInfoDict] + """Information about the tuned model from the base model.""" + + input_token_limit: Optional[int] + """The maximum number of input tokens that the model can handle.""" + + output_token_limit: Optional[int] + """The maximum number of output tokens that the model can generate.""" + + supported_actions: Optional[list[str]] + """List of actions that are supported by the model.""" + + +ModelOrDict = Union[Model, ModelDict] + + +class ListModelsConfig(_common.BaseModel): + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field(default=None, description="""""") + query_base: Optional[bool] = Field( + default=None, + description="""Set true to list base models, false to list tuned models.""", + ) + + +class ListModelsConfigDict(TypedDict, total=False): + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """""" + + query_base: Optional[bool] + """Set true to list base models, false to list tuned models.""" + + +ListModelsConfigOrDict = Union[ListModelsConfig, ListModelsConfigDict] + + +class _ListModelsParameters(_common.BaseModel): + + config: Optional[ListModelsConfig] = Field(default=None, description="""""") + + +class _ListModelsParametersDict(TypedDict, total=False): + + config: Optional[ListModelsConfigDict] + """""" + + +_ListModelsParametersOrDict = Union[ + _ListModelsParameters, _ListModelsParametersDict +] + + +class ListModelsResponse(_common.BaseModel): + + next_page_token: Optional[str] = Field(default=None, description="""""") + models: Optional[list[Model]] = Field(default=None, description="""""") + + +class ListModelsResponseDict(TypedDict, total=False): + + next_page_token: Optional[str] + """""" + + models: Optional[list[ModelDict]] + """""" + + +ListModelsResponseOrDict = Union[ListModelsResponse, ListModelsResponseDict] + + +class UpdateModelConfig(_common.BaseModel): + + display_name: Optional[str] = Field(default=None, description="""""") + description: Optional[str] = Field(default=None, description="""""") + + +class UpdateModelConfigDict(TypedDict, total=False): + + display_name: Optional[str] + """""" + + description: Optional[str] + """""" + + +UpdateModelConfigOrDict = Union[UpdateModelConfig, UpdateModelConfigDict] + + +class _UpdateModelParameters(_common.BaseModel): + + model: Optional[str] = Field(default=None, description="""""") + config: Optional[UpdateModelConfig] = Field(default=None, description="""""") + + +class _UpdateModelParametersDict(TypedDict, total=False): + + model: Optional[str] + """""" + + config: Optional[UpdateModelConfigDict] + """""" + + +_UpdateModelParametersOrDict = Union[ + _UpdateModelParameters, _UpdateModelParametersDict +] + + +class _DeleteModelParameters(_common.BaseModel): + + model: Optional[str] = Field(default=None, description="""""") + + +class _DeleteModelParametersDict(TypedDict, total=False): + + model: Optional[str] + """""" + + +_DeleteModelParametersOrDict = Union[ + _DeleteModelParameters, 
_DeleteModelParametersDict +] + + +class DeleteModelResponse(_common.BaseModel): + + pass + + +class DeleteModelResponseDict(TypedDict, total=False): + + pass + + +DeleteModelResponseOrDict = Union[DeleteModelResponse, DeleteModelResponseDict] + + +class GenerationConfig(_common.BaseModel): + """Generation config.""" + + audio_timestamp: Optional[bool] = Field( + default=None, + description="""Optional. If enabled, audio timestamp will be included in the request to the model.""", + ) + candidate_count: Optional[int] = Field( + default=None, + description="""Optional. Number of candidates to generate.""", + ) + frequency_penalty: Optional[float] = Field( + default=None, description="""Optional. Frequency penalties.""" + ) + logprobs: Optional[int] = Field( + default=None, description="""Optional. Logit probabilities.""" + ) + max_output_tokens: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of output tokens to generate per message.""", + ) + presence_penalty: Optional[float] = Field( + default=None, description="""Optional. Positive penalties.""" + ) + response_logprobs: Optional[bool] = Field( + default=None, + description="""Optional. If true, export the logprobs results in response.""", + ) + response_mime_type: Optional[str] = Field( + default=None, + description="""Optional. Output response mimetype of the generated candidate text. Supported mimetype: - `text/plain`: (default) Text output. - `application/json`: JSON response in the candidates. The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature.""", + ) + response_schema: Optional[Schema] = Field( + default=None, + description="""Optional. The `Schema` object allows the definition of input and output data types. These types can be objects, but also primitives and arrays. Represents a select subset of an [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema). If set, a compatible response_mime_type must also be set. Compatible mimetypes: `application/json`: Schema for JSON response.""", + ) + routing_config: Optional[GenerationConfigRoutingConfig] = Field( + default=None, description="""Optional. Routing configuration.""" + ) + seed: Optional[int] = Field(default=None, description="""Optional. Seed.""") + stop_sequences: Optional[list[str]] = Field( + default=None, description="""Optional. Stop sequences.""" + ) + temperature: Optional[float] = Field( + default=None, + description="""Optional. Controls the randomness of predictions.""", + ) + top_k: Optional[float] = Field( + default=None, + description="""Optional. If specified, top-k sampling will be used.""", + ) + top_p: Optional[float] = Field( + default=None, + description="""Optional. If specified, nucleus sampling will be used.""", + ) + + +class GenerationConfigDict(TypedDict, total=False): + """Generation config.""" + + audio_timestamp: Optional[bool] + """Optional. If enabled, audio timestamp will be included in the request to the model.""" + + candidate_count: Optional[int] + """Optional. Number of candidates to generate.""" + + frequency_penalty: Optional[float] + """Optional. Frequency penalties.""" + + logprobs: Optional[int] + """Optional. Logit probabilities.""" + + max_output_tokens: Optional[int] + """Optional. The maximum number of output tokens to generate per message.""" + + presence_penalty: Optional[float] + """Optional. Positive penalties.""" + + response_logprobs: Optional[bool] + """Optional. 
If true, export the logprobs results in response.""" + + response_mime_type: Optional[str] + """Optional. Output response mimetype of the generated candidate text. Supported mimetype: - `text/plain`: (default) Text output. - `application/json`: JSON response in the candidates. The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature.""" + + response_schema: Optional[SchemaDict] + """Optional. The `Schema` object allows the definition of input and output data types. These types can be objects, but also primitives and arrays. Represents a select subset of an [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema). If set, a compatible response_mime_type must also be set. Compatible mimetypes: `application/json`: Schema for JSON response.""" + + routing_config: Optional[GenerationConfigRoutingConfigDict] + """Optional. Routing configuration.""" + + seed: Optional[int] + """Optional. Seed.""" + + stop_sequences: Optional[list[str]] + """Optional. Stop sequences.""" + + temperature: Optional[float] + """Optional. Controls the randomness of predictions.""" + + top_k: Optional[float] + """Optional. If specified, top-k sampling will be used.""" + + top_p: Optional[float] + """Optional. If specified, nucleus sampling will be used.""" + + +GenerationConfigOrDict = Union[GenerationConfig, GenerationConfigDict] + + +class CountTokensConfig(_common.BaseModel): + """Config for the count_tokens method.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + system_instruction: Optional[ContentUnion] = Field( + default=None, + description="""Instructions for the model to steer it toward better performance. + """, + ) + tools: Optional[list[Tool]] = Field( + default=None, + description="""Code that enables the system to interact with external systems to + perform an action outside of the knowledge and scope of the model. + """, + ) + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="""Configuration that the model uses to generate the response. Not + supported by the Gemini Developer API. + """, + ) + + +class CountTokensConfigDict(TypedDict, total=False): + """Config for the count_tokens method.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + system_instruction: Optional[ContentUnionDict] + """Instructions for the model to steer it toward better performance. + """ + + tools: Optional[list[ToolDict]] + """Code that enables the system to interact with external systems to + perform an action outside of the knowledge and scope of the model. + """ + + generation_config: Optional[GenerationConfigDict] + """Configuration that the model uses to generate the response. Not + supported by the Gemini Developer API. + """ + + +CountTokensConfigOrDict = Union[CountTokensConfig, CountTokensConfigDict] + + +class _CountTokensParameters(_common.BaseModel): + """Parameters for counting tokens.""" + + model: Optional[str] = Field( + default=None, + description="""ID of the model to use. 
For a list of models, see `Google models
+      <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""",
+  )
+  contents: Optional[ContentListUnion] = Field(
+      default=None, description="""Input content."""
+  )
+  config: Optional[CountTokensConfig] = Field(
+      default=None, description="""Configuration for counting tokens."""
+  )
+
+
+class _CountTokensParametersDict(TypedDict, total=False):
+  """Parameters for counting tokens."""
+
+  model: Optional[str]
+  """ID of the model to use. For a list of models, see `Google models
+  <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_."""
+
+  contents: Optional[ContentListUnionDict]
+  """Input content."""
+
+  config: Optional[CountTokensConfigDict]
+  """Configuration for counting tokens."""
+
+
+_CountTokensParametersOrDict = Union[
+    _CountTokensParameters, _CountTokensParametersDict
+]
+
+
+class CountTokensResponse(_common.BaseModel):
+  """Response for counting tokens."""
+
+  total_tokens: Optional[int] = Field(
+      default=None, description="""Total number of tokens."""
+  )
+  cached_content_token_count: Optional[int] = Field(
+      default=None,
+      description="""Number of tokens in the cached part of the prompt (the cached content).""",
+  )
+
+
+class CountTokensResponseDict(TypedDict, total=False):
+  """Response for counting tokens."""
+
+  total_tokens: Optional[int]
+  """Total number of tokens."""
+
+  cached_content_token_count: Optional[int]
+  """Number of tokens in the cached part of the prompt (the cached content)."""
+
+
+CountTokensResponseOrDict = Union[CountTokensResponse, CountTokensResponseDict]
+
+
+class ComputeTokensConfig(_common.BaseModel):
+  """Optional parameters for computing tokens."""
+
+  http_options: Optional[dict[str, Any]] = Field(
+      default=None, description="""Used to override HTTP request options."""
+  )
+
+
+class ComputeTokensConfigDict(TypedDict, total=False):
+  """Optional parameters for computing tokens."""
+
+  http_options: Optional[dict[str, Any]]
+  """Used to override HTTP request options."""
+
+
+ComputeTokensConfigOrDict = Union[ComputeTokensConfig, ComputeTokensConfigDict]
+
+
+class _ComputeTokensParameters(_common.BaseModel):
+  """Parameters for computing tokens."""
+
+  model: Optional[str] = Field(
+      default=None,
+      description="""ID of the model to use. For a list of models, see `Google models
+      <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_.""",
+  )
+  contents: Optional[ContentListUnion] = Field(
+      default=None, description="""Input content."""
+  )
+  config: Optional[ComputeTokensConfig] = Field(
+      default=None,
+      description="""Optional parameters for the request.
+      """,
+  )
+
+
+class _ComputeTokensParametersDict(TypedDict, total=False):
+  """Parameters for computing tokens."""
+
+  model: Optional[str]
+  """ID of the model to use. For a list of models, see `Google models
+  <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models>`_."""
+
+  contents: Optional[ContentListUnionDict]
+  """Input content."""
+
+  config: Optional[ComputeTokensConfigDict]
+  """Optional parameters for the request.
+  """
+
+
+_ComputeTokensParametersOrDict = Union[
+    _ComputeTokensParameters, _ComputeTokensParametersDict
+]
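# Usage sketch (illustrative only): reading the typed token-count result
# defined above; the values here are fabricated for the example, not returned
# by any real call.
#
#   response = CountTokensResponse(total_tokens=42, cached_content_token_count=0)
#   if response.total_tokens is not None:
#       print(f'prompt would consume {response.total_tokens} tokens')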
+class TokensInfo(_common.BaseModel):
+  """Tokens info with a list of tokens and the corresponding list of token ids."""
+
+  role: Optional[str] = Field(
+      default=None,
+      description="""Optional. The role from the corresponding Content.""",
+  )
+  token_ids: Optional[list[int]] = Field(
+      default=None, description="""A list of token ids from the input."""
+  )
+  tokens: Optional[list[bytes]] = Field(
+      default=None, description="""A list of tokens from the input."""
+  )
+
+
+class TokensInfoDict(TypedDict, total=False):
+  """Tokens info with a list of tokens and the corresponding list of token ids."""
+
+  role: Optional[str]
+  """Optional. The role from the corresponding Content."""
+
+  token_ids: Optional[list[int]]
+  """A list of token ids from the input."""
+
+  tokens: Optional[list[bytes]]
+  """A list of tokens from the input."""
+
+
+TokensInfoOrDict = Union[TokensInfo, TokensInfoDict]
+
+
+class ComputeTokensResponse(_common.BaseModel):
+  """Response for computing tokens."""
+
+  tokens_info: Optional[list[TokensInfo]] = Field(
+      default=None,
+      description="""Lists of token info from the input. A ComputeTokensRequest can carry multiple instances, each with its own prompt, so token info is returned for each instance.""",
+  )
+
+
+class ComputeTokensResponseDict(TypedDict, total=False):
+  """Response for computing tokens."""
+
+  tokens_info: Optional[list[TokensInfoDict]]
+  """Lists of token info from the input. A ComputeTokensRequest can carry multiple instances, each with its own prompt, so token info is returned for each instance."""
+
+
+ComputeTokensResponseOrDict = Union[
+    ComputeTokensResponse, ComputeTokensResponseDict
+]
+
+
+class GetTuningJobConfig(_common.BaseModel):
+  """Optional parameters for tunings.get method."""
+
+  http_options: Optional[dict[str, Any]] = Field(
+      default=None, description="""Used to override HTTP request options."""
+  )
+
+
+class GetTuningJobConfigDict(TypedDict, total=False):
+  """Optional parameters for tunings.get method."""
+
+  http_options: Optional[dict[str, Any]]
+  """Used to override HTTP request options."""
+
+
+GetTuningJobConfigOrDict = Union[GetTuningJobConfig, GetTuningJobConfigDict]
+
+
+class _GetTuningJobParameters(_common.BaseModel):
+  """Parameters for the get method."""
+
+  name: Optional[str] = Field(default=None, description="""""")
+  config: Optional[GetTuningJobConfig] = Field(
+      default=None, description="""Optional parameters for the request."""
+  )
+
+
+class _GetTuningJobParametersDict(TypedDict, total=False):
+  """Parameters for the get method."""
+
+  name: Optional[str]
+  """"""
+
+  config: Optional[GetTuningJobConfigDict]
+  """Optional parameters for the request."""
+
+
+_GetTuningJobParametersOrDict = Union[
+    _GetTuningJobParameters, _GetTuningJobParametersDict
+]
+
+
+class TunedModel(_common.BaseModel):
+
+  model: Optional[str] = Field(
+      default=None,
+      description="""Output only. The resource name of the TunedModel. Format: `projects/{project}/locations/{location}/models/{model}`.""",
+  )
+  endpoint: Optional[str] = Field(
+      default=None,
+      description="""Output only. A resource name of an Endpoint. Format: `projects/{project}/locations/{location}/endpoints/{endpoint}`.""",
+  )
+
+
+class TunedModelDict(TypedDict, total=False):
+
+  model: Optional[str]
+  """Output only. The resource name of the TunedModel. Format: `projects/{project}/locations/{location}/models/{model}`."""
+
+  endpoint: Optional[str]
+  """Output only. A resource name of an Endpoint.
Format: `projects/{project}/locations/{location}/endpoints/{endpoint}`.""" + + +TunedModelOrDict = Union[TunedModel, TunedModelDict] + + +class GoogleRpcStatus(_common.BaseModel): + """The `Status` type defines a logical error model that is suitable for different programming environments, including REST APIs and RPC APIs. + + It is used by [gRPC](https://github.com/grpc). Each `Status` message contains + three pieces of data: error code, error message, and error details. You can + find out more about this error model and how to work with it in the [API + Design Guide](https://cloud.google.com/apis/design/errors). + """ + + code: Optional[int] = Field( + default=None, + description="""The status code, which should be an enum value of google.rpc.Code.""", + ) + details: Optional[list[dict[str, Any]]] = Field( + default=None, + description="""A list of messages that carry the error details. There is a common set of message types for APIs to use.""", + ) + message: Optional[str] = Field( + default=None, + description="""A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the google.rpc.Status.details field, or localized by the client.""", + ) + + +class GoogleRpcStatusDict(TypedDict, total=False): + """The `Status` type defines a logical error model that is suitable for different programming environments, including REST APIs and RPC APIs. + + It is used by [gRPC](https://github.com/grpc). Each `Status` message contains + three pieces of data: error code, error message, and error details. You can + find out more about this error model and how to work with it in the [API + Design Guide](https://cloud.google.com/apis/design/errors). + """ + + code: Optional[int] + """The status code, which should be an enum value of google.rpc.Code.""" + + details: Optional[list[dict[str, Any]]] + """A list of messages that carry the error details. There is a common set of message types for APIs to use.""" + + message: Optional[str] + """A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the google.rpc.Status.details field, or localized by the client.""" + + +GoogleRpcStatusOrDict = Union[GoogleRpcStatus, GoogleRpcStatusDict] + + +class SupervisedHyperParameters(_common.BaseModel): + """Hyperparameters for SFT.""" + + adapter_size: Optional[AdapterSize] = Field( + default=None, description="""Optional. Adapter size for tuning.""" + ) + epoch_count: Optional[int] = Field( + default=None, + description="""Optional. Number of complete passes the model makes over the entire training dataset during training.""", + ) + learning_rate_multiplier: Optional[float] = Field( + default=None, + description="""Optional. Multiplier for adjusting the default learning rate.""", + ) + + +class SupervisedHyperParametersDict(TypedDict, total=False): + """Hyperparameters for SFT.""" + + adapter_size: Optional[AdapterSize] + """Optional. Adapter size for tuning.""" + + epoch_count: Optional[int] + """Optional. Number of complete passes the model makes over the entire training dataset during training.""" + + learning_rate_multiplier: Optional[float] + """Optional. 
Multiplier for adjusting the default learning rate.""" + + +SupervisedHyperParametersOrDict = Union[ + SupervisedHyperParameters, SupervisedHyperParametersDict +] + + +class SupervisedTuningSpec(_common.BaseModel): + """Tuning Spec for Supervised Tuning for first party models.""" + + hyper_parameters: Optional[SupervisedHyperParameters] = Field( + default=None, description="""Optional. Hyperparameters for SFT.""" + ) + training_dataset_uri: Optional[str] = Field( + default=None, + description="""Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + validation_dataset_uri: Optional[str] = Field( + default=None, + description="""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + + +class SupervisedTuningSpecDict(TypedDict, total=False): + """Tuning Spec for Supervised Tuning for first party models.""" + + hyper_parameters: Optional[SupervisedHyperParametersDict] + """Optional. Hyperparameters for SFT.""" + + training_dataset_uri: Optional[str] + """Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""" + + validation_dataset_uri: Optional[str] + """Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""" + + +SupervisedTuningSpecOrDict = Union[ + SupervisedTuningSpec, SupervisedTuningSpecDict +] + + +class DatasetDistributionDistributionBucket(_common.BaseModel): + """Dataset bucket used to create a histogram for the distribution given a population of values.""" + + count: Optional[int] = Field( + default=None, + description="""Output only. Number of values in the bucket.""", + ) + left: Optional[float] = Field( + default=None, description="""Output only. Left bound of the bucket.""" + ) + right: Optional[float] = Field( + default=None, description="""Output only. Right bound of the bucket.""" + ) + + +class DatasetDistributionDistributionBucketDict(TypedDict, total=False): + """Dataset bucket used to create a histogram for the distribution given a population of values.""" + + count: Optional[int] + """Output only. Number of values in the bucket.""" + + left: Optional[float] + """Output only. Left bound of the bucket.""" + + right: Optional[float] + """Output only. Right bound of the bucket.""" + + +DatasetDistributionDistributionBucketOrDict = Union[ + DatasetDistributionDistributionBucket, + DatasetDistributionDistributionBucketDict, +] + + +class DatasetDistribution(_common.BaseModel): + """Distribution computed over a tuning dataset.""" + + buckets: Optional[list[DatasetDistributionDistributionBucket]] = Field( + default=None, description="""Output only. Defines the histogram bucket.""" + ) + max: Optional[float] = Field( + default=None, + description="""Output only. The maximum of the population values.""", + ) + mean: Optional[float] = Field( + default=None, + description="""Output only. The arithmetic mean of the values in the population.""", + ) + median: Optional[float] = Field( + default=None, + description="""Output only. The median of the values in the population.""", + ) + min: Optional[float] = Field( + default=None, + description="""Output only. The minimum of the population values.""", + ) + p5: Optional[float] = Field( + default=None, + description="""Output only. 
The 5th percentile of the values in the population.""", + ) + p95: Optional[float] = Field( + default=None, + description="""Output only. The 95th percentile of the values in the population.""", + ) + sum: Optional[float] = Field( + default=None, + description="""Output only. Sum of a given population of values.""", + ) + + +class DatasetDistributionDict(TypedDict, total=False): + """Distribution computed over a tuning dataset.""" + + buckets: Optional[list[DatasetDistributionDistributionBucketDict]] + """Output only. Defines the histogram bucket.""" + + max: Optional[float] + """Output only. The maximum of the population values.""" + + mean: Optional[float] + """Output only. The arithmetic mean of the values in the population.""" + + median: Optional[float] + """Output only. The median of the values in the population.""" + + min: Optional[float] + """Output only. The minimum of the population values.""" + + p5: Optional[float] + """Output only. The 5th percentile of the values in the population.""" + + p95: Optional[float] + """Output only. The 95th percentile of the values in the population.""" + + sum: Optional[float] + """Output only. Sum of a given population of values.""" + + +DatasetDistributionOrDict = Union[DatasetDistribution, DatasetDistributionDict] + + +class DatasetStats(_common.BaseModel): + """Statistics computed over a tuning dataset.""" + + total_billable_character_count: Optional[int] = Field( + default=None, + description="""Output only. Number of billable characters in the tuning dataset.""", + ) + total_tuning_character_count: Optional[int] = Field( + default=None, + description="""Output only. Number of tuning characters in the tuning dataset.""", + ) + tuning_dataset_example_count: Optional[int] = Field( + default=None, + description="""Output only. Number of examples in the tuning dataset.""", + ) + tuning_step_count: Optional[int] = Field( + default=None, + description="""Output only. Number of tuning steps for this Tuning Job.""", + ) + user_dataset_examples: Optional[list[Content]] = Field( + default=None, + description="""Output only. Sample user messages in the training dataset uri.""", + ) + user_input_token_distribution: Optional[DatasetDistribution] = Field( + default=None, + description="""Output only. Dataset distributions for the user input tokens.""", + ) + user_message_per_example_distribution: Optional[DatasetDistribution] = Field( + default=None, + description="""Output only. Dataset distributions for the messages per example.""", + ) + user_output_token_distribution: Optional[DatasetDistribution] = Field( + default=None, + description="""Output only. Dataset distributions for the user output tokens.""", + ) + + +class DatasetStatsDict(TypedDict, total=False): + """Statistics computed over a tuning dataset.""" + + total_billable_character_count: Optional[int] + """Output only. Number of billable characters in the tuning dataset.""" + + total_tuning_character_count: Optional[int] + """Output only. Number of tuning characters in the tuning dataset.""" + + tuning_dataset_example_count: Optional[int] + """Output only. Number of examples in the tuning dataset.""" + + tuning_step_count: Optional[int] + """Output only. Number of tuning steps for this Tuning Job.""" + + user_dataset_examples: Optional[list[ContentDict]] + """Output only. Sample user messages in the training dataset uri.""" + + user_input_token_distribution: Optional[DatasetDistributionDict] + """Output only. 
Dataset distributions for the user input tokens.""" + + user_message_per_example_distribution: Optional[DatasetDistributionDict] + """Output only. Dataset distributions for the messages per example.""" + + user_output_token_distribution: Optional[DatasetDistributionDict] + """Output only. Dataset distributions for the user output tokens.""" + + +DatasetStatsOrDict = Union[DatasetStats, DatasetStatsDict] + + +class DistillationDataStats(_common.BaseModel): + """Statistics computed for datasets used for distillation.""" + + training_dataset_stats: Optional[DatasetStats] = Field( + default=None, + description="""Output only. Statistics computed for the training dataset.""", + ) + + +class DistillationDataStatsDict(TypedDict, total=False): + """Statistics computed for datasets used for distillation.""" + + training_dataset_stats: Optional[DatasetStatsDict] + """Output only. Statistics computed for the training dataset.""" + + +DistillationDataStatsOrDict = Union[ + DistillationDataStats, DistillationDataStatsDict +] + + +class SupervisedTuningDatasetDistributionDatasetBucket(_common.BaseModel): + """Dataset bucket used to create a histogram for the distribution given a population of values.""" + + count: Optional[float] = Field( + default=None, + description="""Output only. Number of values in the bucket.""", + ) + left: Optional[float] = Field( + default=None, description="""Output only. Left bound of the bucket.""" + ) + right: Optional[float] = Field( + default=None, description="""Output only. Right bound of the bucket.""" + ) + + +class SupervisedTuningDatasetDistributionDatasetBucketDict( + TypedDict, total=False +): + """Dataset bucket used to create a histogram for the distribution given a population of values.""" + + count: Optional[float] + """Output only. Number of values in the bucket.""" + + left: Optional[float] + """Output only. Left bound of the bucket.""" + + right: Optional[float] + """Output only. Right bound of the bucket.""" + + +SupervisedTuningDatasetDistributionDatasetBucketOrDict = Union[ + SupervisedTuningDatasetDistributionDatasetBucket, + SupervisedTuningDatasetDistributionDatasetBucketDict, +] + + +class SupervisedTuningDatasetDistribution(_common.BaseModel): + """Dataset distribution for Supervised Tuning.""" + + billable_sum: Optional[int] = Field( + default=None, + description="""Output only. Sum of a given population of values that are billable.""", + ) + buckets: Optional[list[SupervisedTuningDatasetDistributionDatasetBucket]] = ( + Field( + default=None, + description="""Output only. Defines the histogram bucket.""", + ) + ) + max: Optional[float] = Field( + default=None, + description="""Output only. The maximum of the population values.""", + ) + mean: Optional[float] = Field( + default=None, + description="""Output only. The arithmetic mean of the values in the population.""", + ) + median: Optional[float] = Field( + default=None, + description="""Output only. The median of the values in the population.""", + ) + min: Optional[float] = Field( + default=None, + description="""Output only. The minimum of the population values.""", + ) + p5: Optional[float] = Field( + default=None, + description="""Output only. The 5th percentile of the values in the population.""", + ) + p95: Optional[float] = Field( + default=None, + description="""Output only. The 95th percentile of the values in the population.""", + ) + sum: Optional[int] = Field( + default=None, + description="""Output only. 
Sum of a given population of values.""", + ) + + +class SupervisedTuningDatasetDistributionDict(TypedDict, total=False): + """Dataset distribution for Supervised Tuning.""" + + billable_sum: Optional[int] + """Output only. Sum of a given population of values that are billable.""" + + buckets: Optional[list[SupervisedTuningDatasetDistributionDatasetBucketDict]] + """Output only. Defines the histogram bucket.""" + + max: Optional[float] + """Output only. The maximum of the population values.""" + + mean: Optional[float] + """Output only. The arithmetic mean of the values in the population.""" + + median: Optional[float] + """Output only. The median of the values in the population.""" + + min: Optional[float] + """Output only. The minimum of the population values.""" + + p5: Optional[float] + """Output only. The 5th percentile of the values in the population.""" + + p95: Optional[float] + """Output only. The 95th percentile of the values in the population.""" + + sum: Optional[int] + """Output only. Sum of a given population of values.""" + + +SupervisedTuningDatasetDistributionOrDict = Union[ + SupervisedTuningDatasetDistribution, SupervisedTuningDatasetDistributionDict +] + + +class SupervisedTuningDataStats(_common.BaseModel): + """Tuning data statistics for Supervised Tuning.""" + + total_billable_character_count: Optional[int] = Field( + default=None, + description="""Output only. Number of billable characters in the tuning dataset.""", + ) + total_billable_token_count: Optional[int] = Field( + default=None, + description="""Output only. Number of billable tokens in the tuning dataset.""", + ) + total_truncated_example_count: Optional[int] = Field( + default=None, + description="""The number of examples in the dataset that have been truncated by any amount.""", + ) + total_tuning_character_count: Optional[int] = Field( + default=None, + description="""Output only. Number of tuning characters in the tuning dataset.""", + ) + truncated_example_indices: Optional[list[int]] = Field( + default=None, + description="""A partial sample of the indices (starting from 1) of the truncated examples.""", + ) + tuning_dataset_example_count: Optional[int] = Field( + default=None, + description="""Output only. Number of examples in the tuning dataset.""", + ) + tuning_step_count: Optional[int] = Field( + default=None, + description="""Output only. Number of tuning steps for this Tuning Job.""", + ) + user_dataset_examples: Optional[list[Content]] = Field( + default=None, + description="""Output only. Sample user messages in the training dataset uri.""", + ) + user_input_token_distribution: Optional[ + SupervisedTuningDatasetDistribution + ] = Field( + default=None, + description="""Output only. Dataset distributions for the user input tokens.""", + ) + user_message_per_example_distribution: Optional[ + SupervisedTuningDatasetDistribution + ] = Field( + default=None, + description="""Output only. Dataset distributions for the messages per example.""", + ) + user_output_token_distribution: Optional[ + SupervisedTuningDatasetDistribution + ] = Field( + default=None, + description="""Output only. Dataset distributions for the user output tokens.""", + ) + + +class SupervisedTuningDataStatsDict(TypedDict, total=False): + """Tuning data statistics for Supervised Tuning.""" + + total_billable_character_count: Optional[int] + """Output only. Number of billable characters in the tuning dataset.""" + + total_billable_token_count: Optional[int] + """Output only. 
Number of billable tokens in the tuning dataset.""" + + total_truncated_example_count: Optional[int] + """The number of examples in the dataset that have been truncated by any amount.""" + + total_tuning_character_count: Optional[int] + """Output only. Number of tuning characters in the tuning dataset.""" + + truncated_example_indices: Optional[list[int]] + """A partial sample of the indices (starting from 1) of the truncated examples.""" + + tuning_dataset_example_count: Optional[int] + """Output only. Number of examples in the tuning dataset.""" + + tuning_step_count: Optional[int] + """Output only. Number of tuning steps for this Tuning Job.""" + + user_dataset_examples: Optional[list[ContentDict]] + """Output only. Sample user messages in the training dataset uri.""" + + user_input_token_distribution: Optional[ + SupervisedTuningDatasetDistributionDict + ] + """Output only. Dataset distributions for the user input tokens.""" + + user_message_per_example_distribution: Optional[ + SupervisedTuningDatasetDistributionDict + ] + """Output only. Dataset distributions for the messages per example.""" + + user_output_token_distribution: Optional[ + SupervisedTuningDatasetDistributionDict + ] + """Output only. Dataset distributions for the user output tokens.""" + + +SupervisedTuningDataStatsOrDict = Union[ + SupervisedTuningDataStats, SupervisedTuningDataStatsDict +] + + +class TuningDataStats(_common.BaseModel): + """The tuning data statistic values for TuningJob.""" + + distillation_data_stats: Optional[DistillationDataStats] = Field( + default=None, description="""Output only. Statistics for distillation.""" + ) + supervised_tuning_data_stats: Optional[SupervisedTuningDataStats] = Field( + default=None, description="""The SFT Tuning data stats.""" + ) + + +class TuningDataStatsDict(TypedDict, total=False): + """The tuning data statistic values for TuningJob.""" + + distillation_data_stats: Optional[DistillationDataStatsDict] + """Output only. Statistics for distillation.""" + + supervised_tuning_data_stats: Optional[SupervisedTuningDataStatsDict] + """The SFT Tuning data stats.""" + + +TuningDataStatsOrDict = Union[TuningDataStats, TuningDataStatsDict] + + +class EncryptionSpec(_common.BaseModel): + """Represents a customer-managed encryption key spec that can be applied to a top-level resource.""" + + kms_key_name: Optional[str] = Field( + default=None, + description="""Required. The Cloud KMS resource identifier of the customer managed encryption key used to protect a resource. Has the form: `projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`. The key needs to be in the same region as where the compute resource is created.""", + ) + + +class EncryptionSpecDict(TypedDict, total=False): + """Represents a customer-managed encryption key spec that can be applied to a top-level resource.""" + + kms_key_name: Optional[str] + """Required. The Cloud KMS resource identifier of the customer managed encryption key used to protect a resource. Has the form: `projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`. The key needs to be in the same region as where the compute resource is created.""" + + +EncryptionSpecOrDict = Union[EncryptionSpec, EncryptionSpecDict] + + +class DistillationHyperParameters(_common.BaseModel): + """Hyperparameters for Distillation.""" + + adapter_size: Optional[AdapterSize] = Field( + default=None, description="""Optional. 
Adapter size for distillation.""" + ) + epoch_count: Optional[int] = Field( + default=None, + description="""Optional. Number of complete passes the model makes over the entire training dataset during training.""", + ) + learning_rate_multiplier: Optional[float] = Field( + default=None, + description="""Optional. Multiplier for adjusting the default learning rate.""", + ) + + +class DistillationHyperParametersDict(TypedDict, total=False): + """Hyperparameters for Distillation.""" + + adapter_size: Optional[AdapterSize] + """Optional. Adapter size for distillation.""" + + epoch_count: Optional[int] + """Optional. Number of complete passes the model makes over the entire training dataset during training.""" + + learning_rate_multiplier: Optional[float] + """Optional. Multiplier for adjusting the default learning rate.""" + + +DistillationHyperParametersOrDict = Union[ + DistillationHyperParameters, DistillationHyperParametersDict +] + + +class DistillationSpec(_common.BaseModel): + """Tuning Spec for Distillation.""" + + base_teacher_model: Optional[str] = Field( + default=None, + description="""The base teacher model that is being distilled, e.g., "gemini-1.0-pro-002".""", + ) + hyper_parameters: Optional[DistillationHyperParameters] = Field( + default=None, + description="""Optional. Hyperparameters for Distillation.""", + ) + pipeline_root_directory: Optional[str] = Field( + default=None, + description="""Required. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""", + ) + student_model: Optional[str] = Field( + default=None, + description="""The student model that is being tuned, e.g., "google/gemma-2b-1.1-it".""", + ) + training_dataset_uri: Optional[str] = Field( + default=None, + description="""Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + tuned_teacher_model_source: Optional[str] = Field( + default=None, + description="""The resource name of the Tuned teacher model. Format: `projects/{project}/locations/{location}/models/{model}`.""", + ) + validation_dataset_uri: Optional[str] = Field( + default=None, + description="""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + + +class DistillationSpecDict(TypedDict, total=False): + """Tuning Spec for Distillation.""" + + base_teacher_model: Optional[str] + """The base teacher model that is being distilled, e.g., "gemini-1.0-pro-002".""" + + hyper_parameters: Optional[DistillationHyperParametersDict] + """Optional. Hyperparameters for Distillation.""" + + pipeline_root_directory: Optional[str] + """Required. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""" + + student_model: Optional[str] + """The student model that is being tuned, e.g., "google/gemma-2b-1.1-it".""" + + training_dataset_uri: Optional[str] + """Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""" + + tuned_teacher_model_source: Optional[str] + """The resource name of the Tuned teacher model. Format: `projects/{project}/locations/{location}/models/{model}`.""" + + validation_dataset_uri: Optional[str] + """Optional. 
Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file."""
+
+
+DistillationSpecOrDict = Union[DistillationSpec, DistillationSpecDict]
+
+
+class PartnerModelTuningSpec(_common.BaseModel):
+  """Tuning spec for Partner models."""
+
+  hyper_parameters: Optional[dict[str, Any]] = Field(
+      default=None,
+      description="""Hyperparameters for tuning. The accepted hyper_parameters and their valid range of values will differ depending on the base model.""",
+  )
+  training_dataset_uri: Optional[str] = Field(
+      default=None,
+      description="""Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""",
+  )
+  validation_dataset_uri: Optional[str] = Field(
+      default=None,
+      description="""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""",
+  )
+
+
+class PartnerModelTuningSpecDict(TypedDict, total=False):
+  """Tuning spec for Partner models."""
+
+  hyper_parameters: Optional[dict[str, Any]]
+  """Hyperparameters for tuning. The accepted hyper_parameters and their valid range of values will differ depending on the base model."""
+
+  training_dataset_uri: Optional[str]
+  """Required. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file."""
+
+  validation_dataset_uri: Optional[str]
+  """Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file."""
+
+
+PartnerModelTuningSpecOrDict = Union[
+    PartnerModelTuningSpec, PartnerModelTuningSpecDict
+]
+
+
+class TuningJob(_common.BaseModel):
+  """A tuning job."""
+
+  name: Optional[str] = Field(
+      default=None,
+      description="""Output only. Identifier. Resource name of a TuningJob. Format: `projects/{project}/locations/{location}/tuningJobs/{tuning_job}`""",
+  )
+  state: Optional[JobState] = Field(
+      default=None,
+      description="""Output only. The detailed state of the job.""",
+  )
+  create_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. Time when the TuningJob was created.""",
+  )
+  start_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. Time when the TuningJob for the first time entered the `JOB_STATE_RUNNING` state.""",
+  )
+  end_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. Time when the TuningJob entered any of the following JobStates: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`, `JOB_STATE_EXPIRED`.""",
+  )
+  update_time: Optional[datetime.datetime] = Field(
+      default=None,
+      description="""Output only. Time when the TuningJob was most recently updated.""",
+  )
+  error: Optional[GoogleRpcStatus] = Field(
+      default=None,
+      description="""Output only. Only populated when job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`.""",
+  )
+  description: Optional[str] = Field(
+      default=None,
+      description="""Optional. The description of the TuningJob.""",
+  )
+  base_model: Optional[str] = Field(
+      default=None,
+      description="""The base model that is being tuned, e.g., "gemini-1.0-pro-002".""",
+  )
+  tuned_model: Optional[TunedModel] = Field(
+      default=None,
+      description="""Output only.
The tuned model resources associated with this TuningJob.""",
+  )
+  supervised_tuning_spec: Optional[SupervisedTuningSpec] = Field(
+      default=None, description="""Tuning Spec for Supervised Fine Tuning."""
+  )
+  tuning_data_stats: Optional[TuningDataStats] = Field(
+      default=None,
+      description="""Output only. The tuning data statistics associated with this TuningJob.""",
+  )
+  encryption_spec: Optional[EncryptionSpec] = Field(
+      default=None,
+      description="""Customer-managed encryption key options for a TuningJob. If this is set, then all resources created by the TuningJob will be encrypted with the provided encryption key.""",
+  )
+  distillation_spec: Optional[DistillationSpec] = Field(
+      default=None, description="""Tuning Spec for Distillation."""
+  )
+  partner_model_tuning_spec: Optional[PartnerModelTuningSpec] = Field(
+      default=None,
+      description="""Tuning Spec for open sourced and third party Partner models.""",
+  )
+  pipeline_job: Optional[str] = Field(
+      default=None,
+      description="""Output only. The resource name of the PipelineJob associated with the TuningJob. Format: `projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`.""",
+  )
+  experiment: Optional[str] = Field(
+      default=None,
+      description="""Output only. The Experiment associated with this TuningJob.""",
+  )
+  labels: Optional[dict[str, str]] = Field(
+      default=None,
+      description="""Optional. The labels with user-defined metadata to organize TuningJob and generated resources such as Model and Endpoint. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
+  )
+  tuned_model_display_name: Optional[str] = Field(
+      default=None,
+      description="""Optional. The display name of the TunedModel. The name can be up to 128 characters long and can consist of any UTF-8 characters.""",
+  )
+
+  @property
+  def has_ended(self) -> bool:
+    """Whether the tuning job has ended."""
+    return self.state in JOB_STATES_ENDED
+
+  @property
+  def has_succeeded(self) -> bool:
+    """Whether the tuning job has succeeded."""
+    return self.state in JOB_STATES_SUCCEEDED
+
+
+class TuningJobDict(TypedDict, total=False):
+  """A tuning job."""
+
+  name: Optional[str]
+  """Output only. Identifier. Resource name of a TuningJob. Format: `projects/{project}/locations/{location}/tuningJobs/{tuning_job}`"""
+
+  state: Optional[JobState]
+  """Output only. The detailed state of the job."""
+
+  create_time: Optional[datetime.datetime]
+  """Output only. Time when the TuningJob was created."""
+
+  start_time: Optional[datetime.datetime]
+  """Output only. Time when the TuningJob for the first time entered the `JOB_STATE_RUNNING` state."""
+
+  end_time: Optional[datetime.datetime]
+  """Output only. Time when the TuningJob entered any of the following JobStates: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`, `JOB_STATE_EXPIRED`."""
+
+  update_time: Optional[datetime.datetime]
+  """Output only. Time when the TuningJob was most recently updated."""
+
+  error: Optional[GoogleRpcStatusDict]
+  """Output only. Only populated when job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`."""
+
+  description: Optional[str]
+  """Optional. The description of the TuningJob."""
+
+  base_model: Optional[str]
+  """The base model that is being tuned, e.g., "gemini-1.0-pro-002"."""
+
+  tuned_model: Optional[TunedModelDict]
+  """Output only. The tuned model resources associated with this TuningJob."""
+
+  supervised_tuning_spec: Optional[SupervisedTuningSpecDict]
+  """Tuning Spec for Supervised Fine Tuning."""
+
+  tuning_data_stats: Optional[TuningDataStatsDict]
+  """Output only. The tuning data statistics associated with this TuningJob."""
+
+  encryption_spec: Optional[EncryptionSpecDict]
+  """Customer-managed encryption key options for a TuningJob. If this is set, then all resources created by the TuningJob will be encrypted with the provided encryption key."""
+
+  distillation_spec: Optional[DistillationSpecDict]
+  """Tuning Spec for Distillation."""
+
+  partner_model_tuning_spec: Optional[PartnerModelTuningSpecDict]
+  """Tuning Spec for open sourced and third party Partner models."""
+
+  pipeline_job: Optional[str]
+  """Output only. The resource name of the PipelineJob associated with the TuningJob. Format: `projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`."""
+
+  experiment: Optional[str]
+  """Output only. The Experiment associated with this TuningJob."""
+
+  labels: Optional[dict[str, str]]
+  """Optional. The labels with user-defined metadata to organize TuningJob and generated resources such as Model and Endpoint. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""
+
+  tuned_model_display_name: Optional[str]
+  """Optional. The display name of the TunedModel. The name can be up to 128 characters long and can consist of any UTF-8 characters."""
+
+
+TuningJobOrDict = Union[TuningJob, TuningJobDict]
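# Usage sketch (illustrative only): checking the state of a job with the
# convenience properties defined above; `job` is assumed to have been fetched
# elsewhere (e.g., via the SDK's tunings.get method).
#
#   if job.has_ended:
#       print('succeeded' if job.has_succeeded else 'failed or cancelled')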
.""" + + tuned_model: Optional[TunedModelDict] + """Output only. The tuned model resources assiociated with this TuningJob.""" + + supervised_tuning_spec: Optional[SupervisedTuningSpecDict] + """Tuning Spec for Supervised Fine Tuning.""" + + tuning_data_stats: Optional[TuningDataStatsDict] + """Output only. The tuning data statistics associated with this TuningJob.""" + + encryption_spec: Optional[EncryptionSpecDict] + """Customer-managed encryption key options for a TuningJob. If this is set, then all resources created by the TuningJob will be encrypted with the provided encryption key.""" + + distillation_spec: Optional[DistillationSpecDict] + """Tuning Spec for Distillation.""" + + partner_model_tuning_spec: Optional[PartnerModelTuningSpecDict] + """Tuning Spec for open sourced and third party Partner models.""" + + pipeline_job: Optional[str] + """Output only. The resource name of the PipelineJob associated with the TuningJob. Format: `projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`.""" + + experiment: Optional[str] + """Output only. The Experiment associated with this TuningJob.""" + + labels: Optional[dict[str, str]] + """Optional. The labels with user-defined metadata to organize TuningJob and generated resources such as Model and Endpoint. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + tuned_model_display_name: Optional[str] + """Optional. The display name of the TunedModel. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + +TuningJobOrDict = Union[TuningJob, TuningJobDict] + + +class ListTuningJobsConfig(_common.BaseModel): + """Configuration for the list tuning jobs method.""" + + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field(default=None, description="""""") + + +class ListTuningJobsConfigDict(TypedDict, total=False): + """Configuration for the list tuning jobs method.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """""" + + +ListTuningJobsConfigOrDict = Union[ + ListTuningJobsConfig, ListTuningJobsConfigDict +] + + +class _ListTuningJobsParameters(_common.BaseModel): + """Parameters for the list tuning jobs method.""" + + config: Optional[ListTuningJobsConfig] = Field( + default=None, description="""""" + ) + + +class _ListTuningJobsParametersDict(TypedDict, total=False): + """Parameters for the list tuning jobs method.""" + + config: Optional[ListTuningJobsConfigDict] + """""" + + +_ListTuningJobsParametersOrDict = Union[ + _ListTuningJobsParameters, _ListTuningJobsParametersDict +] + + +class ListTuningJobsResponse(_common.BaseModel): + """Response for the list tuning jobs method.""" + + next_page_token: Optional[str] = Field( + default=None, + description="""A token to retrieve the next page of results. Pass to ListTuningJobsRequest.page_token to obtain that page.""", + ) + tuning_jobs: Optional[list[TuningJob]] = Field( + default=None, description="""List of TuningJobs in the requested page.""" + ) + + +class ListTuningJobsResponseDict(TypedDict, total=False): + """Response for the list tuning jobs method.""" + + next_page_token: Optional[str] + """A token to retrieve the next page of results. 
+class ListTuningJobsResponseDict(TypedDict, total=False): + """Response for the list tuning jobs method.""" + + next_page_token: Optional[str] + """A token to retrieve the next page of results. Pass to ListTuningJobsRequest.page_token to obtain that page.""" + + tuning_jobs: Optional[list[TuningJobDict]] + """List of TuningJobs in the requested page.""" + + +ListTuningJobsResponseOrDict = Union[ + ListTuningJobsResponse, ListTuningJobsResponseDict +] + + +class TuningExample(_common.BaseModel): + + text_input: Optional[str] = Field( + default=None, description="""Text model input.""" + ) + output: Optional[str] = Field( + default=None, description="""The expected model output.""" + ) + + +class TuningExampleDict(TypedDict, total=False): + + text_input: Optional[str] + """Text model input.""" + + output: Optional[str] + """The expected model output.""" + + +TuningExampleOrDict = Union[TuningExample, TuningExampleDict] + + +class TuningDataset(_common.BaseModel): + """Supervised fine-tuning training dataset.""" + + gcs_uri: Optional[str] = Field( + default=None, + description="""GCS URI of the file containing training dataset in JSONL format.""", + ) + examples: Optional[list[TuningExample]] = Field( + default=None, + description="""Inline examples with simple input/output text.""", + ) + + +class TuningDatasetDict(TypedDict, total=False): + """Supervised fine-tuning training dataset.""" + + gcs_uri: Optional[str] + """GCS URI of the file containing training dataset in JSONL format.""" + + examples: Optional[list[TuningExampleDict]] + """Inline examples with simple input/output text.""" + + +TuningDatasetOrDict = Union[TuningDataset, TuningDatasetDict] + + +class TuningValidationDataset(_common.BaseModel): + + gcs_uri: Optional[str] = Field( + default=None, + description="""GCS URI of the file containing validation dataset in JSONL format.""", + ) + + +class TuningValidationDatasetDict(TypedDict, total=False): + + gcs_uri: Optional[str] + """GCS URI of the file containing validation dataset in JSONL format.""" + + +TuningValidationDatasetOrDict = Union[ + TuningValidationDataset, TuningValidationDatasetDict +]
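+ # Construction sketch for the dataset types above; the GCS paths and example
+ # text are placeholders.
+ #
+ #   train = TuningDataset(gcs_uri='gs://my-bucket/train.jsonl')  # JSONL on GCS
+ #   inline = TuningDataset(examples=[
+ #       TuningExample(text_input='Why is the sky blue?', output='Rayleigh scattering.'),
+ #   ])
+ #   validation = TuningValidationDataset(gcs_uri='gs://my-bucket/val.jsonl')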
+class CreateTuningJobConfig(_common.BaseModel): + """Supervised fine-tuning job creation request - optional fields.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + validation_dataset: Optional[TuningValidationDataset] = Field( + default=None, + description="""Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + tuned_model_display_name: Optional[str] = Field( + default=None, + description="""The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + description: Optional[str] = Field( + default=None, description="""The description of the TuningJob.""" + ) + epoch_count: Optional[int] = Field( + default=None, + description="""Number of complete passes the model makes over the entire training dataset during training.""", + ) + learning_rate_multiplier: Optional[float] = Field( + default=None, + description="""Multiplier for adjusting the default learning rate.""", + ) + adapter_size: Optional[AdapterSize] = Field( + default=None, description="""Adapter size for tuning.""" + ) + batch_size: Optional[int] = Field( + default=None, + description="""The batch size hyperparameter for tuning. If not set, a default of 4 or 16 will be used based on the number of training examples.""", + ) + learning_rate: Optional[float] = Field( + default=None, + description="""The learning rate hyperparameter for tuning. If not set, a default of 0.001 or 0.0002 will be calculated based on the number of training examples.""", + ) + + +class CreateTuningJobConfigDict(TypedDict, total=False): + """Supervised fine-tuning job creation request - optional fields.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + validation_dataset: Optional[TuningValidationDatasetDict] + """Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""" + + tuned_model_display_name: Optional[str] + """The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + description: Optional[str] + """The description of the TuningJob.""" + + epoch_count: Optional[int] + """Number of complete passes the model makes over the entire training dataset during training.""" + + learning_rate_multiplier: Optional[float] + """Multiplier for adjusting the default learning rate.""" + + adapter_size: Optional[AdapterSize] + """Adapter size for tuning.""" + + batch_size: Optional[int] + """The batch size hyperparameter for tuning. If not set, a default of 4 or 16 will be used based on the number of training examples.""" + + learning_rate: Optional[float] + """The learning rate hyperparameter for tuning. If not set, a default of 0.001 or 0.0002 will be calculated based on the number of training examples.""" + + +CreateTuningJobConfigOrDict = Union[ + CreateTuningJobConfig, CreateTuningJobConfigDict +]
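+ # End-to-end sketch tying the config to a dataset. `client.tunings.tune` is
+ # an assumed call for illustration; model and paths are placeholders.
+ #
+ #   job = client.tunings.tune(
+ #       base_model='gemini-1.0-pro-002',
+ #       training_dataset=TuningDataset(gcs_uri='gs://my-bucket/train.jsonl'),
+ #       config=CreateTuningJobConfig(
+ #           epoch_count=1,
+ #           tuned_model_display_name='my-tuned-model',
+ #       ),
+ #   )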
+class _CreateTuningJobParameters(_common.BaseModel): + """Supervised fine-tuning job creation parameters - optional fields.""" + + base_model: Optional[str] = Field( + default=None, + description="""The base model that is being tuned, e.g., "gemini-1.0-pro-002".""", + ) + training_dataset: Optional[TuningDataset] = Field( + default=None, + description="""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + config: Optional[CreateTuningJobConfig] = Field( + default=None, description="""Configuration for the tuning job.""" + ) + + +class _CreateTuningJobParametersDict(TypedDict, total=False): + """Supervised fine-tuning job creation parameters - optional fields.""" + + base_model: Optional[str] + """The base model that is being tuned, e.g., "gemini-1.0-pro-002".""" + + training_dataset: Optional[TuningDatasetDict] + """Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""" + + config: Optional[CreateTuningJobConfigDict] + """Configuration for the tuning job.""" + + +_CreateTuningJobParametersOrDict = Union[ + _CreateTuningJobParameters, _CreateTuningJobParametersDict +] + + +class TuningJobOrOperation(_common.BaseModel): + """A tuning job or a long-running operation that resolves to a tuning job.""" + + tuning_job: Optional[TuningJob] = Field(default=None, description="""""") + + +class TuningJobOrOperationDict(TypedDict, total=False): + """A tuning job or a long-running operation that resolves to a tuning job.""" + + tuning_job: Optional[TuningJobDict] + """""" + + +TuningJobOrOperationOrDict = Union[ + TuningJobOrOperation, TuningJobOrOperationDict +] + + +class DistillationDataset(_common.BaseModel): + """Training dataset.""" + + gcs_uri: Optional[str] = Field( + default=None, + description="""GCS URI of the file containing training dataset in JSONL format.""", + ) + + +class DistillationDatasetDict(TypedDict, total=False): + """Training dataset.""" + + gcs_uri: Optional[str] + """GCS URI of the file containing training dataset in JSONL format.""" + + +DistillationDatasetOrDict = Union[DistillationDataset, DistillationDatasetDict] + + +class DistillationValidationDataset(_common.BaseModel): + """Validation dataset.""" + + gcs_uri: Optional[str] = Field( + default=None, + description="""GCS URI of the file containing validation dataset in JSONL format.""", + ) + + +class DistillationValidationDatasetDict(TypedDict, total=False): + """Validation dataset.""" + + gcs_uri: Optional[str] + """GCS URI of the file containing validation dataset in JSONL format.""" + + +DistillationValidationDatasetOrDict = Union[ + DistillationValidationDataset, DistillationValidationDatasetDict +] + + +class CreateDistillationJobConfig(_common.BaseModel): + """Distillation job creation request - optional fields.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + validation_dataset: Optional[DistillationValidationDataset] = Field( + default=None, + description="""Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + tuned_model_display_name: Optional[str] = Field( + default=None, + description="""The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + epoch_count: Optional[int] = Field( + default=None, + description="""Number of complete passes the model makes over the entire training dataset during training.""", + ) + learning_rate_multiplier: Optional[float] = Field( + default=None, + description="""Multiplier for adjusting the default learning rate.""", + ) + adapter_size: Optional[AdapterSize] = Field( + default=None, description="""Adapter size for tuning.""" + ) + pipeline_root_directory: Optional[str] = Field( + default=None, + description="""The resource name of the PipelineJob associated with the TuningJob. Format: `projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`.""", + )
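+ # Construction sketch for the distillation config above; all values are
+ # placeholders, and how the config is submitted is SDK-internal.
+ #
+ #   distill_config = CreateDistillationJobConfig(
+ #       epoch_count=1,
+ #       learning_rate_multiplier=1.0,
+ #       tuned_model_display_name='my-distilled-model',
+ #   )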
+class CreateDistillationJobConfigDict(TypedDict, total=False): + """Distillation job creation request - optional fields.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + validation_dataset: Optional[DistillationValidationDatasetDict] + """Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""" + + tuned_model_display_name: Optional[str] + """The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + epoch_count: Optional[int] + """Number of complete passes the model makes over the entire training dataset during training.""" + + learning_rate_multiplier: Optional[float] + """Multiplier for adjusting the default learning rate.""" + + adapter_size: Optional[AdapterSize] + """Adapter size for tuning.""" + + pipeline_root_directory: Optional[str] + """The resource name of the PipelineJob associated with the TuningJob. Format: `projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`.""" + + +CreateDistillationJobConfigOrDict = Union[ + CreateDistillationJobConfig, CreateDistillationJobConfigDict +] + + +class _CreateDistillationJobParameters(_common.BaseModel): + """Distillation job creation parameters - optional fields.""" + + student_model: Optional[str] = Field( + default=None, + description="""The student model that is being tuned, e.g. ``google/gemma-2b-1.1-it``.""", + ) + teacher_model: Optional[str] = Field( + default=None, + description="""The teacher model that is being distilled from, e.g. ``gemini-1.0-pro-002``.""", + ) + training_dataset: Optional[DistillationDataset] = Field( + default=None, + description="""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""", + ) + config: Optional[CreateDistillationJobConfig] = Field( + default=None, description="""Configuration for the distillation job.""" + ) + + +class _CreateDistillationJobParametersDict(TypedDict, total=False): + """Distillation job creation parameters - optional fields.""" + + student_model: Optional[str] + """The student model that is being tuned, e.g. ``google/gemma-2b-1.1-it``.""" + + teacher_model: Optional[str] + """The teacher model that is being distilled from, e.g. ``gemini-1.0-pro-002``.""" + + training_dataset: Optional[DistillationDatasetDict] + """Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""" + + config: Optional[CreateDistillationJobConfigDict] + """Configuration for the distillation job.""" + + +_CreateDistillationJobParametersOrDict = Union[ + _CreateDistillationJobParameters, _CreateDistillationJobParametersDict +]
+class CreateCachedContentConfig(_common.BaseModel): + """Class for configuring optional cached content creation parameters.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + ttl: Optional[str] = Field( + default=None, + description="""The TTL for this resource. The expiration time is computed: now + TTL.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Timestamp of when this resource is considered expired.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-generated meaningful display name of the cached content. + """, + ) + contents: Optional[ContentListUnion] = Field( + default=None, + description="""The content to cache. + """, + ) + system_instruction: Optional[ContentUnion] = Field( + default=None, + description="""Developer-set system instruction. + """, + ) + tools: Optional[list[Tool]] = Field( + default=None, + description="""A list of `Tools` the model may use to generate the next response. + """, + ) + tool_config: Optional[ToolConfig] = Field( + default=None, + description="""Configuration for the tools to use. This config is shared for all tools. + """, + ) + + +class CreateCachedContentConfigDict(TypedDict, total=False): + """Class for configuring optional cached content creation parameters.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + ttl: Optional[str] + """The TTL for this resource. The expiration time is computed: now + TTL.""" + + expire_time: Optional[datetime.datetime] + """Timestamp of when this resource is considered expired.""" + + display_name: Optional[str] + """The user-generated meaningful display name of the cached content. + """ + + contents: Optional[ContentListUnionDict] + """The content to cache. + """ + + system_instruction: Optional[ContentUnionDict] + """Developer-set system instruction. + """ + + tools: Optional[list[ToolDict]] + """A list of `Tools` the model may use to generate the next response. + """ + + tool_config: Optional[ToolConfigDict] + """Configuration for the tools to use. This config is shared for all tools. + """ + + +CreateCachedContentConfigOrDict = Union[ + CreateCachedContentConfig, CreateCachedContentConfigDict +]
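+ # Cache-creation sketch for the config above. `client.caches.create` is an
+ # assumed call mirroring the caches.create parameters below; model, contents,
+ # and TTL values are placeholders.
+ #
+ #   cache = client.caches.create(
+ #       model='gemini-1.5-flash',
+ #       config=CreateCachedContentConfig(
+ #           ttl='3600s',
+ #           display_name='my-cache',
+ #           system_instruction='You are a helpful analyst.',
+ #           contents=['...long document text...'],
+ #       ),
+ #   )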
+class _CreateCachedContentParameters(_common.BaseModel): + """Parameters for caches.create method.""" + + model: Optional[str] = Field( + default=None, + description="""ID of the model to use. Example: gemini-1.5-flash""", + ) + config: Optional[CreateCachedContentConfig] = Field( + default=None, + description="""Configuration that contains optional parameters. + """, + ) + + +class _CreateCachedContentParametersDict(TypedDict, total=False): + """Parameters for caches.create method.""" + + model: Optional[str] + """ID of the model to use. Example: gemini-1.5-flash""" + + config: Optional[CreateCachedContentConfigDict] + """Configuration that contains optional parameters. + """ + + +_CreateCachedContentParametersOrDict = Union[ + _CreateCachedContentParameters, _CreateCachedContentParametersDict +] + + +class CachedContentUsageMetadata(_common.BaseModel): + """Metadata on the usage of the cached content.""" + + audio_duration_seconds: Optional[int] = Field( + default=None, description="""Duration of audio in seconds.""" + ) + image_count: Optional[int] = Field( + default=None, description="""Number of images.""" + ) + text_count: Optional[int] = Field( + default=None, description="""Number of text characters.""" + ) + total_token_count: Optional[int] = Field( + default=None, + description="""Total number of tokens that the cached content consumes.""", + ) + video_duration_seconds: Optional[int] = Field( + default=None, description="""Duration of video in seconds.""" + ) + + +class CachedContentUsageMetadataDict(TypedDict, total=False): + """Metadata on the usage of the cached content.""" + + audio_duration_seconds: Optional[int] + """Duration of audio in seconds.""" + + image_count: Optional[int] + """Number of images.""" + + text_count: Optional[int] + """Number of text characters.""" + + total_token_count: Optional[int] + """Total number of tokens that the cached content consumes.""" + + video_duration_seconds: Optional[int] + """Duration of video in seconds.""" + + +CachedContentUsageMetadataOrDict = Union[ + CachedContentUsageMetadata, CachedContentUsageMetadataDict +] + + +class CachedContent(_common.BaseModel): + """A resource used in LLM queries for users to explicitly specify what to cache.""" + + name: Optional[str] = Field( + default=None, + description="""The server-generated resource name of the cached content.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-generated meaningful display name of the cached content.""", + ) + model: Optional[str] = Field( + default=None, + description="""The name of the publisher model to use for cached content.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, description="""Creation time of the cache entry.""" + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""When the cache entry was last updated in UTC time.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, description="""Expiration time of the cached content.""" + ) + usage_metadata: Optional[CachedContentUsageMetadata] = Field( + default=None, + description="""Metadata on the usage of the cached content.""", + ) + + +class CachedContentDict(TypedDict, total=False): + """A resource used in LLM queries for users to explicitly specify what to cache.""" + + name: Optional[str] + """The server-generated resource name of the cached content.""" + + display_name: Optional[str] + """The user-generated meaningful display name of the cached content.""" + + model: Optional[str] + """The name of the publisher model to use for cached content.""" + + create_time: Optional[datetime.datetime] + """Creation time of the cache entry.""" + + update_time: Optional[datetime.datetime] + """When the cache entry was last updated in UTC time.""" + + expire_time: Optional[datetime.datetime] + """Expiration time of the cached content.""" + + usage_metadata: Optional[CachedContentUsageMetadataDict] + """Metadata on the usage of the cached content.""" + + +CachedContentOrDict = Union[CachedContent, CachedContentDict] + + +class GetCachedContentConfig(_common.BaseModel): + """Optional parameters for caches.get method.""" + + http_options: Optional[dict[str,
Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetCachedContentConfigDict(TypedDict, total=False): + """Optional parameters for caches.get method.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +GetCachedContentConfigOrDict = Union[ + GetCachedContentConfig, GetCachedContentConfigDict +] + + +class _GetCachedContentParameters(_common.BaseModel): + """Parameters for caches.get method.""" + + name: Optional[str] = Field( + default=None, + description="""The server-generated resource name of the cached content. + """, + ) + config: Optional[GetCachedContentConfig] = Field( + default=None, + description="""Optional parameters for the request. + """, + ) + + +class _GetCachedContentParametersDict(TypedDict, total=False): + """Parameters for caches.get method.""" + + name: Optional[str] + """The server-generated resource name of the cached content. + """ + + config: Optional[GetCachedContentConfigDict] + """Optional parameters for the request. + """ + + +_GetCachedContentParametersOrDict = Union[ + _GetCachedContentParameters, _GetCachedContentParametersDict +] + + +class DeleteCachedContentConfig(_common.BaseModel): + """Optional parameters for caches.delete method.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteCachedContentConfigDict(TypedDict, total=False): + """Optional parameters for caches.delete method.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +DeleteCachedContentConfigOrDict = Union[ + DeleteCachedContentConfig, DeleteCachedContentConfigDict +] + + +class _DeleteCachedContentParameters(_common.BaseModel): + """Parameters for caches.delete method.""" + + name: Optional[str] = Field( + default=None, + description="""The server-generated resource name of the cached content. + """, + ) + config: Optional[DeleteCachedContentConfig] = Field( + default=None, + description="""Optional parameters for the request. + """, + ) + + +class _DeleteCachedContentParametersDict(TypedDict, total=False): + """Parameters for caches.delete method.""" + + name: Optional[str] + """The server-generated resource name of the cached content. + """ + + config: Optional[DeleteCachedContentConfigDict] + """Optional parameters for the request. + """ + + +_DeleteCachedContentParametersOrDict = Union[ + _DeleteCachedContentParameters, _DeleteCachedContentParametersDict +] + + +class DeleteCachedContentResponse(_common.BaseModel): + """Empty response for caches.delete method.""" + + pass + + +class DeleteCachedContentResponseDict(TypedDict, total=False): + """Empty response for caches.delete method.""" + + pass + + +DeleteCachedContentResponseOrDict = Union[ + DeleteCachedContentResponse, DeleteCachedContentResponseDict +] + + +class UpdateCachedContentConfig(_common.BaseModel): + """Optional parameters for caches.update method.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + ttl: Optional[str] = Field( + default=None, + description="""The TTL for this resource. 
The expiration time is computed: now + TTL.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Timestamp of when this resource is considered expired.""", + ) + + +class UpdateCachedContentConfigDict(TypedDict, total=False): + """Optional parameters for caches.update method.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + ttl: Optional[str] + """The TTL for this resource. The expiration time is computed: now + TTL.""" + + expire_time: Optional[datetime.datetime] + """Timestamp of when this resource is considered expired.""" + + +UpdateCachedContentConfigOrDict = Union[ + UpdateCachedContentConfig, UpdateCachedContentConfigDict +] + + +class _UpdateCachedContentParameters(_common.BaseModel): + + name: Optional[str] = Field( + default=None, + description="""The server-generated resource name of the cached content. + """, + ) + config: Optional[UpdateCachedContentConfig] = Field( + default=None, + description="""Configuration that contains optional parameters. + """, + ) + + +class _UpdateCachedContentParametersDict(TypedDict, total=False): + + name: Optional[str] + """The server-generated resource name of the cached content. + """ + + config: Optional[UpdateCachedContentConfigDict] + """Configuration that contains optional parameters. + """ + + +_UpdateCachedContentParametersOrDict = Union[ + _UpdateCachedContentParameters, _UpdateCachedContentParametersDict +] + + +class ListCachedContentsConfig(_common.BaseModel): + """Config for caches.list method.""" + + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + + +class ListCachedContentsConfigDict(TypedDict, total=False): + """Config for caches.list method.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + +ListCachedContentsConfigOrDict = Union[ + ListCachedContentsConfig, ListCachedContentsConfigDict +] + + +class _ListCachedContentsParameters(_common.BaseModel): + """Parameters for caches.list method.""" + + config: Optional[ListCachedContentsConfig] = Field( + default=None, + description="""Configuration that contains optional parameters. + """, + ) + + +class _ListCachedContentsParametersDict(TypedDict, total=False): + """Parameters for caches.list method.""" + + config: Optional[ListCachedContentsConfigDict] + """Configuration that contains optional parameters. + """ + + +_ListCachedContentsParametersOrDict = Union[ + _ListCachedContentsParameters, _ListCachedContentsParametersDict +] + + +class ListCachedContentsResponse(_common.BaseModel): + + next_page_token: Optional[str] = Field(default=None, description="""""") + cached_contents: Optional[list[CachedContent]] = Field( + default=None, + description="""List of cached contents. + """, + ) + + +class ListCachedContentsResponseDict(TypedDict, total=False): + + next_page_token: Optional[str] + """""" + + cached_contents: Optional[list[CachedContentDict]] + """List of cached contents. 
+ """ + + +ListCachedContentsResponseOrDict = Union[ + ListCachedContentsResponse, ListCachedContentsResponseDict +] + + +class ListFilesConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + + +class ListFilesConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + +ListFilesConfigOrDict = Union[ListFilesConfig, ListFilesConfigDict] + + +class _ListFilesParameters(_common.BaseModel): + """Generates the parameters for the list method.""" + + config: Optional[ListFilesConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _ListFilesParametersDict(TypedDict, total=False): + """Generates the parameters for the list method.""" + + config: Optional[ListFilesConfigDict] + """Used to override the default configuration.""" + + +_ListFilesParametersOrDict = Union[ + _ListFilesParameters, _ListFilesParametersDict +] + + +class ListFilesResponse(_common.BaseModel): + """Response for the list files method.""" + + next_page_token: Optional[str] = Field( + default=None, description="""A token to retrieve next page of results.""" + ) + files: Optional[list[File]] = Field( + default=None, description="""The list of files.""" + ) + + +class ListFilesResponseDict(TypedDict, total=False): + """Response for the list files method.""" + + next_page_token: Optional[str] + """A token to retrieve next page of results.""" + + files: Optional[list[FileDict]] + """The list of files.""" + + +ListFilesResponseOrDict = Union[ListFilesResponse, ListFilesResponseDict] + + +class CreateFileConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CreateFileConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +CreateFileConfigOrDict = Union[CreateFileConfig, CreateFileConfigDict] + + +class _CreateFileParameters(_common.BaseModel): + """Generates the parameters for the private _create method.""" + + file: Optional[File] = Field( + default=None, + description="""The file to be uploaded. + mime_type: (Required) The MIME type of the file. Must be provided. + name: (Optional) The name of the file in the destination (e.g. + 'files/sample-image'). + display_name: (Optional) The display name of the file. + """, + ) + config: Optional[CreateFileConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _CreateFileParametersDict(TypedDict, total=False): + """Generates the parameters for the private _create method.""" + + file: Optional[FileDict] + """The file to be uploaded. + mime_type: (Required) The MIME type of the file. Must be provided. + name: (Optional) The name of the file in the destination (e.g. + 'files/sample-image'). + display_name: (Optional) The display name of the file. 
+ """ + + config: Optional[CreateFileConfigDict] + """Used to override the default configuration.""" + + +_CreateFileParametersOrDict = Union[ + _CreateFileParameters, _CreateFileParametersDict +] + + +class CreateFileResponse(_common.BaseModel): + """Response for the create file method.""" + + pass + + +class CreateFileResponseDict(TypedDict, total=False): + """Response for the create file method.""" + + pass + + +CreateFileResponseOrDict = Union[CreateFileResponse, CreateFileResponseDict] + + +class GetFileConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetFileConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +GetFileConfigOrDict = Union[GetFileConfig, GetFileConfigDict] + + +class _GetFileParameters(_common.BaseModel): + """Generates the parameters for the get method.""" + + name: Optional[str] = Field( + default=None, + description="""The name identifier for the file to retrieve.""", + ) + config: Optional[GetFileConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _GetFileParametersDict(TypedDict, total=False): + """Generates the parameters for the get method.""" + + name: Optional[str] + """The name identifier for the file to retrieve.""" + + config: Optional[GetFileConfigDict] + """Used to override the default configuration.""" + + +_GetFileParametersOrDict = Union[_GetFileParameters, _GetFileParametersDict] + + +class DeleteFileConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteFileConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +DeleteFileConfigOrDict = Union[DeleteFileConfig, DeleteFileConfigDict] + + +class _DeleteFileParameters(_common.BaseModel): + """Generates the parameters for the delete method.""" + + name: Optional[str] = Field( + default=None, + description="""The name identifier for the file to be deleted.""", + ) + config: Optional[DeleteFileConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _DeleteFileParametersDict(TypedDict, total=False): + """Generates the parameters for the delete method.""" + + name: Optional[str] + """The name identifier for the file to be deleted.""" + + config: Optional[DeleteFileConfigDict] + """Used to override the default configuration.""" + + +_DeleteFileParametersOrDict = Union[ + _DeleteFileParameters, _DeleteFileParametersDict +] + + +class DeleteFileResponse(_common.BaseModel): + """Response for the delete file method.""" + + pass + + +class DeleteFileResponseDict(TypedDict, total=False): + """Response for the delete file method.""" + + pass + + +DeleteFileResponseOrDict = Union[DeleteFileResponse, DeleteFileResponseDict]
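+ # Get/delete sketch for the file parameter types above. `client.files.get`
+ # and `client.files.delete` are assumed calls; the file name is a placeholder.
+ #
+ #   f = client.files.get(name='files/sample-image')  # hypothetical name
+ #   client.files.delete(name=f.name)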
+class BatchJobSource(_common.BaseModel): + """Config class for `src` parameter.""" + + format: Optional[str] = Field( + default=None, + description="""Storage format of the input files. Must be one of: + 'jsonl', 'bigquery'. + """, + ) + gcs_uri: Optional[list[str]] = Field( + default=None, + description="""The Google Cloud Storage URIs to input files. + """, + ) + bigquery_uri: Optional[str] = Field( + default=None, + description="""The BigQuery URI to input table. + """, + ) + + +class BatchJobSourceDict(TypedDict, total=False): + """Config class for `src` parameter.""" + + format: Optional[str] + """Storage format of the input files. Must be one of: + 'jsonl', 'bigquery'. + """ + + gcs_uri: Optional[list[str]] + """The Google Cloud Storage URIs to input files. + """ + + bigquery_uri: Optional[str] + """The BigQuery URI to input table. + """ + + +BatchJobSourceOrDict = Union[BatchJobSource, BatchJobSourceDict] + + +class BatchJobDestination(_common.BaseModel): + """Config class for `dest` parameter.""" + + format: Optional[str] = Field( + default=None, + description="""Storage format of the output files. Must be one of: + 'jsonl', 'bigquery'. + """, + ) + gcs_uri: Optional[str] = Field( + default=None, + description="""The Google Cloud Storage URI to the output file. + """, + ) + bigquery_uri: Optional[str] = Field( + default=None, + description="""The BigQuery URI to the output table. + """, + ) + + +class BatchJobDestinationDict(TypedDict, total=False): + """Config class for `dest` parameter.""" + + format: Optional[str] + """Storage format of the output files. Must be one of: + 'jsonl', 'bigquery'. + """ + + gcs_uri: Optional[str] + """The Google Cloud Storage URI to the output file. + """ + + bigquery_uri: Optional[str] + """The BigQuery URI to the output table. + """ + + +BatchJobDestinationOrDict = Union[BatchJobDestination, BatchJobDestinationDict] + + +class CreateBatchJobConfig(_common.BaseModel): + """Config class for optional parameters.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of this BatchJob. + """, + ) + dest: Optional[str] = Field( + default=None, + description="""GCS or BigQuery URI prefix for the output predictions. Example: + "gs://path/to/output/data" or "bq://projectId.bqDatasetId.bqTableId". + """, + ) + + +class CreateBatchJobConfigDict(TypedDict, total=False): + """Config class for optional parameters.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The user-defined name of this BatchJob. + """ + + dest: Optional[str] + """GCS or BigQuery URI prefix for the output predictions. Example: + "gs://path/to/output/data" or "bq://projectId.bqDatasetId.bqTableId". + """ + + +CreateBatchJobConfigOrDict = Union[ + CreateBatchJobConfig, CreateBatchJobConfigDict +]
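+ # Batch-creation sketch mirroring the batches.create parameters below.
+ # `client.batches.create` is an assumed call; model and URIs are placeholders.
+ #
+ #   batch = client.batches.create(
+ #       model='gemini-1.5-flash',
+ #       src='gs://my-bucket/input.jsonl',
+ #       config=CreateBatchJobConfig(dest='gs://my-bucket/output/'),
+ #   )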
+class _CreateBatchJobParameters(_common.BaseModel): + """Config class for batches.create parameters.""" + + model: Optional[str] = Field( + default=None, + description="""The name of the model that produces the predictions via the BatchJob. + """, + ) + src: Optional[str] = Field( + default=None, + description="""GCS URI(-s) or BigQuery URI to your input data to run batch job. + Example: "gs://path/to/input/data" or "bq://projectId.bqDatasetId.bqTableId". + """, + ) + config: Optional[CreateBatchJobConfig] = Field( + default=None, + description="""Optional parameters for creating a BatchJob. + """, + ) + + +class _CreateBatchJobParametersDict(TypedDict, total=False): + """Config class for batches.create parameters.""" + + model: Optional[str] + """The name of the model that produces the predictions via the BatchJob. + """ + + src: Optional[str] + """GCS URI(-s) or BigQuery URI to your input data to run batch job. + Example: "gs://path/to/input/data" or "bq://projectId.bqDatasetId.bqTableId". + """ + + config: Optional[CreateBatchJobConfigDict] + """Optional parameters for creating a BatchJob. + """ + + +_CreateBatchJobParametersOrDict = Union[ + _CreateBatchJobParameters, _CreateBatchJobParametersDict +] + + +class JobError(_common.BaseModel): + """Config class for the job error.""" + + details: Optional[list[str]] = Field( + default=None, + description="""A list of messages that carry the error details. There is a common set of message types for APIs to use.""", + ) + code: Optional[int] = Field(default=None, description="""The status code.""") + message: Optional[str] = Field( + default=None, + description="""A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the `details` field.""", + ) + + +class JobErrorDict(TypedDict, total=False): + """Config class for the job error.""" + + details: Optional[list[str]] + """A list of messages that carry the error details. There is a common set of message types for APIs to use.""" + + code: Optional[int] + """The status code.""" + + message: Optional[str] + """A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the `details` field.""" + + +JobErrorOrDict = Union[JobError, JobErrorDict] + + +class BatchJob(_common.BaseModel): + """Config class for batches.create return value.""" + + name: Optional[str] = Field( + default=None, description="""Output only. Resource name of the Job.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The user-defined name of this Job.""" + ) + state: Optional[JobState] = Field( + default=None, + description="""Output only. The detailed state of the job.""", + ) + error: Optional[JobError] = Field( + default=None, + description="""Output only. Only populated when the job's state is JOB_STATE_FAILED or JOB_STATE_CANCELLED.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the Job was created.""", + ) + start_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the Job for the first time entered the `JOB_STATE_RUNNING` state.""", + ) + end_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the Job entered any of the following states: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the Job was most recently updated.""", + ) + model: Optional[str] = Field( + default=None, + description="""The name of the model that produces the predictions via the BatchJob. + """, + ) + src: Optional[BatchJobSource] = Field( + default=None, + description="""Configuration for the input data. + """, + ) + dest: Optional[BatchJobDestination] = Field( + default=None, + description="""Configuration for the output data. + """, + ) + + +class BatchJobDict(TypedDict, total=False): + """Config class for batches.create return value.""" + + name: Optional[str] + """Output only.
Resource name of the Job.""" + + display_name: Optional[str] + """The user-defined name of this Job.""" + + state: Optional[JobState] + """Output only. The detailed state of the job.""" + + error: Optional[JobErrorDict] + """Output only. Only populated when the job's state is JOB_STATE_FAILED or JOB_STATE_CANCELLED.""" + + create_time: Optional[datetime.datetime] + """Output only. Time when the Job was created.""" + + start_time: Optional[datetime.datetime] + """Output only. Time when the Job for the first time entered the `JOB_STATE_RUNNING` state.""" + + end_time: Optional[datetime.datetime] + """Output only. Time when the Job entered any of the following states: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`.""" + + update_time: Optional[datetime.datetime] + """Output only. Time when the Job was most recently updated.""" + + model: Optional[str] + """The name of the model that produces the predictions via the BatchJob. + """ + + src: Optional[BatchJobSourceDict] + """Configuration for the input data. + """ + + dest: Optional[BatchJobDestinationDict] + """Configuration for the output data. + """ + + +BatchJobOrDict = Union[BatchJob, BatchJobDict] + + +class GetBatchJobConfig(_common.BaseModel): + """Optional parameters.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetBatchJobConfigDict(TypedDict, total=False): + """Optional parameters.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +GetBatchJobConfigOrDict = Union[GetBatchJobConfig, GetBatchJobConfigDict] + + +class _GetBatchJobParameters(_common.BaseModel): + """Config class for batches.get parameters.""" + + name: Optional[str] = Field( + default=None, + description="""A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. + """, + ) + config: Optional[GetBatchJobConfig] = Field( + default=None, description="""Optional parameters for the request.""" + ) + + +class _GetBatchJobParametersDict(TypedDict, total=False): + """Config class for batches.get parameters.""" + + name: Optional[str] + """A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. + """ + + config: Optional[GetBatchJobConfigDict] + """Optional parameters for the request.""" + + +_GetBatchJobParametersOrDict = Union[ + _GetBatchJobParameters, _GetBatchJobParametersDict +] + + +class CancelBatchJobConfig(_common.BaseModel): + """Optional parameters.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CancelBatchJobConfigDict(TypedDict, total=False): + """Optional parameters.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +CancelBatchJobConfigOrDict = Union[ + CancelBatchJobConfig, CancelBatchJobConfigDict +] + + +class _CancelBatchJobParameters(_common.BaseModel): + """Config class for batches.cancel parameters.""" + + name: Optional[str] = Field( + default=None, + description="""A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. 
+ """, + ) + config: Optional[CancelBatchJobConfig] = Field( + default=None, description="""Optional parameters for the request.""" + ) + + +class _CancelBatchJobParametersDict(TypedDict, total=False): + """Config class for batches.cancel parameters.""" + + name: Optional[str] + """A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. + """ + + config: Optional[CancelBatchJobConfigDict] + """Optional parameters for the request.""" + + +_CancelBatchJobParametersOrDict = Union[ + _CancelBatchJobParameters, _CancelBatchJobParametersDict +] + + +class ListBatchJobConfig(_common.BaseModel): + """Config class for optional parameters.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field(default=None, description="""""") + + +class ListBatchJobConfigDict(TypedDict, total=False): + """Config class for optional parameters.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """""" + + +ListBatchJobConfigOrDict = Union[ListBatchJobConfig, ListBatchJobConfigDict] + + +class _ListBatchJobParameters(_common.BaseModel): + """Config class for batches.list parameters.""" + + config: Optional[ListBatchJobConfig] = Field(default=None, description="""""") + + +class _ListBatchJobParametersDict(TypedDict, total=False): + """Config class for batches.list parameters.""" + + config: Optional[ListBatchJobConfigDict] + """""" + + +_ListBatchJobParametersOrDict = Union[ + _ListBatchJobParameters, _ListBatchJobParametersDict +] + + +class ListBatchJobResponse(_common.BaseModel): + """Config class for batches.list return value.""" + + next_page_token: Optional[str] = Field(default=None, description="""""") + batch_jobs: Optional[list[BatchJob]] = Field(default=None, description="""""") + + +class ListBatchJobResponseDict(TypedDict, total=False): + """Config class for batches.list return value.""" + + next_page_token: Optional[str] + """""" + + batch_jobs: Optional[list[BatchJobDict]] + """""" + + +ListBatchJobResponseOrDict = Union[ + ListBatchJobResponse, ListBatchJobResponseDict +] + + +class _DeleteBatchJobParameters(_common.BaseModel): + """Config class for batches.delete parameters.""" + + name: Optional[str] = Field( + default=None, + description="""A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. + """, + ) + + +class _DeleteBatchJobParametersDict(TypedDict, total=False): + """Config class for batches.delete parameters.""" + + name: Optional[str] + """A fully-qualified BatchJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized in the client. 
+ """ + + +_DeleteBatchJobParametersOrDict = Union[ + _DeleteBatchJobParameters, _DeleteBatchJobParametersDict +] + + +class DeleteResourceJob(_common.BaseModel): + """Config class for the return value of delete operation.""" + + name: Optional[str] = Field(default=None, description="""""") + done: Optional[bool] = Field(default=None, description="""""") + error: Optional[JobError] = Field(default=None, description="""""") + + +class DeleteResourceJobDict(TypedDict, total=False): + """Config class for the return value of delete operation.""" + + name: Optional[str] + """""" + + done: Optional[bool] + """""" + + error: Optional[JobErrorDict] + """""" + + +DeleteResourceJobOrDict = Union[DeleteResourceJob, DeleteResourceJobDict] + + +class TestTableItem(_common.BaseModel): + + name: Optional[str] = Field( + default=None, + description="""The name of the test. This is used to derive the replay id.""", + ) + parameters: Optional[dict[str, Any]] = Field( + default=None, + description="""The parameters to the test. Use pydantic models.""", + ) + exception_if_mldev: Optional[str] = Field( + default=None, + description="""Expects an exception for MLDev matching the string.""", + ) + exception_if_vertex: Optional[str] = Field( + default=None, + description="""Expects an exception for Vertex matching the string.""", + ) + override_replay_id: Optional[str] = Field( + default=None, + description="""Use if you don't want to use the default replay id which is derived from the test name.""", + ) + has_union: Optional[bool] = Field( + default=None, + description="""True if the parameters contain an unsupported union type. This test will be skipped for languages that do not support the union type.""", + ) + skip_in_api_mode: Optional[str] = Field( + default=None, + description="""When set to a reason string, this test will be skipped in the API mode. Use this flag for tests that can not be reproduced with the real API. E.g. a test that deletes a resource.""", + ) + + +class TestTableItemDict(TypedDict, total=False): + + name: Optional[str] + """The name of the test. This is used to derive the replay id.""" + + parameters: Optional[dict[str, Any]] + """The parameters to the test. Use pydantic models.""" + + exception_if_mldev: Optional[str] + """Expects an exception for MLDev matching the string.""" + + exception_if_vertex: Optional[str] + """Expects an exception for Vertex matching the string.""" + + override_replay_id: Optional[str] + """Use if you don't want to use the default replay id which is derived from the test name.""" + + has_union: Optional[bool] + """True if the parameters contain an unsupported union type. This test will be skipped for languages that do not support the union type.""" + + skip_in_api_mode: Optional[str] + """When set to a reason string, this test will be skipped in the API mode. Use this flag for tests that can not be reproduced with the real API. E.g. 
a test that deletes a resource.""" + + +TestTableItemOrDict = Union[TestTableItem, TestTableItemDict] + + +class TestTableFile(_common.BaseModel): + + comment: Optional[str] = Field(default=None, description="""""") + test_method: Optional[str] = Field(default=None, description="""""") + parameter_names: Optional[list[str]] = Field(default=None, description="""""") + test_table: Optional[list[TestTableItem]] = Field( + default=None, description="""""" + ) + + +class TestTableFileDict(TypedDict, total=False): + + comment: Optional[str] + """""" + + test_method: Optional[str] + """""" + + parameter_names: Optional[list[str]] + """""" + + test_table: Optional[list[TestTableItemDict]] + """""" + + +TestTableFileOrDict = Union[TestTableFile, TestTableFileDict] + + +class ReplayRequest(_common.BaseModel): + """Represents a single request in a replay.""" + + method: Optional[str] = Field(default=None, description="""""") + url: Optional[str] = Field(default=None, description="""""") + headers: Optional[dict[str, str]] = Field(default=None, description="""""") + body_segments: Optional[list[dict[str, Any]]] = Field( + default=None, description="""""" + ) + + +class ReplayRequestDict(TypedDict, total=False): + """Represents a single request in a replay.""" + + method: Optional[str] + """""" + + url: Optional[str] + """""" + + headers: Optional[dict[str, str]] + """""" + + body_segments: Optional[list[dict[str, Any]]] + """""" + + +ReplayRequestOrDict = Union[ReplayRequest, ReplayRequestDict] + + +class ReplayResponse(_common.BaseModel): + """Represents a single response in a replay.""" + + status_code: Optional[int] = Field(default=None, description="""""") + headers: Optional[dict[str, str]] = Field(default=None, description="""""") + body_segments: Optional[list[dict[str, Any]]] = Field( + default=None, description="""""" + ) + sdk_response_segments: Optional[list[dict[str, Any]]] = Field( + default=None, description="""""" + ) + + +class ReplayResponseDict(TypedDict, total=False): + """Represents a single response in a replay.""" + + status_code: Optional[int] + """""" + + headers: Optional[dict[str, str]] + """""" + + body_segments: Optional[list[dict[str, Any]]] + """""" + + sdk_response_segments: Optional[list[dict[str, Any]]] + """""" + + +ReplayResponseOrDict = Union[ReplayResponse, ReplayResponseDict] + + +class ReplayInteraction(_common.BaseModel): + """Represents a single interaction, request and response in a replay.""" + + request: Optional[ReplayRequest] = Field(default=None, description="""""") + response: Optional[ReplayResponse] = Field(default=None, description="""""") + + +class ReplayInteractionDict(TypedDict, total=False): + """Represents a single interaction, request and response in a replay.""" + + request: Optional[ReplayRequestDict] + """""" + + response: Optional[ReplayResponseDict] + """""" + + +ReplayInteractionOrDict = Union[ReplayInteraction, ReplayInteractionDict] + + +class ReplayFile(_common.BaseModel): + """Represents a recorded session.""" + + replay_id: Optional[str] = Field(default=None, description="""""") + interactions: Optional[list[ReplayInteraction]] = Field( + default=None, description="""""" + ) + + +class ReplayFileDict(TypedDict, total=False): + """Represents a recorded session.""" + + replay_id: Optional[str] + """""" + + interactions: Optional[list[ReplayInteractionDict]] + """""" + + +ReplayFileOrDict = Union[ReplayFile, ReplayFileDict] + + +class UploadFileConfig(_common.BaseModel): + """Used to override the default configuration.""" + + 
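+ # Upload sketch using this config; `client.files.upload` and its `file=`
+ # parameter are assumed for illustration, and the local path is a placeholder.
+ #
+ #   f = client.files.upload(
+ #       file='./sample.png',
+ #       config=UploadFileConfig(display_name='sample image', mime_type='image/png'),
+ #   )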
http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + name: Optional[str] = Field( + default=None, + description="""The name of the file in the destination (e.g., 'files/sample-image'). If not provided, one will be generated.""", + ) + mime_type: Optional[str] = Field( + default=None, + description="""The MIME type of the file. If not provided, it will be inferred from the file extension.""", + ) + display_name: Optional[str] = Field( + default=None, description="""Optional display name of the file.""" + ) + + +class UploadFileConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + name: Optional[str] + """The name of the file in the destination (e.g., 'files/sample-image'). If not provided, one will be generated.""" + + mime_type: Optional[str] + """The MIME type of the file. If not provided, it will be inferred from the file extension.""" + + display_name: Optional[str] + """Optional display name of the file.""" + + +UploadFileConfigOrDict = Union[UploadFileConfig, UploadFileConfigDict] + + +class DownloadFileConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DownloadFileConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + +DownloadFileConfigOrDict = Union[DownloadFileConfig, DownloadFileConfigDict] + + +class UpscaleImageConfig(_common.BaseModel): + """Configuration for upscaling an image. + + For more information on this configuration, refer to + the `Imagen API reference documentation + <https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api>`_. + """ + + http_options: Optional[dict[str, Any]] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + include_rai_reason: Optional[bool] = Field( + default=None, + description="""Whether to include a reason for filtered-out images in the + response.""", + ) + output_mime_type: Optional[str] = Field( + default=None, + description="""The image format that the output should be saved as.""", + ) + output_compression_quality: Optional[int] = Field( + default=None, + description="""The level of compression if the ``output_mime_type`` is + ``image/jpeg``.""", + ) + + +class UpscaleImageConfigDict(TypedDict, total=False): + """Configuration for upscaling an image. + + For more information on this configuration, refer to + the `Imagen API reference documentation + <https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api>`_.
+ """ + + http_options: Optional[dict[str, Any]] + """Used to override HTTP request options.""" + + include_rai_reason: Optional[bool] + """Whether to include a reason for filtered-out images in the + response.""" + + output_mime_type: Optional[str] + """The image format that the output should be saved as.""" + + output_compression_quality: Optional[int] + """The level of compression if the ``output_mime_type`` is + ``image/jpeg``.""" + + +UpscaleImageConfigOrDict = Union[UpscaleImageConfig, UpscaleImageConfigDict] + + +class UpscaleImageParameters(_common.BaseModel): + """User-facing config UpscaleImageParameters.""" + + model: Optional[str] = Field( + default=None, description="""The model to use.""" + ) + image: Optional[Image] = Field( + default=None, description="""The input image to upscale.""" + ) + upscale_factor: Optional[str] = Field( + default=None, + description="""The factor to upscale the image (x2 or x4).""", + ) + config: Optional[UpscaleImageConfig] = Field( + default=None, description="""Configuration for upscaling.""" + ) + + +class UpscaleImageParametersDict(TypedDict, total=False): + """User-facing config UpscaleImageParameters.""" + + model: Optional[str] + """The model to use.""" + + image: Optional[ImageDict] + """The input image to upscale.""" + + upscale_factor: Optional[str] + """The factor to upscale the image (x2 or x4).""" + + config: Optional[UpscaleImageConfigDict] + """Configuration for upscaling.""" + + +UpscaleImageParametersOrDict = Union[ + UpscaleImageParameters, UpscaleImageParametersDict +] + + +class RawReferenceImage(_common.BaseModel): + """Class that represents a Raw reference image. + + A raw reference image represents the base image to edit, provided by the user. + It can optionally be provided in addition to a mask reference image or + a style reference image. + """ + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + + def __init__( + self, + reference_image: Optional[Image] = None, + reference_id: Optional[int] = None, + ): + super().__init__( + reference_image=reference_image, + reference_id=reference_id, + reference_type='REFERENCE_TYPE_RAW', + ) + + +class RawReferenceImageDict(TypedDict, total=False): + """Class that represents a Raw reference image. + + A raw reference image represents the base image to edit, provided by the user. + It can optionally be provided in addition to a mask reference image or + a style reference image. + """ + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + +RawReferenceImageOrDict = Union[RawReferenceImage, RawReferenceImageDict] + + +class MaskReferenceImage(_common.BaseModel): + """Class that represents a Mask reference image. + + This encapsulates either a mask image provided by the user and configs for + the user provided mask, or only config parameters for the model to generate + a mask. + + A mask image is an image whose non-zero values indicate where to edit the base + image. If the user provides a mask image, the mask must be in the same + dimensions as the raw image. 
+ """ + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + config: Optional[MaskReferenceConfig] = Field( + default=None, + description="""Configuration for the mask reference image.""", + ) + """Re-map config to mask_reference_config to send to API.""" + mask_image_config: Optional['MaskReferenceConfig'] = Field( + default=None, description="""""" + ) + + def __init__( + self, + reference_image: Optional[Image] = None, + reference_id: Optional[int] = None, + config: Optional['MaskReferenceConfig'] = None, + ): + super().__init__( + reference_image=reference_image, + reference_id=reference_id, + reference_type='REFERENCE_TYPE_MASK', + ) + self.mask_image_config = config + + +class MaskReferenceImageDict(TypedDict, total=False): + """Class that represents a Mask reference image. + + This encapsulates either a mask image provided by the user and configs for + the user provided mask, or only config parameters for the model to generate + a mask. + + A mask image is an image whose non-zero values indicate where to edit the base + image. If the user provides a mask image, the mask must be in the same + dimensions as the raw image. + """ + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + config: Optional[MaskReferenceConfigDict] + """Configuration for the mask reference image.""" + + +MaskReferenceImageOrDict = Union[MaskReferenceImage, MaskReferenceImageDict] + + +class ControlReferenceImage(_common.BaseModel): + """Class that represents a Control reference image. + + The image of the control reference image is either a control image provided + by the user, or a regular image which the backend will use to generate a + control image of. In the case of the latter, the + enable_control_image_computation field in the config should be set to True. + + A control image is an image that represents a sketch image of areas for the + model to fill in based on the prompt. + """ + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + config: Optional[ControlReferenceConfig] = Field( + default=None, + description="""Configuration for the control reference image.""", + ) + """Re-map config to control_reference_config to send to API.""" + control_image_config: Optional['ControlReferenceConfig'] = Field( + default=None, description="""""" + ) + + def __init__( + self, + reference_image: Optional[Image] = None, + reference_id: Optional[int] = None, + config: Optional['ControlReferenceConfig'] = None, + ): + super().__init__( + reference_image=reference_image, + reference_id=reference_id, + reference_type='REFERENCE_TYPE_CONTROL', + ) + self.control_image_config = config + + +class ControlReferenceImageDict(TypedDict, total=False): + """Class that represents a Control reference image. 
+ + The image of the control reference image is either a control image provided + by the user, or a regular image from which the backend will generate a + control image. In the latter case, the + enable_control_image_computation field in the config should be set to True. + + A control image is an image that represents a sketch image of areas for the + model to fill in based on the prompt. + """ + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + config: Optional[ControlReferenceConfigDict] + """Configuration for the control reference image.""" + + +ControlReferenceImageOrDict = Union[ + ControlReferenceImage, ControlReferenceImageDict +] + + +class StyleReferenceImage(_common.BaseModel): + """Class that represents a Style reference image. + + This encapsulates a style reference image provided by the user, and + optional config parameters for the style reference image. + + A raw reference image can also be provided as a destination for the style to + be applied to. + """ + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + config: Optional[StyleReferenceConfig] = Field( + default=None, + description="""Configuration for the style reference image.""", + ) + """Re-map config to style_image_config to send to API.""" + style_image_config: Optional['StyleReferenceConfig'] = Field( + default=None, description="""""" + ) + + def __init__( + self, + reference_image: Optional[Image] = None, + reference_id: Optional[int] = None, + config: Optional['StyleReferenceConfig'] = None, + ): + super().__init__( + reference_image=reference_image, + reference_id=reference_id, + reference_type='REFERENCE_TYPE_STYLE', + ) + self.style_image_config = config + + +class StyleReferenceImageDict(TypedDict, total=False): + """Class that represents a Style reference image. + + This encapsulates a style reference image provided by the user, and + optional config parameters for the style reference image. + + A raw reference image can also be provided as a destination for the style to + be applied to. + """ + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + config: Optional[StyleReferenceConfigDict] + """Configuration for the style reference image.""" + + +StyleReferenceImageOrDict = Union[StyleReferenceImage, StyleReferenceImageDict] + + +class SubjectReferenceImage(_common.BaseModel): + """Class that represents a Subject reference image. + + This encapsulates a subject reference image provided by the user, and + optional config parameters for the subject reference image. + + A raw reference image can also be provided as a destination for the subject to + be applied to.
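# --- Usage sketch (editorial annotation, not part of the diff) ---
# A ControlReferenceImage built from a regular image: setting
# enable_control_image_computation asks the backend to derive the
# control image, as the docstring above describes. control_type is an
# assumed ControlReferenceConfig field name, shown for illustration.
control_ref = ControlReferenceImage(
    reference_image=Image(gcs_uri='gs://bucket/photo.png'),  # assumed URI
    reference_id=2,
    config=ControlReferenceConfig(
        control_type='CONTROL_TYPE_SCRIBBLE',
        enable_control_image_computation=True,
    ),
)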
+ """ + + reference_image: Optional[Image] = Field( + default=None, + description="""The reference image for the editing operation.""", + ) + reference_id: Optional[int] = Field( + default=None, description="""The id of the reference image.""" + ) + reference_type: Optional[str] = Field( + default=None, description="""The type of the reference image.""" + ) + config: Optional[SubjectReferenceConfig] = Field( + default=None, + description="""Configuration for the subject reference image.""", + ) + """Re-map config to subject_reference_config to send to API.""" + subject_image_config: Optional['SubjectReferenceConfig'] = Field( + default=None, description="""""" + ) + + def __init__( + self, + reference_image: Optional[Image] = None, + reference_id: Optional[int] = None, + config: Optional['SubjectReferenceConfig'] = None, + ): + super().__init__( + reference_image=reference_image, + reference_id=reference_id, + reference_type='REFERENCE_TYPE_SUBJECT', + ) + self.subject_image_config = config + + +class SubjectReferenceImageDict(TypedDict, total=False): + """Class that represents a Subject reference image. + + This encapsulates a subject reference image provided by the user, and + additionally optional config parameters for the subject reference image. + + A raw reference image can also be provided as a destination for the subject to + be applied to. + """ + + reference_image: Optional[ImageDict] + """The reference image for the editing operation.""" + + reference_id: Optional[int] + """The id of the reference image.""" + + reference_type: Optional[str] + """The type of the reference image.""" + + config: Optional[SubjectReferenceConfigDict] + """Configuration for the subject reference image.""" + + +SubjectReferenceImageOrDict = Union[ + SubjectReferenceImage, SubjectReferenceImageDict +] + + +class LiveServerSetupComplete(_common.BaseModel): + """Sent in response to a `LiveGenerateContentSetup` message from the client.""" + + pass + + +class LiveServerSetupCompleteDict(TypedDict, total=False): + """Sent in response to a `LiveGenerateContentSetup` message from the client.""" + + pass + + +LiveServerSetupCompleteOrDict = Union[ + LiveServerSetupComplete, LiveServerSetupCompleteDict +] + + +class LiveServerContent(_common.BaseModel): + """Incremental server update generated by the model in response to client messages. + + Content is generated as quickly as possible, and not in real time. Clients + may choose to buffer and play it out in real time. + """ + + model_turn: Optional[Content] = Field( + default=None, + description="""The content that the model has generated as part of the current conversation with the user.""", + ) + turn_complete: Optional[bool] = Field( + default=None, + description="""If true, indicates that the model is done generating. Generation will only start in response to additional client messages. Can be set alongside `content`, indicating that the `content` is the last in the turn.""", + ) + interrupted: Optional[bool] = Field( + default=None, + description="""If true, indicates that a client message has interrupted current model generation. If the client is playing out the content in realtime, this is a good signal to stop and empty the current queue. If the client is playing out the content in realtime, this is a good signal to stop and empty the current playback queue.""", + ) + + +class LiveServerContentDict(TypedDict, total=False): + """Incremental server update generated by the model in response to client messages. 
+ + Content is generated as quickly as possible, and not in real time. Clients + may choose to buffer and play it out in real time. + """ + + model_turn: Optional[ContentDict] + """The content that the model has generated as part of the current conversation with the user.""" + + turn_complete: Optional[bool] + """If true, indicates that the model is done generating. Generation will only start in response to additional client messages. Can be set alongside `content`, indicating that the `content` is the last in the turn.""" + + interrupted: Optional[bool] + """If true, indicates that a client message has interrupted current model generation. If the client is playing out the content in realtime, this is a good signal to stop and empty the current playback queue.""" + + +LiveServerContentOrDict = Union[LiveServerContent, LiveServerContentDict] + + +class LiveServerToolCall(_common.BaseModel): + """Request for the client to execute the `function_calls` and return the responses with the matching `id`s.""" + + function_calls: Optional[list[FunctionCall]] = Field( + default=None, description="""The function calls to be executed.""" + ) + + +class LiveServerToolCallDict(TypedDict, total=False): + """Request for the client to execute the `function_calls` and return the responses with the matching `id`s.""" + + function_calls: Optional[list[FunctionCallDict]] + """The function calls to be executed.""" + + +LiveServerToolCallOrDict = Union[LiveServerToolCall, LiveServerToolCallDict] + + +class LiveServerToolCallCancellation(_common.BaseModel): + """Notification for the client that a previously issued `ToolCallMessage` with the specified `id`s should not have been executed and should be cancelled. + + If there were side-effects to those tool calls, clients may attempt to undo + the tool calls. This message occurs only in cases where the clients interrupt + server turns. + """ + + ids: Optional[list[str]] = Field( + default=None, description="""The ids of the tool calls to be cancelled.""" + ) + + +class LiveServerToolCallCancellationDict(TypedDict, total=False): + """Notification for the client that a previously issued `ToolCallMessage` with the specified `id`s should not have been executed and should be cancelled. + + If there were side-effects to those tool calls, clients may attempt to undo + the tool calls. This message occurs only in cases where the clients interrupt + server turns.
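# --- Usage sketch (editorial annotation, not part of the diff) ---
# Dispatching a LiveServerToolCall and echoing results back, matching
# each FunctionResponse to its FunctionCall by `id` as the bidi protocol
# requires. `registry` (function name -> callable) is hypothetical;
# FunctionCall, FunctionResponse, and LiveClientToolResponse come from
# this module.
def handle_tool_call(
    msg: LiveServerToolCall, registry: dict
) -> LiveClientToolResponse:
    responses = []
    for call in msg.function_calls or []:
        result = registry[call.name](**(call.args or {}))
        responses.append(
            FunctionResponse(
                id=call.id, name=call.name, response={'result': result}
            )
        )
    return LiveClientToolResponse(function_responses=responses)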
+ """ + + ids: Optional[list[str]] + """The ids of the tool calls to be cancelled.""" + + +LiveServerToolCallCancellationOrDict = Union[ + LiveServerToolCallCancellation, LiveServerToolCallCancellationDict +] + + +class LiveServerMessage(_common.BaseModel): + """Response message for API call.""" + + setup_complete: Optional[LiveServerSetupComplete] = Field( + default=None, + description="""Sent in response to a `LiveClientSetup` message from the client.""", + ) + server_content: Optional[LiveServerContent] = Field( + default=None, + description="""Content generated by the model in response to client messages.""", + ) + tool_call: Optional[LiveServerToolCall] = Field( + default=None, + description="""Request for the client to execute the `function_calls` and return the responses with the matching `id`s.""", + ) + tool_call_cancellation: Optional[LiveServerToolCallCancellation] = Field( + default=None, + description="""Notification for the client that a previously issued `ToolCallMessage` with the specified `id`s should have been not executed and should be cancelled.""", + ) + + @property + def text(self) -> Optional[str]: + """Returns the concatenation of all text parts in the response.""" + if ( + not self.server_content + or not self.server_content + or not self.server_content.model_turn + or not self.server_content.model_turn.parts + ): + return None + text = '' + for part in self.server_content.model_turn.parts: + if isinstance(part.text, str): + if isinstance(part.thought, bool) and part.thought: + continue + text += part.text + return text if text else None + + @property + def data(self) -> Optional[bytes]: + """Returns the concatenation of all inline data parts in the response.""" + if ( + not self.server_content + or not self.server_content + or not self.server_content.model_turn + or not self.server_content.model_turn.parts + ): + return None + concatenated_data = b'' + for part in self.server_content.model_turn.parts: + if part.inline_data and isinstance(part.inline_data.data, bytes): + concatenated_data += part.inline_data.data + return concatenated_data if len(concatenated_data) > 0 else None + + +class LiveServerMessageDict(TypedDict, total=False): + """Response message for API call.""" + + setup_complete: Optional[LiveServerSetupCompleteDict] + """Sent in response to a `LiveClientSetup` message from the client.""" + + server_content: Optional[LiveServerContentDict] + """Content generated by the model in response to client messages.""" + + tool_call: Optional[LiveServerToolCallDict] + """Request for the client to execute the `function_calls` and return the responses with the matching `id`s.""" + + tool_call_cancellation: Optional[LiveServerToolCallCancellationDict] + """Notification for the client that a previously issued `ToolCallMessage` with the specified `id`s should have been not executed and should be cancelled.""" + + +LiveServerMessageOrDict = Union[LiveServerMessage, LiveServerMessageDict] + + +class LiveClientSetup(_common.BaseModel): + """Message contains configuration that will apply for the duration of the streaming session.""" + + model: Optional[str] = Field( + default=None, + description=""" + The fully qualified name of the publisher model or tuned model endpoint to + use. + """, + ) + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="""The generation configuration for the session.""", + ) + system_instruction: Optional[Content] = Field( + default=None, + description="""The user provided system instructions for the model. 
+ Note: only text should be used in parts and content in each part will be + in a separate paragraph.""", + ) + tools: Optional[list[Tool]] = Field( + default=None, + description=""" A list of `Tools` the model may use to generate the next response. + + A `Tool` is a piece of code that enables the system to interact with + external systems to perform an action, or set of actions, outside of + knowledge and scope of the model.""", + ) + + +class LiveClientSetupDict(TypedDict, total=False): + """Message contains configuration that will apply for the duration of the streaming session.""" + + model: Optional[str] + """ + The fully qualified name of the publisher model or tuned model endpoint to + use. + """ + + generation_config: Optional[GenerationConfigDict] + """The generation configuration for the session.""" + + system_instruction: Optional[ContentDict] + """The user provided system instructions for the model. + Note: only text should be used in parts and content in each part will be + in a separate paragraph.""" + + tools: Optional[list[ToolDict]] + """ A list of `Tools` the model may use to generate the next response. + + A `Tool` is a piece of code that enables the system to interact with + external systems to perform an action, or set of actions, outside of + knowledge and scope of the model.""" + + +LiveClientSetupOrDict = Union[LiveClientSetup, LiveClientSetupDict] + + +class LiveClientContent(_common.BaseModel): + """Incremental update of the current conversation delivered from the client. + + All the content here will unconditionally be appended to the conversation + history and used as part of the prompt to the model to generate content. + + A message here will interrupt any current model generation. + """ + + turns: Optional[list[Content]] = Field( + default=None, + description="""The content appended to the current conversation with the model. + + For single-turn queries, this is a single instance. For multi-turn + queries, this is a repeated field that contains conversation history + + latest request. + """, + ) + turn_complete: Optional[bool] = Field( + default=None, + description="""If true, indicates that the server content generation should start with + the currently accumulated prompt. Otherwise, the server will await + additional messages before starting generation.""", + ) + + +class LiveClientContentDict(TypedDict, total=False): + """Incremental update of the current conversation delivered from the client. + + All the content here will unconditionally be appended to the conversation + history and used as part of the prompt to the model to generate content. + + A message here will interrupt any current model generation. + """ + + turns: Optional[list[ContentDict]] + """The content appended to the current conversation with the model. + + For single-turn queries, this is a single instance. For multi-turn + queries, this is a repeated field that contains conversation history + + latest request. + """ + + turn_complete: Optional[bool] + """If true, indicates that the server content generation should start with + the currently accumulated prompt. Otherwise, the server will await + additional messages before starting generation.""" + + +LiveClientContentOrDict = Union[LiveClientContent, LiveClientContentDict] + + +class LiveClientRealtimeInput(_common.BaseModel): + """User input that is sent in real time. + + This is different from `ClientContentUpdate` in a few ways: + - Can be sent continuously without interrupting model generation.
+ - If there is a need to mix data interleaved across the + `ClientContentUpdate` and the `RealtimeUpdate`, the server will attempt to + optimize for the best response, but there are no guarantees. + - End of turn is not explicitly specified, but is rather derived from user + activity, e.g. end of speech. + - Even before the end of turn, the data will be processed incrementally + to optimize for a fast start of the response from the model. + - Is always assumed to be the user's input (cannot be used to populate + conversation history). + """ + + media_chunks: Optional[list[Blob]] = Field( + default=None, description="""Inlined bytes data for media input.""" + ) + + +class LiveClientRealtimeInputDict(TypedDict, total=False): + """User input that is sent in real time. + + This is different from `ClientContentUpdate` in a few ways: + - Can be sent continuously without interrupting model generation. + - If there is a need to mix data interleaved across the + `ClientContentUpdate` and the `RealtimeUpdate`, the server will attempt to + optimize for the best response, but there are no guarantees. + - End of turn is not explicitly specified, but is rather derived from user + activity, e.g. end of speech. + - Even before the end of turn, the data will be processed incrementally + to optimize for a fast start of the response from the model. + - Is always assumed to be the user's input (cannot be used to populate + conversation history). + """ + + media_chunks: Optional[list[BlobDict]] + """Inlined bytes data for media input.""" + + +LiveClientRealtimeInputOrDict = Union[ + LiveClientRealtimeInput, LiveClientRealtimeInputDict +] + + +class LiveClientToolResponse(_common.BaseModel): + """Client generated response to a `ToolCall` received from the server. + + Individual `FunctionResponse` objects are matched to the respective + `FunctionCall` objects by the `id` field. + + Note that in the unary and server-streaming GenerateContent APIs function + calling happens by exchanging the `Content` parts, while in the bidi + GenerateContent APIs function calling happens over this dedicated set of + messages. + """ + + function_responses: Optional[list[FunctionResponse]] = Field( + default=None, description="""The response to the function calls.""" + ) + + +class LiveClientToolResponseDict(TypedDict, total=False): + """Client generated response to a `ToolCall` received from the server. + + Individual `FunctionResponse` objects are matched to the respective + `FunctionCall` objects by the `id` field. + + Note that in the unary and server-streaming GenerateContent APIs function + calling happens by exchanging the `Content` parts, while in the bidi + GenerateContent APIs function calling happens over this dedicated set of + messages. + """ + + function_responses: Optional[list[FunctionResponseDict]] + """The response to the function calls.""" + + +LiveClientToolResponseOrDict = Union[ + LiveClientToolResponse, LiveClientToolResponseDict +] + + +class LiveClientMessage(_common.BaseModel): + """Messages sent by the client in the API call.""" + + setup: Optional[LiveClientSetup] = Field( + default=None, + description="""Message to be sent by the system when connecting to the API.
SDK users should not send this message.""", + ) + client_content: Optional[LiveClientContent] = Field( + default=None, + description="""Incremental update of the current conversation delivered from the client.""", + ) + realtime_input: Optional[LiveClientRealtimeInput] = Field( + default=None, description="""User input that is sent in real time.""" + ) + tool_response: Optional[LiveClientToolResponse] = Field( + default=None, + description="""Response to a `ToolCallMessage` received from the server.""", + ) + + +class LiveClientMessageDict(TypedDict, total=False): + """Messages sent by the client in the API call.""" + + setup: Optional[LiveClientSetupDict] + """Message to be sent by the system when connecting to the API. SDK users should not send this message.""" + + client_content: Optional[LiveClientContentDict] + """Incremental update of the current conversation delivered from the client.""" + + realtime_input: Optional[LiveClientRealtimeInputDict] + """User input that is sent in real time.""" + + tool_response: Optional[LiveClientToolResponseDict] + """Response to a `ToolCallMessage` received from the server.""" + + +LiveClientMessageOrDict = Union[LiveClientMessage, LiveClientMessageDict] + + +class LiveConnectConfig(_common.BaseModel): + """Config class for the session.""" + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="""The generation configuration for the session.""", + ) + response_modalities: Optional[list[Modality]] = Field( + default=None, + description="""The requested modalities of the response. Represents the set of + modalities that the model can return. Defaults to AUDIO if not specified. + """, + ) + speech_config: Optional[SpeechConfig] = Field( + default=None, + description="""The speech generation configuration. + """, + ) + system_instruction: Optional[Content] = Field( + default=None, + description="""The user provided system instructions for the model. + Note: only text should be used in parts and content in each part will be + in a separate paragraph.""", + ) + tools: Optional[list[Tool]] = Field( + default=None, + description="""A list of `Tools` the model may use to generate the next response. + + A `Tool` is a piece of code that enables the system to interact with + external systems to perform an action, or set of actions, outside of + knowledge and scope of the model.""", + ) + + +class LiveConnectConfigDict(TypedDict, total=False): + """Config class for the session.""" + + generation_config: Optional[GenerationConfigDict] + """The generation configuration for the session.""" + + response_modalities: Optional[list[Modality]] + """The requested modalities of the response. Represents the set of + modalities that the model can return. Defaults to AUDIO if not specified. + """ + + speech_config: Optional[SpeechConfigDict] + """The speech generation configuration. + """ + + system_instruction: Optional[ContentDict] + """The user provided system instructions for the model. + Note: only text should be used in parts and content in each part will be + in a separate paragraph.""" + + tools: Optional[list[ToolDict]] + """A list of `Tools` the model may use to generate the next response. 
+ + A `Tool` is a piece of code that enables the system to interact with + external systems to perform an action, or set of actions, outside of + knowledge and scope of the model.""" + + +LiveConnectConfigOrDict = Union[LiveConnectConfig, LiveConnectConfigDict] diff --git a/.venv/lib/python3.12/site-packages/google/genai/version.py b/.venv/lib/python3.12/site-packages/google/genai/version.py new file mode 100644 index 00000000..f69c999e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/version.py @@ -0,0 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__version__ = '0.6.0' # x-release-please-version
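# --- Usage sketch (editorial annotation, not part of the diff) ---
# Tying the Live API types above together. The client surface shown here
# (client.aio.live.connect, session.send, session.receive) is assumed
# from the SDK at this version and may differ; treat it as illustrative
# only. The model name is hypothetical.
import asyncio
from google import genai
from google.genai import types

async def main() -> None:
    client = genai.Client()  # assumes credentials in the environment
    config = types.LiveConnectConfig(
        response_modalities=['TEXT'],
        system_instruction=types.Content(
            parts=[types.Part(text='You are a concise assistant.')]
        ),
    )
    async with client.aio.live.connect(
        model='gemini-2.0-flash-exp',  # hypothetical model name
        config=config,
    ) as session:
        await session.send(input='Hello there', end_of_turn=True)
        async for message in session.receive():
            if message.text:  # LiveServerMessage.text property from above
                print(message.text)

if __name__ == '__main__':
    asyncio.run(main())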