aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/openai/_utils/_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/openai/_utils/_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/openai/_utils/_utils.py430
1 files changed, 430 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/openai/_utils/_utils.py b/.venv/lib/python3.12/site-packages/openai/_utils/_utils.py
new file mode 100644
index 00000000..d6734e6b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/openai/_utils/_utils.py
@@ -0,0 +1,430 @@
+from __future__ import annotations
+
+import os
+import re
+import inspect
+import functools
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Tuple,
+ Mapping,
+ TypeVar,
+ Callable,
+ Iterable,
+ Sequence,
+ cast,
+ overload,
+)
+from pathlib import Path
+from datetime import date, datetime
+from typing_extensions import TypeGuard
+
+import sniffio
+
+from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
+from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
+
+_T = TypeVar("_T")
+_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
+_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
+_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
+CallableT = TypeVar("CallableT", bound=Callable[..., Any])
+
+if TYPE_CHECKING:
+ from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
+
+
+def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
+ return [item for sublist in t for item in sublist]
+
+
+def extract_files(
+ # TODO: this needs to take Dict but variance issues.....
+ # create protocol type ?
+ query: Mapping[str, object],
+ *,
+ paths: Sequence[Sequence[str]],
+) -> list[tuple[str, FileTypes]]:
+ """Recursively extract files from the given dictionary based on specified paths.
+
+ A path may look like this ['foo', 'files', '<array>', 'data'].
+
+ Note: this mutates the given dictionary.
+ """
+ files: list[tuple[str, FileTypes]] = []
+ for path in paths:
+ files.extend(_extract_items(query, path, index=0, flattened_key=None))
+ return files
+
+
+def _extract_items(
+ obj: object,
+ path: Sequence[str],
+ *,
+ index: int,
+ flattened_key: str | None,
+) -> list[tuple[str, FileTypes]]:
+ try:
+ key = path[index]
+ except IndexError:
+ if isinstance(obj, NotGiven):
+ # no value was provided - we can safely ignore
+ return []
+
+ # cyclical import
+ from .._files import assert_is_file_content
+
+ # We have exhausted the path, return the entry we found.
+ assert_is_file_content(obj, key=flattened_key)
+ assert flattened_key is not None
+ return [(flattened_key, cast(FileTypes, obj))]
+
+ index += 1
+ if is_dict(obj):
+ try:
+ # We are at the last entry in the path so we must remove the field
+ if (len(path)) == index:
+ item = obj.pop(key)
+ else:
+ item = obj[key]
+ except KeyError:
+ # Key was not present in the dictionary, this is not indicative of an error
+ # as the given path may not point to a required field. We also do not want
+ # to enforce required fields as the API may differ from the spec in some cases.
+ return []
+ if flattened_key is None:
+ flattened_key = key
+ else:
+ flattened_key += f"[{key}]"
+ return _extract_items(
+ item,
+ path,
+ index=index,
+ flattened_key=flattened_key,
+ )
+ elif is_list(obj):
+ if key != "<array>":
+ return []
+
+ return flatten(
+ [
+ _extract_items(
+ item,
+ path,
+ index=index,
+ flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
+ )
+ for item in obj
+ ]
+ )
+
+ # Something unexpected was passed, just ignore it.
+ return []
+
+
+def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
+ return not isinstance(obj, NotGiven)
+
+
+# Type safe methods for narrowing types with TypeVars.
+# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
+# however this cause Pyright to rightfully report errors. As we know we don't
+# care about the contained types we can safely use `object` in it's place.
+#
+# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
+# `is_*` is for when you're dealing with an unknown input
+# `is_*_t` is for when you're narrowing a known union type to a specific subset
+
+
+def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
+ return isinstance(obj, tuple)
+
+
+def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
+ return isinstance(obj, tuple)
+
+
+def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
+ return isinstance(obj, Sequence)
+
+
+def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
+ return isinstance(obj, Sequence)
+
+
+def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
+ return isinstance(obj, Mapping)
+
+
+def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
+ return isinstance(obj, Mapping)
+
+
+def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
+ return isinstance(obj, dict)
+
+
+def is_list(obj: object) -> TypeGuard[list[object]]:
+ return isinstance(obj, list)
+
+
+def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
+ return isinstance(obj, Iterable)
+
+
+def deepcopy_minimal(item: _T) -> _T:
+ """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
+
+ - mappings, e.g. `dict`
+ - list
+
+ This is done for performance reasons.
+ """
+ if is_mapping(item):
+ return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
+ if is_list(item):
+ return cast(_T, [deepcopy_minimal(entry) for entry in item])
+ return item
+
+
+# copied from https://github.com/Rapptz/RoboDanny
+def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
+ size = len(seq)
+ if size == 0:
+ return ""
+
+ if size == 1:
+ return seq[0]
+
+ if size == 2:
+ return f"{seq[0]} {final} {seq[1]}"
+
+ return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
+
+
+def quote(string: str) -> str:
+ """Add single quotation marks around the given string. Does *not* do any escaping."""
+ return f"'{string}'"
+
+
+def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
+ """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
+
+ Useful for enforcing runtime validation of overloaded functions.
+
+ Example usage:
+ ```py
+ @overload
+ def foo(*, a: str) -> str: ...
+
+
+ @overload
+ def foo(*, b: bool) -> str: ...
+
+
+ # This enforces the same constraints that a static type checker would
+ # i.e. that either a or b must be passed to the function
+ @required_args(["a"], ["b"])
+ def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
+ ```
+ """
+
+ def inner(func: CallableT) -> CallableT:
+ params = inspect.signature(func).parameters
+ positional = [
+ name
+ for name, param in params.items()
+ if param.kind
+ in {
+ param.POSITIONAL_ONLY,
+ param.POSITIONAL_OR_KEYWORD,
+ }
+ ]
+
+ @functools.wraps(func)
+ def wrapper(*args: object, **kwargs: object) -> object:
+ given_params: set[str] = set()
+ for i, _ in enumerate(args):
+ try:
+ given_params.add(positional[i])
+ except IndexError:
+ raise TypeError(
+ f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
+ ) from None
+
+ for key in kwargs.keys():
+ given_params.add(key)
+
+ for variant in variants:
+ matches = all((param in given_params for param in variant))
+ if matches:
+ break
+ else: # no break
+ if len(variants) > 1:
+ variations = human_join(
+ ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
+ )
+ msg = f"Missing required arguments; Expected either {variations} arguments to be given"
+ else:
+ assert len(variants) > 0
+
+ # TODO: this error message is not deterministic
+ missing = list(set(variants[0]) - given_params)
+ if len(missing) > 1:
+ msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
+ else:
+ msg = f"Missing required argument: {quote(missing[0])}"
+ raise TypeError(msg)
+ return func(*args, **kwargs)
+
+ return wrapper # type: ignore
+
+ return inner
+
+
+_K = TypeVar("_K")
+_V = TypeVar("_V")
+
+
+@overload
+def strip_not_given(obj: None) -> None: ...
+
+
+@overload
+def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
+
+
+@overload
+def strip_not_given(obj: object) -> object: ...
+
+
+def strip_not_given(obj: object | None) -> object:
+ """Remove all top-level keys where their values are instances of `NotGiven`"""
+ if obj is None:
+ return None
+
+ if not is_mapping(obj):
+ return obj
+
+ return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
+
+
+def coerce_integer(val: str) -> int:
+ return int(val, base=10)
+
+
+def coerce_float(val: str) -> float:
+ return float(val)
+
+
+def coerce_boolean(val: str) -> bool:
+ return val == "true" or val == "1" or val == "on"
+
+
+def maybe_coerce_integer(val: str | None) -> int | None:
+ if val is None:
+ return None
+ return coerce_integer(val)
+
+
+def maybe_coerce_float(val: str | None) -> float | None:
+ if val is None:
+ return None
+ return coerce_float(val)
+
+
+def maybe_coerce_boolean(val: str | None) -> bool | None:
+ if val is None:
+ return None
+ return coerce_boolean(val)
+
+
+def removeprefix(string: str, prefix: str) -> str:
+ """Remove a prefix from a string.
+
+ Backport of `str.removeprefix` for Python < 3.9
+ """
+ if string.startswith(prefix):
+ return string[len(prefix) :]
+ return string
+
+
+def removesuffix(string: str, suffix: str) -> str:
+ """Remove a suffix from a string.
+
+ Backport of `str.removesuffix` for Python < 3.9
+ """
+ if string.endswith(suffix):
+ return string[: -len(suffix)]
+ return string
+
+
+def file_from_path(path: str) -> FileTypes:
+ contents = Path(path).read_bytes()
+ file_name = os.path.basename(path)
+ return (file_name, contents)
+
+
+def get_required_header(headers: HeadersLike, header: str) -> str:
+ lower_header = header.lower()
+ if is_mapping_t(headers):
+ # mypy doesn't understand the type narrowing here
+ for k, v in headers.items(): # type: ignore
+ if k.lower() == lower_header and isinstance(v, str):
+ return v
+
+ # to deal with the case where the header looks like Stainless-Event-Id
+ intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
+
+ for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
+ value = headers.get(normalized_header)
+ if value:
+ return value
+
+ raise ValueError(f"Could not find {header} header")
+
+
+def get_async_library() -> str:
+ try:
+ return sniffio.current_async_library()
+ except Exception:
+ return "false"
+
+
+def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
+ """A version of functools.lru_cache that retains the type signature
+ for the wrapped function arguments.
+ """
+ wrapper = functools.lru_cache( # noqa: TID251
+ maxsize=maxsize,
+ )
+ return cast(Any, wrapper) # type: ignore[no-any-return]
+
+
+def json_safe(data: object) -> object:
+ """Translates a mapping / sequence recursively in the same fashion
+ as `pydantic` v2's `model_dump(mode="json")`.
+ """
+ if is_mapping(data):
+ return {json_safe(key): json_safe(value) for key, value in data.items()}
+
+ if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
+ return [json_safe(item) for item in data]
+
+ if isinstance(data, (datetime, date)):
+ return data.isoformat()
+
+ return data
+
+
+def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
+ from ..lib.azure import AzureOpenAI
+
+ return isinstance(client, AzureOpenAI)
+
+
+def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
+ from ..lib.azure import AsyncAzureOpenAI
+
+ return isinstance(client, AsyncAzureOpenAI)