# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Extra utils depending on types that are shared between sync and async modules.
"""
import inspect
import logging
from typing import Any, Callable, Dict, get_args, get_origin, Optional, types as typing_types, Union
import pydantic
from . import _common
from . import errors
from . import types
# Fallback limit on automatic function calling (AFC) remote calls, returned by
# get_max_remote_calls_afc when the config leaves maximum_remote_calls unset.
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
def format_destination(
    src: str,
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Fills in the display name and destination of a batch job config.

  Args:
    src: URI of the batch job input. When the config has no `dest`, this must
      be either a `gs://` URI ending in `.jsonl` or a `bq://` table URI.
    config: Optional batch job config, given as a model or a dict.

  Returns:
    The parsed config. An empty `display_name` becomes
    `genai_batch_job_<unique name>`, and an empty `dest` is derived from `src`.

  Raises:
    ValueError: If `dest` must be derived and `src` has an unsupported form.
  """
  job_config = types._CreateBatchJobParameters(config=config).config
  if not job_config:
    job_config = types.CreateBatchJobConfig()

  # Generated lazily and reused, so a generated display name and a BigQuery
  # destination share the same suffix.
  suffix = None
  if not job_config.display_name:
    suffix = _common.timestamped_unique_name()
    job_config.display_name = f'genai_batch_job_{suffix}'

  if job_config.dest:
    return job_config

  if src.startswith('bq://'):
    # "bq://project.dataset.src" -> "bq://project.dataset.src_dest_<suffix>"
    if not suffix:
      suffix = _common.timestamped_unique_name()
    job_config.dest = f'{src}_dest_{suffix}'
  elif src.startswith('gs://') and src.endswith('.jsonl'):
    # "gs://bucket/path/to/src.jsonl" -> "gs://bucket/path/to/src/dest"
    stem = src[:-len('.jsonl')]
    job_config.dest = f'{stem}/dest'
  else:
    raise ValueError(f'Unsupported source: {src}')
  return job_config
def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> dict[str, object]:
  """Maps each callable in `config.tools` to its `__name__`.

  Non-callable tools are skipped. An absent config, or one without tools,
  yields an empty map.

  Raises:
    errors.UnsupportedFunctionError: If a callable tool is a coroutine
      function, since automatic function calling cannot invoke it.
  """
  if config and isinstance(config, dict):
    config = types.GenerateContentConfig(**config)
  if not config or not config.tools:
    return {}

  callables_by_name = {}
  for candidate in config.tools:
    if not callable(candidate):
      continue
    name = candidate.__name__
    if inspect.iscoroutinefunction(candidate):
      raise errors.UnsupportedFunctionError(
          f'Function {name} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually invoke'
          f' {name} to get the function response.'
      )
    callables_by_name[name] = candidate
  return callables_by_name
def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Recursively turns integral floats (e.g. 3.0) into ints.

  Dicts and lists are walked and rebuilt as plain dicts/lists with converted
  values; every other value is returned unchanged.
  """
  convert = convert_number_values_for_function_call_args
  if isinstance(args, dict):
    return {key: convert(item) for key, item in args.items()}
  if isinstance(args, list):
    return list(map(convert, args))
  if isinstance(args, float) and args.is_integer():
    return int(args)
  return args
def _is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns True if `annotation` is a pydantic.BaseModel subclass.

  Parameterized generics such as `list[int]` are rejected up front. On
  Python 3.9/3.10, `inspect.isclass(list[int])` is True, yet passing it to
  `issubclass` raises TypeError instead of returning False.
  """
  if not inspect.isclass(annotation) or get_origin(annotation) is not None:
    return False
  try:
    return issubclass(annotation, pydantic.BaseModel)
  except TypeError:
    # Defensive: other class-like objects that issubclass cannot handle.
    return False
