aboutsummaryrefslogtreecommitdiff
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional
from typing import Union

from . import _transformers as t
from .models import AsyncModels, Models
from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict


def _validate_response(response: GenerateContentResponse) -> bool:
  """Reports whether the first candidate of ``response`` has usable content.

  The first candidate counts as usable only if all of the following hold:
  the response has at least one candidate; that candidate has content; the
  content has at least one part; and no part equals an empty ``Part()`` or
  carries an empty-string ``text``.

  Args:
    response: A model response, or a single streamed chunk.

  Returns:
    True when the first candidate's content is usable, otherwise False.
  """
  candidates = response.candidates
  if not candidates:
    return False

  first_content = candidates[0].content
  if not first_content:
    return False

  content_parts = first_content.parts
  if not content_parts:
    return False

  # A part is rejected if it is completely empty or has text == "" (a None
  # text never compares equal to "", so it is not rejected by that check).
  return not any(
      candidate_part == Part() or candidate_part.text == ""
      for candidate_part in content_parts
  )


class _BaseChat:
  """Base chat session.

  Holds the state shared by the sync and async chat sessions: the models
  module used to issue requests, the model name, the request config, and the
  curated conversation history that subclasses send ahead of each new message.
  """

  def __init__(
      self,
      *,
      modules: Union[Models, AsyncModels],
      model: str,
      config: Optional[GenerateContentConfigOrDict] = None,
      history: list[Content],
  ):
    """Initializes the chat session state.

    Args:
      modules: The (sync or async) models module used to send requests.
      model: The model to use for the chat.
      config: The configuration passed with every generate content request.
      history: The initial conversation history. The list is copied, so later
        messages sent through this session do not append to the caller's
        list (or to a list shared with another session).
    """
    self._modules = modules
    self._model = model
    self._config = config
    # Copy rather than alias: subclasses append to this list on every turn.
    self._curated_history = list(history)


class Chat(_BaseChat):
  """Chat session.

  Each request sends the curated history followed by the new message. The
  turn is added to the history only when the model's reply passes
  ``_validate_response``.
  """

  def send_message(
      self, message: Union[list[PartUnionDict], PartUnionDict]
  ) -> GenerateContentResponse:
    """Sends the conversation history with the additional message and returns the model's response.

    Args:
      message: The message to send to the model.

    Returns:
      The model's response.

    Usage:

    .. code-block:: python

      chat = client.chats.create(model='gemini-1.5-flash')
      response = chat.send_message('tell me a story')
    """

    input_content = t.t_content(self._modules._api_client, message)
    # Remember how much history was sent, so that prefix can be skipped when
    # recording the automatic function calling (AFC) history below.
    sent_history_len = len(self._curated_history)
    response = self._modules.generate_content(
        model=self._model,
        contents=self._curated_history + [input_content],
        config=self._config,
    )
    if _validate_response(response):
      afc_history = response.automatic_function_calling_history
      if afc_history:
        # The AFC history already includes input_content, which is why it is
        # not appended separately here. generate_content() receives a single
        # flat contents list, so the AFC history replays the curated prefix as
        # well. Skip that prefix so earlier turns are not duplicated.
        # NOTE(review): assumes the AFC history starts with the exact contents
        # that were sent; confirm against Models.generate_content.
        self._curated_history.extend(afc_history[sent_history_len:])
      else:
        self._curated_history.append(input_content)
      self._curated_history.append(response.candidates[0].content)
    return response

  def send_message_stream(
      self, message: Union[list[PartUnionDict], PartUnionDict]
  ):
    """Sends the conversation history with the additional message and yields the model's response in chunks.

    Args:
      message: The message to send to the model.

    Yields:
      The model's response in chunks.

    Usage:

    .. code-block:: python

      chat = client.chats.create(model='gemini-1.5-flash')
      for chunk in chat.send_message_stream('tell me a story'):
        print(chunk.text)
    """

    input_content = t.t_content(self._modules._api_client, message)
    output_contents = []
    finish_reason = None
    for chunk in self._modules.generate_content_stream(
        model=self._model,
        contents=self._curated_history + [input_content],
        config=self._config,
    ):
      # Keep only chunks with usable content for the history.
      if _validate_response(chunk):
        output_contents.append(chunk.candidates[0].content)
      if chunk.candidates and chunk.candidates[0].finish_reason:
        finish_reason = chunk.candidates[0].finish_reason
      yield chunk
    # Record the turn only if the stream produced usable output and the model
    # reported a finish reason.
    if output_contents and finish_reason:
      self._curated_history.append(input_content)
      self._curated_history.extend(output_contents)


