aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/google/genai/chats.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/chats.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/genai/chats.py266
1 files changed, 266 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/chats.py b/.venv/lib/python3.12/site-packages/google/genai/chats.py
new file mode 100644
index 00000000..707b7c8d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/chats.py
@@ -0,0 +1,266 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+from typing import Union
+
+from . import _transformers as t
+from .models import AsyncModels, Models
+from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
+
+
+def _validate_response(response: GenerateContentResponse) -> bool:
+ if not response.candidates:
+ return False
+ if not response.candidates[0].content:
+ return False
+ if not response.candidates[0].content.parts:
+ return False
+ for part in response.candidates[0].content.parts:
+ if part == Part():
+ return False
+ if part.text is not None and part.text == "":
+ return False
+ return True
+
+
+class _BaseChat:
+ """Base chat session."""
+
+ def __init__(
+ self,
+ *,
+ modules: Union[Models, AsyncModels],
+ model: str,
+ config: GenerateContentConfigOrDict = None,
+ history: list[Content],
+ ):
+ self._modules = modules
+ self._model = model
+ self._config = config
+ self._curated_history = history
+
+
+class Chat(_BaseChat):
+ """Chat session."""
+
+ def send_message(
+ self, message: Union[list[PartUnionDict], PartUnionDict]
+ ) -> GenerateContentResponse:
+ """Sends the conversation history with the additional message and returns the model's response.
+
+ Args:
+ message: The message to send to the model.
+
+ Returns:
+ The model's response.
+
+ Usage:
+
+ .. code-block:: python
+
+ chat = client.chats.create(model='gemini-1.5-flash')
+ response = chat.send_message('tell me a story')
+ """
+
+ input_content = t.t_content(self._modules._api_client, message)
+ response = self._modules.generate_content(
+ model=self._model,
+ contents=self._curated_history + [input_content],
+ config=self._config,
+ )
+ if _validate_response(response):
+ if response.automatic_function_calling_history:
+ self._curated_history.extend(
+ response.automatic_function_calling_history
+ )
+ else:
+ self._curated_history.append(input_content)
+ self._curated_history.append(response.candidates[0].content)
+ return response
+
+ def send_message_stream(
+ self, message: Union[list[PartUnionDict], PartUnionDict]
+ ):
+ """Sends the conversation history with the additional message and yields the model's response in chunks.
+
+ Args:
+ message: The message to send to the model.
+
+ Yields:
+ The model's response in chunks.
+
+ Usage:
+
+ .. code-block:: python
+
+ chat = client.chats.create(model='gemini-1.5-flash')
+ for chunk in chat.send_message_stream('tell me a story'):
+ print(chunk.text)
+ """
+
+ input_content = t.t_content(self._modules._api_client, message)
+ output_contents = []
+ finish_reason = None
+ for chunk in self._modules.generate_content_stream(
+ model=self._model,
+ contents=self._curated_history + [input_content],
+ config=self._config,
+ ):
+ if _validate_response(chunk):
+ output_contents.append(chunk.candidates[0].content)
+ if chunk.candidates and chunk.candidates[0].finish_reason:
+ finish_reason = chunk.candidates[0].finish_reason
+ yield chunk
+ if output_contents and finish_reason:
+ self._curated_history.append(input_content)
+ self._curated_history.extend(output_contents)
+
+
+class Chats:
+ """A util class to create chat sessions."""
+
+ def __init__(self, modules: Models):
+ self._modules = modules
+
+ def create(
+ self,
+ *,
+ model: str,
+ config: GenerateContentConfigOrDict = None,
+ history: Optional[list[Content]] = None,
+ ) -> Chat:
+ """Creates a new chat session.
+
+ Args:
+ model: The model to use for the chat.
+ config: The configuration to use for the generate content request.
+ history: The history to use for the chat.
+
+ Returns:
+ A new chat session.
+ """
+ return Chat(
+ modules=self._modules,
+ model=model,
+ config=config,
+ history=history if history else [],
+ )
+
+
+class AsyncChat(_BaseChat):
+ """Async chat session."""
+
+ async def send_message(
+ self, message: Union[list[PartUnionDict], PartUnionDict]
+ ) -> GenerateContentResponse:
+ """Sends the conversation history with the additional message and returns model's response.
+
+ Args:
+ message: The message to send to the model.
+
+ Returns:
+ The model's response.
+
+ Usage:
+
+ .. code-block:: python
+
+ chat = client.aio.chats.create(model='gemini-1.5-flash')
+ response = await chat.send_message('tell me a story')
+ """
+
+ input_content = t.t_content(self._modules._api_client, message)
+ response = await self._modules.generate_content(
+ model=self._model,
+ contents=self._curated_history + [input_content],
+ config=self._config,
+ )
+ if _validate_response(response):
+ if response.automatic_function_calling_history:
+ self._curated_history.extend(
+ response.automatic_function_calling_history
+ )
+ else:
+ self._curated_history.append(input_content)
+ self._curated_history.append(response.candidates[0].content)
+ return response
+
+ async def send_message_stream(
+ self, message: Union[list[PartUnionDict], PartUnionDict]
+ ):
+ """Sends the conversation history with the additional message and yields the model's response in chunks.
+
+ Args:
+ message: The message to send to the model.
+
+ Yields:
+ The model's response in chunks.
+
+ Usage:
+
+ .. code-block:: python
+ chat = client.aio.chats.create(model='gemini-1.5-flash')
+ async for chunk in chat.send_message_stream('tell me a story'):
+ print(chunk.text)
+ """
+
+ input_content = t.t_content(self._modules._api_client, message)
+ output_contents = []
+ finish_reason = None
+ async for chunk in self._modules.generate_content_stream(
+ model=self._model,
+ contents=self._curated_history + [input_content],
+ config=self._config,
+ ):
+ if _validate_response(chunk):
+ output_contents.append(chunk.candidates[0].content)
+ if chunk.candidates and chunk.candidates[0].finish_reason:
+ finish_reason = chunk.candidates[0].finish_reason
+ yield chunk
+ if output_contents and finish_reason:
+ self._curated_history.append(input_content)
+ self._curated_history.extend(output_contents)
+
+
+class AsyncChats:
+ """A util class to create async chat sessions."""
+
+ def __init__(self, modules: AsyncModels):
+ self._modules = modules
+
+ def create(
+ self,
+ *,
+ model: str,
+ config: GenerateContentConfigOrDict = None,
+ history: Optional[list[Content]] = None,
+ ) -> AsyncChat:
+ """Creates a new chat session.
+
+ Args:
+ model: The model to use for the chat.
+ config: The configuration to use for the generate content request.
+ history: The history to use for the chat.
+
+ Returns:
+ A new chat session.
+ """
+ return AsyncChat(
+ modules=self._modules,
+ model=model,
+ config=config,
+ history=history if history else [],
+ )