about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py')
-rw-r--r--.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py402
1 files changed, 402 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py b/.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py
new file mode 100644
index 00000000..18afd9d8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/anthropic/_utils/_transform.py
@@ -0,0 +1,402 @@
+from __future__ import annotations
+
+import io
+import base64
+import pathlib
+from typing import Any, Mapping, TypeVar, cast
+from datetime import date, datetime
+from typing_extensions import Literal, get_args, override, get_type_hints
+
+import anyio
+import pydantic
+
+from ._utils import (
+    is_list,
+    is_mapping,
+    is_iterable,
+)
+from .._files import is_base64_file_input
+from ._typing import (
+    is_list_type,
+    is_union_type,
+    extract_type_arg,
+    is_iterable_type,
+    is_required_type,
+    is_annotated_type,
+    strip_annotated_type,
+)
+from .._compat import get_origin, model_dump, is_typeddict
+
+_T = TypeVar("_T")
+
+
+# TODO: support for drilling globals() and locals()
+# TODO: ensure works correctly with forward references in all cases
+
+
+PropertyFormat = Literal["iso8601", "base64", "custom"]
+
+
+class PropertyInfo:
+    """Metadata class to be used in Annotated types to provide information about a given type.
+
+    For example:
+
+    class MyParams(TypedDict):
+        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
+
+    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
+    """
+
+    alias: str | None
+    format: PropertyFormat | None
+    format_template: str | None
+    discriminator: str | None
+
+    def __init__(
+        self,
+        *,
+        alias: str | None = None,
+        format: PropertyFormat | None = None,
+        format_template: str | None = None,
+        discriminator: str | None = None,
+    ) -> None:
+        self.alias = alias
+        self.format = format
+        self.format_template = format_template
+        self.discriminator = discriminator
+
+    @override
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
+
+
+def maybe_transform(
+    data: object,
+    expected_type: object,
+) -> Any | None:
+    """Wrapper over `transform()` that allows `None` to be passed.
+
+    See `transform()` for more details.
+    """
+    if data is None:
+        return None
+    return transform(data, expected_type)
+
+
+# Wrapper over _transform_recursive providing fake types
+def transform(
+    data: _T,
+    expected_type: object,
+) -> _T:
+    """Transform dictionaries based off of type information from the given type, for example:
+
+    ```py
+    class Params(TypedDict, total=False):
+        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
+
+
+    transformed = transform({"card_id": "<my card ID>"}, Params)
+    # {'cardID': '<my card ID>'}
+    ```
+
+    Any keys / data that does not have type information given will be included as is.
+
+    It should be noted that the transformations that this function does are not represented in the type system.
+    """
+    transformed = _transform_recursive(data, annotation=cast(type, expected_type))
+    return cast(_T, transformed)
+
+
+def _get_annotated_type(type_: type) -> type | None:
+    """If the given type is an `Annotated` type then it is returned, if not `None` is returned.
+
+    This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
+    """
+    if is_required_type(type_):
+        # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
+        type_ = get_args(type_)[0]
+
+    if is_annotated_type(type_):
+        return type_
+
+    return None
+
+
+def _maybe_transform_key(key: str, type_: type) -> str:
+    """Transform the given `data` based on the annotations provided in `type_`.
+
+    Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
+    """
+    annotated_type = _get_annotated_type(type_)
+    if annotated_type is None:
+        # no `Annotated` definition for this type, no transformation needed
+        return key
+
+    # ignore the first argument as it is the actual type
+    annotations = get_args(annotated_type)[1:]
+    for annotation in annotations:
+        if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
+            return annotation.alias
+
+    return key
+
+
+def _transform_recursive(
+    data: object,
+    *,
+    annotation: type,
+    inner_type: type | None = None,
+) -> object:
+    """Transform the given data against the expected type.
+
+    Args:
+        annotation: The direct type annotation given to the particular piece of data.
+            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
+
+        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
+            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
+            the list can be transformed using the metadata from the container type.
+
+            Defaults to the same value as the `annotation` argument.
+    """
+    if inner_type is None:
+        inner_type = annotation
+
+    stripped_type = strip_annotated_type(inner_type)
+    origin = get_origin(stripped_type) or stripped_type
+    if is_typeddict(stripped_type) and is_mapping(data):
+        return _transform_typeddict(data, stripped_type)
+
+    if origin == dict and is_mapping(data):
+        items_type = get_args(stripped_type)[1]
+        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
+    if (
+        # List[T]
+        (is_list_type(stripped_type) and is_list(data))
+        # Iterable[T]
+        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+    ):
+        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
+        # intended as an iterable, so we don't transform it.
+        if isinstance(data, dict):
+            return cast(object, data)
+
+        inner_type = extract_type_arg(stripped_type, 0)
+        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+
+    if is_union_type(stripped_type):
+        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
+        #
+        # TODO: there may be edge cases where the same normalized field name will transform to two different names
+        # in different subtypes.
+        for subtype in get_args(stripped_type):
+            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
+        return data
+
+    if isinstance(data, pydantic.BaseModel):
+        return model_dump(data, exclude_unset=True, mode="json")
+
+    annotated_type = _get_annotated_type(annotation)
+    if annotated_type is None:
+        return data
+
+    # ignore the first argument as it is the actual type
+    annotations = get_args(annotated_type)[1:]
+    for annotation in annotations:
+        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
+            return _format_data(data, annotation.format, annotation.format_template)
+
+    return data
+
+
+def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+    if isinstance(data, (date, datetime)):
+        if format_ == "iso8601":
+            return data.isoformat()
+
+        if format_ == "custom" and format_template is not None:
+            return data.strftime(format_template)
+
+    if format_ == "base64" and is_base64_file_input(data):
+        binary: str | bytes | None = None
+
+        if isinstance(data, pathlib.Path):
+            binary = data.read_bytes()
+        elif isinstance(data, io.IOBase):
+            binary = data.read()
+
+            if isinstance(binary, str):  # type: ignore[unreachable]
+                binary = binary.encode()
+
+        if not isinstance(binary, bytes):
+            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+        return base64.b64encode(binary).decode("ascii")
+
+    return data
+
+
+def _transform_typeddict(
+    data: Mapping[str, object],
+    expected_type: type,
+) -> Mapping[str, object]:
+    result: dict[str, object] = {}
+    annotations = get_type_hints(expected_type, include_extras=True)
+    for key, value in data.items():
+        type_ = annotations.get(key)
+        if type_ is None:
+            # we do not have a type annotation for this field, leave it as is
+            result[key] = value
+        else:
+            result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
+    return result
+
+
+async def async_maybe_transform(
+    data: object,
+    expected_type: object,
+) -> Any | None:
+    """Wrapper over `async_transform()` that allows `None` to be passed.
+
+    See `async_transform()` for more details.
+    """
+    if data is None:
+        return None
+    return await async_transform(data, expected_type)
+
+
+async def async_transform(
+    data: _T,
+    expected_type: object,
+) -> _T:
+    """Transform dictionaries based off of type information from the given type, for example:
+
+    ```py
+    class Params(TypedDict, total=False):
+        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
+
+
+    transformed = transform({"card_id": "<my card ID>"}, Params)
+    # {'cardID': '<my card ID>'}
+    ```
+
+    Any keys / data that does not have type information given will be included as is.
+
+    It should be noted that the transformations that this function does are not represented in the type system.
+    """
+    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
+    return cast(_T, transformed)
+
+
+async def _async_transform_recursive(
+    data: object,
+    *,
+    annotation: type,
+    inner_type: type | None = None,
+) -> object:
+    """Transform the given data against the expected type.
+
+    Args:
+        annotation: The direct type annotation given to the particular piece of data.
+            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
+
+        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
+            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
+            the list can be transformed using the metadata from the container type.
+
+            Defaults to the same value as the `annotation` argument.
+    """
+    if inner_type is None:
+        inner_type = annotation
+
+    stripped_type = strip_annotated_type(inner_type)
+    origin = get_origin(stripped_type) or stripped_type
+    if is_typeddict(stripped_type) and is_mapping(data):
+        return await _async_transform_typeddict(data, stripped_type)
+
+    if origin == dict and is_mapping(data):
+        items_type = get_args(stripped_type)[1]
+        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
+    if (
+        # List[T]
+        (is_list_type(stripped_type) and is_list(data))
+        # Iterable[T]
+        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+    ):
+        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
+        # intended as an iterable, so we don't transform it.
+        if isinstance(data, dict):
+            return cast(object, data)
+
+        inner_type = extract_type_arg(stripped_type, 0)
+        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+
+    if is_union_type(stripped_type):
+        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
+        #
+        # TODO: there may be edge cases where the same normalized field name will transform to two different names
+        # in different subtypes.
+        for subtype in get_args(stripped_type):
+            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
+        return data
+
+    if isinstance(data, pydantic.BaseModel):
+        return model_dump(data, exclude_unset=True, mode="json")
+
+    annotated_type = _get_annotated_type(annotation)
+    if annotated_type is None:
+        return data
+
+    # ignore the first argument as it is the actual type
+    annotations = get_args(annotated_type)[1:]
+    for annotation in annotations:
+        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
+            return await _async_format_data(data, annotation.format, annotation.format_template)
+
+    return data
+
+
+async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+    if isinstance(data, (date, datetime)):
+        if format_ == "iso8601":
+            return data.isoformat()
+
+        if format_ == "custom" and format_template is not None:
+            return data.strftime(format_template)
+
+    if format_ == "base64" and is_base64_file_input(data):
+        binary: str | bytes | None = None
+
+        if isinstance(data, pathlib.Path):
+            binary = await anyio.Path(data).read_bytes()
+        elif isinstance(data, io.IOBase):
+            binary = data.read()
+
+            if isinstance(binary, str):  # type: ignore[unreachable]
+                binary = binary.encode()
+
+        if not isinstance(binary, bytes):
+            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+        return base64.b64encode(binary).decode("ascii")
+
+    return data
+
+
+async def _async_transform_typeddict(
+    data: Mapping[str, object],
+    expected_type: type,
+) -> Mapping[str, object]:
+    result: dict[str, object] = {}
+    annotations = get_type_hints(expected_type, include_extras=True)
+    for key, value in data.items():
+        type_ = annotations.get(key)
+        if type_ is None:
+            # we do not have a type annotation for this field, leave it as is
+            result[key] = value
+        else:
+            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
+    return result