# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Pagers for the GenAI List APIs."""

# pylint: disable=protected-access

import copy
from typing import Any, AsyncIterator, Awaitable, Callable, Generic, Iterator, Literal, TypeVar

# Type of the individual items held in a pager's page.
T = TypeVar('T')

# Names of the list-response attributes a pager can read its items from.
PagedItem = Literal[
    'batch_jobs',
    'models',
    'tuning_jobs',
    'files',
    'cached_contents',
]


class _BasePager(Generic[T]):
  """Base pager class for iterating through paginated results.

  Holds the current page of items, the callable used to request further
  pages, and the request config (including the ``page_token`` of the next
  page). Subclasses add the synchronous or asynchronous iteration protocol.
  """

  def __init__(
      self,
      name: PagedItem,
      request: Callable[..., Any],
      response: Any,
      config: Any,
  ) -> None:
    """Initializes the pager from the response of a list request.

    Args:
      name: The attribute of ``response`` that holds the list of items.
      request: The callable used to fetch the next page.
      response: The response object of the most recent list request. Its
        ``next_page_token`` attribute is stored in the config as
        ``page_token``.
      config: The config of the most recent request. A falsy value is treated
        as an empty config, a dict is deep-copied, and anything else is
        converted with ``dict(config)``; the caller's object is not mutated.
    """
    self._name = name
    self._request = request

    # A missing/None item list is normalized to an empty page.
    self._page = getattr(response, self._name) or []
    self._idx = 0

    if not config:
      request_config = {}
    elif isinstance(config, dict):
      request_config = copy.deepcopy(config)
    else:
      request_config = dict(config)
    request_config['page_token'] = getattr(response, 'next_page_token')
    self._config = request_config

    # Use `or` rather than a `.get()` default: the key may be present with a
    # value of None (e.g. after `dict(config)` on a config object), and
    # `page_size` must still be an int.
    self._page_size = request_config.get('page_size') or len(self._page)

  @property
  def page(self) -> list[T]:
    """Returns the current page, which is a list of items.

    The returned list of items is a subset of the entire list.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"first page: {batch_jobs_pager.page}")
      # first page: [BatchJob(name='projects/./locations/./batchPredictionJobs/1
    """

    return self._page

  @property
  def name(self) -> str:
    """Returns the type of paged item (for example, ``batch_jobs``).

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"name: {batch_jobs_pager.name}")
      # name: batch_jobs
    """

    return self._name

  @property
  def page_size(self) -> int:
    """Returns the length of the page fetched each time by this pager.

    The number of items in the page is less than or equal to the page length.
    When the config carries no ``page_size``, the length of the current page
    is used instead.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"page_size: {batch_jobs_pager.page_size}")
      # page_size: 5
    """

    return self._page_size

  @property
  def config(self) -> dict[str, Any]:
    """Returns the configuration when making the API request for the next page.

    A configuration is a set of optional parameters and arguments that can be
    used to customize the API request. For example, the ``page_token`` parameter
    contains the token to request the next page.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"config: {batch_jobs_pager.config}")
      # config: {'page_size': 5, 'page_token': 'AMEw9yO5jnsGnZJLHSKDFHJJu'}
    """

    return self._config

  def __len__(self) -> int:
    """Returns the total number of items in the current page."""
    return len(self.page)

  def __getitem__(self, index: int) -> T:
    """Returns the item at the given index of the current page."""
    return self.page[index]

  def _init_next_page(self, response: Any) -> None:
    """Initializes the next page from the response.

    This is an internal method that should be called by subclasses after
    fetching the next page.

    Args:
      response: The response object from the API request.
    """
    # Re-running the initializer swaps in the new page, resets the iteration
    # index to 0 and stores the new `page_token` in the config.
    self.__init__(self.name, self._request, response, self.config)


class Pager(_BasePager[T]):
  """Pager class for iterating through paginated results."""

  def __next__(self) -> T:
    """Returns the next item, fetching further pages when needed.

    Raises:
      StopIteration: The current page is exhausted and there are no more
        pages to fetch.
    """
    # Loop rather than check once: a freshly fetched page can itself be empty,
    # and indexing into it would leak an IndexError out of the iteration.
    while self._idx >= len(self):
      try:
        self.next_page()
      except IndexError:
        # `from None` keeps the internal IndexError out of the traceback.
        raise StopIteration from None

    item = self.page[self._idx]
    self._idx += 1
    return item

  def __iter__(self) -> Iterator[T]:
    """Returns an iterator over the items, starting at the current page."""
    self._idx = 0
    return self

  def next_page(self) -> list[T]:
    """Fetches the next page of items. This makes a new API request.

    Returns:
      The next page of items.

    Raises:
      IndexError: No more pages to fetch.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"current page: {batch_jobs_pager.page}")
      batch_jobs_pager.next_page()
      print(f"next page: {batch_jobs_pager.page}")
      # current page: [BatchJob(name='projects/.../batchPredictionJobs/1
      # next page: [BatchJob(name='projects/.../batchPredictionJobs/6
    """

    # An empty/None token means the previous response was the last page.
    if not self.config.get('page_token'):
      raise IndexError('No more pages to fetch.')

    response = self._request(config=self.config)
    self._init_next_page(response)
    return self.page


class AsyncPager(_BasePager[T]):
  """AsyncPager class for iterating through paginated results."""

  def __init__(
      self,
      name: PagedItem,
      request: Callable[..., Awaitable[Any]],
      response: Any,
      config: Any,
  ):
    """Initializes the async pager.

    Identical to the base initializer, except that ``request`` must return an
    awaitable, because ``next_page`` awaits its result.

    Args:
      name: The attribute of ``response`` that holds the list of items.
      request: The awaitable-returning callable used to fetch the next page.
      response: The response object of the most recent list request.
      config: The config of the most recent request.
    """
    super().__init__(name, request, response, config)

  def __aiter__(self) -> AsyncIterator[T]:
    """Returns an async iterator over the items, starting at the current page."""
    self._idx = 0
    return self

  async def __anext__(self) -> T:
    """Returns the next item asynchronously, fetching further pages when needed.

    Raises:
      StopAsyncIteration: The current page is exhausted and there are no more
        pages to fetch.
    """
    # Loop rather than check once: a freshly fetched page can itself be empty,
    # and indexing into it would leak an IndexError out of the iteration.
    while self._idx >= len(self):
      try:
        await self.next_page()
      except IndexError:
        # `from None` keeps the internal IndexError out of the traceback.
        raise StopAsyncIteration from None

    item = self.page[self._idx]
    self._idx += 1
    return item

  async def next_page(self) -> list[T]:
    """Fetches the next page of items asynchronously.

    This makes a new API request.

    Returns:
      The next page of items.

    Raises:
      IndexError: No more pages to fetch.

    Usage:

    .. code-block:: python

      batch_jobs_pager = await client.aio.batches.list(config={'page_size': 5})
      print(f"current page: {batch_jobs_pager.page}")
      await batch_jobs_pager.next_page()
      print(f"next page: {batch_jobs_pager.page}")
      # current page: [BatchJob(name='projects/.../batchPredictionJobs/1
      # next page: [BatchJob(name='projects/.../batchPredictionJobs/6
    """

    # An empty/None token means the previous response was the last page.
    if not self.config.get('page_token'):
      raise IndexError('No more pages to fetch.')

    response = await self._request(config=self.config)
    self._init_next_page(response)
    return self.page