aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/fastapi/dependencies
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/fastapi/dependencies')
-rw-r--r--.venv/lib/python3.12/site-packages/fastapi/dependencies/__init__.py0
-rw-r--r--.venv/lib/python3.12/site-packages/fastapi/dependencies/models.py37
-rw-r--r--.venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py980
3 files changed, 1017 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/fastapi/dependencies/__init__.py b/.venv/lib/python3.12/site-packages/fastapi/dependencies/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/fastapi/dependencies/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/fastapi/dependencies/models.py b/.venv/lib/python3.12/site-packages/fastapi/dependencies/models.py
new file mode 100644
index 00000000..418c1172
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/fastapi/dependencies/models.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass, field
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+
+from fastapi._compat import ModelField
+from fastapi.security.base import SecurityBase
+
+
+@dataclass
+class SecurityRequirement:
+ security_scheme: SecurityBase
+ scopes: Optional[Sequence[str]] = None
+
+
+@dataclass
+class Dependant:
+ path_params: List[ModelField] = field(default_factory=list)
+ query_params: List[ModelField] = field(default_factory=list)
+ header_params: List[ModelField] = field(default_factory=list)
+ cookie_params: List[ModelField] = field(default_factory=list)
+ body_params: List[ModelField] = field(default_factory=list)
+ dependencies: List["Dependant"] = field(default_factory=list)
+ security_requirements: List[SecurityRequirement] = field(default_factory=list)
+ name: Optional[str] = None
+ call: Optional[Callable[..., Any]] = None
+ request_param_name: Optional[str] = None
+ websocket_param_name: Optional[str] = None
+ http_connection_param_name: Optional[str] = None
+ response_param_name: Optional[str] = None
+ background_tasks_param_name: Optional[str] = None
+ security_scopes_param_name: Optional[str] = None
+ security_scopes: Optional[List[str]] = None
+ use_cache: bool = True
+ path: Optional[str] = None
+ cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
+
+ def __post_init__(self) -> None:
+ self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
diff --git a/.venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py b/.venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py
new file mode 100644
index 00000000..84dfa4d0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py
@@ -0,0 +1,980 @@
+import inspect
+from contextlib import AsyncExitStack, contextmanager
+from copy import copy, deepcopy
+from dataclasses import dataclass
+from typing import (
+ Any,
+ Callable,
+ Coroutine,
+ Dict,
+ ForwardRef,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+import anyio
+from fastapi import params
+from fastapi._compat import (
+ PYDANTIC_V2,
+ ErrorWrapper,
+ ModelField,
+ RequiredParam,
+ Undefined,
+ _regenerate_error_with_loc,
+ copy_field_info,
+ create_body_model,
+ evaluate_forwardref,
+ field_annotation_is_scalar,
+ get_annotation_from_field_info,
+ get_cached_model_fields,
+ get_missing_field_error,
+ is_bytes_field,
+ is_bytes_sequence_field,
+ is_scalar_field,
+ is_scalar_sequence_field,
+ is_sequence_field,
+ is_uploadfile_or_nonable_uploadfile_annotation,
+ is_uploadfile_sequence_annotation,
+ lenient_issubclass,
+ sequence_types,
+ serialize_sequence_value,
+ value_is_sequence,
+)
+from fastapi.background import BackgroundTasks
+from fastapi.concurrency import (
+ asynccontextmanager,
+ contextmanager_in_threadpool,
+)
+from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.logger import logger
+from fastapi.security.base import SecurityBase
+from fastapi.security.oauth2 import OAuth2, SecurityScopes
+from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.utils import create_model_field, get_path_param_names
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo
+from starlette.background import BackgroundTasks as StarletteBackgroundTasks
+from starlette.concurrency import run_in_threadpool
+from starlette.datastructures import (
+ FormData,
+ Headers,
+ ImmutableMultiDict,
+ QueryParams,
+ UploadFile,
+)
+from starlette.requests import HTTPConnection, Request
+from starlette.responses import Response
+from starlette.websockets import WebSocket
+from typing_extensions import Annotated, get_args, get_origin
+
+multipart_not_installed_error = (
+ 'Form data requires "python-multipart" to be installed. \n'
+ 'You can install "python-multipart" with: \n\n'
+ "pip install python-multipart\n"
+)
+multipart_incorrect_install_error = (
+ 'Form data requires "python-multipart" to be installed. '
+ 'It seems you installed "multipart" instead. \n'
+ 'You can remove "multipart" with: \n\n'
+ "pip uninstall multipart\n\n"
+ 'And then install "python-multipart" with: \n\n'
+ "pip install python-multipart\n"
+)
+
+
+def ensure_multipart_is_installed() -> None:
+ try:
+ from python_multipart import __version__
+
+ # Import an attribute that can be mocked/deleted in testing
+ assert __version__ > "0.0.12"
+ except (ImportError, AssertionError):
+ try:
+ # __version__ is available in both multiparts, and can be mocked
+ from multipart import __version__ # type: ignore[no-redef,import-untyped]
+
+ assert __version__
+ try:
+ # parse_options_header is only available in the right multipart
+ from multipart.multipart import ( # type: ignore[import-untyped]
+ parse_options_header,
+ )
+
+ assert parse_options_header
+ except ImportError:
+ logger.error(multipart_incorrect_install_error)
+ raise RuntimeError(multipart_incorrect_install_error) from None
+ except ImportError:
+ logger.error(multipart_not_installed_error)
+ raise RuntimeError(multipart_not_installed_error) from None
+
+
+def get_param_sub_dependant(
+ *,
+ param_name: str,
+ depends: params.Depends,
+ path: str,
+ security_scopes: Optional[List[str]] = None,
+) -> Dependant:
+ assert depends.dependency
+ return get_sub_dependant(
+ depends=depends,
+ dependency=depends.dependency,
+ path=path,
+ name=param_name,
+ security_scopes=security_scopes,
+ )
+
+
+def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
+ assert callable(depends.dependency), (
+ "A parameter-less dependency must have a callable dependency"
+ )
+ return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
+
+
+def get_sub_dependant(
+ *,
+ depends: params.Depends,
+ dependency: Callable[..., Any],
+ path: str,
+ name: Optional[str] = None,
+ security_scopes: Optional[List[str]] = None,
+) -> Dependant:
+ security_requirement = None
+ security_scopes = security_scopes or []
+ if isinstance(depends, params.Security):
+ dependency_scopes = depends.scopes
+ security_scopes.extend(dependency_scopes)
+ if isinstance(dependency, SecurityBase):
+ use_scopes: List[str] = []
+ if isinstance(dependency, (OAuth2, OpenIdConnect)):
+ use_scopes = security_scopes
+ security_requirement = SecurityRequirement(
+ security_scheme=dependency, scopes=use_scopes
+ )
+ sub_dependant = get_dependant(
+ path=path,
+ call=dependency,
+ name=name,
+ security_scopes=security_scopes,
+ use_cache=depends.use_cache,
+ )
+ if security_requirement:
+ sub_dependant.security_requirements.append(security_requirement)
+ return sub_dependant
+
+
+CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
+
+
+def get_flat_dependant(
+ dependant: Dependant,
+ *,
+ skip_repeats: bool = False,
+ visited: Optional[List[CacheKey]] = None,
+) -> Dependant:
+ if visited is None:
+ visited = []
+ visited.append(dependant.cache_key)
+
+ flat_dependant = Dependant(
+ path_params=dependant.path_params.copy(),
+ query_params=dependant.query_params.copy(),
+ header_params=dependant.header_params.copy(),
+ cookie_params=dependant.cookie_params.copy(),
+ body_params=dependant.body_params.copy(),
+ security_requirements=dependant.security_requirements.copy(),
+ use_cache=dependant.use_cache,
+ path=dependant.path,
+ )
+ for sub_dependant in dependant.dependencies:
+ if skip_repeats and sub_dependant.cache_key in visited:
+ continue
+ flat_sub = get_flat_dependant(
+ sub_dependant, skip_repeats=skip_repeats, visited=visited
+ )
+ flat_dependant.path_params.extend(flat_sub.path_params)
+ flat_dependant.query_params.extend(flat_sub.query_params)
+ flat_dependant.header_params.extend(flat_sub.header_params)
+ flat_dependant.cookie_params.extend(flat_sub.cookie_params)
+ flat_dependant.body_params.extend(flat_sub.body_params)
+ flat_dependant.security_requirements.extend(flat_sub.security_requirements)
+ return flat_dependant
+
+
+def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
+ if not fields:
+ return fields
+ first_field = fields[0]
+ if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
+ fields_to_extract = get_cached_model_fields(first_field.type_)
+ return fields_to_extract
+ return fields
+
+
+def get_flat_params(dependant: Dependant) -> List[ModelField]:
+ flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
+ path_params = _get_flat_fields_from_params(flat_dependant.path_params)
+ query_params = _get_flat_fields_from_params(flat_dependant.query_params)
+ header_params = _get_flat_fields_from_params(flat_dependant.header_params)
+ cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
+ return path_params + query_params + header_params + cookie_params
+
+
+def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
+ signature = inspect.signature(call)
+ globalns = getattr(call, "__globals__", {})
+ typed_params = [
+ inspect.Parameter(
+ name=param.name,
+ kind=param.kind,
+ default=param.default,
+ annotation=get_typed_annotation(param.annotation, globalns),
+ )
+ for param in signature.parameters.values()
+ ]
+ typed_signature = inspect.Signature(typed_params)
+ return typed_signature
+
+
+def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+ annotation = evaluate_forwardref(annotation, globalns, globalns)
+ return annotation
+
+
+def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
+ signature = inspect.signature(call)
+ annotation = signature.return_annotation
+
+ if annotation is inspect.Signature.empty:
+ return None
+
+ globalns = getattr(call, "__globals__", {})
+ return get_typed_annotation(annotation, globalns)
+
+
+def get_dependant(
+ *,
+ path: str,
+ call: Callable[..., Any],
+ name: Optional[str] = None,
+ security_scopes: Optional[List[str]] = None,
+ use_cache: bool = True,
+) -> Dependant:
+ path_param_names = get_path_param_names(path)
+ endpoint_signature = get_typed_signature(call)
+ signature_params = endpoint_signature.parameters
+ dependant = Dependant(
+ call=call,
+ name=name,
+ path=path,
+ security_scopes=security_scopes,
+ use_cache=use_cache,
+ )
+ for param_name, param in signature_params.items():
+ is_path_param = param_name in path_param_names
+ param_details = analyze_param(
+ param_name=param_name,
+ annotation=param.annotation,
+ value=param.default,
+ is_path_param=is_path_param,
+ )
+ if param_details.depends is not None:
+ sub_dependant = get_param_sub_dependant(
+ param_name=param_name,
+ depends=param_details.depends,
+ path=path,
+ security_scopes=security_scopes,
+ )
+ dependant.dependencies.append(sub_dependant)
+ continue
+ if add_non_field_param_to_dependency(
+ param_name=param_name,
+ type_annotation=param_details.type_annotation,
+ dependant=dependant,
+ ):
+ assert param_details.field is None, (
+ f"Cannot specify multiple FastAPI annotations for {param_name!r}"
+ )
+ continue
+ assert param_details.field is not None
+ if isinstance(param_details.field.field_info, params.Body):
+ dependant.body_params.append(param_details.field)
+ else:
+ add_param_to_fields(field=param_details.field, dependant=dependant)
+ return dependant
+
+
+def add_non_field_param_to_dependency(
+ *, param_name: str, type_annotation: Any, dependant: Dependant
+) -> Optional[bool]:
+ if lenient_issubclass(type_annotation, Request):
+ dependant.request_param_name = param_name
+ return True
+ elif lenient_issubclass(type_annotation, WebSocket):
+ dependant.websocket_param_name = param_name
+ return True
+ elif lenient_issubclass(type_annotation, HTTPConnection):
+ dependant.http_connection_param_name = param_name
+ return True
+ elif lenient_issubclass(type_annotation, Response):
+ dependant.response_param_name = param_name
+ return True
+ elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
+ dependant.background_tasks_param_name = param_name
+ return True
+ elif lenient_issubclass(type_annotation, SecurityScopes):
+ dependant.security_scopes_param_name = param_name
+ return True
+ return None
+
+
+@dataclass
+class ParamDetails:
+ type_annotation: Any
+ depends: Optional[params.Depends]
+ field: Optional[ModelField]
+
+
+def analyze_param(
+ *,
+ param_name: str,
+ annotation: Any,
+ value: Any,
+ is_path_param: bool,
+) -> ParamDetails:
+ field_info = None
+ depends = None
+ type_annotation: Any = Any
+ use_annotation: Any = Any
+ if annotation is not inspect.Signature.empty:
+ use_annotation = annotation
+ type_annotation = annotation
+ # Extract Annotated info
+ if get_origin(use_annotation) is Annotated:
+ annotated_args = get_args(annotation)
+ type_annotation = annotated_args[0]
+ fastapi_annotations = [
+ arg
+ for arg in annotated_args[1:]
+ if isinstance(arg, (FieldInfo, params.Depends))
+ ]
+ fastapi_specific_annotations = [
+ arg
+ for arg in fastapi_annotations
+ if isinstance(arg, (params.Param, params.Body, params.Depends))
+ ]
+ if fastapi_specific_annotations:
+ fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
+ fastapi_specific_annotations[-1]
+ )
+ else:
+ fastapi_annotation = None
+ # Set default for Annotated FieldInfo
+ if isinstance(fastapi_annotation, FieldInfo):
+ # Copy `field_info` because we mutate `field_info.default` below.
+ field_info = copy_field_info(
+ field_info=fastapi_annotation, annotation=use_annotation
+ )
+ assert (
+ field_info.default is Undefined or field_info.default is RequiredParam
+ ), (
+ f"`{field_info.__class__.__name__}` default value cannot be set in"
+ f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
+ )
+ if value is not inspect.Signature.empty:
+ assert not is_path_param, "Path parameters cannot have default values"
+ field_info.default = value
+ else:
+ field_info.default = RequiredParam
+ # Get Annotated Depends
+ elif isinstance(fastapi_annotation, params.Depends):
+ depends = fastapi_annotation
+ # Get Depends from default value
+ if isinstance(value, params.Depends):
+ assert depends is None, (
+ "Cannot specify `Depends` in `Annotated` and default value"
+ f" together for {param_name!r}"
+ )
+ assert field_info is None, (
+ "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
+ f" default value together for {param_name!r}"
+ )
+ depends = value
+ # Get FieldInfo from default value
+ elif isinstance(value, FieldInfo):
+ assert field_info is None, (
+ "Cannot specify FastAPI annotations in `Annotated` and default value"
+ f" together for {param_name!r}"
+ )
+ field_info = value
+ if PYDANTIC_V2:
+ field_info.annotation = type_annotation
+
+ # Get Depends from type annotation
+ if depends is not None and depends.dependency is None:
+ # Copy `depends` before mutating it
+ depends = copy(depends)
+ depends.dependency = type_annotation
+
+ # Handle non-param type annotations like Request
+ if lenient_issubclass(
+ type_annotation,
+ (
+ Request,
+ WebSocket,
+ HTTPConnection,
+ Response,
+ StarletteBackgroundTasks,
+ SecurityScopes,
+ ),
+ ):
+ assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
+ assert field_info is None, (
+ f"Cannot specify FastAPI annotation for type {type_annotation!r}"
+ )
+ # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
+ elif field_info is None and depends is None:
+ default_value = value if value is not inspect.Signature.empty else RequiredParam
+ if is_path_param:
+ # We might check here that `default_value is RequiredParam`, but the fact is that the same
+ # parameter might sometimes be a path parameter and sometimes not. See
+ # `tests/test_infer_param_optionality.py` for an example.
+ field_info = params.Path(annotation=use_annotation)
+ elif is_uploadfile_or_nonable_uploadfile_annotation(
+ type_annotation
+ ) or is_uploadfile_sequence_annotation(type_annotation):
+ field_info = params.File(annotation=use_annotation, default=default_value)
+ elif not field_annotation_is_scalar(annotation=type_annotation):
+ field_info = params.Body(annotation=use_annotation, default=default_value)
+ else:
+ field_info = params.Query(annotation=use_annotation, default=default_value)
+
+ field = None
+ # It's a field_info, not a dependency
+ if field_info is not None:
+ # Handle field_info.in_
+ if is_path_param:
+ assert isinstance(field_info, params.Path), (
+ f"Cannot use `{field_info.__class__.__name__}` for path param"
+ f" {param_name!r}"
+ )
+ elif (
+ isinstance(field_info, params.Param)
+ and getattr(field_info, "in_", None) is None
+ ):
+ field_info.in_ = params.ParamTypes.query
+ use_annotation_from_field_info = get_annotation_from_field_info(
+ use_annotation,
+ field_info,
+ param_name,
+ )
+ if isinstance(field_info, params.Form):
+ ensure_multipart_is_installed()
+ if not field_info.alias and getattr(field_info, "convert_underscores", None):
+ alias = param_name.replace("_", "-")
+ else:
+ alias = field_info.alias or param_name
+ field_info.alias = alias
+ field = create_model_field(
+ name=param_name,
+ type_=use_annotation_from_field_info,
+ default=field_info.default,
+ alias=alias,
+ required=field_info.default in (RequiredParam, Undefined),
+ field_info=field_info,
+ )
+ if is_path_param:
+ assert is_scalar_field(field=field), (
+ "Path params must be of one of the supported types"
+ )
+ elif isinstance(field_info, params.Query):
+ assert (
+ is_scalar_field(field)
+ or is_scalar_sequence_field(field)
+ or (
+ lenient_issubclass(field.type_, BaseModel)
+ # For Pydantic v1
+ and getattr(field, "shape", 1) == 1
+ )
+ )
+
+ return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
+
+
+def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
+ field_info = field.field_info
+ field_info_in = getattr(field_info, "in_", None)
+ if field_info_in == params.ParamTypes.path:
+ dependant.path_params.append(field)
+ elif field_info_in == params.ParamTypes.query:
+ dependant.query_params.append(field)
+ elif field_info_in == params.ParamTypes.header:
+ dependant.header_params.append(field)
+ else:
+ assert field_info_in == params.ParamTypes.cookie, (
+ f"non-body parameters must be in path, query, header or cookie: {field.name}"
+ )
+ dependant.cookie_params.append(field)
+
+
+def is_coroutine_callable(call: Callable[..., Any]) -> bool:
+ if inspect.isroutine(call):
+ return inspect.iscoroutinefunction(call)
+ if inspect.isclass(call):
+ return False
+ dunder_call = getattr(call, "__call__", None) # noqa: B004
+ return inspect.iscoroutinefunction(dunder_call)
+
+
+def is_async_gen_callable(call: Callable[..., Any]) -> bool:
+ if inspect.isasyncgenfunction(call):
+ return True
+ dunder_call = getattr(call, "__call__", None) # noqa: B004
+ return inspect.isasyncgenfunction(dunder_call)
+
+
+def is_gen_callable(call: Callable[..., Any]) -> bool:
+ if inspect.isgeneratorfunction(call):
+ return True
+ dunder_call = getattr(call, "__call__", None) # noqa: B004
+ return inspect.isgeneratorfunction(dunder_call)
+
+
+async def solve_generator(
+ *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
+) -> Any:
+ if is_gen_callable(call):
+ cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
+ elif is_async_gen_callable(call):
+ cm = asynccontextmanager(call)(**sub_values)
+ return await stack.enter_async_context(cm)
+
+
+@dataclass
+class SolvedDependency:
+ values: Dict[str, Any]
+ errors: List[Any]
+ background_tasks: Optional[StarletteBackgroundTasks]
+ response: Response
+ dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+
+
+async def solve_dependencies(
+ *,
+ request: Union[Request, WebSocket],
+ dependant: Dependant,
+ body: Optional[Union[Dict[str, Any], FormData]] = None,
+ background_tasks: Optional[StarletteBackgroundTasks] = None,
+ response: Optional[Response] = None,
+ dependency_overrides_provider: Optional[Any] = None,
+ dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
+ async_exit_stack: AsyncExitStack,
+ embed_body_fields: bool,
+) -> SolvedDependency:
+ values: Dict[str, Any] = {}
+ errors: List[Any] = []
+ if response is None:
+ response = Response()
+ del response.headers["content-length"]
+ response.status_code = None # type: ignore
+ dependency_cache = dependency_cache or {}
+ sub_dependant: Dependant
+ for sub_dependant in dependant.dependencies:
+ sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
+ sub_dependant.cache_key = cast(
+ Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
+ )
+ call = sub_dependant.call
+ use_sub_dependant = sub_dependant
+ if (
+ dependency_overrides_provider
+ and dependency_overrides_provider.dependency_overrides
+ ):
+ original_call = sub_dependant.call
+ call = getattr(
+ dependency_overrides_provider, "dependency_overrides", {}
+ ).get(original_call, original_call)
+ use_path: str = sub_dependant.path # type: ignore
+ use_sub_dependant = get_dependant(
+ path=use_path,
+ call=call,
+ name=sub_dependant.name,
+ security_scopes=sub_dependant.security_scopes,
+ )
+
+ solved_result = await solve_dependencies(
+ request=request,
+ dependant=use_sub_dependant,
+ body=body,
+ background_tasks=background_tasks,
+ response=response,
+ dependency_overrides_provider=dependency_overrides_provider,
+ dependency_cache=dependency_cache,
+ async_exit_stack=async_exit_stack,
+ embed_body_fields=embed_body_fields,
+ )
+ background_tasks = solved_result.background_tasks
+ dependency_cache.update(solved_result.dependency_cache)
+ if solved_result.errors:
+ errors.extend(solved_result.errors)
+ continue
+ if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
+ solved = dependency_cache[sub_dependant.cache_key]
+ elif is_gen_callable(call) or is_async_gen_callable(call):
+ solved = await solve_generator(
+ call=call, stack=async_exit_stack, sub_values=solved_result.values
+ )
+ elif is_coroutine_callable(call):
+ solved = await call(**solved_result.values)
+ else:
+ solved = await run_in_threadpool(call, **solved_result.values)
+ if sub_dependant.name is not None:
+ values[sub_dependant.name] = solved
+ if sub_dependant.cache_key not in dependency_cache:
+ dependency_cache[sub_dependant.cache_key] = solved
+ path_values, path_errors = request_params_to_args(
+ dependant.path_params, request.path_params
+ )
+ query_values, query_errors = request_params_to_args(
+ dependant.query_params, request.query_params
+ )
+ header_values, header_errors = request_params_to_args(
+ dependant.header_params, request.headers
+ )
+ cookie_values, cookie_errors = request_params_to_args(
+ dependant.cookie_params, request.cookies
+ )
+ values.update(path_values)
+ values.update(query_values)
+ values.update(header_values)
+ values.update(cookie_values)
+ errors += path_errors + query_errors + header_errors + cookie_errors
+ if dependant.body_params:
+ (
+ body_values,
+ body_errors,
+ ) = await request_body_to_args( # body_params checked above
+ body_fields=dependant.body_params,
+ received_body=body,
+ embed_body_fields=embed_body_fields,
+ )
+ values.update(body_values)
+ errors.extend(body_errors)
+ if dependant.http_connection_param_name:
+ values[dependant.http_connection_param_name] = request
+ if dependant.request_param_name and isinstance(request, Request):
+ values[dependant.request_param_name] = request
+ elif dependant.websocket_param_name and isinstance(request, WebSocket):
+ values[dependant.websocket_param_name] = request
+ if dependant.background_tasks_param_name:
+ if background_tasks is None:
+ background_tasks = BackgroundTasks()
+ values[dependant.background_tasks_param_name] = background_tasks
+ if dependant.response_param_name:
+ values[dependant.response_param_name] = response
+ if dependant.security_scopes_param_name:
+ values[dependant.security_scopes_param_name] = SecurityScopes(
+ scopes=dependant.security_scopes
+ )
+ return SolvedDependency(
+ values=values,
+ errors=errors,
+ background_tasks=background_tasks,
+ response=response,
+ dependency_cache=dependency_cache,
+ )
+
+
+def _validate_value_with_model_field(
+ *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
+) -> Tuple[Any, List[Any]]:
+ if value is None:
+ if field.required:
+ return None, [get_missing_field_error(loc=loc)]
+ else:
+ return deepcopy(field.default), []
+ v_, errors_ = field.validate(value, values, loc=loc)
+ if isinstance(errors_, ErrorWrapper):
+ return None, [errors_]
+ elif isinstance(errors_, list):
+ new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
+ return None, new_errors
+ else:
+ return v_, []
+
+
+def _get_multidict_value(
+ field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
+) -> Any:
+ alias = alias or field.alias
+ if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
+ value = values.getlist(alias)
+ else:
+ value = values.get(alias, None)
+ if (
+ value is None
+ or (
+ isinstance(field.field_info, params.Form)
+ and isinstance(value, str) # For type checks
+ and value == ""
+ )
+ or (is_sequence_field(field) and len(value) == 0)
+ ):
+ if field.required:
+ return
+ else:
+ return deepcopy(field.default)
+ return value
+
+
+def request_params_to_args(
+ fields: Sequence[ModelField],
+ received_params: Union[Mapping[str, Any], QueryParams, Headers],
+) -> Tuple[Dict[str, Any], List[Any]]:
+ values: Dict[str, Any] = {}
+ errors: List[Dict[str, Any]] = []
+
+ if not fields:
+ return values, errors
+
+ first_field = fields[0]
+ fields_to_extract = fields
+ single_not_embedded_field = False
+ default_convert_underscores = True
+ if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
+ fields_to_extract = get_cached_model_fields(first_field.type_)
+ single_not_embedded_field = True
+ # If headers are in a Pydantic model, the way to disable convert_underscores
+ # would be with Header(convert_underscores=False) at the Pydantic model level
+ default_convert_underscores = getattr(
+ first_field.field_info, "convert_underscores", True
+ )
+
+ params_to_process: Dict[str, Any] = {}
+
+ processed_keys = set()
+
+ for field in fields_to_extract:
+ alias = None
+ if isinstance(received_params, Headers):
+ # Handle fields extracted from a Pydantic Model for a header, each field
+ # doesn't have a FieldInfo of type Header with the default convert_underscores=True
+ convert_underscores = getattr(
+ field.field_info, "convert_underscores", default_convert_underscores
+ )
+ if convert_underscores:
+ alias = (
+ field.alias
+ if field.alias != field.name
+ else field.name.replace("_", "-")
+ )
+ value = _get_multidict_value(field, received_params, alias=alias)
+ if value is not None:
+ params_to_process[field.name] = value
+ processed_keys.add(alias or field.alias)
+ processed_keys.add(field.name)
+
+ for key, value in received_params.items():
+ if key not in processed_keys:
+ params_to_process[key] = value
+
+ if single_not_embedded_field:
+ field_info = first_field.field_info
+ assert isinstance(field_info, params.Param), (
+ "Params must be subclasses of Param"
+ )
+ loc: Tuple[str, ...] = (field_info.in_.value,)
+ v_, errors_ = _validate_value_with_model_field(
+ field=first_field, value=params_to_process, values=values, loc=loc
+ )
+ return {first_field.name: v_}, errors_
+
+ for field in fields:
+ value = _get_multidict_value(field, received_params)
+ field_info = field.field_info
+ assert isinstance(field_info, params.Param), (
+ "Params must be subclasses of Param"
+ )
+ loc = (field_info.in_.value, field.alias)
+ v_, errors_ = _validate_value_with_model_field(
+ field=field, value=value, values=values, loc=loc
+ )
+ if errors_:
+ errors.extend(errors_)
+ else:
+ values[field.name] = v_
+ return values, errors
+
+
+def _should_embed_body_fields(fields: List[ModelField]) -> bool:
+ if not fields:
+ return False
+ # More than one dependency could have the same field, it would show up as multiple
+ # fields but it's the same one, so count them by name
+ body_param_names_set = {field.name for field in fields}
+ # A top level field has to be a single field, not multiple
+ if len(body_param_names_set) > 1:
+ return True
+ first_field = fields[0]
+ # If it explicitly specifies it is embedded, it has to be embedded
+ if getattr(first_field.field_info, "embed", None):
+ return True
+ # If it's a Form (or File) field, it has to be a BaseModel to be top level
+ # otherwise it has to be embedded, so that the key value pair can be extracted
+ if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
+ first_field.type_, BaseModel
+ ):
+ return True
+ return False
+
+
+async def _extract_form_body(
+ body_fields: List[ModelField],
+ received_body: FormData,
+) -> Dict[str, Any]:
+ values = {}
+ first_field = body_fields[0]
+ first_field_info = first_field.field_info
+
+ for field in body_fields:
+ value = _get_multidict_value(field, received_body)
+ if (
+ isinstance(first_field_info, params.File)
+ and is_bytes_field(field)
+ and isinstance(value, UploadFile)
+ ):
+ value = await value.read()
+ elif (
+ is_bytes_sequence_field(field)
+ and isinstance(first_field_info, params.File)
+ and value_is_sequence(value)
+ ):
+ # For types
+ assert isinstance(value, sequence_types) # type: ignore[arg-type]
+ results: List[Union[bytes, str]] = []
+
+ async def process_fn(
+ fn: Callable[[], Coroutine[Any, Any, Any]],
+ ) -> None:
+ result = await fn()
+ results.append(result) # noqa: B023
+
+ async with anyio.create_task_group() as tg:
+ for sub_value in value:
+ tg.start_soon(process_fn, sub_value.read)
+ value = serialize_sequence_value(field=field, value=results)
+ if value is not None:
+ values[field.alias] = value
+ for key, value in received_body.items():
+ if key not in values:
+ values[key] = value
+ return values
+
+
+async def request_body_to_args(
+ body_fields: List[ModelField],
+ received_body: Optional[Union[Dict[str, Any], FormData]],
+ embed_body_fields: bool,
+) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
+ values: Dict[str, Any] = {}
+ errors: List[Dict[str, Any]] = []
+ assert body_fields, "request_body_to_args() should be called with fields"
+ single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
+ first_field = body_fields[0]
+ body_to_process = received_body
+
+ fields_to_extract: List[ModelField] = body_fields
+
+ if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
+ fields_to_extract = get_cached_model_fields(first_field.type_)
+
+ if isinstance(received_body, FormData):
+ body_to_process = await _extract_form_body(fields_to_extract, received_body)
+
+ if single_not_embedded_field:
+ loc: Tuple[str, ...] = ("body",)
+ v_, errors_ = _validate_value_with_model_field(
+ field=first_field, value=body_to_process, values=values, loc=loc
+ )
+ return {first_field.name: v_}, errors_
+ for field in body_fields:
+ loc = ("body", field.alias)
+ value: Optional[Any] = None
+ if body_to_process is not None:
+ try:
+ value = body_to_process.get(field.alias)
+ # If the received body is a list, not a dict
+ except AttributeError:
+ errors.append(get_missing_field_error(loc))
+ continue
+ v_, errors_ = _validate_value_with_model_field(
+ field=field, value=value, values=values, loc=loc
+ )
+ if errors_:
+ errors.extend(errors_)
+ else:
+ values[field.name] = v_
+ return values, errors
+
+
+def get_body_field(
+ *, flat_dependant: Dependant, name: str, embed_body_fields: bool
+) -> Optional[ModelField]:
+ """
+ Get a ModelField representing the request body for a path operation, combining
+ all body parameters into a single field if necessary.
+
+ Used to check if it's form data (with `isinstance(body_field, params.Form)`)
+ or JSON and to generate the JSON Schema for a request body.
+
+ This is **not** used to validate/parse the request body, that's done with each
+ individual body parameter.
+ """
+ if not flat_dependant.body_params:
+ return None
+ first_param = flat_dependant.body_params[0]
+ if not embed_body_fields:
+ return first_param
+ model_name = "Body_" + name
+ BodyModel = create_body_model(
+ fields=flat_dependant.body_params, model_name=model_name
+ )
+ required = any(True for f in flat_dependant.body_params if f.required)
+ BodyFieldInfo_kwargs: Dict[str, Any] = {
+ "annotation": BodyModel,
+ "alias": "body",
+ }
+ if not required:
+ BodyFieldInfo_kwargs["default"] = None
+ if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
+ BodyFieldInfo: Type[params.Body] = params.File
+ elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
+ BodyFieldInfo = params.Form
+ else:
+ BodyFieldInfo = params.Body
+
+ body_param_media_types = [
+ f.field_info.media_type
+ for f in flat_dependant.body_params
+ if isinstance(f.field_info, params.Body)
+ ]
+ if len(set(body_param_media_types)) == 1:
+ BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
+ final_field = create_model_field(
+ name="body",
+ type_=BodyModel,
+ required=required,
+ alias="body",
+ field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
+ )
+ return final_field