diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/google/genai/_extra_utils.py | 310 |
1 files changed, 310 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py new file mode 100644 index 00000000..db8d377b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py @@ -0,0 +1,310 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Extra utils depending on types that are shared between sync and async modules. +""" + +import inspect +import logging +from typing import Any, Callable, Dict, get_args, get_origin, Optional, types as typing_types, Union + +import pydantic + +from . import _common +from . import errors +from . import types + + +_DEFAULT_MAX_REMOTE_CALLS_AFC = 10 + + +def format_destination( + src: str, + config: Optional[types.CreateBatchJobConfigOrDict] = None, +) -> types.CreateBatchJobConfig: + """Formats the destination uri based on the source uri.""" + config = ( + types._CreateBatchJobParameters(config=config).config + or types.CreateBatchJobConfig() + ) + + unique_name = None + if not config.display_name: + unique_name = _common.timestamped_unique_name() + config.display_name = f'genai_batch_job_{unique_name}' + + if not config.dest: + if src.startswith('gs://') and src.endswith('.jsonl'): + # If source uri is "gs://bucket/path/to/src.jsonl", then the destination + # uri prefix will be "gs://bucket/path/to/src/dest". + config.dest = f'{src[:-6]}/dest' + elif src.startswith('bq://'): + # If source uri is "bq://project.dataset.src", then the destination + # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID". + unique_name = unique_name or _common.timestamped_unique_name() + config.dest = f'{src}_dest_{unique_name}' + else: + raise ValueError(f'Unsupported source: {src}') + return config + + +def get_function_map( + config: Optional[types.GenerateContentConfigOrDict] = None, +) -> dict[str, object]: + """Returns a function map from the config.""" + config_model = ( + types.GenerateContentConfig(**config) + if config and isinstance(config, dict) + else config + ) + function_map = {} + if not config_model: + return function_map + if config_model.tools: + for tool in config_model.tools: + if callable(tool): + if inspect.iscoroutinefunction(tool): + raise errors.UnsupportedFunctionError( + f'Function {tool.__name__} is a coroutine function, which is not' + ' supported for automatic function calling. Please manually invoke' + f' {tool.__name__} to get the function response.' + ) + function_map[tool.__name__] = tool + return function_map + + +def convert_number_values_for_function_call_args( + args: Union[dict[str, object], list[object], object], +) -> Union[dict[str, object], list[object], object]: + """Converts float values with no decimal to integers.""" + if isinstance(args, float) and args.is_integer(): + return int(args) + if isinstance(args, dict): + return { + key: convert_number_values_for_function_call_args(value) + for key, value in args.items() + } + if isinstance(args, list): + return [ + convert_number_values_for_function_call_args(value) for value in args + ] + return args + + +def _is_annotation_pydantic_model(annotation: Any) -> bool: + return inspect.isclass(annotation) and issubclass( + annotation, pydantic.BaseModel + ) + + +def convert_if_exist_pydantic_model( + value: Any, annotation: Any, param_name: str, func_name: str +) -> Any: + if isinstance(value, dict) and _is_annotation_pydantic_model(annotation): + try: + return annotation(**value) + except pydantic.ValidationError as e: + raise errors.UnknownFunctionCallArgumentError( + f'Failed to parse parameter {param_name} for function' + f' {func_name} from function call part because function call argument' + f' value {value} is not compatible with parameter annotation' + f' {annotation}, due to error {e}' + ) + if isinstance(value, list) and get_origin(annotation) == list: + item_type = get_args(annotation)[0] + return [ + convert_if_exist_pydantic_model(item, item_type, param_name, func_name) + for item in value + ] + if isinstance(value, dict) and get_origin(annotation) == dict: + _, value_type = get_args(annotation) + return { + k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name) + for k, v in value.items() + } + # example 1: typing.Union[int, float] + # example 2: int | float equivalent to typing.types.UnionType[int, float] + if get_origin(annotation) in (Union, typing_types.UnionType): + for arg in get_args(annotation): + if isinstance(value, arg) or ( + isinstance(value, dict) and _is_annotation_pydantic_model(arg) + ): + try: + return convert_if_exist_pydantic_model( + value, arg, param_name, func_name + ) + # do not raise here because there could be multiple pydantic model types + # in the union type. + except pydantic.ValidationError: + continue + # if none of the union type is matched, raise error + raise errors.UnknownFunctionCallArgumentError( + f'Failed to parse parameter {param_name} for function' + f' {func_name} from function call part because function call argument' + f' value {value} cannot be converted to parameter annotation' + f' {annotation}.' + ) + # the only exception for value and annotation type to be different is int and + # float. see convert_number_values_for_function_call_args function for context + if isinstance(value, int) and annotation is float: + return value + if not isinstance(value, annotation): + raise errors.UnknownFunctionCallArgumentError( + f'Failed to parse parameter {param_name} for function {func_name} from' + f' function call part because function call argument value {value} is' + f' not compatible with parameter annotation {annotation}.' + ) + return value + + +def invoke_function_from_dict_args( + args: Dict[str, Any], function_to_invoke: Callable +) -> Any: + signature = inspect.signature(function_to_invoke) + func_name = function_to_invoke.__name__ + converted_args = {} + for param_name, param in signature.parameters.items(): + if param_name in args: + converted_args[param_name] = convert_if_exist_pydantic_model( + args[param_name], + param.annotation, + param_name, + func_name, + ) + try: + return function_to_invoke(**converted_args) + except Exception as e: + raise errors.FunctionInvocationError( + f'Failed to invoke function {func_name} with converted arguments' + f' {converted_args} from model returned function call argument' + f' {args} because of error {e}' + ) + + +def get_function_response_parts( + response: types.GenerateContentResponse, + function_map: dict[str, object], +) -> list[types.Part]: + """Returns the function response parts from the response.""" + func_response_parts = [] + for part in response.candidates[0].content.parts: + if not part.function_call: + continue + func_name = part.function_call.name + func = function_map[func_name] + args = convert_number_values_for_function_call_args(part.function_call.args) + try: + response = {'result': invoke_function_from_dict_args(args, func)} + except Exception as e: # pylint: disable=broad-except + response = {'error': str(e)} + func_response = types.Part.from_function_response(func_name, response) + + func_response_parts.append(func_response) + return func_response_parts + + +def should_disable_afc( + config: Optional[types.GenerateContentConfigOrDict] = None, +) -> bool: + """Returns whether automatic function calling is enabled.""" + config_model = ( + types.GenerateContentConfig(**config) + if config and isinstance(config, dict) + else config + ) + + # If max_remote_calls is less or equal to 0, warn and disable AFC. + if ( + config_model + and config_model.automatic_function_calling + and config_model.automatic_function_calling.maximum_remote_calls + is not None + and int(config_model.automatic_function_calling.maximum_remote_calls) + <= 0 + ): + logging.warning( + 'max_remote_calls in automatic_function_calling_config' + f' {config_model.automatic_function_calling.maximum_remote_calls} is' + ' less than or equal to 0. Disabling automatic function calling.' + ' Please set max_remote_calls to a positive integer.' + ) + return True + + # Default to enable AFC if not specified. + if ( + not config_model + or not config_model.automatic_function_calling + or config_model.automatic_function_calling.disable is None + ): + return False + + if ( + config_model.automatic_function_calling.disable + and config_model.automatic_function_calling.maximum_remote_calls + is not None + and int(config_model.automatic_function_calling.maximum_remote_calls) > 0 + ): + logging.warning( + '`automatic_function_calling.disable` is set to `True`. But' + ' `automatic_function_calling.maximum_remote_calls` is set to be a' + ' positive number' + f' {config_model.automatic_function_calling.maximum_remote_calls}.' + ' Disabling automatic function calling. If you want to enable' + ' automatic function calling, please set' + ' `automatic_function_calling.disable` to `False` or leave it unset,' + ' and set `automatic_function_calling.maximum_remote_calls` to a' + ' positive integer or leave' + ' `automatic_function_calling.maximum_remote_calls` unset.' + ) + + return config_model.automatic_function_calling.disable + + +def get_max_remote_calls_afc( + config: Optional[types.GenerateContentConfigOrDict] = None, +) -> int: + """Returns the remaining remote calls for automatic function calling.""" + if should_disable_afc(config): + raise ValueError( + 'automatic function calling is not enabled, but SDK is trying to get' + ' max remote calls.' + ) + config_model = ( + types.GenerateContentConfig(**config) + if config and isinstance(config, dict) + else config + ) + if ( + not config_model + or not config_model.automatic_function_calling + or config_model.automatic_function_calling.maximum_remote_calls is None + ): + return _DEFAULT_MAX_REMOTE_CALLS_AFC + return int(config_model.automatic_function_calling.maximum_remote_calls) + +def should_append_afc_history( + config: Optional[types.GenerateContentConfigOrDict] = None, +) -> bool: + config_model = ( + types.GenerateContentConfig(**config) + if config and isinstance(config, dict) + else config + ) + if ( + not config_model + or not config_model.automatic_function_calling + ): + return True + return not config_model.automatic_function_calling.ignore_call_history |