aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/google/genai/live.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/google/genai/live.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/live.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/genai/live.py696
1 files changed, 696 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/live.py b/.venv/lib/python3.12/site-packages/google/genai/live.py
new file mode 100644
index 00000000..586ba0ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/live.py
@@ -0,0 +1,696 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Live client."""
+
+import asyncio
+import base64
+import contextlib
+import json
+import logging
+from typing import AsyncIterator, Optional, Sequence, Union
+
+import google.auth
+from websockets import ConnectionClosed
+
+from . import _common
+from . import _transformers as t
+from . import client
+from . import types
+from ._api_client import ApiClient
+from ._common import get_value_by_path as getv
+from ._common import set_value_by_path as setv
+from .models import _Content_from_mldev
+from .models import _Content_from_vertex
+from .models import _Content_to_mldev
+from .models import _Content_to_vertex
+from .models import _GenerateContentConfig_to_mldev
+from .models import _GenerateContentConfig_to_vertex
+from .models import _SafetySetting_to_mldev
+from .models import _SafetySetting_to_vertex
+from .models import _SpeechConfig_to_mldev
+from .models import _SpeechConfig_to_vertex
+from .models import _Tool_to_mldev
+from .models import _Tool_to_vertex
+
+try:
+ from websockets.asyncio.client import ClientConnection
+ from websockets.asyncio.client import connect
+except ModuleNotFoundError:
+ from websockets.client import ClientConnection
+ from websockets.client import connect
+
+
+_FUNCTION_RESPONSE_REQUIRES_ID = (
+ 'FunctionResponse request must have an `id` field from the'
+ ' response of a ToolCall.FunctionalCalls in Google AI.'
+)
+
+
+class AsyncSession:
+ """AsyncSession."""
+
+ def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
+ self._api_client = api_client
+ self._ws = websocket
+
+ async def send(
+ self,
+ *,
+ input: Union[
+ types.ContentListUnion,
+ types.ContentListUnionDict,
+ types.LiveClientContentOrDict,
+ types.LiveClientRealtimeInputOrDict,
+ types.LiveClientRealtimeInputOrDict,
+ types.LiveClientToolResponseOrDict,
+ types.FunctionResponseOrDict,
+ Sequence[types.FunctionResponseOrDict],
+ ],
+ end_of_turn: Optional[bool] = False,
+ ):
+ """Send input to the model.
+
+ The method will send the input request to the server.
+
+ Args:
+ input: The input request to the model.
+ end_of_turn: Whether the input is the last message in a turn.
+
+ Example usage:
+
+ .. code-block:: python
+
+ client = genai.Client(api_key=API_KEY)
+
+ async with client.aio.live.connect(model='...') as session:
+ await session.send(input='Hello world!', end_of_turn=True)
+ async for message in session.receive():
+ print(message)
+ """
+ client_message = self._parse_client_message(input, end_of_turn)
+ await self._ws.send(json.dumps(client_message))
+
+ async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
+ """Receive model responses from the server.
+
+ The method will yield the model responses from the server. The returned
+ responses will represent a complete model turn. When the returned message
+ is function call, user must call `send` with the function response to
+ continue the turn.
+
+ Yields:
+ The model responses from the server.
+
+ Example usage:
+
+ .. code-block:: python
+
+ client = genai.Client(api_key=API_KEY)
+
+ async with client.aio.live.connect(model='...') as session:
+ await session.send(input='Hello world!', end_of_turn=True)
+ async for message in session.receive():
+ print(message)
+ """
+ # TODO(b/365983264) Handle intermittent issues for the user.
+ while result := await self._receive():
+ if result.server_content and result.server_content.turn_complete:
+ yield result
+ break
+ yield result
+
+ async def start_stream(
+ self, *, stream: AsyncIterator[bytes], mime_type: str
+ ) -> AsyncIterator[types.LiveServerMessage]:
+ """start a live session from a data stream.
+
+ The interaction terminates when the input stream is complete.
+ This method will start two async tasks. One task will be used to send the
+ input stream to the model and the other task will be used to receive the
+ responses from the model.
+
+ Args:
+ stream: An iterator that yields the model response.
+ mime_type: The MIME type of the data in the stream.
+
+ Yields:
+ The audio bytes received from the model and server response messages.
+
+ Example usage:
+
+ .. code-block:: python
+
+ client = genai.Client(api_key=API_KEY)
+ config = {'response_modalities': ['AUDIO']}
+ async def audio_stream():
+ stream = read_audio()
+ for data in stream:
+ yield data
+ async with client.aio.live.connect(model='...') as session:
+ for audio in session.start_stream(stream = audio_stream(),
+ mime_type = 'audio/pcm'):
+ play_audio_chunk(audio.data)
+ """
+ stop_event = asyncio.Event()
+ # Start the send loop. When stream is complete stop_event is set.
+ asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
+ recv_task = None
+ while not stop_event.is_set():
+ try:
+ recv_task = asyncio.create_task(self._receive())
+ await asyncio.wait(
+ [
+ recv_task,
+ asyncio.create_task(stop_event.wait()),
+ ],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ if recv_task.done():
+ yield recv_task.result()
+ # Give a chance for the send loop to process requests.
+ await asyncio.sleep(10**-12)
+ except ConnectionClosed:
+ break
+ if recv_task is not None and not recv_task.done():
+ recv_task.cancel()
+ # Wait for the task to finish (cancelled or not)
+ try:
+ await recv_task
+ except asyncio.CancelledError:
+ pass
+
+ async def _receive(self) -> types.LiveServerMessage:
+ parameter_model = types.LiveServerMessage()
+ raw_response = await self._ws.recv(decode=False)
+ if raw_response:
+ try:
+ response = json.loads(raw_response)
+ except json.decoder.JSONDecodeError:
+ raise ValueError(f'Failed to parse response: {raw_response}')
+ else:
+ response = {}
+ if self._api_client.vertexai:
+ response_dict = self._LiveServerMessage_from_vertex(response)
+ else:
+ response_dict = self._LiveServerMessage_from_mldev(response)
+
+ return types.LiveServerMessage._from_response(
+ response_dict, parameter_model
+ )
+
+ async def _send_loop(
+ self,
+ data_stream: AsyncIterator[bytes],
+ mime_type: str,
+ stop_event: asyncio.Event,
+ ):
+ async for data in data_stream:
+ input = {'data': data, 'mimeType': mime_type}
+ await self.send(input=input)
+ # Give a chance for the receive loop to process responses.
+ await asyncio.sleep(10**-12)
+ # Give a chance for the receiver to process the last response.
+ stop_event.set()
+
+ def _LiveServerContent_from_mldev(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['modelTurn']) is not None:
+ setv(
+ to_object,
+ ['model_turn'],
+ _Content_from_mldev(
+ self._api_client,
+ getv(from_object, ['modelTurn']),
+ ),
+ )
+ if getv(from_object, ['turnComplete']) is not None:
+ setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+ if getv(from_object, ['interrupted']) is not None:
+ setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
+ return to_object
+
+ def _LiveToolCall_from_mldev(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['functionCalls']) is not None:
+ setv(
+ to_object,
+ ['function_calls'],
+ getv(from_object, ['functionCalls']),
+ )
+ return to_object
+
+ def _LiveToolCall_from_vertex(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['functionCalls']) is not None:
+ setv(
+ to_object,
+ ['function_calls'],
+ getv(from_object, ['functionCalls']),
+ )
+ return to_object
+
+ def _LiveServerMessage_from_mldev(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['serverContent']) is not None:
+ setv(
+ to_object,
+ ['server_content'],
+ self._LiveServerContent_from_mldev(
+ getv(from_object, ['serverContent'])
+ ),
+ )
+ if getv(from_object, ['toolCall']) is not None:
+ setv(
+ to_object,
+ ['tool_call'],
+ self._LiveToolCall_from_mldev(getv(from_object, ['toolCall'])),
+ )
+ if getv(from_object, ['toolCallCancellation']) is not None:
+ setv(
+ to_object,
+ ['tool_call_cancellation'],
+ getv(from_object, ['toolCallCancellation']),
+ )
+ return to_object
+
+ def _LiveServerContent_from_vertex(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['modelTurn']) is not None:
+ setv(
+ to_object,
+ ['model_turn'],
+ _Content_from_vertex(
+ self._api_client,
+ getv(from_object, ['modelTurn']),
+ ),
+ )
+ if getv(from_object, ['turnComplete']) is not None:
+ setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+ if getv(from_object, ['interrupted']) is not None:
+ setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
+ return to_object
+
+ def _LiveServerMessage_from_vertex(
+ self,
+ from_object: Union[dict, object],
+ ) -> dict:
+ to_object = {}
+ if getv(from_object, ['serverContent']) is not None:
+ setv(
+ to_object,
+ ['server_content'],
+ self._LiveServerContent_from_vertex(
+ getv(from_object, ['serverContent'])
+ ),
+ )
+
+ if getv(from_object, ['toolCall']) is not None:
+ setv(
+ to_object,
+ ['tool_call'],
+ self._LiveToolCall_from_vertex(getv(from_object, ['toolCall'])),
+ )
+ if getv(from_object, ['toolCallCancellation']) is not None:
+ setv(
+ to_object,
+ ['tool_call_cancellation'],
+ getv(from_object, ['toolCallCancellation']),
+ )
+ return to_object
+
+ def _parse_client_message(
+ self,
+ input: Union[
+ types.ContentListUnion,
+ types.ContentListUnionDict,
+ types.LiveClientContentOrDict,
+ types.LiveClientRealtimeInputOrDict,
+ types.LiveClientRealtimeInputOrDict,
+ types.LiveClientToolResponseOrDict,
+ types.FunctionResponseOrDict,
+ Sequence[types.FunctionResponseOrDict],
+ ],
+ end_of_turn: Optional[bool] = False,
+ ) -> dict:
+ if isinstance(input, str):
+ input = [input]
+ elif isinstance(input, dict) and 'data' in input:
+ if isinstance(input['data'], bytes):
+ decoded_data = base64.b64encode(input['data']).decode('utf-8')
+ input['data'] = decoded_data
+ input = [input]
+ elif isinstance(input, types.Blob):
+ input.data = base64.b64encode(input.data).decode('utf-8')
+ input = [input]
+ elif isinstance(input, dict) and 'name' in input and 'response' in input:
+ # ToolResponse.FunctionResponse
+ if not (self._api_client.vertexai) and 'id' not in input:
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+ input = [input]
+
+ if isinstance(input, Sequence) and any(
+ isinstance(c, dict) and 'name' in c and 'response' in c for c in input
+ ):
+ # ToolResponse.FunctionResponse
+ if not (self._api_client.vertexai):
+ for item in input:
+ if 'id' not in item:
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+ client_message = {'tool_response': {'function_responses': input}}
+ elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
+ to_object = {}
+ if self._api_client.vertexai:
+ contents = [
+ _Content_to_vertex(self._api_client, item, to_object)
+ for item in t.t_contents(self._api_client, input)
+ ]
+ else:
+ contents = [
+ _Content_to_mldev(self._api_client, item, to_object)
+ for item in t.t_contents(self._api_client, input)
+ ]
+
+ client_message = {
+ 'client_content': {'turns': contents, 'turn_complete': end_of_turn}
+ }
+ elif isinstance(input, Sequence):
+ if any((isinstance(b, dict) and 'data' in b) for b in input):
+ pass
+ elif any(isinstance(b, types.Blob) for b in input):
+ input = [b.model_dump(exclude_none=True) for b in input]
+ else:
+ raise ValueError(
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
+ )
+
+ client_message = {'realtime_input': {'media_chunks': input}}
+
+ elif isinstance(input, dict) and 'content' in input:
+ # TODO(b/365983264) Add validation checks for content_update input_dict.
+ client_message = {'client_content': input}
+ elif isinstance(input, types.LiveClientRealtimeInput):
+ client_message = {'realtime_input': input.model_dump(exclude_none=True)}
+ if isinstance(
+ client_message['realtime_input']['media_chunks'][0]['data'], bytes
+ ):
+ client_message['realtime_input']['media_chunks'] = [
+ {
+ 'data': base64.b64encode(item['data']).decode('utf-8'),
+ 'mime_type': item['mime_type'],
+ }
+ for item in client_message['realtime_input']['media_chunks']
+ ]
+
+ elif isinstance(input, types.LiveClientContent):
+ client_message = {'client_content': input.model_dump(exclude_none=True)}
+ elif isinstance(input, types.LiveClientToolResponse):
+ # ToolResponse.FunctionResponse
+ if not (self._api_client.vertexai) and not (
+ input.function_responses[0].id
+ ):
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+ client_message = {'tool_response': input.model_dump(exclude_none=True)}
+ elif isinstance(input, types.FunctionResponse):
+ if not (self._api_client.vertexai) and not (input.id):
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+ client_message = {
+ 'tool_response': {
+ 'function_responses': [input.model_dump(exclude_none=True)]
+ }
+ }
+ elif isinstance(input, Sequence) and isinstance(
+ input[0], types.FunctionResponse
+ ):
+ if not (self._api_client.vertexai) and not (input[0].id):
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+ client_message = {
+ 'tool_response': {
+ 'function_responses': [
+ c.model_dump(exclude_none=True) for c in input
+ ]
+ }
+ }
+ else:
+ raise ValueError(
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
+ )
+
+ return client_message
+
+ async def close(self):
+ # Close the websocket connection.
+ await self._ws.close()
+
+
+class AsyncLive(_common.BaseModule):
+ """AsyncLive."""
+
+ def _LiveSetup_to_mldev(
+ self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
+ ):
+ if isinstance(config, types.LiveConnectConfig):
+ from_object = config.model_dump(exclude_none=True)
+ else:
+ from_object = config
+
+ to_object = {}
+ if getv(from_object, ['generation_config']) is not None:
+ setv(
+ to_object,
+ ['generationConfig'],
+ _GenerateContentConfig_to_mldev(
+ self._api_client,
+ getv(from_object, ['generation_config']),
+ to_object,
+ ),
+ )
+ if getv(from_object, ['response_modalities']) is not None:
+ if getv(to_object, ['generationConfig']) is not None:
+ to_object['generationConfig']['responseModalities'] = from_object[
+ 'response_modalities'
+ ]
+ else:
+ to_object['generationConfig'] = {
+ 'responseModalities': from_object['response_modalities']
+ }
+ if getv(from_object, ['speech_config']) is not None:
+ if getv(to_object, ['generationConfig']) is not None:
+ to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
+ self._api_client,
+ t.t_speech_config(
+ self._api_client, getv(from_object, ['speech_config'])
+ ),
+ to_object,
+ )
+ else:
+ to_object['generationConfig'] = {
+ 'speechConfig': _SpeechConfig_to_mldev(
+ self._api_client,
+ t.t_speech_config(
+ self._api_client, getv(from_object, ['speech_config'])
+ ),
+ to_object,
+ )
+ }
+
+ if getv(from_object, ['system_instruction']) is not None:
+ setv(
+ to_object,
+ ['systemInstruction'],
+ _Content_to_mldev(
+ self._api_client,
+ t.t_content(
+ self._api_client, getv(from_object, ['system_instruction'])
+ ),
+ to_object,
+ ),
+ )
+ if getv(from_object, ['tools']) is not None:
+ setv(
+ to_object,
+ ['tools'],
+ [
+ _Tool_to_mldev(self._api_client, item, to_object)
+ for item in getv(from_object, ['tools'])
+ ],
+ )
+
+ return_value = {'setup': {'model': model}}
+ return_value['setup'].update(to_object)
+ return return_value
+
+ def _LiveSetup_to_vertex(
+ self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
+ ):
+ if isinstance(config, types.LiveConnectConfig):
+ from_object = config.model_dump(exclude_none=True)
+ else:
+ from_object = config
+
+ to_object = {}
+
+ if getv(from_object, ['generation_config']) is not None:
+ setv(
+ to_object,
+ ['generationConfig'],
+ _GenerateContentConfig_to_vertex(
+ self._api_client,
+ getv(from_object, ['generation_config']),
+ to_object,
+ ),
+ )
+ if getv(from_object, ['response_modalities']) is not None:
+ if getv(to_object, ['generationConfig']) is not None:
+ to_object['generationConfig']['responseModalities'] = from_object[
+ 'response_modalities'
+ ]
+ else:
+ to_object['generationConfig'] = {
+ 'responseModalities': from_object['response_modalities']
+ }
+ else:
+ # Set default to AUDIO to align with MLDev API.
+ if getv(to_object, ['generationConfig']) is not None:
+ to_object['generationConfig'].update({'responseModalities': ['AUDIO']})
+ else:
+ to_object.update(
+ {'generationConfig': {'responseModalities': ['AUDIO']}}
+ )
+ if getv(from_object, ['speech_config']) is not None:
+ if getv(to_object, ['generationConfig']) is not None:
+ to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
+ self._api_client,
+ t.t_speech_config(
+ self._api_client, getv(from_object, ['speech_config'])
+ ),
+ to_object,
+ )
+ else:
+ to_object['generationConfig'] = {
+ 'speechConfig': _SpeechConfig_to_vertex(
+ self._api_client,
+ t.t_speech_config(
+ self._api_client, getv(from_object, ['speech_config'])
+ ),
+ to_object,
+ )
+ }
+ if getv(from_object, ['system_instruction']) is not None:
+ setv(
+ to_object,
+ ['systemInstruction'],
+ _Content_to_vertex(
+ self._api_client,
+ t.t_content(
+ self._api_client, getv(from_object, ['system_instruction'])
+ ),
+ to_object,
+ ),
+ )
+ if getv(from_object, ['tools']) is not None:
+ setv(
+ to_object,
+ ['tools'],
+ [
+ _Tool_to_vertex(self._api_client, item, to_object)
+ for item in getv(from_object, ['tools'])
+ ],
+ )
+
+ return_value = {'setup': {'model': model}}
+ return_value['setup'].update(to_object)
+ return return_value
+
+ @contextlib.asynccontextmanager
+ async def connect(
+ self,
+ *,
+ model: str,
+ config: Optional[types.LiveConnectConfigOrDict] = None,
+ ) -> AsyncSession:
+ """Connect to the live server.
+
+ Usage:
+
+ .. code-block:: python
+
+ client = genai.Client(api_key=API_KEY)
+ config = {}
+ async with client.aio.live.connect(model='...', config=config) as session:
+ await session.send(input='Hello world!', end_of_turn=True)
+ async for message in session.receive():
+ print(message)
+ """
+ base_url = self._api_client._websocket_base_url()
+ if self._api_client.api_key:
+ api_key = self._api_client.api_key
+ version = self._api_client._http_options['api_version']
+ uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
+ headers = self._api_client._http_options['headers']
+
+ transformed_model = t.t_model(self._api_client, model)
+ request = json.dumps(
+ self._LiveSetup_to_mldev(model=transformed_model, config=config)
+ )
+ else:
+ # Get bearer token through Application Default Credentials.
+ creds, _ = google.auth.default(
+ scopes=['https://www.googleapis.com/auth/cloud-platform']
+ )
+
+ # creds.valid is False, and creds.token is None
+ # Need to refresh credentials to populate those
+ auth_req = google.auth.transport.requests.Request()
+ creds.refresh(auth_req)
+ bearer_token = creds.token
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': 'Bearer {}'.format(bearer_token),
+ }
+ version = self._api_client._http_options['api_version']
+ uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
+ location = self._api_client.location
+ project = self._api_client.project
+ transformed_model = t.t_model(self._api_client, model)
+ if transformed_model.startswith('publishers/'):
+ transformed_model = (
+ f'projects/{project}/locations/{location}/' + transformed_model
+ )
+
+ request = json.dumps(
+ self._LiveSetup_to_vertex(model=transformed_model, config=config)
+ )
+
+ async with connect(uri, additional_headers=headers) as ws:
+ await ws.send(request)
+ logging.info(await ws.recv(decode=False))
+
+ yield AsyncSession(api_client=self._api_client, websocket=ws)