from __future__ import annotations import asyncio import http import logging from typing import Any, Callable, Literal, cast from urllib.parse import unquote import h11 from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE from uvicorn._types import ( ASGI3Application, ASGIReceiveEvent, ASGISendEvent, HTTPRequestEvent, HTTPResponseBodyEvent, HTTPResponseStartEvent, HTTPScope, ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.http.flow_control import ( CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable, ) from uvicorn.protocols.utils import ( get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl, ) from uvicorn.server import ServerState def _get_status_phrase(status_code: int) -> bytes: try: return http.HTTPStatus(status_code).phrase.encode() except ValueError: return b"" STATUS_PHRASES = { status_code: _get_status_phrase(status_code) for status_code in range(100, 600) } class H11Protocol(asyncio.Protocol): def __init__( self, config: Config, server_state: ServerState, app_state: dict[str, Any], _loop: asyncio.AbstractEventLoop | None = None, ) -> None: if not config.loaded: config.load() self.config = config self.app = config.loaded_app self.loop = _loop or asyncio.get_event_loop() self.logger = logging.getLogger("uvicorn.error") self.access_logger = logging.getLogger("uvicorn.access") self.access_log = self.access_logger.hasHandlers() self.conn = h11.Connection( h11.SERVER, config.h11_max_incomplete_event_size if config.h11_max_incomplete_event_size is not None else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, ) self.ws_protocol_class = config.ws_protocol_class self.root_path = config.root_path self.limit_concurrency = config.limit_concurrency self.app_state = app_state # Timeouts self.timeout_keep_alive_task: asyncio.TimerHandle | None = None self.timeout_keep_alive = config.timeout_keep_alive # Shared server state self.server_state = server_state self.connections = server_state.connections self.tasks = server_state.tasks # Per-connection state self.transport: asyncio.Transport = None # type: ignore[assignment] self.flow: FlowControl = None # type: ignore[assignment] self.server: tuple[str, int] | None = None self.client: tuple[str, int] | None = None self.scheme: Literal["http", "https"] | None = None # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] self.cycle: RequestResponseCycle = None # type: ignore[assignment] # Protocol interface def connection_made( # type: ignore[override] self, transport: asyncio.Transport ) -> None: self.connections.add(self) self.transport = transport self.flow = FlowControl(transport) self.server = get_local_addr(transport) self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) def connection_lost(self, exc: Exception | None) -> None: self.connections.discard(self) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) if self.cycle and not self.cycle.response_complete: self.cycle.disconnected = True if self.conn.our_state != h11.ERROR: event = h11.ConnectionClosed() try: self.conn.send(event) except h11.LocalProtocolError: # Premature client disconnect pass if self.cycle is not None: self.cycle.message_event.set() if self.flow is not None: self.flow.resume_writing() if exc is None: self.transport.close() self._unset_keepalive_if_required() def eof_received(self) -> None: pass def _unset_keepalive_if_required(self) -> None: if self.timeout_keep_alive_task is not None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None def _get_upgrade(self) -> bytes | None: connection = [] upgrade = None for name, value in self.headers: if name == b"connection": connection = [token.lower().strip() for token in value.split(b",")] if name == b"upgrade": upgrade = value.lower() if b"upgrade" in connection: return upgrade return None def _should_upgrade_to_ws(self) -> bool: if self.ws_protocol_class is None: if self.config.ws == "auto": msg = "Unsupported upgrade request." self.logger.warning(msg) msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501 self.logger.warning(msg) return False return True def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() self.conn.receive_data(data) self.handle_events() def handle_events(self) -> None: while True: try: event = self.conn.next_event() except h11.RemoteProtocolError: msg = "Invalid HTTP request received." self.logger.warning(msg) self.send_400_response(msg) return if event is h11.NEED_DATA: break elif event is h11.PAUSED: # This case can occur in HTTP pipelining, so we need to # stop reading any more data, and ensure that at the end # of the active request/response cycle we handle any # events that have been buffered up. self.flow.pause_reading() break elif isinstance(event, h11.Request): self.headers = [(key.lower(), value) for key, value in event.headers] raw_path, _, query_string = event.target.partition(b"?") path = unquote(raw_path.decode("ascii")) full_path = self.root_path + path full_raw_path = self.root_path.encode("ascii") + raw_path self.scope = { "type": "http", "asgi": { "version": self.config.asgi_version, "spec_version": "2.3", }, "http_version": event.http_version.decode("ascii"), "server": self.server, "client": self.client, "scheme": self.scheme, # type: ignore[typeddict-item] "method": event.method.decode("ascii"), "root_path": self.root_path, "path": full_path, "raw_path": full_raw_path, "query_string": query_string, "headers": self.headers, "state": self.app_state.copy(), } upgrade = self._get_upgrade() if upgrade == b"websocket" and self._should_upgrade_to_ws(): self.handle_websocket_upgrade(event) return # Handle 503 responses when 'limit_concurrency' is exceeded. if self.limit_concurrency is not None and ( len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency ): app = service_unavailable message = "Exceeded concurrency limit." self.logger.warning(message) else: app = self.app # When starting to process a request, disable the keep-alive # timeout. Normally we disable this when receiving data from # client and set back when finishing processing its request. # However, for pipelined requests processing finishes after # already receiving the next request and thus the timer may # be set here, which we don't want. self._unset_keepalive_if_required() self.cycle = RequestResponseCycle( scope=self.scope, conn=self.conn, transport=self.transport, flow=self.flow, logger=self.logger, access_logger=self.access_logger, access_log=self.access_log, default_headers=self.server_state.default_headers, message_event=asyncio.Event(), on_response=self.on_response_complete, ) task = self.loop.create_task(self.cycle.run_asgi(app)) task.add_done_callback(self.tasks.discard) self.tasks.add(task) elif isinstance(event, h11.Data): if self.conn.our_state is h11.DONE: continue self.cycle.body += event.data if len(self.cycle.body) > HIGH_WATER_LIMIT: self.flow.pause_reading() self.cycle.message_event.set() elif isinstance(event, h11.EndOfMessage): if self.conn.our_state is h11.DONE: self.transport.resume_reading() self.conn.start_next_cycle() continue self.cycle.more_body = False self.cycle.message_event.set() def handle_websocket_upgrade(self, event: h11.Request) -> None: if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) self.connections.discard(self) output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"] for name, value in self.headers: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] config=self.config, server_state=self.server_state, app_state=self.app_state, ) protocol.connection_made(self.transport) protocol.data_received(b"".join(output)) self.transport.set_protocol(protocol) def send_400_response(self, msg: str) -> None: reason = STATUS_PHRASES[400] headers: list[tuple[bytes, bytes]] = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), ] event = h11.Response(status_code=400, headers=headers, reason=reason) output = self.conn.send(event) self.transport.write(output) output = self.conn.send(event=h11.Data(data=msg.encode("ascii"))) self.transport.write(output) output = self.conn.send(event=h11.EndOfMessage()) self.transport.write(output) self.transport.close() def on_response_complete(self) -> None: self.server_state.total_requests += 1 if self.transport.is_closing(): return # Set a short Keep-Alive timeout. self._unset_keepalive_if_required() self.timeout_keep_alive_task = self.loop.call_later( self.timeout_keep_alive, self.timeout_keep_alive_handler ) # Unpause data reads if needed. self.flow.resume_reading() # Unblock any pipelined events. if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE: self.conn.start_next_cycle() self.handle_events() def shutdown(self) -> None: """ Called by the server to commence a graceful shutdown. """ if self.cycle is None or self.cycle.response_complete: event = h11.ConnectionClosed() self.conn.send(event) self.transport.close() else: self.cycle.keep_alive = False def pause_writing(self) -> None: """ Called by the transport when the write buffer exceeds the high water mark. """ self.flow.pause_writing() def resume_writing(self) -> None: """ Called by the transport when the write buffer drops below the low water mark. """ self.flow.resume_writing() def timeout_keep_alive_handler(self) -> None: """ Called on a keep-alive connection if no new data is received after a short delay. """ if not self.transport.is_closing(): event = h11.ConnectionClosed() self.conn.send(event) self.transport.close() class RequestResponseCycle: def __init__( self, scope: "HTTPScope", conn: h11.Connection, transport: asyncio.Transport, flow: FlowControl, logger: logging.Logger, access_logger: logging.Logger, access_log: bool, default_headers: list[tuple[bytes, bytes]], message_event: asyncio.Event, on_response: Callable[..., None], ) -> None: self.scope = scope self.conn = conn self.transport = transport self.flow = flow self.logger = logger self.access_logger = access_logger self.access_log = access_log self.default_headers = default_headers self.message_event = message_event self.on_response = on_response # Connection state self.disconnected = False self.keep_alive = True self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue # Request state self.body = b"" self.more_body = True # Response state self.response_started = False self.response_complete = False # ASGI exception wrapper async def run_asgi(self, app: "ASGI3Application") -> None: try: result = await app( # type: ignore[func-returns-value] self.scope, self.receive, self.send ) except BaseException as exc: msg = "Exception in ASGI application\n" self.logger.error(msg, exc_info=exc) if not self.response_started: await self.send_500_response() else: self.transport.close() else: if result is not None: msg = "ASGI callable should return None, but returned '%s'." self.logger.error(msg, result) self.transport.close() elif not self.response_started and not self.disconnected: msg = "ASGI callable returned without starting response." self.logger.error(msg) await self.send_500_response() elif not self.response_complete and not self.disconnected: msg = "ASGI callable returned without completing response." self.logger.error(msg) self.transport.close() finally: self.on_response = lambda: None async def send_500_response(self) -> None: response_start_event: "HTTPResponseStartEvent" = { "type": "http.response.start", "status": 500, "headers": [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), ], } await self.send(response_start_event) response_body_event: "HTTPResponseBodyEvent" = { "type": "http.response.body", "body": b"Internal Server Error", "more_body": False, } await self.send(response_body_event) # ASGI interface async def send(self, message: "ASGISendEvent") -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: await self.flow.drain() if self.disconnected: return if not self.response_started: # Sending response status line and headers if message_type != "http.response.start": msg = "Expected ASGI message 'http.response.start', but got '%s'." raise RuntimeError(msg % message_type) message = cast("HTTPResponseStartEvent", message) self.response_started = True self.waiting_for_100_continue = False status = message["status"] headers = self.default_headers + list(message.get("headers", [])) if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers: headers = headers + [CLOSE_HEADER] if self.access_log: self.access_logger.info( '%s - "%s %s HTTP/%s" %d', get_client_addr(self.scope), self.scope["method"], get_path_with_query_string(self.scope), self.scope["http_version"], status, ) # Write response status line and headers reason = STATUS_PHRASES[status] response = h11.Response(status_code=status, headers=headers, reason=reason) output = self.conn.send(event=response) self.transport.write(output) elif not self.response_complete: # Sending response body if message_type != "http.response.body": msg = "Expected ASGI message 'http.response.body', but got '%s'." raise RuntimeError(msg % message_type) message = cast("HTTPResponseBodyEvent", message) body = message.get("body", b"") more_body = message.get("more_body", False) # Write response body data = b"" if self.scope["method"] == "HEAD" else body output = self.conn.send(event=h11.Data(data=data)) self.transport.write(output) # Handle response completion if not more_body: self.response_complete = True self.message_event.set() output = self.conn.send(event=h11.EndOfMessage()) self.transport.write(output) else: # Response already sent msg = "Unexpected ASGI message '%s' sent, after response already completed." raise RuntimeError(msg % message_type) if self.response_complete: if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive: self.conn.send(event=h11.ConnectionClosed()) self.transport.close() self.on_response() async def receive(self) -> "ASGIReceiveEvent": if self.waiting_for_100_continue and not self.transport.is_closing(): headers: list[tuple[str, str]] = [] event = h11.InformationalResponse( status_code=100, headers=headers, reason="Continue" ) output = self.conn.send(event=event) self.transport.write(output) self.waiting_for_100_continue = False if not self.disconnected and not self.response_complete: self.flow.resume_reading() await self.message_event.wait() self.message_event.clear() if self.disconnected or self.response_complete: return {"type": "http.disconnect"} message: "HTTPRequestEvent" = { "type": "http.request", "body": self.body, "more_body": self.more_body, } self.body = b"" return message