class Chats:
  """A util class to create chat sessions."""

  def __init__(self, modules: Models):
    """Initializes the factory.

    Args:
      modules: The models module that created chat sessions send requests
        through.
    """
    self._modules = modules

  def create(
      self,
      *,
      model: str,
      config: Optional[GenerateContentConfigOrDict] = None,
      history: Optional[list[Content]] = None,
  ) -> Chat:
    """Creates a new chat session.

    Args:
      model: The model to use for the chat.
      config: The configuration to use for the generate content request.
      history: The history to use for the chat. The list is copied, so the
        session never appends to the caller's list.

    Returns:
      A new chat session.
    """
    return Chat(
        modules=self._modules,
        model=model,
        config=config,
        # Copy so sessions created from the same list stay independent.
        history=list(history) if history else [],
    )


class AsyncChat(_BaseChat):
  """Async chat session.

  Each request sends the curated history followed by the new message. The
  turn is added to the history only when the model's reply passes
  ``_validate_response``.
  """

  async def send_message(
      self, message: Union[list[PartUnionDict], PartUnionDict]
  ) -> GenerateContentResponse:
    """Sends the conversation history with the additional message and returns the model's response.

    Args:
      message: The message to send to the model.

    Returns:
      The model's response.

    Usage:

    .. code-block:: python

      chat = client.aio.chats.create(model='gemini-1.5-flash')
      response = await chat.send_message('tell me a story')
    """

    input_content = t.t_content(self._modules._api_client, message)
    # Remember how much history was sent (before awaiting), so that prefix
    # can be skipped when recording the automatic function calling (AFC)
    # history below.
    sent_history_len = len(self._curated_history)
    response = await self._modules.generate_content(
        model=self._model,
        contents=self._curated_history + [input_content],
        config=self._config,
    )
    if _validate_response(response):
      afc_history = response.automatic_function_calling_history
      if afc_history:
        # The AFC history already includes input_content, which is why it is
        # not appended separately here. generate_content() receives a single
        # flat contents list, so the AFC history replays the curated prefix as
        # well. Skip that prefix so earlier turns are not duplicated.
        # NOTE(review): assumes the AFC history starts with the exact contents
        # that were sent; confirm against AsyncModels.generate_content.
        self._curated_history.extend(afc_history[sent_history_len:])
      else:
        self._curated_history.append(input_content)
      self._curated_history.append(response.candidates[0].content)
    return response

  async def send_message_stream(
      self, message: Union[list[PartUnionDict], PartUnionDict]
  ):
    """Sends the conversation history with the additional message and yields the model's response in chunks.

    Args:
      message: The message to send to the model.

    Yields:
      The model's response in chunks.

    Usage:

    .. code-block:: python

      chat = client.aio.chats.create(model='gemini-1.5-flash')
      async for chunk in chat.send_message_stream('tell me a story'):
        print(chunk.text)
    """

    input_content = t.t_content(self._modules._api_client, message)
    output_contents = []
    finish_reason = None
    async for chunk in self._modules.generate_content_stream(
        model=self._model,
        contents=self._curated_history + [input_content],
        config=self._config,
    ):
      # Keep only chunks with usable content for the history.
      if _validate_response(chunk):
        output_contents.append(chunk.candidates[0].content)
      if chunk.candidates and chunk.candidates[0].finish_reason:
        finish_reason = chunk.candidates[0].finish_reason
      yield chunk
    # Record the turn only if the stream produced usable output and the model
    # reported a finish reason.
    if output_contents and finish_reason:
      self._curated_history.append(input_content)
      self._curated_history.extend(output_contents)


class AsyncChats:
  """A util class to create async chat sessions."""

  def __init__(self, modules: AsyncModels):
    """Initializes the factory.

    Args:
      modules: The async models module that created chat sessions send
        requests through.
    """
    self._modules = modules

  def create(
      self,
      *,
      model: str,
      config: Optional[GenerateContentConfigOrDict] = None,
      history: Optional[list[Content]] = None,
  ) -> AsyncChat:
    """Creates a new chat session.

    Args:
      model: The model to use for the chat.
      config: The configuration to use for the generate content request.
      history: The history to use for the chat. The list is copied, so the
        session never appends to the caller's list.

    Returns:
      A new chat session.
    """
    return AsyncChat(
        modules=self._modules,
        model=model,
        config=config,
        # Copy so sessions created from the same list stay independent.
        history=list(history) if history else [],
    )