# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Extra utils depending on types that are shared between sync and async modules.
"""

import inspect
import logging
from typing import (
    Any,
    Callable,
    Dict,
    get_args,
    get_origin,
    Optional,
    types as typing_types,
    Union,
)

import pydantic

from . import _common
from . import errors
from . import types

_DEFAULT_MAX_REMOTE_CALLS_AFC = 10

# Module logger: library code should not log through (and implicitly
# configure) the root logger. Records still propagate to root handlers.
logger = logging.getLogger(__name__)

# Origins that mark a union annotation. `int | float` (PEP 604) has origin
# `types.UnionType`, which only exists on Python 3.10+; on older interpreters
# fall back to `typing.Union` alone instead of raising AttributeError.
_UNION_ORIGINS = (Union, getattr(typing_types, 'UnionType', Union))


def format_destination(
    src: str,
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Fills in a batch job config's display name and destination from `src`.

  Args:
    src: The batch job source URI. When `config` has no `dest`, it must be a
      `gs://...jsonl` URI or a `bq://...` URI.
    config: Optional batch job config, as a model or a dict. It is not
      modified; a copy is filled in and returned.

  Returns:
    A `CreateBatchJobConfig` with `display_name` and `dest` set. Values that
    `config` already provides are kept as-is.

  Raises:
    ValueError: If `dest` has to be derived and `src` is neither a
      `gs://...jsonl` URI nor a `bq://` URI.
  """
  config = (
      types._CreateBatchJobParameters(config=config).config  # pylint: disable=protected-access
      or types.CreateBatchJobConfig()
  )
  # Work on a copy. When `config` is already a model, validation may hand
  # back the caller's own instance. Filling in display_name/dest on it would
  # leak into later calls that reuse the same config object.
  config = config.model_copy()

  unique_name = None
  if not config.display_name:
    unique_name = _common.timestamped_unique_name()
    config.display_name = f'genai_batch_job_{unique_name}'

  if not config.dest:
    if src.startswith('gs://') and src.endswith('.jsonl'):
      # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
      # uri prefix will be "gs://bucket/path/to/src/dest".
      src_prefix = src[: -len('.jsonl')]
      config.dest = f'{src_prefix}/dest'
    elif src.startswith('bq://'):
      # If source uri is "bq://project.dataset.src", then the destination
      # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID". Reuse the
      # display-name suffix when one was generated above.
      unique_name = unique_name or _common.timestamped_unique_name()
      config.dest = f'{src}_dest_{unique_name}'
    else:
      raise ValueError(f'Unsupported source: {src}')
  return config


def _to_generate_content_config(
    config: Optional[types.GenerateContentConfigOrDict],
) -> Any:
  """Normalizes `config` to a `GenerateContentConfig` model or None.

  A falsy `config` (None or an empty dict) becomes None. A non-empty dict is
  validated into a `GenerateContentConfig`. Anything else is returned as-is.
  """
  if not config:
    return None
  if isinstance(config, dict):
    return types.GenerateContentConfig(**config)
  return config


def _get_afc_config(
    config: Optional[types.GenerateContentConfigOrDict],
) -> Any:
  """Returns `config.automatic_function_calling`, or None if it is unset."""
  config_model = _to_generate_content_config(config)
  if not config_model:
    return None
  return config_model.automatic_function_calling


def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> dict[str, object]:
  """Returns the callables in `config.tools`, keyed by their `__name__`.

  Tools that are not callable are skipped.

  Raises:
    errors.UnsupportedFunctionError: If a callable tool is a coroutine
      function, which automatic function calling cannot invoke.
  """
  config_model = _to_generate_content_config(config)
  function_map: dict[str, object] = {}
  if not config_model or not config_model.tools:
    return function_map

  for tool in config_model.tools:
    if not callable(tool):
      continue
    if inspect.iscoroutinefunction(tool):
      raise errors.UnsupportedFunctionError(
          f'Function {tool.__name__} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually invoke'
          f' {tool.__name__} to get the function response.'
      )
    function_map[tool.__name__] = tool
  return function_map


def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Converts float values with no fractional part to ints, recursively.

  Dict values and list items are converted recursively. Other values are
  returned unchanged, e.g. `{'n': 2.0, 'xs': [1.5, 3.0]}` becomes
  `{'n': 2, 'xs': [1.5, 3]}`.
  """
  if isinstance(args, float) and args.is_integer():
    return int(args)
  if isinstance(args, dict):
    return {
        key: convert_number_values_for_function_call_args(value)
        for key, value in args.items()
    }
  if isinstance(args, list):
    return [
        convert_number_values_for_function_call_args(value) for value in args
    ]
  return args


def _is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns whether `annotation` is a pydantic `BaseModel` subclass."""
  return inspect.isclass(annotation) and issubclass(
      annotation, pydantic.BaseModel
  )


