aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiohttp/connector.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiohttp/connector.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiohttp/connector.py1666
1 files changed, 1666 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiohttp/connector.py b/.venv/lib/python3.12/site-packages/aiohttp/connector.py
new file mode 100644
index 00000000..e5cf3674
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiohttp/connector.py
@@ -0,0 +1,1666 @@
+import asyncio
+import functools
+import random
+import socket
+import sys
+import traceback
+import warnings
+from collections import OrderedDict, defaultdict, deque
+from contextlib import suppress
+from http import HTTPStatus
+from itertools import chain, cycle, islice
+from time import monotonic
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ DefaultDict,
+ Deque,
+ Dict,
+ Iterator,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+import aiohappyeyeballs
+
+from . import hdrs, helpers
+from .abc import AbstractResolver, ResolveResult
+from .client_exceptions import (
+ ClientConnectionError,
+ ClientConnectorCertificateError,
+ ClientConnectorDNSError,
+ ClientConnectorError,
+ ClientConnectorSSLError,
+ ClientHttpProxyError,
+ ClientProxyConnectionError,
+ ServerFingerprintMismatch,
+ UnixClientConnectorError,
+ cert_errors,
+ ssl_errors,
+)
+from .client_proto import ResponseHandler
+from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
+from .helpers import (
+ ceil_timeout,
+ is_ip_address,
+ noop,
+ sentinel,
+ set_exception,
+ set_result,
+)
+from .resolver import DefaultResolver
+
+if TYPE_CHECKING:
+ import ssl
+
+ SSLContext = ssl.SSLContext
+else:
+ try:
+ import ssl
+
+ SSLContext = ssl.SSLContext
+ except ImportError: # pragma: no cover
+ ssl = None # type: ignore[assignment]
+ SSLContext = object # type: ignore[misc,assignment]
+
+EMPTY_SCHEMA_SET = frozenset({""})
+HTTP_SCHEMA_SET = frozenset({"http", "https"})
+WS_SCHEMA_SET = frozenset({"ws", "wss"})
+
+HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
+HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET
+
+NEEDS_CLEANUP_CLOSED = (3, 13, 0) <= sys.version_info < (
+ 3,
+ 13,
+ 1,
+) or sys.version_info < (3, 12, 7)
+# Cleanup closed is no longer needed after https://github.com/python/cpython/pull/118960
+# which first appeared in Python 3.12.7 and 3.13.1
+
+
+__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
+
+
+if TYPE_CHECKING:
+ from .client import ClientTimeout
+ from .client_reqrep import ConnectionKey
+ from .tracing import Trace
+
+
+class _DeprecationWaiter:
+ __slots__ = ("_awaitable", "_awaited")
+
+ def __init__(self, awaitable: Awaitable[Any]) -> None:
+ self._awaitable = awaitable
+ self._awaited = False
+
+ def __await__(self) -> Any:
+ self._awaited = True
+ return self._awaitable.__await__()
+
+ def __del__(self) -> None:
+ if not self._awaited:
+ warnings.warn(
+ "Connector.close() is a coroutine, "
+ "please use await connector.close()",
+ DeprecationWarning,
+ )
+
+
+class Connection:
+
+ _source_traceback = None
+
+ def __init__(
+ self,
+ connector: "BaseConnector",
+ key: "ConnectionKey",
+ protocol: ResponseHandler,
+ loop: asyncio.AbstractEventLoop,
+ ) -> None:
+ self._key = key
+ self._connector = connector
+ self._loop = loop
+ self._protocol: Optional[ResponseHandler] = protocol
+ self._callbacks: List[Callable[[], None]] = []
+
+ if loop.get_debug():
+ self._source_traceback = traceback.extract_stack(sys._getframe(1))
+
+ def __repr__(self) -> str:
+ return f"Connection<{self._key}>"
+
+ def __del__(self, _warnings: Any = warnings) -> None:
+ if self._protocol is not None:
+ kwargs = {"source": self}
+ _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs)
+ if self._loop.is_closed():
+ return
+
+ self._connector._release(self._key, self._protocol, should_close=True)
+
+ context = {"client_connection": self, "message": "Unclosed connection"}
+ if self._source_traceback is not None:
+ context["source_traceback"] = self._source_traceback
+ self._loop.call_exception_handler(context)
+
+ def __bool__(self) -> Literal[True]:
+ """Force subclasses to not be falsy, to make checks simpler."""
+ return True
+
+ @property
+ def loop(self) -> asyncio.AbstractEventLoop:
+ warnings.warn(
+ "connector.loop property is deprecated", DeprecationWarning, stacklevel=2
+ )
+ return self._loop
+
+ @property
+ def transport(self) -> Optional[asyncio.Transport]:
+ if self._protocol is None:
+ return None
+ return self._protocol.transport
+
+ @property
+ def protocol(self) -> Optional[ResponseHandler]:
+ return self._protocol
+
+ def add_callback(self, callback: Callable[[], None]) -> None:
+ if callback is not None:
+ self._callbacks.append(callback)
+
+ def _notify_release(self) -> None:
+ callbacks, self._callbacks = self._callbacks[:], []
+
+ for cb in callbacks:
+ with suppress(Exception):
+ cb()
+
+ def close(self) -> None:
+ self._notify_release()
+
+ if self._protocol is not None:
+ self._connector._release(self._key, self._protocol, should_close=True)
+ self._protocol = None
+
+ def release(self) -> None:
+ self._notify_release()
+
+ if self._protocol is not None:
+ self._connector._release(self._key, self._protocol)
+ self._protocol = None
+
+ @property
+ def closed(self) -> bool:
+ return self._protocol is None or not self._protocol.is_connected()
+
+
+class _TransportPlaceholder:
+ """placeholder for BaseConnector.connect function"""
+
+ __slots__ = ()
+
+ def close(self) -> None:
+ """Close the placeholder transport."""
+
+
+class BaseConnector:
+ """Base connector class.
+
+ keepalive_timeout - (optional) Keep-alive timeout.
+ force_close - Set to True to force close and do reconnect
+ after each request (and between redirects).
+ limit - The total number of simultaneous connections.
+ limit_per_host - Number of simultaneous connections to one host.
+ enable_cleanup_closed - Enables clean-up closed ssl transports.
+ Disabled by default.
+ timeout_ceil_threshold - Trigger ceiling of timeout values when
+ it's above timeout_ceil_threshold.
+ loop - Optional event loop.
+ """
+
+ _closed = True # prevent AttributeError in __del__ if ctor was failed
+ _source_traceback = None
+
+ # abort transport after 2 seconds (cleanup broken connections)
+ _cleanup_closed_period = 2.0
+
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET
+
+ def __init__(
+ self,
+ *,
+ keepalive_timeout: Union[object, None, float] = sentinel,
+ force_close: bool = False,
+ limit: int = 100,
+ limit_per_host: int = 0,
+ enable_cleanup_closed: bool = False,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ timeout_ceil_threshold: float = 5,
+ ) -> None:
+
+ if force_close:
+ if keepalive_timeout is not None and keepalive_timeout is not sentinel:
+ raise ValueError(
+ "keepalive_timeout cannot be set if force_close is True"
+ )
+ else:
+ if keepalive_timeout is sentinel:
+ keepalive_timeout = 15.0
+
+ loop = loop or asyncio.get_running_loop()
+ self._timeout_ceil_threshold = timeout_ceil_threshold
+
+ self._closed = False
+ if loop.get_debug():
+ self._source_traceback = traceback.extract_stack(sys._getframe(1))
+
+ # Connection pool of reusable connections.
+ # We use a deque to store connections because it has O(1) popleft()
+ # and O(1) append() operations to implement a FIFO queue.
+ self._conns: DefaultDict[
+ ConnectionKey, Deque[Tuple[ResponseHandler, float]]
+ ] = defaultdict(deque)
+ self._limit = limit
+ self._limit_per_host = limit_per_host
+ self._acquired: Set[ResponseHandler] = set()
+ self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = (
+ defaultdict(set)
+ )
+ self._keepalive_timeout = cast(float, keepalive_timeout)
+ self._force_close = force_close
+
+ # {host_key: FIFO list of waiters}
+ # The FIFO is implemented with an OrderedDict with None keys because
+ # python does not have an ordered set.
+ self._waiters: DefaultDict[
+ ConnectionKey, OrderedDict[asyncio.Future[None], None]
+ ] = defaultdict(OrderedDict)
+
+ self._loop = loop
+ self._factory = functools.partial(ResponseHandler, loop=loop)
+
+ # start keep-alive connection cleanup task
+ self._cleanup_handle: Optional[asyncio.TimerHandle] = None
+
+ # start cleanup closed transports task
+ self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None
+
+ if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED:
+ warnings.warn(
+ "enable_cleanup_closed ignored because "
+ "https://github.com/python/cpython/pull/118960 is fixed "
+ f"in Python version {sys.version_info}",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ enable_cleanup_closed = False
+
+ self._cleanup_closed_disabled = not enable_cleanup_closed
+ self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
+ self._cleanup_closed()
+
+ def __del__(self, _warnings: Any = warnings) -> None:
+ if self._closed:
+ return
+ if not self._conns:
+ return
+
+ conns = [repr(c) for c in self._conns.values()]
+
+ self._close()
+
+ kwargs = {"source": self}
+ _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs)
+ context = {
+ "connector": self,
+ "connections": conns,
+ "message": "Unclosed connector",
+ }
+ if self._source_traceback is not None:
+ context["source_traceback"] = self._source_traceback
+ self._loop.call_exception_handler(context)
+
+ def __enter__(self) -> "BaseConnector":
+ warnings.warn(
+ '"with Connector():" is deprecated, '
+ 'use "async with Connector():" instead',
+ DeprecationWarning,
+ )
+ return self
+
+ def __exit__(self, *exc: Any) -> None:
+ self._close()
+
+ async def __aenter__(self) -> "BaseConnector":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]] = None,
+ exc_value: Optional[BaseException] = None,
+ exc_traceback: Optional[TracebackType] = None,
+ ) -> None:
+ await self.close()
+
+ @property
+ def force_close(self) -> bool:
+ """Ultimately close connection on releasing if True."""
+ return self._force_close
+
+ @property
+ def limit(self) -> int:
+ """The total number for simultaneous connections.
+
+ If limit is 0 the connector has no limit.
+ The default limit size is 100.
+ """
+ return self._limit
+
+ @property
+ def limit_per_host(self) -> int:
+ """The limit for simultaneous connections to the same endpoint.
+
+ Endpoints are the same if they are have equal
+ (host, port, is_ssl) triple.
+ """
+ return self._limit_per_host
+
+ def _cleanup(self) -> None:
+ """Cleanup unused transports."""
+ if self._cleanup_handle:
+ self._cleanup_handle.cancel()
+ # _cleanup_handle should be unset, otherwise _release() will not
+ # recreate it ever!
+ self._cleanup_handle = None
+
+ now = monotonic()
+ timeout = self._keepalive_timeout
+
+ if self._conns:
+ connections = defaultdict(deque)
+ deadline = now - timeout
+ for key, conns in self._conns.items():
+ alive: Deque[Tuple[ResponseHandler, float]] = deque()
+ for proto, use_time in conns:
+ if proto.is_connected() and use_time - deadline >= 0:
+ alive.append((proto, use_time))
+ continue
+ transport = proto.transport
+ proto.close()
+ if not self._cleanup_closed_disabled and key.is_ssl:
+ self._cleanup_closed_transports.append(transport)
+
+ if alive:
+ connections[key] = alive
+
+ self._conns = connections
+
+ if self._conns:
+ self._cleanup_handle = helpers.weakref_handle(
+ self,
+ "_cleanup",
+ timeout,
+ self._loop,
+ timeout_ceil_threshold=self._timeout_ceil_threshold,
+ )
+
+ def _cleanup_closed(self) -> None:
+ """Double confirmation for transport close.
+
+ Some broken ssl servers may leave socket open without proper close.
+ """
+ if self._cleanup_closed_handle:
+ self._cleanup_closed_handle.cancel()
+
+ for transport in self._cleanup_closed_transports:
+ if transport is not None:
+ transport.abort()
+
+ self._cleanup_closed_transports = []
+
+ if not self._cleanup_closed_disabled:
+ self._cleanup_closed_handle = helpers.weakref_handle(
+ self,
+ "_cleanup_closed",
+ self._cleanup_closed_period,
+ self._loop,
+ timeout_ceil_threshold=self._timeout_ceil_threshold,
+ )
+
+ def close(self) -> Awaitable[None]:
+ """Close all opened transports."""
+ self._close()
+ return _DeprecationWaiter(noop())
+
+ def _close(self) -> None:
+ if self._closed:
+ return
+
+ self._closed = True
+
+ try:
+ if self._loop.is_closed():
+ return
+
+ # cancel cleanup task
+ if self._cleanup_handle:
+ self._cleanup_handle.cancel()
+
+ # cancel cleanup close task
+ if self._cleanup_closed_handle:
+ self._cleanup_closed_handle.cancel()
+
+ for data in self._conns.values():
+ for proto, t0 in data:
+ proto.close()
+
+ for proto in self._acquired:
+ proto.close()
+
+ for transport in self._cleanup_closed_transports:
+ if transport is not None:
+ transport.abort()
+
+ finally:
+ self._conns.clear()
+ self._acquired.clear()
+ for keyed_waiters in self._waiters.values():
+ for keyed_waiter in keyed_waiters:
+ keyed_waiter.cancel()
+ self._waiters.clear()
+ self._cleanup_handle = None
+ self._cleanup_closed_transports.clear()
+ self._cleanup_closed_handle = None
+
+ @property
+ def closed(self) -> bool:
+ """Is connector closed.
+
+ A readonly property.
+ """
+ return self._closed
+
+ def _available_connections(self, key: "ConnectionKey") -> int:
+ """
+ Return number of available connections.
+
+ The limit, limit_per_host and the connection key are taken into account.
+
+ If it returns less than 1 means that there are no connections
+ available.
+ """
+ # check total available connections
+ # If there are no limits, this will always return 1
+ total_remain = 1
+
+ if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0:
+ return total_remain
+
+ # check limit per host
+ if host_remain := self._limit_per_host:
+ if acquired := self._acquired_per_host.get(key):
+ host_remain -= len(acquired)
+ if total_remain > host_remain:
+ return host_remain
+
+ return total_remain
+
+ async def connect(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> Connection:
+ """Get from pool or create new connection."""
+ key = req.connection_key
+ if (conn := await self._get(key, traces)) is not None:
+ # If we do not have to wait and we can get a connection from the pool
+ # we can avoid the timeout ceil logic and directly return the connection
+ return conn
+
+ async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
+ if self._available_connections(key) <= 0:
+ await self._wait_for_available_connection(key, traces)
+ if (conn := await self._get(key, traces)) is not None:
+ return conn
+
+ placeholder = cast(ResponseHandler, _TransportPlaceholder())
+ self._acquired.add(placeholder)
+ if self._limit_per_host:
+ self._acquired_per_host[key].add(placeholder)
+
+ try:
+ # Traces are done inside the try block to ensure that the
+ # that the placeholder is still cleaned up if an exception
+ # is raised.
+ if traces:
+ for trace in traces:
+ await trace.send_connection_create_start()
+ proto = await self._create_connection(req, traces, timeout)
+ if traces:
+ for trace in traces:
+ await trace.send_connection_create_end()
+ except BaseException:
+ self._release_acquired(key, placeholder)
+ raise
+ else:
+ if self._closed:
+ proto.close()
+ raise ClientConnectionError("Connector is closed.")
+
+ # The connection was successfully created, drop the placeholder
+ # and add the real connection to the acquired set. There should
+ # be no awaits after the proto is added to the acquired set
+ # to ensure that the connection is not left in the acquired set
+ # on cancellation.
+ self._acquired.remove(placeholder)
+ self._acquired.add(proto)
+ if self._limit_per_host:
+ acquired_per_host = self._acquired_per_host[key]
+ acquired_per_host.remove(placeholder)
+ acquired_per_host.add(proto)
+ return Connection(self, key, proto, self._loop)
+
+ async def _wait_for_available_connection(
+ self, key: "ConnectionKey", traces: List["Trace"]
+ ) -> None:
+ """Wait for an available connection slot."""
+ # We loop here because there is a race between
+ # the connection limit check and the connection
+ # being acquired. If the connection is acquired
+ # between the check and the await statement, we
+ # need to loop again to check if the connection
+ # slot is still available.
+ attempts = 0
+ while True:
+ fut: asyncio.Future[None] = self._loop.create_future()
+ keyed_waiters = self._waiters[key]
+ keyed_waiters[fut] = None
+ if attempts:
+ # If we have waited before, we need to move the waiter
+ # to the front of the queue as otherwise we might get
+ # starved and hit the timeout.
+ keyed_waiters.move_to_end(fut, last=False)
+
+ try:
+ # Traces happen in the try block to ensure that the
+ # the waiter is still cleaned up if an exception is raised.
+ if traces:
+ for trace in traces:
+ await trace.send_connection_queued_start()
+ await fut
+ if traces:
+ for trace in traces:
+ await trace.send_connection_queued_end()
+ finally:
+ # pop the waiter from the queue if its still
+ # there and not already removed by _release_waiter
+ keyed_waiters.pop(fut, None)
+ if not self._waiters.get(key, True):
+ del self._waiters[key]
+
+ if self._available_connections(key) > 0:
+ break
+ attempts += 1
+
+ async def _get(
+ self, key: "ConnectionKey", traces: List["Trace"]
+ ) -> Optional[Connection]:
+ """Get next reusable connection for the key or None.
+
+ The connection will be marked as acquired.
+ """
+ if (conns := self._conns.get(key)) is None:
+ return None
+
+ t1 = monotonic()
+ while conns:
+ proto, t0 = conns.popleft()
+ # We will we reuse the connection if its connected and
+ # the keepalive timeout has not been exceeded
+ if proto.is_connected() and t1 - t0 <= self._keepalive_timeout:
+ if not conns:
+ # The very last connection was reclaimed: drop the key
+ del self._conns[key]
+ self._acquired.add(proto)
+ if self._limit_per_host:
+ self._acquired_per_host[key].add(proto)
+ if traces:
+ for trace in traces:
+ try:
+ await trace.send_connection_reuseconn()
+ except BaseException:
+ self._release_acquired(key, proto)
+ raise
+ return Connection(self, key, proto, self._loop)
+
+ # Connection cannot be reused, close it
+ transport = proto.transport
+ proto.close()
+ # only for SSL transports
+ if not self._cleanup_closed_disabled and key.is_ssl:
+ self._cleanup_closed_transports.append(transport)
+
+ # No more connections: drop the key
+ del self._conns[key]
+ return None
+
+ def _release_waiter(self) -> None:
+ """
+ Iterates over all waiters until one to be released is found.
+
+ The one to be released is not finished and
+ belongs to a host that has available connections.
+ """
+ if not self._waiters:
+ return
+
+ # Having the dict keys ordered this avoids to iterate
+ # at the same order at each call.
+ queues = list(self._waiters)
+ random.shuffle(queues)
+
+ for key in queues:
+ if self._available_connections(key) < 1:
+ continue
+
+ waiters = self._waiters[key]
+ while waiters:
+ waiter, _ = waiters.popitem(last=False)
+ if not waiter.done():
+ waiter.set_result(None)
+ return
+
+ def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
+ """Release acquired connection."""
+ if self._closed:
+ # acquired connection is already released on connector closing
+ return
+
+ self._acquired.discard(proto)
+ if self._limit_per_host and (conns := self._acquired_per_host.get(key)):
+ conns.discard(proto)
+ if not conns:
+ del self._acquired_per_host[key]
+ self._release_waiter()
+
+ def _release(
+ self,
+ key: "ConnectionKey",
+ protocol: ResponseHandler,
+ *,
+ should_close: bool = False,
+ ) -> None:
+ if self._closed:
+ # acquired connection is already released on connector closing
+ return
+
+ self._release_acquired(key, protocol)
+
+ if self._force_close or should_close or protocol.should_close:
+ transport = protocol.transport
+ protocol.close()
+
+ if key.is_ssl and not self._cleanup_closed_disabled:
+ self._cleanup_closed_transports.append(transport)
+ return
+
+ self._conns[key].append((protocol, monotonic()))
+
+ if self._cleanup_handle is None:
+ self._cleanup_handle = helpers.weakref_handle(
+ self,
+ "_cleanup",
+ self._keepalive_timeout,
+ self._loop,
+ timeout_ceil_threshold=self._timeout_ceil_threshold,
+ )
+
+ async def _create_connection(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> ResponseHandler:
+ raise NotImplementedError()
+
+
+class _DNSCacheTable:
+ def __init__(self, ttl: Optional[float] = None) -> None:
+ self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
+ self._timestamps: Dict[Tuple[str, int], float] = {}
+ self._ttl = ttl
+
+ def __contains__(self, host: object) -> bool:
+ return host in self._addrs_rr
+
+ def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
+ self._addrs_rr[key] = (cycle(addrs), len(addrs))
+
+ if self._ttl is not None:
+ self._timestamps[key] = monotonic()
+
+ def remove(self, key: Tuple[str, int]) -> None:
+ self._addrs_rr.pop(key, None)
+
+ if self._ttl is not None:
+ self._timestamps.pop(key, None)
+
+ def clear(self) -> None:
+ self._addrs_rr.clear()
+ self._timestamps.clear()
+
+ def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
+ loop, length = self._addrs_rr[key]
+ addrs = list(islice(loop, length))
+ # Consume one more element to shift internal state of `cycle`
+ next(loop)
+ return addrs
+
+ def expired(self, key: Tuple[str, int]) -> bool:
+ if self._ttl is None:
+ return False
+
+ return self._timestamps[key] + self._ttl < monotonic()
+
+
+def _make_ssl_context(verified: bool) -> SSLContext:
+ """Create SSL context.
+
+ This method is not async-friendly and should be called from a thread
+ because it will load certificates from disk and do other blocking I/O.
+ """
+ if ssl is None:
+ # No ssl support
+ return None
+ if verified:
+ sslcontext = ssl.create_default_context()
+ else:
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.options |= ssl.OP_NO_SSLv3
+ sslcontext.check_hostname = False
+ sslcontext.verify_mode = ssl.CERT_NONE
+ sslcontext.options |= ssl.OP_NO_COMPRESSION
+ sslcontext.set_default_verify_paths()
+ sslcontext.set_alpn_protocols(("http/1.1",))
+ return sslcontext
+
+
+# The default SSLContext objects are created at import time
+# since they do blocking I/O to load certificates from disk,
+# and imports should always be done before the event loop starts
+# or in a thread.
+_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
+_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)
+
+
+class TCPConnector(BaseConnector):
+ """TCP connector.
+
+ verify_ssl - Set to True to check ssl certifications.
+ fingerprint - Pass the binary sha256
+ digest of the expected certificate in DER format to verify
+ that the certificate the server presents matches. See also
+ https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning
+ resolver - Enable DNS lookups and use this
+ resolver
+ use_dns_cache - Use memory cache for DNS lookups.
+ ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
+ family - socket address family
+ local_addr - local tuple of (host, port) to bind socket to
+
+ keepalive_timeout - (optional) Keep-alive timeout.
+ force_close - Set to True to force close and do reconnect
+ after each request (and between redirects).
+ limit - The total number of simultaneous connections.
+ limit_per_host - Number of simultaneous connections to one host.
+ enable_cleanup_closed - Enables clean-up closed ssl transports.
+ Disabled by default.
+ happy_eyeballs_delay - This is the “Connection Attempt Delay”
+ as defined in RFC 8305. To disable
+ the happy eyeballs algorithm, set to None.
+ interleave - “First Address Family Count” as defined in RFC 8305
+ loop - Optional event loop.
+ """
+
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
+
+ def __init__(
+ self,
+ *,
+ verify_ssl: bool = True,
+ fingerprint: Optional[bytes] = None,
+ use_dns_cache: bool = True,
+ ttl_dns_cache: Optional[int] = 10,
+ family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
+ ssl_context: Optional[SSLContext] = None,
+ ssl: Union[bool, Fingerprint, SSLContext] = True,
+ local_addr: Optional[Tuple[str, int]] = None,
+ resolver: Optional[AbstractResolver] = None,
+ keepalive_timeout: Union[None, float, object] = sentinel,
+ force_close: bool = False,
+ limit: int = 100,
+ limit_per_host: int = 0,
+ enable_cleanup_closed: bool = False,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ timeout_ceil_threshold: float = 5,
+ happy_eyeballs_delay: Optional[float] = 0.25,
+ interleave: Optional[int] = None,
+ ):
+ super().__init__(
+ keepalive_timeout=keepalive_timeout,
+ force_close=force_close,
+ limit=limit,
+ limit_per_host=limit_per_host,
+ enable_cleanup_closed=enable_cleanup_closed,
+ loop=loop,
+ timeout_ceil_threshold=timeout_ceil_threshold,
+ )
+
+ self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
+ if resolver is None:
+ resolver = DefaultResolver(loop=self._loop)
+ self._resolver = resolver
+
+ self._use_dns_cache = use_dns_cache
+ self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
+ self._throttle_dns_futures: Dict[
+ Tuple[str, int], Set["asyncio.Future[None]"]
+ ] = {}
+ self._family = family
+ self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
+ self._happy_eyeballs_delay = happy_eyeballs_delay
+ self._interleave = interleave
+ self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
+
+ def close(self) -> Awaitable[None]:
+ """Close all ongoing DNS calls."""
+ for fut in chain.from_iterable(self._throttle_dns_futures.values()):
+ fut.cancel()
+
+ for t in self._resolve_host_tasks:
+ t.cancel()
+
+ return super().close()
+
+ @property
+ def family(self) -> int:
+ """Socket family like AF_INET."""
+ return self._family
+
+ @property
+ def use_dns_cache(self) -> bool:
+ """True if local DNS caching is enabled."""
+ return self._use_dns_cache
+
+ def clear_dns_cache(
+ self, host: Optional[str] = None, port: Optional[int] = None
+ ) -> None:
+ """Remove specified host/port or clear all dns local cache."""
+ if host is not None and port is not None:
+ self._cached_hosts.remove((host, port))
+ elif host is not None or port is not None:
+ raise ValueError("either both host and port or none of them are allowed")
+ else:
+ self._cached_hosts.clear()
+
+ async def _resolve_host(
+ self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
+ ) -> List[ResolveResult]:
+ """Resolve host and return list of addresses."""
+ if is_ip_address(host):
+ return [
+ {
+ "hostname": host,
+ "host": host,
+ "port": port,
+ "family": self._family,
+ "proto": 0,
+ "flags": 0,
+ }
+ ]
+
+ if not self._use_dns_cache:
+
+ if traces:
+ for trace in traces:
+ await trace.send_dns_resolvehost_start(host)
+
+ res = await self._resolver.resolve(host, port, family=self._family)
+
+ if traces:
+ for trace in traces:
+ await trace.send_dns_resolvehost_end(host)
+
+ return res
+
+ key = (host, port)
+ if key in self._cached_hosts and not self._cached_hosts.expired(key):
+ # get result early, before any await (#4014)
+ result = self._cached_hosts.next_addrs(key)
+
+ if traces:
+ for trace in traces:
+ await trace.send_dns_cache_hit(host)
+ return result
+
+ futures: Set["asyncio.Future[None]"]
+ #
+ # If multiple connectors are resolving the same host, we wait
+ # for the first one to resolve and then use the result for all of them.
+ # We use a throttle to ensure that we only resolve the host once
+ # and then use the result for all the waiters.
+ #
+ if key in self._throttle_dns_futures:
+ # get futures early, before any await (#4014)
+ futures = self._throttle_dns_futures[key]
+ future: asyncio.Future[None] = self._loop.create_future()
+ futures.add(future)
+ if traces:
+ for trace in traces:
+ await trace.send_dns_cache_hit(host)
+ try:
+ await future
+ finally:
+ futures.discard(future)
+ return self._cached_hosts.next_addrs(key)
+
+ # update dict early, before any await (#4014)
+ self._throttle_dns_futures[key] = futures = set()
+ # In this case we need to create a task to ensure that we can shield
+ # the task from cancellation as cancelling this lookup should not cancel
+ # the underlying lookup or else the cancel event will get broadcast to
+ # all the waiters across all connections.
+ #
+ coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
+ loop = asyncio.get_running_loop()
+ if sys.version_info >= (3, 12):
+ # Optimization for Python 3.12, try to send immediately
+ resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True)
+ else:
+ resolved_host_task = loop.create_task(coro)
+
+ if not resolved_host_task.done():
+ self._resolve_host_tasks.add(resolved_host_task)
+ resolved_host_task.add_done_callback(self._resolve_host_tasks.discard)
+
+ try:
+ return await asyncio.shield(resolved_host_task)
+ except asyncio.CancelledError:
+
+ def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
+ with suppress(Exception, asyncio.CancelledError):
+ fut.result()
+
+ resolved_host_task.add_done_callback(drop_exception)
+ raise
+
+ async def _resolve_host_with_throttle(
+ self,
+ key: Tuple[str, int],
+ host: str,
+ port: int,
+ futures: Set["asyncio.Future[None]"],
+ traces: Optional[Sequence["Trace"]],
+ ) -> List[ResolveResult]:
+ """Resolve host and set result for all waiters.
+
+ This method must be run in a task and shielded from cancellation
+ to avoid cancelling the underlying lookup.
+ """
+ try:
+ if traces:
+ for trace in traces:
+ await trace.send_dns_cache_miss(host)
+
+ for trace in traces:
+ await trace.send_dns_resolvehost_start(host)
+
+ addrs = await self._resolver.resolve(host, port, family=self._family)
+ if traces:
+ for trace in traces:
+ await trace.send_dns_resolvehost_end(host)
+
+ self._cached_hosts.add(key, addrs)
+ for fut in futures:
+ set_result(fut, None)
+ except BaseException as e:
+ # any DNS exception is set for the waiters to raise the same exception.
+ # This coro is always run in task that is shielded from cancellation so
+ # we should never be propagating cancellation here.
+ for fut in futures:
+ set_exception(fut, e)
+ raise
+ finally:
+ self._throttle_dns_futures.pop(key)
+
+ return self._cached_hosts.next_addrs(key)
+
+ async def _create_connection(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> ResponseHandler:
+ """Create connection.
+
+ Has same keyword arguments as BaseEventLoop.create_connection.
+ """
+ if req.proxy:
+ _, proto = await self._create_proxy_connection(req, traces, timeout)
+ else:
+ _, proto = await self._create_direct_connection(req, traces, timeout)
+
+ return proto
+
+ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
+ """Logic to get the correct SSL context
+
+ 0. if req.ssl is false, return None
+
+ 1. if ssl_context is specified in req, use it
+ 2. if _ssl_context is specified in self, use it
+ 3. otherwise:
+ 1. if verify_ssl is not specified in req, use self.ssl_context
+ (will generate a default context according to self.verify_ssl)
+ 2. if verify_ssl is True in req, generate a default SSL context
+ 3. if verify_ssl is False in req, generate a SSL context that
+ won't verify
+ """
+ if not req.is_ssl():
+ return None
+
+ if ssl is None: # pragma: no cover
+ raise RuntimeError("SSL is not supported.")
+ sslcontext = req.ssl
+ if isinstance(sslcontext, ssl.SSLContext):
+ return sslcontext
+ if sslcontext is not True:
+ # not verified or fingerprinted
+ return _SSL_CONTEXT_UNVERIFIED
+ sslcontext = self._ssl
+ if isinstance(sslcontext, ssl.SSLContext):
+ return sslcontext
+ if sslcontext is not True:
+ # not verified or fingerprinted
+ return _SSL_CONTEXT_UNVERIFIED
+ return _SSL_CONTEXT_VERIFIED
+
+ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
+ ret = req.ssl
+ if isinstance(ret, Fingerprint):
+ return ret
+ ret = self._ssl
+ if isinstance(ret, Fingerprint):
+ return ret
+ return None
+
+ async def _wrap_create_connection(
+ self,
+ *args: Any,
+ addr_infos: List[aiohappyeyeballs.AddrInfoType],
+ req: ClientRequest,
+ timeout: "ClientTimeout",
+ client_error: Type[Exception] = ClientConnectorError,
+ **kwargs: Any,
+ ) -> Tuple[asyncio.Transport, ResponseHandler]:
+ sock: Union[socket.socket, None] = None
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ sock = await aiohappyeyeballs.start_connection(
+ addr_infos=addr_infos,
+ local_addr_infos=self._local_addr_infos,
+ happy_eyeballs_delay=self._happy_eyeballs_delay,
+ interleave=self._interleave,
+ loop=self._loop,
+ )
+ connection = await self._loop.create_connection(
+ *args, **kwargs, sock=sock
+ )
+ sock = None
+ return connection
+ except cert_errors as exc:
+ raise ClientConnectorCertificateError(req.connection_key, exc) from exc
+ except ssl_errors as exc:
+ raise ClientConnectorSSLError(req.connection_key, exc) from exc
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise client_error(req.connection_key, exc) from exc
+ finally:
+ if sock is not None:
+ # Will be hit if an exception is thrown before the event loop takes the socket.
+ # In that case, proactively close the socket to guard against event loop leaks.
+ # For example, see https://github.com/MagicStack/uvloop/issues/653.
+ try:
+ sock.close()
+ except OSError as exc:
+ raise client_error(req.connection_key, exc) from exc
+
+ async def _wrap_existing_connection(
+ self,
+ *args: Any,
+ req: ClientRequest,
+ timeout: "ClientTimeout",
+ client_error: Type[Exception] = ClientConnectorError,
+ **kwargs: Any,
+ ) -> Tuple[asyncio.Transport, ResponseHandler]:
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ return await self._loop.create_connection(*args, **kwargs)
+ except cert_errors as exc:
+ raise ClientConnectorCertificateError(req.connection_key, exc) from exc
+ except ssl_errors as exc:
+ raise ClientConnectorSSLError(req.connection_key, exc) from exc
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise client_error(req.connection_key, exc) from exc
+
+ def _fail_on_no_start_tls(self, req: "ClientRequest") -> None:
+ """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``.
+
+ It is necessary for TLS-in-TLS so that it is possible to
+ send HTTPS queries through HTTPS proxies.
+
+ This doesn't affect regular HTTP requests, though.
+ """
+ if not req.is_ssl():
+ return
+
+ proxy_url = req.proxy
+ assert proxy_url is not None
+ if proxy_url.scheme != "https":
+ return
+
+ self._check_loop_for_start_tls()
+
+ def _check_loop_for_start_tls(self) -> None:
+ try:
+ self._loop.start_tls
+ except AttributeError as attr_exc:
+ raise RuntimeError(
+ "An HTTPS request is being sent through an HTTPS proxy. "
+ "This needs support for TLS in TLS but it is not implemented "
+ "in your runtime for the stdlib asyncio.\n\n"
+ "Please upgrade to Python 3.11 or higher. For more details, "
+ "please see:\n"
+ "* https://bugs.python.org/issue37179\n"
+ "* https://github.com/python/cpython/pull/28073\n"
+ "* https://docs.aiohttp.org/en/stable/"
+ "client_advanced.html#proxy-support\n"
+ "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
+ ) from attr_exc
+
+ def _loop_supports_start_tls(self) -> bool:
+ try:
+ self._check_loop_for_start_tls()
+ except RuntimeError:
+ return False
+ else:
+ return True
+
+ def _warn_about_tls_in_tls(
+ self,
+ underlying_transport: asyncio.Transport,
+ req: ClientRequest,
+ ) -> None:
+ """Issue a warning if the requested URL has HTTPS scheme."""
+ if req.request_info.url.scheme != "https":
+ return
+
+ asyncio_supports_tls_in_tls = getattr(
+ underlying_transport,
+ "_start_tls_compatible",
+ False,
+ )
+
+ if asyncio_supports_tls_in_tls:
+ return
+
+ warnings.warn(
+ "An HTTPS request is being sent through an HTTPS proxy. "
+ "This support for TLS in TLS is known to be disabled "
+ "in the stdlib asyncio (Python <3.11). This is why you'll probably see "
+ "an error in the log below.\n\n"
+ "It is possible to enable it via monkeypatching. "
+ "For more details, see:\n"
+ "* https://bugs.python.org/issue37179\n"
+ "* https://github.com/python/cpython/pull/28073\n\n"
+ "You can temporarily patch this as follows:\n"
+ "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n"
+ "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
+ RuntimeWarning,
+ source=self,
+ # Why `4`? At least 3 of the calls in the stack originate
+ # from the methods in this class.
+ stacklevel=3,
+ )
+
+ async def _start_tls_connection(
+ self,
+ underlying_transport: asyncio.Transport,
+ req: ClientRequest,
+ timeout: "ClientTimeout",
+ client_error: Type[Exception] = ClientConnectorError,
+ ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
+ """Wrap the raw TCP transport with TLS."""
+ tls_proto = self._factory() # Create a brand new proto for TLS
+ sslcontext = self._get_ssl_context(req)
+ if TYPE_CHECKING:
+ # _start_tls_connection is unreachable in the current code path
+ # if sslcontext is None.
+ assert sslcontext is not None
+
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ try:
+ tls_transport = await self._loop.start_tls(
+ underlying_transport,
+ tls_proto,
+ sslcontext,
+ server_hostname=req.server_hostname or req.host,
+ ssl_handshake_timeout=timeout.total,
+ )
+ except BaseException:
+ # We need to close the underlying transport since
+ # `start_tls()` probably failed before it had a
+ # chance to do this:
+ underlying_transport.close()
+ raise
+ if isinstance(tls_transport, asyncio.Transport):
+ fingerprint = self._get_fingerprint(req)
+ if fingerprint:
+ try:
+ fingerprint.check(tls_transport)
+ except ServerFingerprintMismatch:
+ tls_transport.close()
+ if not self._cleanup_closed_disabled:
+ self._cleanup_closed_transports.append(tls_transport)
+ raise
+ except cert_errors as exc:
+ raise ClientConnectorCertificateError(req.connection_key, exc) from exc
+ except ssl_errors as exc:
+ raise ClientConnectorSSLError(req.connection_key, exc) from exc
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise client_error(req.connection_key, exc) from exc
+ except TypeError as type_err:
+ # Example cause looks like this:
+ # TypeError: transport <asyncio.sslproto._SSLProtocolTransport
+ # object at 0x7f760615e460> is not supported by start_tls()
+
+ raise ClientConnectionError(
+ "Cannot initialize a TLS-in-TLS connection to host "
+ f"{req.host!s}:{req.port:d} through an underlying connection "
+ f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} "
+ f"[{type_err!s}]"
+ ) from type_err
+ else:
+ if tls_transport is None:
+ msg = "Failed to start TLS (possibly caused by closing transport)"
+ raise client_error(req.connection_key, OSError(msg))
+ tls_proto.connection_made(
+ tls_transport
+ ) # Kick the state machine of the new TLS protocol
+
+ return tls_transport, tls_proto
+
+ def _convert_hosts_to_addr_infos(
+ self, hosts: List[ResolveResult]
+ ) -> List[aiohappyeyeballs.AddrInfoType]:
+ """Converts the list of hosts to a list of addr_infos.
+
+ The list of hosts is the result of a DNS lookup. The list of
+ addr_infos is the result of a call to `socket.getaddrinfo()`.
+ """
+ addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
+ for hinfo in hosts:
+ host = hinfo["host"]
+ is_ipv6 = ":" in host
+ family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
+ if self._family and self._family != family:
+ continue
+ addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"])
+ addr_infos.append(
+ (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)
+ )
+ return addr_infos
+
+ async def _create_direct_connection(
+ self,
+ req: ClientRequest,
+ traces: List["Trace"],
+ timeout: "ClientTimeout",
+ *,
+ client_error: Type[Exception] = ClientConnectorError,
+ ) -> Tuple[asyncio.Transport, ResponseHandler]:
+ sslcontext = self._get_ssl_context(req)
+ fingerprint = self._get_fingerprint(req)
+
+ host = req.url.raw_host
+ assert host is not None
+ # Replace multiple trailing dots with a single one.
+ # A trailing dot is only present for fully-qualified domain names.
+ # See https://github.com/aio-libs/aiohttp/pull/7364.
+ if host.endswith(".."):
+ host = host.rstrip(".") + "."
+ port = req.port
+ assert port is not None
+ try:
+ # Cancelling this lookup should not cancel the underlying lookup
+ # or else the cancel event will get broadcast to all the waiters
+ # across all connections.
+ hosts = await self._resolve_host(host, port, traces=traces)
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ # in case of proxy it is not ClientProxyConnectionError
+ # it is problem of resolving proxy ip itself
+ raise ClientConnectorDNSError(req.connection_key, exc) from exc
+
+ last_exc: Optional[Exception] = None
+ addr_infos = self._convert_hosts_to_addr_infos(hosts)
+ while addr_infos:
+ # Strip trailing dots, certificates contain FQDN without dots.
+ # See https://github.com/aio-libs/aiohttp/issues/3636
+ server_hostname = (
+ (req.server_hostname or host).rstrip(".") if sslcontext else None
+ )
+
+ try:
+ transp, proto = await self._wrap_create_connection(
+ self._factory,
+ timeout=timeout,
+ ssl=sslcontext,
+ addr_infos=addr_infos,
+ server_hostname=server_hostname,
+ req=req,
+ client_error=client_error,
+ )
+ except (ClientConnectorError, asyncio.TimeoutError) as exc:
+ last_exc = exc
+ aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
+ continue
+
+ if req.is_ssl() and fingerprint:
+ try:
+ fingerprint.check(transp)
+ except ServerFingerprintMismatch as exc:
+ transp.close()
+ if not self._cleanup_closed_disabled:
+ self._cleanup_closed_transports.append(transp)
+ last_exc = exc
+ # Remove the bad peer from the list of addr_infos
+ sock: socket.socket = transp.get_extra_info("socket")
+ bad_peer = sock.getpeername()
+ aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
+ continue
+
+ return transp, proto
+ else:
+ assert last_exc is not None
+ raise last_exc
+
+ async def _create_proxy_connection(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
+ self._fail_on_no_start_tls(req)
+ runtime_has_start_tls = self._loop_supports_start_tls()
+
+ headers: Dict[str, str] = {}
+ if req.proxy_headers is not None:
+ headers = req.proxy_headers # type: ignore[assignment]
+ headers[hdrs.HOST] = req.headers[hdrs.HOST]
+
+ url = req.proxy
+ assert url is not None
+ proxy_req = ClientRequest(
+ hdrs.METH_GET,
+ url,
+ headers=headers,
+ auth=req.proxy_auth,
+ loop=self._loop,
+ ssl=req.ssl,
+ )
+
+ # create connection to proxy server
+ transport, proto = await self._create_direct_connection(
+ proxy_req, [], timeout, client_error=ClientProxyConnectionError
+ )
+
+ auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
+ if auth is not None:
+ if not req.is_ssl():
+ req.headers[hdrs.PROXY_AUTHORIZATION] = auth
+ else:
+ proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
+
+ if req.is_ssl():
+ if runtime_has_start_tls:
+ self._warn_about_tls_in_tls(transport, req)
+
+ # For HTTPS requests over HTTP proxy
+ # we must notify proxy to tunnel connection
+ # so we send CONNECT command:
+ # CONNECT www.python.org:443 HTTP/1.1
+ # Host: www.python.org
+ #
+ # next we must do TLS handshake and so on
+ # to do this we must wrap raw socket into secure one
+ # asyncio handles this perfectly
+ proxy_req.method = hdrs.METH_CONNECT
+ proxy_req.url = req.url
+ key = req.connection_key._replace(
+ proxy=None, proxy_auth=None, proxy_headers_hash=None
+ )
+ conn = Connection(self, key, proto, self._loop)
+ proxy_resp = await proxy_req.send(conn)
+ try:
+ protocol = conn._protocol
+ assert protocol is not None
+
+ # read_until_eof=True will ensure the connection isn't closed
+ # once the response is received and processed allowing
+ # START_TLS to work on the connection below.
+ protocol.set_response_params(
+ read_until_eof=runtime_has_start_tls,
+ timeout_ceil_threshold=self._timeout_ceil_threshold,
+ )
+ resp = await proxy_resp.start(conn)
+ except BaseException:
+ proxy_resp.close()
+ conn.close()
+ raise
+ else:
+ conn._protocol = None
+ try:
+ if resp.status != 200:
+ message = resp.reason
+ if message is None:
+ message = HTTPStatus(resp.status).phrase
+ raise ClientHttpProxyError(
+ proxy_resp.request_info,
+ resp.history,
+ status=resp.status,
+ message=message,
+ headers=resp.headers,
+ )
+ if not runtime_has_start_tls:
+ rawsock = transport.get_extra_info("socket", default=None)
+ if rawsock is None:
+ raise RuntimeError(
+ "Transport does not expose socket instance"
+ )
+ # Duplicate the socket, so now we can close proxy transport
+ rawsock = rawsock.dup()
+ except BaseException:
+ # It shouldn't be closed in `finally` because it's fed to
+ # `loop.start_tls()` and the docs say not to touch it after
+ # passing there.
+ transport.close()
+ raise
+ finally:
+ if not runtime_has_start_tls:
+ transport.close()
+
+ if not runtime_has_start_tls:
+ # HTTP proxy with support for upgrade to HTTPS
+ sslcontext = self._get_ssl_context(req)
+ return await self._wrap_existing_connection(
+ self._factory,
+ timeout=timeout,
+ ssl=sslcontext,
+ sock=rawsock,
+ server_hostname=req.host,
+ req=req,
+ )
+
+ return await self._start_tls_connection(
+ # Access the old transport for the last time before it's
+ # closed and forgotten forever:
+ transport,
+ req=req,
+ timeout=timeout,
+ )
+ finally:
+ proxy_resp.close()
+
+ return transport, proto
+
+
+class UnixConnector(BaseConnector):
+ """Unix socket connector.
+
+ path - Unix socket path.
+ keepalive_timeout - (optional) Keep-alive timeout.
+ force_close - Set to True to force close and do reconnect
+ after each request (and between redirects).
+ limit - The total number of simultaneous connections.
+ limit_per_host - Number of simultaneous connections to one host.
+ loop - Optional event loop.
+ """
+
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})
+
+ def __init__(
+ self,
+ path: str,
+ force_close: bool = False,
+ keepalive_timeout: Union[object, float, None] = sentinel,
+ limit: int = 100,
+ limit_per_host: int = 0,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ) -> None:
+ super().__init__(
+ force_close=force_close,
+ keepalive_timeout=keepalive_timeout,
+ limit=limit,
+ limit_per_host=limit_per_host,
+ loop=loop,
+ )
+ self._path = path
+
+ @property
+ def path(self) -> str:
+ """Path to unix socket."""
+ return self._path
+
+ async def _create_connection(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> ResponseHandler:
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ _, proto = await self._loop.create_unix_connection(
+ self._factory, self._path
+ )
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc
+
+ return proto
+
+
+class NamedPipeConnector(BaseConnector):
+ """Named pipe connector.
+
+ Only supported by the proactor event loop.
+ See also: https://docs.python.org/3/library/asyncio-eventloop.html
+
+ path - Windows named pipe path.
+ keepalive_timeout - (optional) Keep-alive timeout.
+ force_close - Set to True to force close and do reconnect
+ after each request (and between redirects).
+ limit - The total number of simultaneous connections.
+ limit_per_host - Number of simultaneous connections to one host.
+ loop - Optional event loop.
+ """
+
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})
+
+ def __init__(
+ self,
+ path: str,
+ force_close: bool = False,
+ keepalive_timeout: Union[object, float, None] = sentinel,
+ limit: int = 100,
+ limit_per_host: int = 0,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ) -> None:
+ super().__init__(
+ force_close=force_close,
+ keepalive_timeout=keepalive_timeout,
+ limit=limit,
+ limit_per_host=limit_per_host,
+ loop=loop,
+ )
+ if not isinstance(
+ self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
+ ):
+ raise RuntimeError(
+ "Named Pipes only available in proactor loop under windows"
+ )
+ self._path = path
+
+ @property
+ def path(self) -> str:
+ """Path to the named pipe."""
+ return self._path
+
+ async def _create_connection(
+ self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
+ ) -> ResponseHandler:
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ _, proto = await self._loop.create_pipe_connection( # type: ignore[attr-defined]
+ self._factory, self._path
+ )
+ # the drain is required so that the connection_made is called
+ # and transport is set otherwise it is not set before the
+ # `assert conn.transport is not None`
+ # in client.py's _request method
+ await asyncio.sleep(0)
+ # other option is to manually set transport like
+ # `proto.transport = trans`
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise ClientConnectorError(req.connection_key, exc) from exc
+
+ return cast(ResponseHandler, proto)