def _annotation_accepts(value: Any, annotation: Any) -> bool:
  """Best-effort `isinstance(value, annotation)` that tolerates typing forms.

  - ints are accepted for `float`, because
    convert_number_values_for_function_call_args turns integral floats
    (e.g. 5.0) into ints before they reach this check.
  - Parameterized generics (e.g. `list[int]`) are checked against their
    origin (`list`), since isinstance rejects parameterized generics.
  - Annotations that isinstance cannot evaluate at all (e.g. `Any`,
    `Literal[...]`, string forward references) are treated as accepting the
    value, because they cannot be verified here.
  """
  if isinstance(value, int) and annotation is float:
    return True
  try:
    return isinstance(value, get_origin(annotation) or annotation)
  except TypeError:
    return True


def convert_if_exist_pydantic_model(
    value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
  """Converts a function call argument value to fit a parameter annotation.

  Dict values are instantiated as pydantic models wherever the annotation
  calls for one. That covers the annotation itself, the item type of a list
  annotation, the value type of a dict annotation, and members of a Union.
  Any other value is checked against the annotation and returned unchanged.

  Args:
    value: Argument value taken from the model's function call.
    annotation: The parameter's annotation from the function signature.
    param_name: Parameter name, used in error messages.
    func_name: Function name, used in error messages.

  Returns:
    The (possibly converted) value.

  Raises:
    errors.UnknownFunctionCallArgumentError: If the value cannot be converted
      to, or is not compatible with, the annotation.
  """
  # Unannotated parameters carry no type information to convert or check.
  if annotation is inspect.Parameter.empty:
    return value

  if isinstance(value, dict) and _is_annotation_pydantic_model(annotation):
    try:
      return annotation(**value)
    except pydantic.ValidationError as e:
      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
          f' {annotation}, due to error {e}'
      ) from e

  origin = get_origin(annotation)
  type_args = get_args(annotation)

  if isinstance(value, list) and origin is list:
    if not type_args:  # bare typing.List: no item type to convert to
      return value
    item_type = type_args[0]
    return [
        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
        for item in value
    ]

  if isinstance(value, dict) and origin is dict:
    if len(type_args) != 2:  # bare typing.Dict: no value type to convert to
      return value
    value_type = type_args[1]
    return {
        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
        for k, v in value.items()
    }

  # example 1: typing.Union[int, float]
  # example 2: int | float, whose origin is types.UnionType (Python 3.10+ only;
  # getattr keeps this working on 3.9, where that type does not exist).
  union_origins = (Union, getattr(typing_types, 'UnionType', Union))
  if origin in union_origins:
    for arg in type_args:
      if not (
          _annotation_accepts(value, arg)
          or (isinstance(value, dict) and _is_annotation_pydantic_model(arg))
      ):
        continue
      try:
        return convert_if_exist_pydantic_model(
            value, arg, param_name, func_name
        )
      # Do not raise here: another member of the union (e.g. a second
      # pydantic model type) may still match. The recursive call reports
      # failures as UnknownFunctionCallArgumentError, so that must be caught
      # as well as ValidationError.
      except (
          pydantic.ValidationError,
          errors.UnknownFunctionCallArgumentError,
      ):
        continue
    # if none of the union type is matched, raise error
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
        f' {annotation}.'
    )

  if not _annotation_accepts(value, annotation):
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
    )
  return value
def invoke_function_from_dict_args(
    args: Dict[str, Any], function_to_invoke: Callable
) -> Any:
  """Calls `function_to_invoke` with keyword arguments taken from `args`.

  Only keys of `args` that name a parameter in the function's signature are
  passed; other keys are ignored. Each passed value is first converted with
  convert_if_exist_pydantic_model against that parameter's annotation.

  Args:
    args: Argument name -> value mapping from the model's function call.
    function_to_invoke: The Python function to call.

  Returns:
    Whatever `function_to_invoke` returns.

  Raises:
    errors.UnknownFunctionCallArgumentError: If an argument value cannot be
      converted to its parameter annotation.
    errors.FunctionInvocationError: If the call itself raises; the original
      exception is chained as `__cause__`.
  """
  signature = inspect.signature(function_to_invoke)
  func_name = function_to_invoke.__name__
  converted_args = {
      param_name: convert_if_exist_pydantic_model(
          args[param_name], param.annotation, param_name, func_name
      )
      for param_name, param in signature.parameters.items()
      if param_name in args
  }
  try:
    return function_to_invoke(**converted_args)
  except Exception as e:
    # Chain the cause so the user function's traceback is not lost.
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {func_name} with converted arguments'
        f' {converted_args} from model returned function call argument'
        f' {args} because of error {e}'
    ) from e