def convert_if_exist_pydantic_model(
    value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
  """Converts a function call argument to match its parameter annotation.

  A dict annotated with a pydantic model is parsed into that model. Lists,
  dicts and unions are handled recursively. Any other value is checked
  against the annotation and returned unchanged.

  Args:
    value: The argument value from the model's function call.
    annotation: The parameter annotation from the function signature.
    param_name: The parameter name, used in error messages.
    func_name: The function name, used in error messages.

  Returns:
    The converted (or unchanged) value.

  Raises:
    errors.UnknownFunctionCallArgumentError: If the value cannot be converted
      to, or is not an instance of, the annotated type.
  """
  # Unannotated parameters and `Any` accept every value. isinstance() would
  # reject the former and raise TypeError for the latter.
  if annotation is inspect.Parameter.empty or annotation is Any:
    return value

  if isinstance(value, dict) and _is_annotation_pydantic_model(annotation):
    try:
      return annotation(**value)
    except pydantic.ValidationError as e:
      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
          f' {annotation}, due to error {e}'
      ) from e

  origin = get_origin(annotation)
  type_args = get_args(annotation)

  # Bare `typing.List` / `typing.Dict` have an origin but no type arguments.
  # Only recurse when the element type is actually known.
  if isinstance(value, list) and origin is list and type_args:
    item_type = type_args[0]
    return [
        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
        for item in value
    ]

  if isinstance(value, dict) and origin is dict and len(type_args) == 2:
    value_type = type_args[1]
    return {
        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
        for k, v in value.items()
    }

  # example 1: typing.Union[int, float]
  # example 2: int | float (origin types.UnionType on Python 3.10+)
  if origin in _UNION_ORIGINS:
    for arg in type_args:
      try:
        return convert_if_exist_pydantic_model(
            value, arg, param_name, func_name
        )
      except errors.UnknownFunctionCallArgumentError:
        # The recursive call reports mismatches (including pydantic
        # validation failures) as UnknownFunctionCallArgumentError. Try the
        # remaining members, e.g. several pydantic model types in one union.
        continue
    # If none of the union members matched, raise an error.
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
        f' {annotation}.'
    )

  # The only allowed mismatch between value type and annotation is int for
  # float. See convert_number_values_for_function_call_args for context.
  if isinstance(value, int) and annotation is float:
    return value

  # For parameterized generics such as `list[int]`, check against the origin
  # class; isinstance() raises TypeError on the parameterized form.
  expected_type = origin or annotation
  if not inspect.isclass(expected_type):
    # Annotations such as Literal[...], TypeVars or string forward references
    # cannot be checked with isinstance(); pass the value through unchanged.
    return value
  if not isinstance(value, expected_type):
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
    )
  return value


def invoke_function_from_dict_args(
    args: Dict[str, Any], function_to_invoke: Callable
) -> Any:
  """Calls `function_to_invoke` with keyword arguments taken from `args`.

  Only keys that name a parameter in the function's signature are passed.
  Each value is first converted with `convert_if_exist_pydantic_model`. Keys
  that match no parameter are dropped.

  Raises:
    errors.UnknownFunctionCallArgumentError: If an argument cannot be
      converted to its parameter annotation.
    errors.FunctionInvocationError: If the function call itself raises.
  """
  signature = inspect.signature(function_to_invoke)
  func_name = function_to_invoke.__name__
  converted_args = {
      param_name: convert_if_exist_pydantic_model(
          args[param_name], param.annotation, param_name, func_name
      )
      for param_name, param in signature.parameters.items()
      if param_name in args
  }
  try:
    return function_to_invoke(**converted_args)
  except Exception as e:  # pylint: disable=broad-except
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {func_name} with converted arguments'
        f' {converted_args} from model returned function call argument'
        f' {args} because of error {e}'
    ) from e


def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, object],
) -> list[types.Part]:
  """Invokes the function calls in the first candidate and wraps the results.

  Args:
    response: The model response whose first candidate may contain function
      call parts.
    function_map: Callables keyed by function name, as built by
      `get_function_map`.

  Returns:
    One function response part per function call part, in order. Each carries
    `{'result': <return value>}` on success, or `{'error': <message>}` if
    argument conversion or the call failed. The list is empty when the
    response has no candidates, content or parts.

  Raises:
    KeyError: If a function call names a function missing from
      `function_map`.
  """
  if not response.candidates:
    return []
  content = response.candidates[0].content
  if not content or not content.parts:
    return []

  func_response_parts = []
  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    func = function_map[func_name]
    # If `args` is unset (None), call the function with no arguments instead
    # of failing on a lookup against None.
    args = convert_number_values_for_function_call_args(
        function_call.args or {}
    )
    # Named separately so the `response` parameter is not shadowed.
    try:
      func_result = {'result': invoke_function_from_dict_args(args, func)}
    except Exception as e:  # pylint: disable=broad-except
      func_result = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(func_name, func_result)
    )
  return func_response_parts


def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling (AFC) should be disabled.

  AFC is disabled when `maximum_remote_calls` is set to a value <= 0, or when
  `automatic_function_calling.disable` is true. It stays enabled when there is
  no config, no `automatic_function_calling` section, or `disable` is unset.
  Conflicting settings are logged as warnings.
  """
  afc_config = _get_afc_config(config)
  # Default to enable AFC if not specified.
  if not afc_config:
    return False

  max_remote_calls = afc_config.maximum_remote_calls
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
  if max_remote_calls is not None and int(max_remote_calls) <= 0:
    logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {max_remote_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  # Default to enable AFC if not specified.
  if afc_config.disable is None:
    return False

  if (
      afc_config.disable
      and max_remote_calls is not None
      and int(max_remote_calls) > 0
  ):
    logger.warning(
        '`automatic_function_calling.disable` is set to `True`. But'
        ' `automatic_function_calling.maximum_remote_calls` is set to be a'
        ' positive number'
        f' {max_remote_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )
  return afc_config.disable


def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum number of remote calls allowed for AFC.

  Falls back to `_DEFAULT_MAX_REMOTE_CALLS_AFC` when
  `maximum_remote_calls` is not configured.

  Raises:
    ValueError: If AFC is disabled for `config` (see `should_disable_afc`).
  """
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  afc_config = _get_afc_config(config)
  if not afc_config or afc_config.maximum_remote_calls is None:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(afc_config.maximum_remote_calls)


def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether the AFC call history should be appended.

  True unless `automatic_function_calling.ignore_call_history` is truthy.
  """
  afc_config = _get_afc_config(config)
  if not afc_config:
    return True
  return not afc_config.ignore_call_history