1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
|
"""Http related parsers and protocol."""
import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)
from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
# Public names exported by ``from <this module> import *``.
__all__ = (
    "StreamWriter",
    "HttpVersion",
    "HttpVersion10",
    "HttpVersion11",
)
# Combined payloads smaller than this many bytes are joined and sent with a
# single transport.write() call instead of transport.writelines().
MIN_PAYLOAD_FOR_WRITELINES = 2048

# transport.writelines() is not safe to use on Python 3.12 before 3.12.9 or
# on Python 3.13 before 3.13.2 (CVE-2024-12254,
# https://github.com/python/cpython/pull/127656); on older versions it is not
# any faster than write(), so it is skipped there too.
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
IS_PY313_BEFORE_313_2 = (
    sys.version_info >= (3, 13, 0) and sys.version_info < (3, 13, 2)
)
SKIP_WRITELINES = IS_PY_BEFORE_312_9 or IS_PY313_BEFORE_313_2
class HttpVersion(NamedTuple):
    """An HTTP protocol version expressed as a ``(major, minor)`` pair."""

    major: int
    minor: int


HttpVersion10 = HttpVersion(major=1, minor=0)
HttpVersion11 = HttpVersion(major=1, minor=1)

# Optional async callbacks: StreamWriter awaits the first with each body chunk
# and the second with the header mapping before writing them.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
    """Write an HTTP message head and body to a protocol's transport.

    The body can optionally be compressed (``enable_compression``), framed
    with chunked transfer-encoding (``enable_chunking``) and/or capped at
    ``length`` bytes.
    """

    # Remaining number of body bytes that may still be written; ``None``
    # means no limit.  write() truncates chunks that would exceed it.
    length: Optional[int] = None
    # When True, body data is framed using chunked transfer-encoding.
    chunked: bool = False
    # Set once the message is complete; write_eof() then returns immediately.
    _eof: bool = False
    # Body compressor installed by enable_compression(), if any.
    _compress: Optional[ZLibCompressor] = None

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol
        self.loop = loop
        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """The protocol's current transport, or ``None``."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer sends data through."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Frame subsequent body writes using chunked transfer-encoding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequent body writes with *encoding* and zlib *strategy*."""
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
        """Send *chunk* to the transport and update the size counters.

        Raises:
            ClientConnectionResetError: if there is no transport or it is
                closing.
        """
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    def _writelines(self, chunks: Iterable[bytes]) -> None:
        """Send several buffers to the transport as one logical write.

        Small combined payloads (below ``MIN_PAYLOAD_FOR_WRITELINES``), and
        interpreters where ``transport.writelines()`` is skipped
        (``SKIP_WRITELINES``), get a single joined ``transport.write()``.

        Raises:
            ClientConnectionResetError: if there is no transport or it is
                closing.
        """
        # The buffers are traversed twice (sizing, then sending).  Materialize
        # them so a one-shot iterator such as a generator is not exhausted by
        # the first pass and then silently sent as nothing.
        buffers = tuple(chunks)
        size = sum(map(len, buffers))
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
            transport.write(b"".join(buffers))
        else:
            transport.writelines(buffers)

    async def write(
        self,
        chunk: Union[bytes, bytearray, memoryview],
        *,
        drain: bool = True,
        LIMIT: int = 0x10000,
    ) -> None:
        """Write a chunk of body data to the stream.

        The chunk is first passed to the ``on_chunk_sent`` callback, then
        compressed (if enabled), truncated to the remaining ``length`` (if
        set) and framed as a transfer-encoding chunk (if chunking is enabled).
        Finish the message with write_eof(); the writer must not be used after
        that.

        Args:
            chunk: body data to send.
            drain: when true, wait for the transport to drain once more than
                ``LIMIT`` bytes have been buffered since the last drain.
            LIMIT: buffered-byte threshold that triggers that drain.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # Multi-byte item format: recast to single bytes so that len()
                # below counts bytes rather than items.
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # The compressor produced no output for this input yet.
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Drop whatever exceeds the remaining allowance.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                # Chunk framing: hex size line, the data, then CRLF.
                self._writelines(
                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
                )
            else:
                self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write the request/response status line and headers.

        The ``on_headers_sent`` callback, if set, is awaited first.
        """
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)
        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    def set_eof(self) -> None:
        """Mark the message as complete without writing anything."""
        self._eof = True

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Finish the message, optionally sending a final *chunk* first.

        Flushes the compressor (if enabled), emits the terminating zero-length
        chunk in chunked mode, drains the transport and marks the writer as
        finished.  Calling it again afterwards does nothing.
        """
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if self._compress:
            chunks: List[bytes] = []
            chunks_len = 0
            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
                chunks_len = len(compressed_chunk)
                chunks.append(compressed_chunk)

            # Emit whatever the compressor is still holding.
            flush_chunk = self._compress.flush()
            chunks_len += len(flush_chunk)
            chunks.append(flush_chunk)
            assert chunks_len

            if self.chunked:
                # One final data chunk followed by the zero-length last chunk.
                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
            elif len(chunks) > 1:
                self._writelines(chunks)
            else:
                self._write(chunks[0])
        elif self.chunked:
            if chunk:
                chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
            else:
                # Only the terminating zero-length chunk.
                self._write(b"0\r\n\r\n")
        elif chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Wait for the transport's write buffer to drain if writing is paused.

        The intended use is::

            await w.write(data)
            await w.drain()

        Returns immediately when there is no transport or the protocol is not
        paused.
        """
        protocol = self._protocol
        if protocol.transport is not None and protocol._paused:
            await protocol._drain_helper()
def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
return line.encode("utf-8")
# Default to the pure-Python serializer; switch to the compiled one when the
# extension module imports and extensions are not disabled.
_serialize_headers = _py_serialize_headers
try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import-not-found]
except ImportError:
    pass
else:
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers
|