def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, object],
) -> list[types.Part]:
  """Invokes the function calls in the first candidate and wraps the results.

  Each function call part is looked up in `function_map` by name and invoked
  with its arguments (integral floats converted to ints first). The outcome is
  returned to the model as a function response part: `{'result': ...}` on
  success, or `{'error': str(e)}` if invocation raised.

  Args:
    response: The model response to scan for function calls.
    function_map: Function name -> Python callable.

  Returns:
    One function response part per function call part. The list is empty if
    the response has no candidates, no content, or no parts.

  Raises:
    KeyError: If the model calls a function not present in `function_map`.
  """
  if not response.candidates:
    return []
  content = response.candidates[0].content
  if content is None or not content.parts:
    return []

  func_response_parts = []
  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    func = function_map[func_name]
    # Treat missing args as "no arguments" rather than passing None through.
    args = convert_number_values_for_function_call_args(
        function_call.args or {}
    )
    # Distinct name: do not shadow the `response` parameter.
    try:
      func_response = {'result': invoke_function_from_dict_args(args, func)}
    except Exception as e:  # pylint: disable=broad-except
      # Errors are reported back to the model instead of aborting the call.
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(func_name, func_response)
    )
  return func_response_parts
def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling (AFC) should be DISABLED.

  AFC stays enabled (returns False) when there is no config, no
  `automatic_function_calling` section, or `disable` is unset. It is disabled
  (returns True, with a warning) when `maximum_remote_calls` is <= 0.
  Otherwise the result is `automatic_function_calling.disable`. A warning is
  logged if `disable` is True while `maximum_remote_calls` is positive.

  Args:
    config: Optional generate-content config, as a model or a dict.

  Returns:
    True if AFC should be disabled, False otherwise.
  """
  config_model = (
      types.GenerateContentConfig(**config)
      if config and isinstance(config, dict)
      else config
  )
  afc_config = config_model.automatic_function_calling if config_model else None
  # Default to enable AFC if not specified.
  if not afc_config:
    return False

  max_calls = afc_config.maximum_remote_calls
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
  if max_calls is not None and int(max_calls) <= 0:
    logging.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {max_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  if afc_config.disable is None:
    return False

  # Conflicting settings: `disable` wins, but tell the user.
  if afc_config.disable and max_calls is not None and int(max_calls) > 0:
    logging.warning(
        '`automatic_function_calling.disable` is set to `True`. But'
        ' `automatic_function_calling.maximum_remote_calls` is set to be a'
        ' positive number'
        f' {max_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )
  return afc_config.disable
def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum number of remote calls AFC may make.

  Falls back to `_DEFAULT_MAX_REMOTE_CALLS_AFC` when the config does not set
  `automatic_function_calling.maximum_remote_calls`.

  Raises:
    ValueError: If should_disable_afc(config) reports AFC as disabled.
  """
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  if config and isinstance(config, dict):
    parsed = types.GenerateContentConfig(**config)
  else:
    parsed = config
  afc_config = parsed.automatic_function_calling if parsed else None
  limit = afc_config.maximum_remote_calls if afc_config else None
  if limit is None:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(limit)
def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling history should be kept.

  The result is False only when `automatic_function_calling.ignore_call_history`
  is truthy. It is True when the config or its `automatic_function_calling`
  section is absent.
  """
  if config and isinstance(config, dict):
    config = types.GenerateContentConfig(**config)
  afc_config = config.automatic_function_calling if config else None
  if afc_config:
    return not afc_config.ignore_call_history
  return True