Bug 1934094 - Update aiohttp from 3.8.5 to 3.10.11 r=firefox-build-system-reviewers,mach-reviewers,frontend-codestyle-reviewers,glandium,ahal

`3.8.5` cannot build with python 3.12, and `3.10.11` still builds on
python 3.8, so `3.10.11` gives better coverage. (There are newer
versions, but they aren't compatible with python 3.8)

Differential Revision: https://phabricator.services.mozilla.com/D230869
This commit is contained in:
ahochheiden
2025-01-21 18:19:43 +00:00
parent e3b6d7a64d
commit 8ec74b0a9b
249 changed files with 75890 additions and 27761 deletions

View File

@@ -10,6 +10,7 @@ vendored:testing/web-platform/tests/tools/wptrunner
vendored:testing/web-platform/tests/tools/wptserve
vendored:third_party/python/Jinja2
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -43,6 +44,7 @@ vendored:third_party/python/pathspec
vendored:third_party/python/platformdirs
vendored:third_party/python/ply
vendored:third_party/python/polib
vendored:third_party/python/propcache
vendored:third_party/python/pyasn1
vendored:third_party/python/pyasn1_modules
vendored:third_party/python/pygments

View File

@@ -9,6 +9,7 @@ vendored:testing/web-platform/tests/tools/wptrunner
vendored:testing/web-platform/tests/tools/wptserve
vendored:third_party/python/MarkupSafe/src
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -40,6 +41,7 @@ vendored:third_party/python/multidict
vendored:third_party/python/pathspec
vendored:third_party/python/platformdirs
vendored:third_party/python/ply
vendored:third_party/python/propcache
vendored:third_party/python/pyasn1
vendored:third_party/python/pyasn1_modules
vendored:third_party/python/pygments

View File

@@ -23,6 +23,7 @@ pypi:pywatchman==1.4.1
vendored:third_party/python/Jinja2
vendored:third_party/python/MarkupSafe/src
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -44,6 +45,7 @@ vendored:third_party/python/markdown_it_py
vendored:third_party/python/mdurl
vendored:third_party/python/mohawk
vendored:third_party/python/mozilla_repo_urls
vendored:third_party/python/propcache
vendored:third_party/python/pygments
vendored:third_party/python/pylru
vendored:third_party/python/redo

View File

@@ -1,5 +1,6 @@
requires-python:>=3.8
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -21,6 +22,7 @@ vendored:third_party/python/mohawk
vendored:third_party/python/mozilla_repo_urls
vendored:third_party/python/multidict
vendored:third_party/python/pathspec
vendored:third_party/python/propcache
vendored:third_party/python/pygments
vendored:third_party/python/python_dateutil
vendored:third_party/python/python_slugify

View File

@@ -17,6 +17,7 @@ vendored:testing/web-platform/tests/tools/wptrunner
vendored:testing/web-platform/tests/tools/wptserve
vendored:third_party/python/MarkupSafe/src
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -48,6 +49,7 @@ vendored:third_party/python/multidict
vendored:third_party/python/pathspec
vendored:third_party/python/platformdirs
vendored:third_party/python/ply
vendored:third_party/python/propcache
vendored:third_party/python/pyasn1
vendored:third_party/python/pyasn1_modules
vendored:third_party/python/pygments

View File

@@ -9,6 +9,7 @@ vendored:testing/web-platform/tests/tools/wptrunner
vendored:testing/web-platform/tests/tools/wptserve
vendored:third_party/python/MarkupSafe/src
vendored:third_party/python/PyYAML/lib/
vendored:third_party/python/aiohappyeyeballs
vendored:third_party/python/aiohttp
vendored:third_party/python/aiosignal
vendored:third_party/python/appdirs
@@ -42,6 +43,7 @@ vendored:third_party/python/pathspec
vendored:third_party/python/pkgutil_resolve_name
vendored:third_party/python/platformdirs
vendored:third_party/python/ply
vendored:third_party/python/propcache
vendored:third_party/python/pyasn1
vendored:third_party/python/pyasn1_modules
vendored:third_party/python/pygments

View File

@@ -0,0 +1,279 @@
A. HISTORY OF THE SOFTWARE
==========================
Python was created in the early 1990s by Guido van Rossum at Stichting
Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands
as a successor of a language called ABC. Guido remains Python's
principal author, although it includes many contributions from others.
In 1995, Guido continued his work on Python at the Corporation for
National Research Initiatives (CNRI, see https://www.cnri.reston.va.us)
in Reston, Virginia where he released several versions of the
software.
In May 2000, Guido and the Python core development team moved to
BeOpen.com to form the BeOpen PythonLabs team. In October of the same
year, the PythonLabs team moved to Digital Creations, which became
Zope Corporation. In 2001, the Python Software Foundation (PSF, see
https://www.python.org/psf/) was formed, a non-profit organization
created specifically to own Python-related Intellectual Property.
Zope Corporation was a sponsoring member of the PSF.
All Python releases are Open Source (see https://opensource.org for
the Open Source Definition). Historically, most, but not all, Python
releases have also been GPL-compatible; the table below summarizes
the various releases.
Release Derived Year Owner GPL-
from compatible? (1)
0.9.0 thru 1.2 1991-1995 CWI yes
1.3 thru 1.5.2 1.2 1995-1999 CNRI yes
1.6 1.5.2 2000 CNRI no
2.0 1.6 2000 BeOpen.com no
1.6.1 1.6 2001 CNRI yes (2)
2.1 2.0+1.6.1 2001 PSF no
2.0.1 2.0+1.6.1 2001 PSF yes
2.1.1 2.1+2.0.1 2001 PSF yes
2.1.2 2.1.1 2002 PSF yes
2.1.3 2.1.2 2002 PSF yes
2.2 and above 2.1.1 2001-now PSF yes
Footnotes:
(1) GPL-compatible doesn't mean that we're distributing Python under
the GPL. All Python licenses, unlike the GPL, let you distribute
a modified version without making your changes open source. The
GPL-compatible licenses make it possible to combine Python with
other software that is released under the GPL; the others don't.
(2) According to Richard Stallman, 1.6.1 is not GPL-compatible,
because its license has a choice of law clause. According to
CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1
is "not incompatible" with the GPL.
Thanks to the many outside volunteers who have worked under Guido's
direction to make these releases possible.
B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON
===============================================================
Python software and documentation are licensed under the
Python Software Foundation License Version 2.
Starting with Python 3.8.6, examples, recipes, and other code in
the documentation are dual licensed under the PSF License Version 2
and the Zero-Clause BSD license.
Some software incorporated into Python is under different licenses.
The licenses are listed with code falling under that license.
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
--------------------------------------------
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF hereby
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
analyze, test, perform and/or display publicly, prepare derivative works,
distribute, and otherwise use Python alone or in any derivative version,
provided, however, that PSF's License Agreement and PSF's notice of copyright,
i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation;
All Rights Reserved" are retained in Python alone or in any derivative version
prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0
-------------------------------------------
BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1
1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an
office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the
Individual or Organization ("Licensee") accessing and otherwise using
this software in source or binary form and its associated
documentation ("the Software").
2. Subject to the terms and conditions of this BeOpen Python License
Agreement, BeOpen hereby grants Licensee a non-exclusive,
royalty-free, world-wide license to reproduce, analyze, test, perform
and/or display publicly, prepare derivative works, distribute, and
otherwise use the Software alone or in any derivative version,
provided, however, that the BeOpen Python License is retained in the
Software, alone or in any derivative version prepared by Licensee.
3. BeOpen is making the Software available to Licensee on an "AS IS"
basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE
SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS
AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY
DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
5. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
6. This License Agreement shall be governed by and interpreted in all
respects by the law of the State of California, excluding conflict of
law provisions. Nothing in this License Agreement shall be deemed to
create any relationship of agency, partnership, or joint venture
between BeOpen and Licensee. This License Agreement does not grant
permission to use BeOpen trademarks or trade names in a trademark
sense to endorse or promote products or services of Licensee, or any
third party. As an exception, the "BeOpen Python" logos available at
http://www.pythonlabs.com/logos.html may be used according to the
permissions granted on that web page.
7. By copying, installing or otherwise using the software, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1
---------------------------------------
1. This LICENSE AGREEMENT is between the Corporation for National
Research Initiatives, having an office at 1895 Preston White Drive,
Reston, VA 20191 ("CNRI"), and the Individual or Organization
("Licensee") accessing and otherwise using Python 1.6.1 software in
source or binary form and its associated documentation.
2. Subject to the terms and conditions of this License Agreement, CNRI
hereby grants Licensee a nonexclusive, royalty-free, world-wide
license to reproduce, analyze, test, perform and/or display publicly,
prepare derivative works, distribute, and otherwise use Python 1.6.1
alone or in any derivative version, provided, however, that CNRI's
License Agreement and CNRI's notice of copyright, i.e., "Copyright (c)
1995-2001 Corporation for National Research Initiatives; All Rights
Reserved" are retained in Python 1.6.1 alone or in any derivative
version prepared by Licensee. Alternately, in lieu of CNRI's License
Agreement, Licensee may substitute the following text (omitting the
quotes): "Python 1.6.1 is made available subject to the terms and
conditions in CNRI's License Agreement. This Agreement together with
Python 1.6.1 may be located on the internet using the following
unique, persistent identifier (known as a handle): 1895.22/1013. This
Agreement may also be obtained from a proxy server on the internet
using the following URL: http://hdl.handle.net/1895.22/1013".
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python 1.6.1 or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python 1.6.1.
4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS"
basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. This License Agreement shall be governed by the federal
intellectual property law of the United States, including without
limitation the federal copyright law, and, to the extent such
U.S. federal law does not apply, by the law of the Commonwealth of
Virginia, excluding Virginia's conflict of law provisions.
Notwithstanding the foregoing, with regard to derivative works based
on Python 1.6.1 that incorporate non-separable material that was
previously distributed under the GNU General Public License (GPL), the
law of the Commonwealth of Virginia shall govern this License
Agreement only as to issues arising under or with respect to
Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this
License Agreement shall be deemed to create any relationship of
agency, partnership, or joint venture between CNRI and Licensee. This
License Agreement does not grant permission to use CNRI trademarks or
trade name in a trademark sense to endorse or promote products or
services of Licensee, or any third party.
8. By clicking on the "ACCEPT" button where indicated, or by copying,
installing or otherwise using Python 1.6.1, Licensee agrees to be
bound by the terms and conditions of this License Agreement.
ACCEPT
CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2
--------------------------------------------------
Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam,
The Netherlands. All rights reserved.
Permission to use, copy, modify, and distribute this software and its
documentation for any purpose and without fee is hereby granted,
provided that the above copyright notice appear in all copies and that
both that copyright notice and this permission notice appear in
supporting documentation, and that the name of Stichting Mathematisch
Centrum or CWI not be used in advertising or publicity pertaining to
distribution of the software without specific, written prior
permission.
STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO
THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE
FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION
----------------------------------------------------------------------
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

View File

@@ -0,0 +1,126 @@
Metadata-Version: 2.1
Name: aiohappyeyeballs
Version: 2.4.4
Summary: Happy Eyeballs for asyncio
Home-page: https://github.com/aio-libs/aiohappyeyeballs
License: PSF-2.0
Author: J. Nick Koston
Author-email: nick@koston.org
Requires-Python: >=3.8
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Python Software Foundation License
Classifier: License :: Other/Proprietary License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Software Development :: Libraries
Project-URL: Bug Tracker, https://github.com/aio-libs/aiohappyeyeballs/issues
Project-URL: Changelog, https://github.com/aio-libs/aiohappyeyeballs/blob/main/CHANGELOG.md
Project-URL: Documentation, https://aiohappyeyeballs.readthedocs.io
Project-URL: Repository, https://github.com/aio-libs/aiohappyeyeballs
Description-Content-Type: text/markdown
# aiohappyeyeballs
<p align="center">
<a href="https://github.com/aio-libs/aiohappyeyeballs/actions/workflows/ci.yml?query=branch%3Amain">
<img src="https://img.shields.io/github/actions/workflow/status/aio-libs/aiohappyeyeballs/ci-cd.yml?branch=main&label=CI&logo=github&style=flat-square" alt="CI Status" >
</a>
<a href="https://aiohappyeyeballs.readthedocs.io">
<img src="https://img.shields.io/readthedocs/aiohappyeyeballs.svg?logo=read-the-docs&logoColor=fff&style=flat-square" alt="Documentation Status">
</a>
<a href="https://codecov.io/gh/aio-libs/aiohappyeyeballs">
<img src="https://img.shields.io/codecov/c/github/aio-libs/aiohappyeyeballs.svg?logo=codecov&logoColor=fff&style=flat-square" alt="Test coverage percentage">
</a>
</p>
<p align="center">
<a href="https://python-poetry.org/">
<img src="https://img.shields.io/badge/packaging-poetry-299bd7?style=flat-square&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA4AAAASCAYAAABrXO8xAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAJJSURBVHgBfZLPa1NBEMe/s7tNXoxW1KJQKaUHkXhQvHgW6UHQQ09CBS/6V3hKc/AP8CqCrUcpmop3Cx48eDB4yEECjVQrlZb80CRN8t6OM/teagVxYZi38+Yz853dJbzoMV3MM8cJUcLMSUKIE8AzQ2PieZzFxEJOHMOgMQQ+dUgSAckNXhapU/NMhDSWLs1B24A8sO1xrN4NECkcAC9ASkiIJc6k5TRiUDPhnyMMdhKc+Zx19l6SgyeW76BEONY9exVQMzKExGKwwPsCzza7KGSSWRWEQhyEaDXp6ZHEr416ygbiKYOd7TEWvvcQIeusHYMJGhTwF9y7sGnSwaWyFAiyoxzqW0PM/RjghPxF2pWReAowTEXnDh0xgcLs8l2YQmOrj3N7ByiqEoH0cARs4u78WgAVkoEDIDoOi3AkcLOHU60RIg5wC4ZuTC7FaHKQm8Hq1fQuSOBvX/sodmNJSB5geaF5CPIkUeecdMxieoRO5jz9bheL6/tXjrwCyX/UYBUcjCaWHljx1xiX6z9xEjkYAzbGVnB8pvLmyXm9ep+W8CmsSHQQY77Zx1zboxAV0w7ybMhQmfqdmmw3nEp1I0Z+FGO6M8LZdoyZnuzzBdjISicKRnpxzI9fPb+0oYXsNdyi+d3h9bm9MWYHFtPeIZfLwzmFDKy1ai3p+PDls1Llz4yyFpferxjnyjJDSEy9CaCx5m2cJPerq6Xm34eTrZt3PqxYO1XOwDYZrFlH1fWnpU38Y9HRze3lj0vOujZcXKuuXm3jP+s3KbZVra7y2EAAAAAASUVORK5CYII=" alt="Poetry">
</a>
<a href="https://github.com/astral-sh/ruff">
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff">
</a>
<a href="https://github.com/pre-commit/pre-commit">
<img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=flat-square" alt="pre-commit">
</a>
</p>
<p align="center">
<a href="https://pypi.org/project/aiohappyeyeballs/">
<img src="https://img.shields.io/pypi/v/aiohappyeyeballs.svg?logo=python&logoColor=fff&style=flat-square" alt="PyPI Version">
</a>
<img src="https://img.shields.io/pypi/pyversions/aiohappyeyeballs.svg?style=flat-square&logo=python&amp;logoColor=fff" alt="Supported Python versions">
<img src="https://img.shields.io/pypi/l/aiohappyeyeballs.svg?style=flat-square" alt="License">
</p>
---
**Documentation**: <a href="https://aiohappyeyeballs.readthedocs.io" target="_blank">https://aiohappyeyeballs.readthedocs.io </a>
**Source Code**: <a href="https://github.com/aio-libs/aiohappyeyeballs" target="_blank">https://github.com/aio-libs/aiohappyeyeballs </a>
---
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
## Use case
This library exists to allow connecting with
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
when you
already have a list of addrinfo and not a DNS name.
The stdlib version of `loop.create_connection()`
will only work when you pass in an unresolved name which
is not a good fit when using DNS caching or resolving
names via another method such as `zeroconf`.
## Installation
Install this via pip (or your favourite package manager):
`pip install aiohappyeyeballs`
## License
[aiohappyeyeballs is licensed under the same terms as cpython itself.](https://github.com/python/cpython/blob/main/LICENSE)
## Example usage
```python
addr_infos = await loop.getaddrinfo("example.org", 80)
socket = await start_connection(addr_infos)
socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, happy_eyeballs_delay=0.2)
transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)
# Remove the first address for each family from addr_info
pop_addr_infos_interleave(addr_info, 1)
# Remove all matching address from addr_info
remove_addr_infos(addr_info, "dead::beef::")
# Convert a local_addr to local_addr_infos
local_addr_infos = addr_to_addr_infos(("127.0.0.1",0))
```
## Credits
This package contains code from cpython and is licensed under the same terms as cpython itself.
This package was created with
[Copier](https://copier.readthedocs.io/) and the
[browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template)
project template.

View File

@@ -0,0 +1,10 @@
aiohappyeyeballs/__init__.py,sha256=64CUKZ1vpW6MnkJIyy-CHBU7o6c_TbKO7f6WAViSl9s,317
aiohappyeyeballs/_staggered.py,sha256=LbTGSjib2cb11QDE4RlSVQNUauK3X9p1avCR9YuJF7s,6723
aiohappyeyeballs/impl.py,sha256=qrAnR-7xaxh6W7mg0i-9ozAJpzvCm7w-P3Yhy_LaTaM,8109
aiohappyeyeballs/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
aiohappyeyeballs/types.py,sha256=iYPiBTl5J7YEjnIqEOVUTRPzz2DwqSHBRhvbAlM0zv0,234
aiohappyeyeballs/utils.py,sha256=on9GxIR0LhEfZu8P6Twi9hepX9zDanuZM20MWsb3xlQ,3028
aiohappyeyeballs-2.4.4.dist-info/LICENSE,sha256=Oy-B_iHRgcSZxZolbI4ZaEVdZonSaaqFNzv7avQdo78,13936
aiohappyeyeballs-2.4.4.dist-info/METADATA,sha256=CT9LuDMNOove0oCR6kFFKMoLkA-D_XuBVr_w4uCpcpY,6070
aiohappyeyeballs-2.4.4.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
aiohappyeyeballs-2.4.4.dist-info/RECORD,,

View File

@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: poetry-core 1.9.1
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,13 @@
__version__ = "2.4.4"
from .impl import start_connection
from .types import AddrInfoType
from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos
__all__ = (
"AddrInfoType",
"addr_to_addr_infos",
"pop_addr_infos_interleave",
"remove_addr_infos",
"start_connection",
)

View File

@@ -0,0 +1,202 @@
import asyncio
import contextlib
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
_T = TypeVar("_T")
def _set_result(wait_next: "asyncio.Future[None]") -> None:
"""Set the result of a future if it is not already done."""
if not wait_next.done():
wait_next.set_result(None)
async def _wait_one(
futures: "Iterable[asyncio.Future[Any]]",
loop: asyncio.AbstractEventLoop,
) -> _T:
"""Wait for the first future to complete."""
wait_next = loop.create_future()
def _on_completion(fut: "asyncio.Future[Any]") -> None:
if not wait_next.done():
wait_next.set_result(fut)
for f in futures:
f.add_done_callback(_on_completion)
try:
return await wait_next
finally:
for f in futures:
f.remove_done_callback(_on_completion)
async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
delay: Optional[float],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
"""
Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is
started immediately. From then on, whenever the immediately preceding one
fails (raises an exception), or when *delay* seconds has passed, the next
coroutine is started. This continues until one of the coroutines complete
successfully, in which case all others are cancelled, or until all
coroutines fail.
The coroutines provided should be well-behaved in the following way:
* They should only ``return`` if completed successfully.
* They should always raise an exception if they did not complete
successfully. In particular, if they handle cancellation, they should
probably reraise, like this::
try:
# do work
except asyncio.CancelledError:
# undo partially completed work
raise
Args:
----
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
loop: the event loop to use. If ``None``, the running loop is used.
Returns:
-------
tuple *(winner_result, winner_index, exceptions)* where
- *winner_result*: the result of the winning coroutine, or ``None``
if no coroutines won.
- *winner_index*: the index of the winning coroutine in
``coro_fns``, or ``None`` if no coroutines won. If the winning
coroutine may return None on success, *winner_index* can be used
to definitively determine whether any coroutine won.
- *exceptions*: list of exceptions returned by the coroutines.
``len(exceptions)`` is equal to the number of coroutines actually
started, and the order is the same as in ``coro_fns``. The winning
coroutine's entry is ``None``.
"""
loop = loop or asyncio.get_running_loop()
exceptions: List[Optional[BaseException]] = []
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
async def run_one_coro(
coro_fn: Callable[[], Awaitable[_T]],
this_index: int,
start_next: "asyncio.Future[None]",
) -> Optional[Tuple[_T, int]]:
"""
Run a single coroutine.
If the coroutine fails, set the exception in the exceptions list and
start the next coroutine by setting the result of the start_next.
If the coroutine succeeds, return the result and the index of the
coroutine in the coro_fns list.
If SystemExit or KeyboardInterrupt is raised, re-raise it.
"""
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
_set_result(start_next) # Kickstart the next coroutine
return None
return result, this_index
start_next_timer: Optional[asyncio.TimerHandle] = None
start_next: Optional[asyncio.Future[None]]
task: asyncio.Task[Optional[Tuple[_T, int]]]
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
coro_iter = iter(coro_fns)
this_index = -1
try:
while True:
if coro_fn := next(coro_iter, None):
this_index += 1
exceptions.append(None)
start_next = loop.create_future()
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
tasks.add(task)
start_next_timer = (
loop.call_later(delay, _set_result, start_next) if delay else None
)
elif not tasks:
# We exhausted the coro_fns list and no tasks are running
# so we have no winner and all coroutines failed.
break
while tasks:
done = await _wait_one(
[*tasks, start_next] if start_next else tasks, loop
)
if done is start_next:
# The current task has failed or the timer has expired
# so we need to start the next task.
start_next = None
if start_next_timer:
start_next_timer.cancel()
start_next_timer = None
# Break out of the task waiting loop to start the next
# task.
break
if TYPE_CHECKING:
assert isinstance(done, asyncio.Task)
tasks.remove(done)
if winner := done.result():
return *winner, exceptions
finally:
# We either have:
# - a winner
# - all tasks failed
# - a KeyboardInterrupt or SystemExit.
#
# If the timer is still running, cancel it.
#
if start_next_timer:
start_next_timer.cancel()
#
# If there are any tasks left, cancel them and than
# wait them so they fill the exceptions list.
#
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
return None, None, exceptions

View File

@@ -0,0 +1,221 @@
"""Base implementation."""
import asyncio
import collections
import functools
import itertools
import socket
import sys
from typing import List, Optional, Sequence, Union
from . import _staggered
from .types import AddrInfoType
if sys.version_info < (3, 8, 2): # noqa: UP036
# asyncio.staggered is broken in Python 3.8.0 and 3.8.1
# so it must be patched:
# https://github.com/aio-libs/aiohttp/issues/8556
# https://bugs.python.org/issue39129
# https://github.com/python/cpython/pull/17693
import asyncio.futures
asyncio.futures.TimeoutError = asyncio.TimeoutError # type: ignore[attr-defined]
async def start_connection(
addr_infos: Sequence[AddrInfoType],
*,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
happy_eyeballs_delay: Optional[float] = None,
interleave: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> socket.socket:
"""
Connect to a TCP server.
Create a socket connection to a specified destination. The
destination is specified as a list of AddrInfoType tuples as
returned from getaddrinfo().
The arguments are, in order:
* ``family``: the address family, e.g. ``socket.AF_INET`` or
``socket.AF_INET6``.
* ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
``socket.SOCK_DGRAM``.
* ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
``socket.IPPROTO_UDP``.
* ``canonname``: the canonical name of the address, e.g.
``"www.python.org"``.
* ``sockaddr``: the socket address
This method is a coroutine which will try to establish the connection
in the background. When successful, the coroutine returns a
socket.
The expected use case is to use this method in conjunction with
loop.create_connection() to establish a connection to a server::
socket = await start_connection(addr_infos)
transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)
"""
if not (current_loop := loop):
current_loop = asyncio.get_running_loop()
single_addr_info = len(addr_infos) == 1
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
if interleave and not single_addr_info:
addr_infos = _interleave_addrinfos(addr_infos, interleave)
sock: Optional[socket.socket] = None
# uvloop can raise RuntimeError instead of OSError
exceptions: List[List[Union[OSError, RuntimeError]]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
try:
sock = await _connect_sock(
current_loop, exceptions, addrinfo, local_addr_infos
)
break
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
)
for addrinfo in addr_infos
),
happy_eyeballs_delay,
)
if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
try:
first_exception = all_exceptions[0]
if len(all_exceptions) == 1:
raise first_exception
else:
# If they all have the same str(), raise one.
model = str(first_exception)
if all(str(exc) == model for exc in all_exceptions):
raise first_exception
# Raise a combined exception so the user can see all
# the various error messages.
msg = "Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
if isinstance(first_exception, OSError):
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
elif isinstance(first_exception, RuntimeError) and all(
isinstance(exc, RuntimeError) for exc in all_exceptions
):
raise RuntimeError(msg)
# We have a mix of OSError and RuntimeError
# so we have to pick which one to raise.
# and we raise OSError for compatibility
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
exceptions = None # type: ignore[assignment]
return sock
async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: List[Union[OSError, RuntimeError]] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
try:
sock = socket.socket(family=family, type=type_, proto=proto)
sock.setblocking(False)
if local_addr_infos is not None:
for lfamily, _, _, _, laddr in local_addr_infos:
# skip local addresses of different family
if lfamily != family:
continue
try:
sock.bind(laddr)
break
except OSError as exc:
msg = (
f"error while attempting to bind on "
f"address {laddr!r}: "
f"{exc.strerror.lower()}"
)
exc = OSError(exc.errno, msg)
my_exceptions.append(exc)
else: # all bind attempts failed
if my_exceptions:
raise my_exceptions.pop()
else:
raise OSError(f"no matching local address with {family=} found")
await loop.sock_connect(sock, address)
return sock
except (RuntimeError, OSError) as exc:
my_exceptions.append(exc)
if sock is not None:
try:
sock.close()
except OSError as e:
my_exceptions.append(e)
raise
raise
except:
if sock is not None:
try:
sock.close()
except OSError as e:
my_exceptions.append(e)
raise
raise
finally:
exceptions = my_exceptions = None # type: ignore[assignment]
def _interleave_addrinfos(
addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> List[AddrInfoType]:
"""Interleave list of addrinfo tuples by family."""
# Group addresses by family
addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
collections.OrderedDict()
)
for addr in addrinfos:
family = addr[0]
if family not in addrinfos_by_family:
addrinfos_by_family[family] = []
addrinfos_by_family[family].append(addr)
addrinfos_lists = list(addrinfos_by_family.values())
reordered: List[AddrInfoType] = []
if first_address_family_count > 1:
reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
del addrinfos_lists[0][: first_address_family_count - 1]
reordered.extend(
a
for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
if a is not None
)
return reordered

View File

@@ -0,0 +1,12 @@
"""Types for aiohappyeyeballs."""
import socket
from typing import Tuple, Union
AddrInfoType = Tuple[
Union[int, socket.AddressFamily],
Union[int, socket.SocketKind],
int,
str,
Tuple, # type: ignore[type-arg]
]

View File

@@ -0,0 +1,97 @@
"""Utility functions for aiohappyeyeballs."""
import ipaddress
import socket
from typing import Dict, List, Optional, Tuple, Union
from .types import AddrInfoType
def addr_to_addr_infos(
addr: Optional[
Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]]
],
) -> Optional[List[AddrInfoType]]:
"""Convert an address tuple to a list of addr_info tuples."""
if addr is None:
return None
host = addr[0]
port = addr[1]
is_ipv6 = ":" in host
if is_ipv6:
flowinfo = 0
scopeid = 0
addr_len = len(addr)
if addr_len >= 4:
scopeid = addr[3] # type: ignore[misc]
if addr_len >= 3:
flowinfo = addr[2] # type: ignore[misc]
addr = (host, port, flowinfo, scopeid)
family = socket.AF_INET6
else:
addr = (host, port)
family = socket.AF_INET
return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)]
def pop_addr_infos_interleave(
addr_infos: List[AddrInfoType], interleave: Optional[int] = None
) -> None:
"""
Pop addr_info from the list of addr_infos by family up to interleave times.
The interleave parameter is used to know how many addr_infos for
each family should be popped of the top of the list.
"""
seen: Dict[int, int] = {}
if interleave is None:
interleave = 1
to_remove: List[AddrInfoType] = []
for addr_info in addr_infos:
family = addr_info[0]
if family not in seen:
seen[family] = 0
if seen[family] < interleave:
to_remove.append(addr_info)
seen[family] += 1
for addr_info in to_remove:
addr_infos.remove(addr_info)
def _addr_tuple_to_ip_address(
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> Union[
Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
]:
"""Convert an address tuple to an IPv4Address."""
return (ipaddress.ip_address(addr[0]), *addr[1:])
def remove_addr_infos(
addr_infos: List[AddrInfoType],
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> None:
"""
Remove an address from the list of addr_infos.
The addr value is typically the return value of
sock.getpeername().
"""
bad_addrs_infos: List[AddrInfoType] = []
for addr_info in addr_infos:
if addr_info[-1] == addr:
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
# Slow path in case addr is formatted differently
match_addr = _addr_tuple_to_ip_address(addr)
for addr_info in addr_infos:
if match_addr == _addr_tuple_to_ip_address(addr_info[-1]):
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
raise ValueError(f"Address {addr} not found in addr_infos")

File diff suppressed because it is too large Load Diff

View File

@@ -27,6 +27,7 @@ Alexander Shorin
Alexander Travov
Alexandru Mihai
Alexey Firsov
Alexey Nikitin
Alexey Popravka
Alexey Stepanov
Amin Etesamian
@@ -45,18 +46,21 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arseny Timoniq
Artem Yushkovskiy
Arthur Darcet
Austin Scola
Ben Bader
Ben Greiner
Ben Kallus
Ben Timby
Benedikt Reinartz
Bob Haddleton
Boris Feld
Boyi Chen
Brett Cannon
Brett Higgins
Brian Bouterse
Brian C. Lane
Brian Muller
@@ -71,6 +75,7 @@ Chih-Yuan Chen
Chris AtLee
Chris Laws
Chris Moore
Chris Shucksmith
Christopher Schmitt
Claudiu Popa
Colin Dunklau
@@ -79,6 +84,7 @@ Damien Nadé
Dan King
Dan Xu
Daniel García
Daniel Golding
Daniel Grossmann-Kavanagh
Daniel Nelson
Danny Song
@@ -90,6 +96,7 @@ Denis Moshensky
Dennis Kliban
Dima Veselov
Dimitar Dimitrov
Diogo Dutra da Mata
Dmitriy Safonov
Dmitry Doroshev
Dmitry Erlikh
@@ -137,9 +144,11 @@ Hrishikesh Paranjape
Hu Bo
Hugh Young
Hugo Herter
Hugo Hromic
Hugo van Kemenade
Hynek Schlawack
Igor Alexandrov
Igor Bolshakov
Igor Davydenko
Igor Mozharovsky
Igor Pavlov
@@ -147,13 +156,19 @@ Illia Volochii
Ilya Chichak
Ilya Gruzinov
Ingmar Steen
Ivan Lakovic
Ivan Larin
J. Nick Koston
Jacob Champion
Jaesung Lee
Jake Davis
Jakob Ackermann
Jakub Wilk
Jan Buchar
Jan Gosmann
Jarno Elonen
Jashandeep Sohi
Jean-Baptiste Estival
Jens Steinhauser
Jeonghun Lee
Jeongkyu Shin
@@ -162,9 +177,11 @@ Jesus Cea
Jian Zeng
Jinkyu Yi
Joel Watts
John Parton
Jon Nabozny
Jonas Krüger Svensson
Jonas Obrist
Jonathan Ballet
Jonathan Wright
Jonny Tan
Joongi Kim
@@ -190,6 +207,7 @@ Krzysztof Blazewicz
Kyrylo Perevozchikov
Kyungmin Lee
Lars P. Søndergaard
Lee LieWhite
Liu Hua
Louis-Philippe Huberdeau
Loïc Lajeanne
@@ -198,20 +216,27 @@ Lubomir Gelo
Ludovic Gasc
Luis Pedrosa
Lukasz Marcin Dobrzanski
Lénárd Szolnoki
Makc Belousow
Manuel Miranda
Marat Sharafutdinov
Marc Mueller
Marco Paolini
Marcus Stojcevich
Mariano Anaya
Mariusz Masztalerczuk
Marko Kohtala
Martijn Pieters
Martin Melka
Martin Richard
Martin Sucha
Mathias Fröjdman
Mathieu Dugré
Matt VanEseltine
Matthias Marquardt
Matthieu Hauglustaine
Matthieu Rigal
Matvey Tingaev
Meet Mangukiya
Michael Ihnatenko
Michał Górny
@@ -235,8 +260,10 @@ Olaf Conradi
Pahaz Blinov
Panagiotis Kolokotronis
Pankaj Pandey
Parag Jain
Pau Freixes
Paul Colomiets
Paul J. Dorn
Paulius Šileikis
Paulus Schoutsen
Pavel Kamaev
@@ -247,15 +274,19 @@ Pawel Kowalski
Pawel Miech
Pepe Osca
Philipp A.
Pierre-Louis Peeters
Pieter van Beek
Qiao Han
Rafael Viotti
Rahul Nahata
Raphael Bialon
Raúl Cumplido
Required Field
Robert Lu
Robert Nikolich
Roman Podoliaka
Rong Zhang
Samir Akarioh
Samuel Colvin
Sean Hunt
Sebastian Acuna
@@ -277,6 +308,7 @@ Stepan Pletnev
Stephan Jaensch
Stephen Cirelli
Stephen Granade
Steve Repsher
Steven Seguin
Sunghyun Hwang
Sunit Deshpande
@@ -293,6 +325,7 @@ Tolga Tezel
Tomasz Trebski
Toshiaki Tanaka
Trinh Hoang Nhu
Tymofii Tsiapa
Vadim Suharnikov
Vaibhav Sagar
Vamsi Krishna Avula
@@ -324,6 +357,9 @@ Willem de Groot
William Grzybowski
William S.
Wilson Ong
wouter bolsterlee
Xavier Halloran
Xiang Li
Yang Zhou
Yannick Koechlin
Yannick Péroux
@@ -336,7 +372,10 @@ Yury Pliner
Yury Selivanov
Yusuke Tsutsumi
Yuval Ofir
Yuvi Panda
Zainab Lawal
Zeal Wierslee
Zlatan Sičanica
Łukasz Setla
Марк Коренберг
Семён Марьясин

View File

@@ -7,6 +7,7 @@ graft aiohttp
graft docs
graft examples
graft tests
graft requirements
recursive-include vendor *
global-include aiohttp *.pyi
global-exclude *.pyc

View File

@@ -25,7 +25,7 @@ FORCE:
# check_sum.py works perfectly fine but slow when called for every file from $(ALLS)
# (perhaps even several times for each file).
# That is why much less readable but faster solution exists
ifneq (, $(shell which sha256sum))
ifneq (, $(shell command -v sha256sum))
%.hash: FORCE
$(eval $@_ABS := $(abspath $@))
$(eval $@_NAME := $($@_ABS))
@@ -50,7 +50,7 @@ endif
@python -m pip install --upgrade pip
.install-cython: .update-pip $(call to-hash,requirements/cython.txt)
@python -m pip install -r requirements/cython.txt -c requirements/constraints.txt
@python -m pip install -r requirements/cython.in -c requirements/cython.txt
@touch .install-cython
aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py)
@@ -58,10 +58,10 @@ aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py)
# _find_headers generator creates _headers.pyi as well
aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c
cython -3 -o $@ $< -I aiohttp
cython -3 -o $@ $< -I aiohttp -Werror
vendor/llhttp/node_modules: vendor/llhttp/package.json
cd vendor/llhttp; npm install
cd vendor/llhttp; npm ci
.llhttp-gen: vendor/llhttp/node_modules
$(MAKE) -C vendor/llhttp generate
@@ -74,7 +74,7 @@ generate-llhttp: .llhttp-gen
cythonize: .install-cython $(PYXS:.pyx=.c)
.install-deps: .install-cython $(PYXS:.pyx=.c) $(call to-hash,$(CYS) $(REQS))
@python -m pip install -r requirements/dev.txt -c requirements/constraints.txt
@python -m pip install -r requirements/dev.in -c requirements/dev.txt
@touch .install-deps
.PHONY: lint
@@ -89,7 +89,7 @@ mypy:
mypy
.develop: .install-deps generate-llhttp $(call to-hash,$(PYS) $(CYS) $(CS))
python -m pip install -e . -c requirements/constraints.txt
python -m pip install -e . -c requirements/runtime-deps.txt
@touch .develop
.PHONY: test
@@ -99,10 +99,12 @@ test: .develop
.PHONY: vtest
vtest: .develop
@pytest -s -v
@python -X dev -m pytest -s -v -m dev_mode
.PHONY: vvtest
vvtest: .develop
@pytest -vv
@python -X dev -m pytest -s -v -m dev_mode
define run_tests_in_docker
@@ -110,11 +112,7 @@ define run_tests_in_docker
docker run --rm -ti -v `pwd`:/src -w /src "aiohttp-test-$(1)-$(2)" $(TEST_SPEC)
endef
.PHONY: test-3.7-no-extensions test-3.7 test-3.8-no-extensions test-3.8 test-3.9-no-extensions test-3.9 test-3.10-no-extensions test-3.10
test-3.7-no-extensions:
$(call run_tests_in_docker,3.7,y)
test-3.7:
$(call run_tests_in_docker,3.7,n)
.PHONY: test-3.8-no-extensions test-3.8 test-3.9-no-extensions test
test-3.8-no-extensions:
$(call run_tests_in_docker,3.8,y)
test-3.8:
@@ -174,15 +172,14 @@ doc:
doc-spelling:
@make -C docs spelling SPHINXOPTS="-W --keep-going -n -E"
.PHONY: compile-deps
compile-deps: .update-pip $(REQS)
pip-compile --no-header --allow-unsafe -q --strip-extras \
-o requirements/constraints.txt \
requirements/constraints.in
.PHONY: install
install: .update-pip
@python -m pip install -r requirements/dev.txt -c requirements/constraints.txt
@python -m pip install -r requirements/dev.in -c requirements/dev.txt
.PHONY: install-dev
install-dev: .develop
.PHONY: sync-direct-runtime-deps
sync-direct-runtime-deps:
@echo Updating 'requirements/runtime-deps.in' from 'setup.cfg'... >&2
@python requirements/sync-direct-runtime-deps.py

View File

@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: aiohttp
Version: 3.8.5
Version: 3.10.11
Summary: Async http client/server framework (asyncio)
Home-page: https://github.com/aio-libs/aiohttp
Maintainer: aiohttp team <team@aiohttp.org>
@@ -23,16 +23,27 @@ Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: Microsoft :: Windows
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Internet :: WWW/HTTP
Requires-Python: >=3.6
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
Provides-Extra: speedups
License-File: LICENSE.txt
Requires-Dist: aiohappyeyeballs>=2.3.0
Requires-Dist: aiosignal>=1.1.2
Requires-Dist: async-timeout<6.0,>=4.0; python_version < "3.11"
Requires-Dist: attrs>=17.3.0
Requires-Dist: frozenlist>=1.1.1
Requires-Dist: multidict<7.0,>=4.5
Requires-Dist: yarl<2.0,>=1.12.0
Provides-Extra: speedups
Requires-Dist: aiodns>=3.2.0; (sys_platform == "linux" or sys_platform == "darwin") and extra == "speedups"
Requires-Dist: Brotli; platform_python_implementation == "CPython" and extra == "speedups"
Requires-Dist: brotlicffi; platform_python_implementation != "CPython" and extra == "speedups"
==================================
Async http client/server framework
@@ -53,6 +64,10 @@ Async http client/server framework
:target: https://codecov.io/gh/aio-libs/aiohttp
:alt: codecov.io status for master branch
.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json
:target: https://codspeed.io/aio-libs/aiohttp
:alt: Codspeed.io status for aiohttp
.. image:: https://badge.fury.io/py/aiohttp.svg
:target: https://pypi.org/project/aiohttp
:alt: Latest PyPI package version
@@ -76,7 +91,7 @@ Key Features
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box and avoids
Callback Hell.
- Provides Web-server with middlewares and plugable routing.
- Provides Web-server with middleware and pluggable routing.
Getting started
@@ -103,7 +118,7 @@ To get something from the web:
html = await response.text()
print("Body:", html[:15], "...")
asyncio.run(main())
asyncio.run(main())
This prints:
@@ -182,9 +197,9 @@ Feel free to make a Pull Request for adding your link to these pages!
Communication channels
======================
*aio-libs discourse group*: https://aio-libs.discourse.group
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
*gitter chat* https://gitter.im/aio-libs/Lobby
*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
@@ -193,25 +208,19 @@ Please add *aiohttp* tag to your question there.
Requirements
============
- Python >= 3.6
- async-timeout_
- attrs_
- charset-normalizer_
- multidict_
- yarl_
- frozenlist_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for sake of speed).
Optionally you may install the aiodns_ library (highly recommended for sake of speed).
.. _charset-normalizer: https://pypi.org/project/charset-normalizer
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _attrs: https://github.com/python-attrs/attrs
.. _multidict: https://pypi.python.org/pypi/multidict
.. _frozenlist: https://pypi.org/project/frozenlist/
.. _yarl: https://pypi.python.org/pypi/yarl
.. _async-timeout: https://pypi.python.org/pypi/async_timeout
.. _cChardet: https://pypi.python.org/pypi/cchardet
License
=======

View File

@@ -17,6 +17,10 @@ Async http client/server framework
:target: https://codecov.io/gh/aio-libs/aiohttp
:alt: codecov.io status for master branch
.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json
:target: https://codspeed.io/aio-libs/aiohttp
:alt: Codspeed.io status for aiohttp
.. image:: https://badge.fury.io/py/aiohttp.svg
:target: https://pypi.org/project/aiohttp
:alt: Latest PyPI package version
@@ -40,7 +44,7 @@ Key Features
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box and avoids
Callback Hell.
- Provides Web-server with middlewares and plugable routing.
- Provides Web-server with middleware and pluggable routing.
Getting started
@@ -67,7 +71,7 @@ To get something from the web:
html = await response.text()
print("Body:", html[:15], "...")
asyncio.run(main())
asyncio.run(main())
This prints:
@@ -146,9 +150,9 @@ Feel free to make a Pull Request for adding your link to these pages!
Communication channels
======================
*aio-libs discourse group*: https://aio-libs.discourse.group
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
*gitter chat* https://gitter.im/aio-libs/Lobby
*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
@@ -157,25 +161,19 @@ Please add *aiohttp* tag to your question there.
Requirements
============
- Python >= 3.6
- async-timeout_
- attrs_
- charset-normalizer_
- multidict_
- yarl_
- frozenlist_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for sake of speed).
Optionally you may install the aiodns_ library (highly recommended for sake of speed).
.. _charset-normalizer: https://pypi.org/project/charset-normalizer
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _attrs: https://github.com/python-attrs/attrs
.. _multidict: https://pypi.python.org/pypi/multidict
.. _frozenlist: https://pypi.org/project/frozenlist/
.. _yarl: https://pypi.python.org/pypi/yarl
.. _async-timeout: https://pypi.python.org/pypi/async_timeout
.. _cChardet: https://pypi.python.org/pypi/cchardet
License
=======

View File

@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: aiohttp
Version: 3.8.5
Version: 3.10.11
Summary: Async http client/server framework (asyncio)
Home-page: https://github.com/aio-libs/aiohttp
Maintainer: aiohttp team <team@aiohttp.org>
@@ -23,16 +23,27 @@ Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: Microsoft :: Windows
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Internet :: WWW/HTTP
Requires-Python: >=3.6
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
Provides-Extra: speedups
License-File: LICENSE.txt
Requires-Dist: aiohappyeyeballs>=2.3.0
Requires-Dist: aiosignal>=1.1.2
Requires-Dist: async-timeout<6.0,>=4.0; python_version < "3.11"
Requires-Dist: attrs>=17.3.0
Requires-Dist: frozenlist>=1.1.1
Requires-Dist: multidict<7.0,>=4.5
Requires-Dist: yarl<2.0,>=1.12.0
Provides-Extra: speedups
Requires-Dist: aiodns>=3.2.0; (sys_platform == "linux" or sys_platform == "darwin") and extra == "speedups"
Requires-Dist: Brotli; platform_python_implementation == "CPython" and extra == "speedups"
Requires-Dist: brotlicffi; platform_python_implementation != "CPython" and extra == "speedups"
==================================
Async http client/server framework
@@ -53,6 +64,10 @@ Async http client/server framework
:target: https://codecov.io/gh/aio-libs/aiohttp
:alt: codecov.io status for master branch
.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json
:target: https://codspeed.io/aio-libs/aiohttp
:alt: Codspeed.io status for aiohttp
.. image:: https://badge.fury.io/py/aiohttp.svg
:target: https://pypi.org/project/aiohttp
:alt: Latest PyPI package version
@@ -76,7 +91,7 @@ Key Features
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box and avoids
Callback Hell.
- Provides Web-server with middlewares and plugable routing.
- Provides Web-server with middleware and pluggable routing.
Getting started
@@ -103,7 +118,7 @@ To get something from the web:
html = await response.text()
print("Body:", html[:15], "...")
asyncio.run(main())
asyncio.run(main())
This prints:
@@ -182,9 +197,9 @@ Feel free to make a Pull Request for adding your link to these pages!
Communication channels
======================
*aio-libs discourse group*: https://aio-libs.discourse.group
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
*gitter chat* https://gitter.im/aio-libs/Lobby
*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
@@ -193,25 +208,19 @@ Please add *aiohttp* tag to your question there.
Requirements
============
- Python >= 3.6
- async-timeout_
- attrs_
- charset-normalizer_
- multidict_
- yarl_
- frozenlist_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for sake of speed).
Optionally you may install the aiodns_ library (highly recommended for sake of speed).
.. _charset-normalizer: https://pypi.org/project/charset-normalizer
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _attrs: https://github.com/python-attrs/attrs
.. _multidict: https://pypi.python.org/pypi/multidict
.. _frozenlist: https://pypi.org/project/frozenlist/
.. _yarl: https://pypi.python.org/pypi/yarl
.. _async-timeout: https://pypi.python.org/pypi/async_timeout
.. _cChardet: https://pypi.python.org/pypi/cchardet
License
=======

View File

@@ -29,6 +29,7 @@ aiohttp/client_exceptions.py
aiohttp/client_proto.py
aiohttp/client_reqrep.py
aiohttp/client_ws.py
aiohttp/compression_utils.py
aiohttp/connector.py
aiohttp/cookiejar.py
aiohttp/formdata.py
@@ -39,7 +40,6 @@ aiohttp/http_exceptions.py
aiohttp/http_parser.py
aiohttp/http_websocket.py
aiohttp/http_writer.py
aiohttp/locks.py
aiohttp/log.py
aiohttp/multipart.py
aiohttp/payload.py
@@ -92,6 +92,7 @@ docs/client_advanced.rst
docs/client_quickstart.rst
docs/client_reference.rst
docs/conf.py
docs/contributing-admins.rst
docs/contributing.rst
docs/deployment.rst
docs/essays.rst
@@ -126,8 +127,11 @@ docs/web_reference.rst
docs/websocket_utilities.rst
docs/whats_new_1_1.rst
docs/whats_new_3_0.rst
docs/_snippets/cchardet-unmaintained-admonition.rst
docs/_static/css/logo-adjustments.css
docs/_static/img/contributing-cov-comment.svg
docs/_static/img/contributing-cov-header.svg
docs/_static/img/contributing-cov-miss.svg
docs/_static/img/contributing-cov-partial.svg
examples/__init__.py
examples/background_tasks.py
examples/cli_app.py
@@ -150,15 +154,42 @@ examples/web_srv_route_deco.py
examples/web_srv_route_table.py
examples/web_ws.py
examples/websocket.html
requirements/base.in
requirements/base.txt
requirements/broken-projects.in
requirements/constraints.in
requirements/constraints.txt
requirements/cython.in
requirements/cython.txt
requirements/dev.in
requirements/dev.txt
requirements/doc-spelling.in
requirements/doc-spelling.txt
requirements/doc.in
requirements/doc.txt
requirements/lint.in
requirements/lint.txt
requirements/multidict.in
requirements/multidict.txt
requirements/runtime-deps.in
requirements/runtime-deps.txt
requirements/sync-direct-runtime-deps.py
requirements/test.in
requirements/test.txt
requirements/.hash/cython.txt.hash
tests/aiohttp.jpg
tests/aiohttp.png
tests/conftest.py
tests/data.unknown_mime_type
tests/data.zero_bytes
tests/hello.txt.gz
tests/sample.txt
tests/test___all__.py
tests/test_base_protocol.py
tests/test_benchmarks_client.py
tests/test_benchmarks_client_request.py
tests/test_benchmarks_client_ws.py
tests/test_benchmarks_cookiejar.py
tests/test_benchmarks_http_websocket.py
tests/test_benchmarks_http_writer.py
tests/test_circular_imports.py
tests/test_classbasedview.py
tests/test_client_connection.py
@@ -171,6 +202,7 @@ tests/test_client_response.py
tests/test_client_session.py
tests/test_client_ws.py
tests/test_client_ws_functional.py
tests/test_compression_utils.py
tests/test_connector.py
tests/test_cookiejar.py
tests/test_flowcontrol_streams.py
@@ -179,7 +211,7 @@ tests/test_helpers.py
tests/test_http_exceptions.py
tests/test_http_parser.py
tests/test_http_writer.py
tests/test_locks.py
tests/test_imports.py
tests/test_loop.py
tests/test_multipart.py
tests/test_multipart_helpers.py
@@ -236,12 +268,18 @@ vendor/llhttp/LICENSE-MIT
vendor/llhttp/Makefile
vendor/llhttp/README.md
vendor/llhttp/_config.yml
vendor/llhttp/eslint.json
vendor/llhttp/libllhttp.pc.in
vendor/llhttp/package-lock.json
vendor/llhttp/package.json
vendor/llhttp/tsconfig.base.json
vendor/llhttp/tsconfig.eslint.json
vendor/llhttp/tsconfig.json
vendor/llhttp/tslint.json
vendor/llhttp/.github/dependabot.yml
vendor/llhttp/.github/workflows/aiohttp.yml
vendor/llhttp/.github/workflows/ci.yaml
vendor/llhttp/.github/workflows/codeql.yml
vendor/llhttp/.github/workflows/scorecards.yml
vendor/llhttp/bench/index.ts
vendor/llhttp/bin/build_wasm.ts
vendor/llhttp/bin/generate.ts

View File

@@ -1,21 +1,20 @@
attrs>=17.3.0
charset-normalizer<4.0,>=2.0
multidict<7.0,>=4.5
async_timeout<5.0,>=4.0.0a3
yarl<2.0,>=1.0
frozenlist>=1.1.1
aiohappyeyeballs>=2.3.0
aiosignal>=1.1.2
attrs>=17.3.0
frozenlist>=1.1.1
multidict<7.0,>=4.5
yarl<2.0,>=1.12.0
[:python_version < "3.7"]
idna-ssl>=1.0
[:python_version < "3.8"]
asynctest==0.13.0
typing_extensions>=3.7.4
[:python_version < "3.11"]
async-timeout<6.0,>=4.0
[speedups]
aiodns
[speedups:platform_python_implementation != "CPython"]
brotlicffi
[speedups:platform_python_implementation == "CPython"]
Brotli
[speedups:python_version < "3.10"]
cchardet
[speedups:sys_platform == "linux" or sys_platform == "darwin"]
aiodns>=3.2.0

View File

@@ -1 +1 @@
e6d134d56d5f516ab2b5c3b295d0d440a3bef911f4384d506204018895a1f833 /home/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd
f2318883e549f69de597009a914603b0f1b10381e265ef5d98af499ad973fb98 /home/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd

View File

@@ -1 +1 @@
5de2db35fb795ffe227e2f1007c8ba4f2ad1b9aca28cc48edc80c779203cf6e3 /home/runner/work/aiohttp/aiohttp/aiohttp/_helpers.pyx
19d98f08efd55a40c99b2fc4c8341da7ee5cc143b1a59181014c3f43a3e95423 /home/runner/work/aiohttp/aiohttp/aiohttp/_helpers.pyx

View File

@@ -1 +1 @@
43bc2c42b9dbb09c19d0782c7aefd1a656a039b31c57a9fa809f82c2807eeaa9 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx
e2d962e51a183b6e2723c1cb97b9f11c795bedc7093ae1eb038a7040dd8f4d70 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx

View File

@@ -1 +1 @@
6881c0a7c838655e646c645d99971efaf5e310bc3633a7c62b226e39d81842ac /home/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx
6638cb235efe7f79472f050b3bc59d9763ecca457aad88c1ab82dbb045476e7c /home/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx

View File

@@ -1 +1 @@
a30351c34760a1d7835b2a1b0552e463cf1d2db90da0cdb473313dc66e34a031 /home/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py
bb39f96a09ff8d789dda1fa4cba63464043c06b3de4c62c31abfb07a231cb6ca /home/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py

View File

@@ -1,40 +1,49 @@
__version__ = "3.8.5"
__version__ = "3.10.11"
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
from . import hdrs as hdrs
from .client import (
BaseConnector as BaseConnector,
ClientConnectionError as ClientConnectionError,
ClientConnectorCertificateError as ClientConnectorCertificateError,
ClientConnectorError as ClientConnectorError,
ClientConnectorSSLError as ClientConnectorSSLError,
ClientError as ClientError,
ClientHttpProxyError as ClientHttpProxyError,
ClientOSError as ClientOSError,
ClientPayloadError as ClientPayloadError,
ClientProxyConnectionError as ClientProxyConnectionError,
ClientRequest as ClientRequest,
ClientResponse as ClientResponse,
ClientResponseError as ClientResponseError,
ClientSession as ClientSession,
ClientSSLError as ClientSSLError,
ClientTimeout as ClientTimeout,
ClientWebSocketResponse as ClientWebSocketResponse,
ContentTypeError as ContentTypeError,
Fingerprint as Fingerprint,
InvalidURL as InvalidURL,
NamedPipeConnector as NamedPipeConnector,
RequestInfo as RequestInfo,
ServerConnectionError as ServerConnectionError,
ServerDisconnectedError as ServerDisconnectedError,
ServerFingerprintMismatch as ServerFingerprintMismatch,
ServerTimeoutError as ServerTimeoutError,
TCPConnector as TCPConnector,
TooManyRedirects as TooManyRedirects,
UnixConnector as UnixConnector,
WSServerHandshakeError as WSServerHandshakeError,
request as request,
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientRequest,
ClientResponse,
ClientResponseError,
ClientSession,
ClientSSLError,
ClientTimeout,
ClientWebSocketResponse,
ConnectionTimeoutError,
ContentTypeError,
Fingerprint,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NamedPipeConnector,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
RequestInfo,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TCPConnector,
TooManyRedirects,
UnixConnector,
WSServerHandshakeError,
request,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
@@ -99,17 +108,27 @@ from .tracing import (
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
if TYPE_CHECKING:
# At runtime these are lazy-loaded at the bottom of the file.
from .worker import (
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
GunicornWebWorker as GunicornWebWorker,
)
__all__: Tuple[str, ...] = (
"hdrs",
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
@@ -124,14 +143,21 @@ __all__: Tuple[str, ...] = (
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
"ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"InvalidURL",
"InvalidUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlClientError",
"NonHttpUrlRedirectClientError",
"RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
@@ -203,14 +229,32 @@ __all__: Tuple[str, ...] = (
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
)
try:
from .worker import GunicornUVLoopWebWorker, GunicornWebWorker
__all__ += ("GunicornWebWorker", "GunicornUVLoopWebWorker")
except ImportError: # pragma: no cover
pass
def __dir__() -> Tuple[str, ...]:
return __all__ + ("__author__", "__doc__")
def __getattr__(name: str) -> object:
global GunicornUVLoopWebWorker, GunicornWebWorker
# Importing gunicorn takes a long time (>100ms), so only import if actually needed.
if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
try:
from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
except ImportError:
return None
GunicornUVLoopWebWorker = guv # type: ignore[misc]
GunicornWebWorker = gw # type: ignore[misc]
return guv if name == "GunicornUVLoopWebWorker" else gw
raise AttributeError(f"module {__name__} has no attribute {name}")

View File

@@ -1,13 +1,4 @@
from libc.stdint cimport (
int8_t,
int16_t,
int32_t,
int64_t,
uint8_t,
uint16_t,
uint32_t,
uint64_t,
)
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t
cdef extern from "../vendor/llhttp/build/llhttp.h":
@@ -88,30 +79,14 @@ cdef extern from "../vendor/llhttp/build/llhttp.h":
ctypedef llhttp_errno llhttp_errno_t
enum llhttp_flags:
F_CONNECTION_KEEP_ALIVE,
F_CONNECTION_CLOSE,
F_CONNECTION_UPGRADE,
F_CHUNKED,
F_UPGRADE,
F_CONTENT_LENGTH,
F_SKIPBODY,
F_TRAILING,
F_TRANSFER_ENCODING
enum llhttp_lenient_flags:
LENIENT_HEADERS,
LENIENT_CHUNKED_LENGTH
F_CONTENT_LENGTH
enum llhttp_type:
HTTP_REQUEST,
HTTP_RESPONSE,
HTTP_BOTH
enum llhttp_finish_t:
HTTP_FINISH_SAFE,
HTTP_FINISH_SAFE_WITH_CB,
HTTP_FINISH_UNSAFE
enum llhttp_method:
HTTP_DELETE,
HTTP_GET,
@@ -167,24 +142,17 @@ cdef extern from "../vendor/llhttp/build/llhttp.h":
const llhttp_settings_t* settings)
llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
llhttp_errno_t llhttp_finish(llhttp_t* parser)
int llhttp_message_needs_eof(const llhttp_t* parser)
int llhttp_should_keep_alive(const llhttp_t* parser)
void llhttp_pause(llhttp_t* parser)
void llhttp_resume(llhttp_t* parser)
void llhttp_resume_after_upgrade(llhttp_t* parser)
llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
const char* llhttp_get_error_reason(const llhttp_t* parser)
void llhttp_set_error_reason(llhttp_t* parser, const char* reason)
const char* llhttp_get_error_pos(const llhttp_t* parser)
const char* llhttp_errno_name(llhttp_errno_t err)
const char* llhttp_method_name(llhttp_method_t method)
void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
void llhttp_set_lenient_chunked_length(llhttp_t* parser, int enabled)
void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,6 @@
cdef _sentinel = object()
cdef class reify:
"""Use as a class method decorator. It operates almost exactly like
the Python `@property` decorator, but it puts the result of the
@@ -19,17 +22,14 @@ cdef class reify:
return self.wrapped.__doc__
def __get__(self, inst, owner):
try:
try:
return inst._cache[self.name]
except KeyError:
val = self.wrapped(inst)
inst._cache[self.name] = val
return val
except AttributeError:
if inst is None:
return self
raise
if inst is None:
return self
cdef dict cache = inst._cache
val = cache.get(self.name, _sentinel)
if val is _sentinel:
val = self.wrapped(inst)
cache[self.name] = val
return val
def __set__(self, inst, value):
raise AttributeError("reified property is read-only")

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,6 @@
#
# Based on https://github.com/MagicStack/httptools
#
from __future__ import absolute_import, print_function
from cpython cimport (
Py_buffer,
@@ -20,6 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
from yarl import URL as _URL
from aiohttp import hdrs
from aiohttp.helpers import DEBUG, set_exception
from .http_exceptions import (
BadHttpMessage,
@@ -47,6 +47,7 @@ include "_headers.pxi"
from aiohttp cimport _find_header
ALLOWED_UPGRADES = frozenset({"websocket"})
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
@@ -417,7 +418,6 @@ cdef class HttpParser:
cdef _on_headers_complete(self):
self._process_header()
method = http_method_str(self._cparser.method)
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
@@ -425,8 +425,13 @@ cdef class HttpParser:
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(self._headers)
if upgrade or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
if self._cparser.type == cparser.HTTP_REQUEST:
allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES
if allowed or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
else:
if upgrade and self._cparser.status_code == 101:
self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
@@ -441,6 +446,7 @@ cdef class HttpParser:
encoding = enc
if self._cparser.type == cparser.HTTP_REQUEST:
method = http_method_str(self._cparser.method)
msg = _new_request_message(
method, self._path,
self.http_version(), headers, raw_headers,
@@ -548,8 +554,8 @@ cdef class HttpParser:
else:
after = cparser.llhttp_get_error_pos(self._cparser)
before = data[:after - <char*>self.py_buf.buf]
after_b = after.split(b"\n", 1)[0]
before = before.rsplit(b"\n", 1)[-1]
after_b = after.split(b"\r\n", 1)[0]
before = before.rsplit(b"\r\n", 1)[-1]
data = before + after_b
pointer = " " * (len(repr(before))-1) + "^"
ex = parser_error_from_errno(self._cparser, data, pointer)
@@ -565,7 +571,7 @@ cdef class HttpParser:
if self._upgraded:
return messages, True, data[nb:]
else:
return messages, False, b''
return messages, False, b""
def set_upgraded(self, val):
self._upgraded = val
@@ -648,6 +654,11 @@ cdef class HttpResponseParser(HttpParser):
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
# Use strict parsing on dev mode, so users are warned about broken servers.
if not DEBUG:
cparser.llhttp_set_lenient_headers(self._cparser, 1)
cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)
cdef object _on_status_complete(self):
if self._buf:
@@ -743,10 +754,7 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
pyparser._last_error = exc
return -1
else:
if (
pyparser._cparser.upgrade or
pyparser._cparser.method == cparser.HTTP_CONNECT
):
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
else:
return 0
@@ -758,11 +766,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
else:
pyparser._payload.set_exception(exc)
reraised_exc = pyparser._payload_exception(str(underlying_exc))
set_exception(pyparser._payload, reraised_exc, underlying_exc)
pyparser._payload_error = 1
return -1
else:
@@ -807,7 +817,9 @@ cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
cdef bytes desc = cparser.llhttp_get_error_reason(parser)
if errno in (cparser.HPE_CB_MESSAGE_BEGIN,
err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)
if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
cparser.HPE_CB_HEADERS_COMPLETE,
cparser.HPE_CB_MESSAGE_COMPLETE,
cparser.HPE_CB_CHUNK_HEADER,
@@ -817,22 +829,13 @@ cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
cparser.HPE_INVALID_CONTENT_LENGTH,
cparser.HPE_INVALID_CHUNK_SIZE,
cparser.HPE_INVALID_EOF_STATE,
cparser.HPE_INVALID_TRANSFER_ENCODING):
cls = BadHttpMessage
elif errno == cparser.HPE_INVALID_STATUS:
cls = BadStatusLine
elif errno == cparser.HPE_INVALID_METHOD:
cls = BadStatusLine
elif errno == cparser.HPE_INVALID_VERSION:
cls = BadStatusLine
cparser.HPE_INVALID_TRANSFER_ENCODING}:
return BadHttpMessage(err_msg)
elif errno in {cparser.HPE_INVALID_STATUS,
cparser.HPE_INVALID_METHOD,
cparser.HPE_INVALID_VERSION}:
return BadStatusLine(error=err_msg)
elif errno == cparser.HPE_INVALID_URL:
cls = InvalidURLError
return InvalidURLError(err_msg)
else:
cls = BadHttpMessage
return cls("{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer))
return BadHttpMessage(err_msg)

File diff suppressed because it is too large Load Diff

View File

@@ -127,10 +127,6 @@ def _serialize_headers(str status_line, headers):
_init_writer(&writer)
for key, val in headers.items():
_safe_header(to_str(key))
_safe_header(to_str(val))
try:
if _write_str(&writer, status_line) < 0:
raise
@@ -140,6 +136,9 @@ def _serialize_headers(str status_line, headers):
raise
for key, val in headers.items():
_safe_header(to_str(key))
_safe_header(to_str(val))
if _write_str(&writer, to_str(key)) < 0:
raise
if _write_byte(&writer, b':') < 0:

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import socket
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
@@ -14,15 +15,15 @@ from typing import (
List,
Optional,
Tuple,
TypedDict,
)
from multidict import CIMultiDict
from yarl import URL
from .helpers import get_running_loop
from .typedefs import LooseCookies
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_app import Application
from .web_exceptions import HTTPException
from .web_request import BaseRequest, Request
@@ -65,7 +66,9 @@ class AbstractMatchInfo(ABC):
@property
@abstractmethod
def expect_handler(self) -> Callable[[Request], Awaitable[None]]:
def expect_handler(
self,
) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@@ -117,11 +120,35 @@ class AbstractView(ABC):
"""Execute the view handler."""
class ResolveResult(TypedDict):
"""Resolve result.
This is the result returned from an AbstractResolver's
resolve method.
:param hostname: The hostname that was provided.
:param host: The IP address that was resolved.
:param port: The port that was resolved.
:param family: The address family that was resolved.
:param proto: The protocol that was resolved.
:param flags: The flags that were resolved.
"""
hostname: str
host: str
port: int
family: int
proto: int
flags: int
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@abstractmethod
async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
"""Return IP address for given hostname"""
@abstractmethod
@@ -129,7 +156,7 @@ class AbstractResolver(ABC):
"""Release resolver"""
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
IterableBase = Iterable[Morsel[str]]
else:
IterableBase = Iterable
@@ -142,7 +169,7 @@ class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = get_running_loop(loop)
self._loop = loop or asyncio.get_running_loop()
@abstractmethod
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:

View File

@@ -1,6 +1,8 @@
import asyncio
from typing import Optional, cast
from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
@@ -76,11 +78,15 @@ class BaseProtocol(asyncio.Protocol):
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)
async def _drain_helper(self) -> None:
if not self.connected:
raise ConnectionResetError("Connection lost")
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter

View File

@@ -9,12 +9,14 @@ import sys
import traceback
import warnings
from contextlib import suppress
from types import SimpleNamespace, TracebackType
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Final,
FrozenSet,
Generator,
Generic,
@@ -25,6 +27,7 @@ from typing import (
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
@@ -36,25 +39,34 @@ from yarl import URL
from . import hdrs, http, payload
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError as ClientConnectionError,
ClientConnectorCertificateError as ClientConnectorCertificateError,
ClientConnectorError as ClientConnectorError,
ClientConnectorSSLError as ClientConnectorSSLError,
ClientError as ClientError,
ClientHttpProxyError as ClientHttpProxyError,
ClientOSError as ClientOSError,
ClientPayloadError as ClientPayloadError,
ClientProxyConnectionError as ClientProxyConnectionError,
ClientResponseError as ClientResponseError,
ClientSSLError as ClientSSLError,
ContentTypeError as ContentTypeError,
InvalidURL as InvalidURL,
ServerConnectionError as ServerConnectionError,
ServerDisconnectedError as ServerDisconnectedError,
ServerFingerprintMismatch as ServerFingerprintMismatch,
ServerTimeoutError as ServerTimeoutError,
TooManyRedirects as TooManyRedirects,
WSServerHandshakeError as WSServerHandshakeError,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientResponseError,
ClientSSLError,
ConnectionTimeoutError,
ContentTypeError,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TooManyRedirects,
WSServerHandshakeError,
)
from .client_reqrep import (
ClientRequest as ClientRequest,
@@ -65,6 +77,7 @@ from .client_reqrep import (
)
from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse
from .connector import (
HTTP_AND_EMPTY_SCHEMA_SET,
BaseConnector as BaseConnector,
NamedPipeConnector as NamedPipeConnector,
TCPConnector as TCPConnector,
@@ -72,13 +85,12 @@ from .connector import (
)
from .cookiejar import CookieJar
from .helpers import (
_SENTINEL,
DEBUG,
PY_36,
BasicAuth,
TimeoutHandle,
ceil_timeout,
get_env_proxy_for_url,
get_running_loop,
method_must_be_empty_body,
sentinel,
strip_auth_from_url,
)
@@ -86,12 +98,14 @@ from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tracing import Trace, TraceConfig
from .typedefs import Final, JSONEncoder, LooseCookies, LooseHeaders, StrOrURL
from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL
__all__ = (
# client_exceptions
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
@@ -101,12 +115,19 @@ __all__ = (
"ClientProxyConnectionError",
"ClientResponseError",
"ClientSSLError",
"ConnectionTimeoutError",
"ContentTypeError",
"InvalidURL",
"InvalidUrlClientError",
"RedirectClientError",
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketTimeoutError",
"TooManyRedirects",
"WSServerHandshakeError",
# client_reqrep
@@ -128,10 +149,41 @@ __all__ = (
)
try:
if TYPE_CHECKING:
from ssl import SSLContext
except ImportError: # pragma: no cover
SSLContext = object # type: ignore[misc,assignment]
else:
SSLContext = None
if sys.version_info >= (3, 11) and TYPE_CHECKING:
from typing import Unpack
class _RequestOptions(TypedDict, total=False):
params: Query
data: Any
json: Any
cookies: Union[LooseCookies, None]
headers: Union[LooseHeaders, None]
skip_auto_headers: Union[Iterable[str], None]
auth: Union[BasicAuth, None]
allow_redirects: bool
max_redirects: int
compress: Union[str, bool, None]
chunked: Union[bool, None]
expect100: bool
raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]]
read_until_eof: bool
proxy: Union[StrOrURL, None]
proxy_auth: Union[BasicAuth, None]
timeout: "Union[ClientTimeout, _SENTINEL, None]"
ssl: Union[SSLContext, bool, Fingerprint]
server_hostname: Union[str, None]
proxy_headers: Union[LooseHeaders, None]
trace_request_ctx: Union[Mapping[str, Any], None]
read_bufsize: Union[int, None]
auto_decompress: Union[bool, None]
max_line_size: Union[int, None]
max_field_size: Union[int, None]
@attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -140,6 +192,7 @@ class ClientTimeout:
connect: Optional[float] = None
sock_read: Optional[float] = None
sock_connect: Optional[float] = None
ceil_threshold: float = 5
# pool_queue_timeout: Optional[float] = None
# dns_resolution_timeout: Optional[float] = None
@@ -156,9 +209,13 @@ class ClientTimeout:
# 5 Minute default read timeout
DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60)
DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60, sock_connect=30)
_RetType = TypeVar("_RetType")
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
_CharsetResolver = Callable[[ClientResponse, bytes], str]
class ClientSession:
@@ -188,11 +245,14 @@ class ClientSession:
"_ws_response_class",
"_trace_configs",
"_read_bufsize",
"_max_line_size",
"_max_field_size",
"_resolve_charset",
]
)
_source_traceback = None # type: Optional[traceback.StackSummary]
_connector = None # type: Optional[BaseConnector]
_source_traceback: Optional[traceback.StackSummary] = None
_connector: Optional[BaseConnector] = None
def __init__(
self,
@@ -211,8 +271,10 @@ class ClientSession:
version: HttpVersion = http.HttpVersion11,
cookie_jar: Optional[AbstractCookieJar] = None,
connector_owner: bool = True,
raise_for_status: bool = False,
read_timeout: Union[float, object] = sentinel,
raise_for_status: Union[
bool, Callable[[ClientResponse], Awaitable[None]]
] = False,
read_timeout: Union[float, _SENTINEL] = sentinel,
conn_timeout: Optional[float] = None,
timeout: Union[object, ClientTimeout] = sentinel,
auto_decompress: bool = True,
@@ -220,12 +282,19 @@ class ClientSession:
requote_redirect_url: bool = True,
trace_configs: Optional[List[TraceConfig]] = None,
read_bufsize: int = 2**16,
max_line_size: int = 8190,
max_field_size: int = 8190,
fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
) -> None:
# We initialise _connector to None immediately, as it's referenced in __del__()
# and could cause issues if an exception occurs during initialisation.
self._connector: Optional[BaseConnector] = None
if loop is None:
if connector is not None:
loop = connector._loop
loop = get_running_loop(loop)
loop = loop or asyncio.get_running_loop()
if base_url is None or isinstance(base_url, URL):
self._base_url: Optional[URL] = base_url
@@ -235,30 +304,7 @@ class ClientSession:
self._base_url.origin() == self._base_url
), "Only absolute URLs without path part are supported"
if connector is None:
connector = TCPConnector(loop=loop)
if connector._loop is not loop:
raise RuntimeError("Session and connector has to use same event loop")
self._loop = loop
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
if cookie_jar is None:
cookie_jar = CookieJar(loop=loop)
self._cookie_jar = cookie_jar
if cookies is not None:
self._cookie_jar.update_cookies(cookies)
self._connector = connector
self._connector_owner = connector_owner
self._default_auth = auth
self._version = version
self._json_serialize = json_serialize
if timeout is sentinel:
if timeout is sentinel or timeout is None:
self._timeout = DEFAULT_TIMEOUT
if read_timeout is not sentinel:
warnings.warn(
@@ -275,7 +321,12 @@ class ClientSession:
stacklevel=2,
)
else:
self._timeout = timeout # type: ignore[assignment]
if not isinstance(timeout, ClientTimeout):
raise ValueError(
f"timeout parameter cannot be of {type(timeout)} type, "
"please use 'timeout=ClientTimeout(...)'",
)
self._timeout = timeout
if read_timeout is not sentinel:
raise ValueError(
"read_timeout and timeout parameters "
@@ -288,11 +339,37 @@ class ClientSession:
"conflict, please setup "
"timeout.connect"
)
if connector is None:
connector = TCPConnector(loop=loop)
if connector._loop is not loop:
raise RuntimeError("Session and connector has to use same event loop")
self._loop = loop
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
if cookie_jar is None:
cookie_jar = CookieJar(loop=loop)
self._cookie_jar = cookie_jar
if cookies:
self._cookie_jar.update_cookies(cookies)
self._connector = connector
self._connector_owner = connector_owner
self._default_auth = auth
self._version = version
self._json_serialize = json_serialize
self._raise_for_status = raise_for_status
self._auto_decompress = auto_decompress
self._trust_env = trust_env
self._requote_redirect_url = requote_redirect_url
self._read_bufsize = read_bufsize
self._max_line_size = max_line_size
self._max_field_size = max_field_size
# Convert to list of tuples
if headers:
@@ -313,6 +390,8 @@ class ClientSession:
for trace_config in self._trace_configs:
trace_config.freeze()
self._resolve_charset = fallback_charset_resolver
def __init_subclass__(cls: Type["ClientSession"]) -> None:
warnings.warn(
"Inheritance class {} from ClientSession "
@@ -335,10 +414,7 @@ class ClientSession:
def __del__(self, _warnings: Any = warnings) -> None:
if not self.closed:
if PY_36:
kwargs = {"source": self}
else:
kwargs = {}
kwargs = {"source": self}
_warnings.warn(
f"Unclosed client session {self!r}", ResourceWarning, **kwargs
)
@@ -347,18 +423,29 @@ class ClientSession:
context["source_traceback"] = self._source_traceback
self._loop.call_exception_handler(context)
def request(
self, method: str, url: StrOrURL, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP request."""
return _RequestContextManager(self._request(method, url, **kwargs))
if sys.version_info >= (3, 11) and TYPE_CHECKING:
def request(
self,
method: str,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
else:
def request(
self, method: str, url: StrOrURL, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP request."""
return _RequestContextManager(self._request(method, url, **kwargs))
def _build_url(self, str_or_url: StrOrURL) -> URL:
url = URL(str_or_url)
if self._base_url is None:
return url
else:
assert not url.is_absolute() and url.path.startswith("/")
assert not url.absolute and url.path.startswith("/")
return self._base_url.join(url)
async def _request(
@@ -366,7 +453,7 @@ class ClientSession:
method: str,
str_or_url: StrOrURL,
*,
params: Optional[Mapping[str, str]] = None,
params: Query = None,
data: Any = None,
json: Any = None,
cookies: Optional[LooseCookies] = None,
@@ -375,21 +462,27 @@ class ClientSession:
auth: Optional[BasicAuth] = None,
allow_redirects: bool = True,
max_redirects: int = 10,
compress: Optional[str] = None,
compress: Union[str, bool, None] = None,
chunked: Optional[bool] = None,
expect100: bool = False,
raise_for_status: Optional[bool] = None,
raise_for_status: Union[
None, bool, Callable[[ClientResponse], Awaitable[None]]
] = None,
read_until_eof: bool = True,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
timeout: Union[ClientTimeout, object] = sentinel,
timeout: Union[ClientTimeout, _SENTINEL] = sentinel,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
trace_request_ctx: Optional[Mapping[str, Any]] = None,
read_bufsize: Optional[int] = None,
auto_decompress: Optional[bool] = None,
max_line_size: Optional[int] = None,
max_field_size: Optional[int] = None,
) -> ClientResponse:
# NOTE: timeout clamps existing connect and read timeouts. We cannot
@@ -412,24 +505,31 @@ class ClientSession:
warnings.warn("Chunk size is deprecated #1615", DeprecationWarning)
redirects = 0
history = []
history: List[ClientResponse] = []
version = self._version
params = params or {}
# Merge with default headers and transform to CIMultiDict
headers = self._prepare_headers(headers)
proxy_headers = self._prepare_headers(proxy_headers)
try:
url = self._build_url(str_or_url)
except ValueError as e:
raise InvalidURL(str_or_url) from e
raise InvalidUrlClientError(str_or_url) from e
assert self._connector is not None
if url.scheme not in self._connector.allowed_protocol_schema_set:
raise NonHttpUrlClientError(url)
skip_headers = set(self._skip_auto_headers)
if skip_auto_headers is not None:
for i in skip_auto_headers:
skip_headers.add(istr(i))
if proxy is not None:
if proxy is None:
proxy_headers = None
else:
proxy_headers = self._prepare_headers(proxy_headers)
try:
proxy = URL(proxy)
except ValueError as e:
@@ -439,17 +539,28 @@ class ClientSession:
real_timeout: ClientTimeout = self._timeout
else:
if not isinstance(timeout, ClientTimeout):
real_timeout = ClientTimeout(total=timeout) # type: ignore[arg-type]
real_timeout = ClientTimeout(total=timeout)
else:
real_timeout = timeout
# timeout is cumulative for all request operations
# (request, redirects, responses, data consuming)
tm = TimeoutHandle(self._loop, real_timeout.total)
tm = TimeoutHandle(
self._loop, real_timeout.total, ceil_threshold=real_timeout.ceil_threshold
)
handle = tm.start()
if read_bufsize is None:
read_bufsize = self._read_bufsize
if auto_decompress is None:
auto_decompress = self._auto_decompress
if max_line_size is None:
max_line_size = self._max_line_size
if max_field_size is None:
max_field_size = self._max_field_size
traces = [
Trace(
self,
@@ -465,15 +576,31 @@ class ClientSession:
timer = tm.timer()
try:
with timer:
# https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests
retry_persistent_connection = method in IDEMPOTENT_METHODS
while True:
url, auth_from_url = strip_auth_from_url(url)
if auth and auth_from_url:
if not url.raw_host:
# NOTE: Bail early, otherwise, causes `InvalidURL` through
# NOTE: `self._request_class()` below.
err_exc_cls = (
InvalidUrlRedirectClientError
if redirects
else InvalidUrlClientError
)
raise err_exc_cls(url)
# If `auth` was passed for an already authenticated URL,
# disallow only if this is the initial URL; this is to avoid issues
# with sketchy redirects that are not the caller's responsibility
if not history and (auth and auth_from_url):
raise ValueError(
"Cannot combine AUTH argument with "
"credentials encoded in URL"
)
if auth is None:
# Override the auth with the one from the URL only if we
# have no auth, or if we got an auth from a redirect URL
if auth is None or (history and auth_from_url is not None):
auth = auth_from_url
if auth is None:
auth = self._default_auth
@@ -510,7 +637,7 @@ class ClientSession:
url,
params=params,
headers=headers,
skip_auto_headers=skip_headers,
skip_auto_headers=skip_headers if skip_headers else None,
data=data,
cookies=all_cookies,
auth=auth,
@@ -524,21 +651,21 @@ class ClientSession:
proxy_auth=proxy_auth,
timer=timer,
session=self,
ssl=ssl,
ssl=ssl if ssl is not None else True,
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
trust_env=self.trust_env,
)
# connection timeout
try:
async with ceil_timeout(real_timeout.connect):
assert self._connector is not None
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
except asyncio.TimeoutError as exc:
raise ServerTimeoutError(
"Connection timeout " "to host {}".format(url)
raise ConnectionTimeoutError(
f"Connection timeout to host {url}"
) from exc
assert conn.transport is not None
@@ -546,11 +673,14 @@ class ClientSession:
assert conn.protocol is not None
conn.protocol.set_response_params(
timer=timer,
skip_payload=method.upper() == "HEAD",
skip_payload=method_must_be_empty_body(method),
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress,
auto_decompress=auto_decompress,
read_timeout=real_timeout.sock_read,
read_bufsize=read_bufsize,
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
max_line_size=max_line_size,
max_field_size=max_field_size,
)
try:
@@ -564,6 +694,11 @@ class ClientSession:
except BaseException:
conn.close()
raise
except (ClientOSError, ServerDisconnectedError):
if retry_persistent_connection:
retry_persistent_connection = False
continue
raise
except ClientError:
raise
except OSError as exc:
@@ -571,7 +706,8 @@ class ClientSession:
raise
raise ClientOSError(*exc.args) from exc
self._cookie_jar.update_cookies(resp.cookies, resp.url)
if cookies := resp.cookies:
self._cookie_jar.update_cookies(cookies, resp.url)
# redirects
if resp.status in (301, 302, 303, 307, 308) and allow_redirects:
@@ -611,26 +747,36 @@ class ClientSession:
resp.release()
try:
parsed_url = URL(
parsed_redirect_url = URL(
r_url, encoded=not self._requote_redirect_url
)
except ValueError as e:
raise InvalidURL(r_url) from e
raise InvalidUrlRedirectClientError(
r_url,
"Server attempted redirecting to a location that does not look like a URL",
) from e
scheme = parsed_url.scheme
if scheme not in ("http", "https", ""):
scheme = parsed_redirect_url.scheme
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
resp.close()
raise ValueError("Can redirect only to http or https")
raise NonHttpUrlRedirectClientError(r_url)
elif not scheme:
parsed_url = url.join(parsed_url)
parsed_redirect_url = url.join(parsed_redirect_url)
if url.origin() != parsed_url.origin():
try:
redirect_origin = parsed_redirect_url.origin()
except ValueError as origin_val_err:
raise InvalidUrlRedirectClientError(
parsed_redirect_url,
"Invalid redirect URL origin",
) from origin_val_err
if url.origin() != redirect_origin:
auth = None
headers.pop(hdrs.AUTHORIZATION, None)
url = parsed_url
params = None
url = parsed_redirect_url
params = {}
resp.release()
continue
@@ -639,7 +785,12 @@ class ClientSession:
# check response status
if raise_for_status is None:
raise_for_status = self._raise_for_status
if raise_for_status:
if raise_for_status is None:
pass
elif callable(raise_for_status):
await raise_for_status(resp)
elif raise_for_status:
resp.raise_for_status()
# register connection
@@ -683,11 +834,11 @@ class ClientSession:
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
params: Optional[Mapping[str, str]] = None,
params: Query = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
@@ -735,11 +886,11 @@ class ClientSession:
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
params: Optional[Mapping[str, str]] = None,
params: Query = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
@@ -755,7 +906,7 @@ class ClientSession:
default_headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.CONNECTION: "Upgrade",
hdrs.SEC_WEBSOCKET_VERSION: "13",
}
@@ -773,6 +924,14 @@ class ClientSession:
extstr = ws_ext_gen(compress=compress)
real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr
# For the sake of backward compatibility, if user passes in None, convert it to True
if ssl is None:
warnings.warn(
"ssl=None is deprecated, please use ssl=True",
DeprecationWarning,
stacklevel=2,
)
ssl = True
ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
# send request
@@ -866,6 +1025,16 @@ class ClientSession:
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None
# For WS connection the read_timeout must be either receive_timeout or greater
# None == no timeout, i.e. infinite timeout, so None is the max timeout possible
if receive_timeout is None:
# Reset regardless
conn_proto.read_timeout = receive_timeout
elif conn_proto.read_timeout is not None:
# If read_timeout was set check which wins
conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout)
transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
@@ -914,61 +1083,111 @@ class ClientSession:
added_names.add(key)
return result
def get(
self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP GET request."""
return _RequestContextManager(
self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs)
)
if sys.version_info >= (3, 11) and TYPE_CHECKING:
def options(
self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP OPTIONS request."""
return _RequestContextManager(
self._request(
hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs
def get(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def options(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def head(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def post(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def put(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def patch(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
def delete(
self,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...
else:
def get(
self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP GET request."""
return _RequestContextManager(
self._request(
hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs
)
)
)
def head(
self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP HEAD request."""
return _RequestContextManager(
self._request(
hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs
def options(
self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP OPTIONS request."""
return _RequestContextManager(
self._request(
hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs
)
)
)
def post(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP POST request."""
return _RequestContextManager(
self._request(hdrs.METH_POST, url, data=data, **kwargs)
)
def head(
self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP HEAD request."""
return _RequestContextManager(
self._request(
hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs
)
)
def put(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP PUT request."""
return _RequestContextManager(
self._request(hdrs.METH_PUT, url, data=data, **kwargs)
)
def post(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP POST request."""
return _RequestContextManager(
self._request(hdrs.METH_POST, url, data=data, **kwargs)
)
def patch(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_PATCH, url, data=data, **kwargs)
)
def put(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP PUT request."""
return _RequestContextManager(
self._request(hdrs.METH_PUT, url, data=data, **kwargs)
)
def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager":
"""Perform HTTP DELETE request."""
return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs))
def patch(
self, url: StrOrURL, *, data: Any = None, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_PATCH, url, data=data, **kwargs)
)
def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager":
"""Perform HTTP DELETE request."""
return _RequestContextManager(
self._request(hdrs.METH_DELETE, url, **kwargs)
)
async def close(self) -> None:
"""Close underlying connector.
@@ -1119,13 +1338,13 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType
__slots__ = ("_coro", "_resp")
def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
self._coro = coro
self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro
def send(self, arg: None) -> "asyncio.Future[Any]":
return self._coro.send(arg)
def throw(self, arg: BaseException) -> None: # type: ignore[arg-type,override]
self._coro.throw(arg)
def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]":
return self._coro.throw(*args, **kwargs)
def close(self) -> None:
return self._coro.close()
@@ -1138,12 +1357,8 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType
return self.__await__()
async def __aenter__(self) -> _RetType:
self._resp = await self._coro
return self._resp
class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
__slots__ = ()
self._resp: _RetType = await self._coro
return await self._resp.__aenter__()
async def __aexit__(
self,
@@ -1151,24 +1366,11 @@ class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
# We're basing behavior on the exception as it can be caused by
# user code unrelated to the status of the connection. If you
# would like to close a connection you must do that
# explicitly. Otherwise connection error handling should kick in
# and close/recycle the connection as required.
self._resp.release()
await self._resp.__aexit__(exc_type, exc, tb)
class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
__slots__ = ()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self._resp.close()
_RequestContextManager = _BaseRequestContextManager[ClientResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
class _SessionRequestContextManager:
@@ -1208,7 +1410,7 @@ def request(
method: str,
url: StrOrURL,
*,
params: Optional[Mapping[str, str]] = None,
params: Query = None,
data: Any = None,
json: Any = None,
headers: Optional[LooseHeaders] = None,
@@ -1229,6 +1431,8 @@ def request(
connector: Optional[BaseConnector] = None,
read_bufsize: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
max_line_size: int = 8190,
max_field_size: int = 8190,
) -> _SessionRequestContextManager:
"""Constructs and sends a request.
@@ -1300,6 +1504,8 @@ def request(
proxy=proxy,
proxy_auth=proxy_auth,
read_bufsize=read_bufsize,
max_line_size=max_line_size,
max_field_size=max_field_size,
),
session,
)

View File

@@ -2,10 +2,11 @@
import asyncio
import warnings
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
from .http_parser import RawResponseMessage
from .typedefs import LooseHeaders
from multidict import MultiMapping
from .typedefs import StrOrURL
try:
import ssl
@@ -15,20 +16,25 @@ except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore[assignment]
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = None
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
"ClientSSLError",
"ClientConnectorDNSError",
"ClientConnectorSSLError",
"ClientConnectorCertificateError",
"ConnectionTimeoutError",
"SocketTimeoutError",
"ServerConnectionError",
"ServerTimeoutError",
"ServerDisconnectedError",
@@ -39,6 +45,11 @@ __all__ = (
"ContentTypeError",
"ClientPayloadError",
"InvalidURL",
"InvalidUrlClientError",
"RedirectClientError",
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
)
@@ -47,9 +58,13 @@ class ClientError(Exception):
class ClientResponseError(ClientError):
"""Connection error during reading response.
"""Base class for exceptions that occur after getting a response.
request_info: instance of RequestInfo
request_info: An instance of RequestInfo.
history: A sequence of responses, if redirects occurred.
status: HTTP status code.
message: Error message.
headers: Response headers.
"""
def __init__(
@@ -60,7 +75,7 @@ class ClientResponseError(ClientError):
code: Optional[int] = None,
status: Optional[int] = None,
message: str = "",
headers: Optional[LooseHeaders] = None,
headers: Optional[MultiMapping[str]] = None,
) -> None:
self.request_info = request_info
if code is not None:
@@ -89,7 +104,7 @@ class ClientResponseError(ClientError):
return "{}, message={!r}, url={!r}".format(
self.status,
self.message,
self.request_info.real_url,
str(self.request_info.real_url),
)
def __repr__(self) -> str:
@@ -146,6 +161,10 @@ class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
@@ -176,18 +195,26 @@ class ClientConnectorError(ClientOSError):
return self._conn_key.port
@property
def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]:
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl
def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)
# OSError.__reduce__ does too much black magick
__reduce__ = BaseException.__reduce__
class ClientConnectorDNSError(ClientConnectorError):
"""DNS resolution failed during client connection.
Raised in :class:`aiohttp.connector.TCPConnector` if
DNS resolution fails.
"""
class ClientProxyConnectionError(ClientConnectorError):
"""Proxy connection error.
@@ -215,7 +242,7 @@ class UnixClientConnectorError(ClientConnectorError):
def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)
@@ -238,6 +265,14 @@ class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
class ConnectionTimeoutError(ServerTimeoutError):
"""Connection timeout error."""
class SocketTimeoutError(ServerTimeoutError):
"""Socket timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
@@ -267,17 +302,52 @@ class InvalidURL(ClientError, ValueError):
# Derive from ValueError for backward compatibility
def __init__(self, url: Any) -> None:
def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
super().__init__(url)
self._url = url
self._description = description
if description:
super().__init__(url, description)
else:
super().__init__(url)
@property
def url(self) -> Any:
return self.args[0]
def url(self) -> StrOrURL:
return self._url
@property
def description(self) -> "str | None":
return self._description
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.url}>"
return f"<{self.__class__.__name__} {self}>"
def __str__(self) -> str:
if self._description:
return f"{self._url} - {self._description}"
return str(self._url)
class InvalidUrlClientError(InvalidURL):
"""Invalid URL client error."""
class RedirectClientError(ClientError):
"""Client redirect error."""
class NonHttpUrlClientError(ClientError):
"""Non http URL client error."""
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
"""Invalid URL redirect client error."""
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
"""Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):

View File

@@ -7,10 +7,16 @@ from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
ServerTimeoutError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
set_exception,
status_code_must_be_empty_body,
)
from .helpers import BaseTimerContext
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
@@ -36,21 +42,21 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
self._read_timeout: Optional[float] = None
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
self._timeout_ceil_threshold: Optional[float] = 5
@property
def upgraded(self) -> bool:
return self._upgraded
@property
def should_close(self) -> bool:
if self._payload is not None and not self._payload.is_eof() or self._upgraded:
return True
return (
self._should_close
or (self._payload is not None and not self._payload.is_eof())
or self._upgraded
or self.exception() is not None
or self._exception is not None
or self._payload_parser is not None
or len(self) > 0
or bool(self._buffer)
or bool(self._tail)
)
@@ -71,28 +77,50 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()
original_connection_error = exc
reraised_exc = original_connection_error
connection_closed_cleanly = original_connection_error is None
if self._payload_parser is not None:
with suppress(Exception):
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()
uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception:
except Exception as underlying_exc:
if self._payload is not None:
self._payload.set_exception(
ClientPayloadError("Response payload is not completed")
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)
if not self.is_eof():
if isinstance(exc, OSError):
exc = ClientOSError(*exc.args)
if exc is None:
exc = ServerDisconnectedError(uncompleted)
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as side effect,
# we do it anyway below
self.set_exception(exc)
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)
self._should_close = True
self._parser = None
@@ -100,7 +128,7 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
self._payload_parser = None
self._reading_paused = False
super().connection_lost(exc)
super().connection_lost(reraised_exc)
def eof_received(self) -> None:
# should call parser.feed_eof() most likely
@@ -114,10 +142,14 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
super().resume_reading()
self._reschedule_timeout()
def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc)
super().set_exception(exc, exc_cause)
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
@@ -143,11 +175,15 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
auto_decompress: bool = True,
read_timeout: Optional[float] = None,
read_bufsize: int = 2**16,
timeout_ceil_threshold: float = 5,
max_line_size: int = 8190,
max_field_size: int = 8190,
) -> None:
self._skip_payload = skip_payload
self._read_timeout = read_timeout
self._reschedule_timeout()
self._timeout_ceil_threshold = timeout_ceil_threshold
self._parser = HttpResponseParser(
self,
@@ -158,6 +194,8 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
response_with_body=not skip_payload,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
max_line_size=max_line_size,
max_field_size=max_field_size,
)
if self._tail:
@@ -181,11 +219,22 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
else:
self._read_timeout_handle = None
def start_timeout(self) -> None:
self._reschedule_timeout()
@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout
@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout
def _on_read_timeout(self) -> None:
exc = ServerTimeoutError("Timeout on reading data from socket")
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
self._payload.set_exception(exc)
set_exception(self._payload, exc)
def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
@@ -211,14 +260,22 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as exc:
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
self.set_exception(exc)
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return
self._upgraded = upgraded
@@ -230,7 +287,9 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
self._payload = payload
if self._skip_payload or message.code in (204, 304):
if self._skip_payload or status_code_must_be_empty_body(
message.code
):
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
"""WebSocket client for asyncio."""
import asyncio
from typing import Any, Optional, cast
import sys
from types import TracebackType
from typing import Any, Optional, Type, cast
import async_timeout
from .client_exceptions import ClientError
from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
@@ -25,6 +25,11 @@ from .typedefs import (
JSONEncoder,
)
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
class ClientWebSocketResponse:
def __init__(
@@ -58,53 +63,123 @@ class ClientWebSocketResponse:
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
if self._heartbeat is not None:
self._heartbeat_cb = call_later(
self._send_heartbeat, self._heartbeat, self._loop
)
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
loop = self._loop
assert loop is not None
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received, self._pong_heartbeat, self._loop
self._heartbeat_cb = None
loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
if not self._closed:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
self._response.close()
self._handle_ping_pong_exception(ServerTimeoutError())
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
def _set_closing(self) -> None:
"""Set the connection to closing.
Cancel any heartbeat timers and set the closing flag.
"""
self._closing = True
self._cancel_heartbeat()
@property
def closed(self) -> bool:
@@ -167,15 +242,37 @@ class ClientWebSocketResponse:
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
await self._close_wait
if not self._closed:
self._cancel_heartbeat()
self._closed = True
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if self._close_code:
self._response.close()
return True
while True:
try:
await self._writer.close(code, message)
async with async_timeout.timeout(self._timeout):
msg = await self._reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
@@ -186,34 +283,16 @@ class ClientWebSocketResponse:
self._response.close()
return True
if self._closing:
if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
while True:
try:
async with async_timeout.timeout(self._timeout):
msg = await self._reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
else:
return False
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
receive_timeout = timeout or self._receive_timeout
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
@@ -223,15 +302,22 @@ class ClientWebSocketResponse:
return WS_CLOSED_MESSAGE
try:
self._waiting = self._loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(timeout or self._receive_timeout):
if receive_timeout:
# Entering the context manager and creating
# Timeout() object can take almost 50% of the
# run time in this loop so we avoid it if
# there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
self._waiting = None
set_result(waiter, True)
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
@@ -240,7 +326,8 @@ class ClientWebSocketResponse:
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
self._closed = True
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
@@ -249,35 +336,35 @@ class ClientWebSocketResponse:
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
if msg.type is WSMsgType.CLOSE:
self._set_closing()
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
elif msg.type is WSMsgType.CLOSING:
self._set_closing()
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type == WSMsgType.PONG and self._autoping:
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type != WSMsgType.TEXT:
if msg.type is not WSMsgType.TEXT:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str")
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type != WSMsgType.BINARY:
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
@@ -298,3 +385,14 @@ class ClientWebSocketResponse:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
async def __aenter__(self) -> "ClientWebSocketResponse":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()

View File

@@ -0,0 +1,173 @@
import asyncio
import zlib
from concurrent.futures import Executor
from typing import Optional, cast
try:
try:
import brotlicffi as brotli
except ImportError:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
MAX_SYNC_CHUNK_SIZE = 1024
def encoding_to_mode(
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + zlib.MAX_WBITS
return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
class ZlibBaseHandler:
def __init__(
self,
mode: int,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
self._mode = mode
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
class ZLibCompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
level: Optional[int] = None,
wbits: Optional[int] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=(
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
if level is None:
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
else:
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)
self._compress_lock = asyncio.Lock()
def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)
async def compress(self, data: bytes) -> bytes:
"""Compress the data and returned the compressed bytes.
Note that flush() must be called after the last call to compress()
If the data size is large than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)
class ZLibDecompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
self._decompressor = zlib.decompressobj(wbits=self._mode)
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
"""Decompress the data and return the decompressed bytes.
If the data size is large than the max_sync_chunk_size, the decompression
will be done in the executor. Otherwise, the decompression will be done
in the event loop.
"""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._decompressor.decompress, data, max_length
)
return self.decompress_sync(data, max_length)
def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
if length > 0
else self._decompressor.flush()
)
@property
def eof(self) -> bool:
return self._decompressor.eof
@property
def unconsumed_tail(self) -> bytes:
return self._decompressor.unconsumed_tail
@property
def unused_data(self) -> bytes:
return self._decompressor.unused_data
class BrotliDecompressor:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()
def decompress_sync(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,17 @@
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import os # noqa
import pathlib
import pickle
import re
import time
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import ( # noqa
from typing import (
DefaultDict,
Dict,
Iterable,
@@ -24,7 +28,7 @@ from typing import ( # noqa
from yarl import URL
from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address, next_whole_second
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")
@@ -32,6 +36,15 @@ __all__ = ("CookieJar", "DummyCookieJar")
CookieItem = Union[str, "Morsel[str]"]
# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
@@ -52,9 +65,23 @@ class CookieJar(AbstractCookieJar):
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)
MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2**31 - 1)
# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except (OSError, ValueError):
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
# Throws ValueError on PyPy 3.8 and 3.9, OSError elsewhere
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
# Avoid minuses in the future, 3x faster
SUB_MAX_TIME = MAX_TIME - 1
def __init__(
self,
@@ -65,9 +92,12 @@ class CookieJar(AbstractCookieJar):
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
super().__init__(loop=loop)
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie[str]] = defaultdict(
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
SimpleCookie
)
self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
defaultdict(dict)
)
self._host_only_cookies: Set[Tuple[str, str]] = set()
self._unsafe = unsafe
self._quote_cookie = quote_cookie
@@ -83,14 +113,8 @@ class CookieJar(AbstractCookieJar):
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._next_expiration = next_whole_second()
self._expirations: Dict[Tuple[str, str, str], datetime.datetime] = {}
# #4515: datetime.max may not be representable on 32-bit platforms
self._max_time = self.MAX_TIME
try:
self._max_time.timestamp()
except OverflowError:
self._max_time = self.MAX_32BIT_TIME
self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
self._expirations: Dict[Tuple[str, str, str], float] = {}
def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
@@ -104,36 +128,26 @@ class CookieJar(AbstractCookieJar):
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
self._next_expiration = next_whole_second()
self._expire_heap.clear()
self._cookies.clear()
self._morsel_cache.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return
to_del = []
now = datetime.datetime.now(datetime.timezone.utc)
for (domain, path), cookie in self._cookies.items():
for name, morsel in cookie.items():
key = (domain, path, name)
if (
key in self._expirations and self._expirations[key] <= now
) or predicate(morsel):
to_del.append(key)
for domain, path, name in to_del:
self._host_only_cookies.discard((domain, name))
key = (domain, path, name)
if key in self._expirations:
del self._expirations[(domain, path, name)]
self._cookies[(domain, path)].pop(name, None)
next_expiration = min(self._expirations.values(), default=self._max_time)
try:
self._next_expiration = next_expiration.replace(
microsecond=0
) + datetime.timedelta(seconds=1)
except OverflowError:
self._next_expiration = self._max_time
now = time.time()
to_del = [
key
for (domain, path), cookie in self._cookies.items()
for name, morsel in cookie.items()
if (
(key := (domain, path, name)) in self._expirations
and self._expirations[key] <= now
)
or predicate(morsel)
]
if to_del:
self._delete_cookies(to_del)
def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
@@ -144,16 +158,70 @@ class CookieJar(AbstractCookieJar):
yield from val.values()
def __len__(self) -> int:
return sum(1 for i in self)
"""Return number of cookies.
This function does not iterate self to avoid unnecessary expiration
checks.
"""
return sum(len(cookie.values()) for cookie in self._cookies.values())
def _do_expiration(self) -> None:
self.clear(lambda x: False)
"""Remove expired cookies."""
if not (expire_heap_len := len(self._expire_heap)):
return
def _expire_cookie(
self, when: datetime.datetime, domain: str, path: str, name: str
) -> None:
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, path, name)] = when
# If the expiration heap grows larger than the number expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the cookie has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry
for entry in self._expire_heap
if self._expirations.get(entry[1]) == entry[0]
]
heapq.heapify(self._expire_heap)
now = time.time()
to_del: List[Tuple[str, str, str]] = []
# Find any expired cookies and add them to the to-delete list
while self._expire_heap:
when, cookie_key = self._expire_heap[0]
if when > now:
break
heapq.heappop(self._expire_heap)
# Check if the cookie hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
if self._expirations.get(cookie_key) == when:
to_del.append(cookie_key)
if to_del:
self._delete_cookies(to_del)
def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
for domain, path, name in to_del:
self._host_only_cookies.discard((domain, name))
self._cookies[(domain, path)].pop(name, None)
self._morsel_cache[(domain, path)].pop(name, None)
self._expirations.pop((domain, path, name), None)
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
cookie_key = (domain, path, name)
if self._expirations.get(cookie_key) == when:
# Avoid adding duplicates to the heap
return
heapq.heappush(self._expire_heap, (when, cookie_key))
self._expirations[cookie_key] = when
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
@@ -168,14 +236,14 @@ class CookieJar(AbstractCookieJar):
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp: SimpleCookie[str] = SimpleCookie()
tmp = SimpleCookie()
tmp[name] = cookie # type: ignore[assignment]
cookie = tmp[name]
domain = cookie["domain"]
# ignore domains with trailing dots
if domain.endswith("."):
if domain and domain[-1] == ".":
domain = ""
del cookie["domain"]
@@ -185,7 +253,7 @@ class CookieJar(AbstractCookieJar):
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
if domain.startswith("."):
if domain and domain[0] == ".":
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
@@ -195,7 +263,7 @@ class CookieJar(AbstractCookieJar):
continue
path = cookie["path"]
if not path or not path.startswith("/"):
if not path or path[0] != "/":
# Set the cookie's path to the response path
path = response_url.path
if not path.startswith("/"):
@@ -204,82 +272,99 @@ class CookieJar(AbstractCookieJar):
# Cut everything from the last slash to the end
path = "/" + path[1 : path.rfind("/")]
cookie["path"] = path
path = path.rstrip("/")
max_age = cookie["max-age"]
if max_age:
if max_age := cookie["max-age"]:
try:
delta_seconds = int(max_age)
try:
max_age_expiration = datetime.datetime.now(
datetime.timezone.utc
) + datetime.timedelta(seconds=delta_seconds)
except OverflowError:
max_age_expiration = self._max_time
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
else:
expires = cookie["expires"]
if expires:
expire_time = self._parse_date(expires)
if expire_time:
self._expire_cookie(expire_time, domain, path, name)
else:
cookie["expires"] = ""
elif expires := cookie["expires"]:
if expire_time := self._parse_date(expires):
self._expire_cookie(expire_time, domain, path, name)
else:
cookie["expires"] = ""
self._cookies[(domain, path)][name] = cookie
key = (domain, path)
if self._cookies[key].get(name) != cookie:
# Don't blow away the cache if the same
# cookie gets set again
self._cookies[key][name] = cookie
self._morsel_cache[key].pop(name, None)
self._do_expiration()
def filter_cookies(
self, request_url: URL = URL()
) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
"""Returns this jar's cookies filtered by their attributes."""
self._do_expiration()
request_url = URL(request_url)
filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
SimpleCookie() if self._quote_cookie else BaseCookie()
)
if not self._cookies:
# Skip do_expiration() if there are no cookies.
return filtered
self._do_expiration()
if not self._cookies:
# Skip rest of function if no non-expired cookies.
return filtered
request_url = URL(request_url)
hostname = request_url.raw_host or ""
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = (
request_url.scheme not in ("https", "wss")
and request_origin not in self._treat_as_secure_origin
)
is_not_secure = request_url.scheme not in ("https", "wss")
if is_not_secure and self._treat_as_secure_origin:
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
for cookie in self:
name = cookie.key
domain = cookie["domain"]
# Send shared cookie
for c in self._cookies[("", "")].values():
filtered[c.key] = c.value
# Send shared cookies
if not domain:
filtered[name] = cookie.value
continue
if is_ip_address(hostname):
if not self._unsafe:
return filtered
domains: Iterable[str] = (hostname,)
else:
# Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
domains = itertools.accumulate(
reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
)
if not self._unsafe and is_ip_address(hostname):
continue
# Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
# Create every combination of (domain, path) pairs.
pairs = itertools.product(domains, paths)
if (domain, name) in self._host_only_cookies:
if domain != hostname:
path_len = len(request_url.path)
# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
for p in pairs:
for name, cookie in self._cookies[p].items():
domain = cookie["domain"]
if (domain, name) in self._host_only_cookies and domain != hostname:
continue
elif not self._is_domain_match(domain, hostname):
continue
if not self._is_path_match(request_url.path, cookie["path"]):
continue
# Skip edge case when the cookie has a trailing slash but request doesn't.
if len(cookie["path"]) > path_len:
continue
if is_not_secure and cookie["secure"]:
continue
if is_not_secure and cookie["secure"]:
continue
# It's critical we use the Morsel so the coded_value
# (based on cookie version) is preserved
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
filtered[name] = mrsl_val
# We already built the Morsel so reuse it here
if name in self._morsel_cache[p]:
filtered[name] = self._morsel_cache[p][name]
continue
# It's critical we use the Morsel so the coded_value
# (based on cookie version) is preserved
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
self._morsel_cache[p][name] = mrsl_val
filtered[name] = mrsl_val
return filtered
@@ -299,27 +384,8 @@ class CookieJar(AbstractCookieJar):
return not is_ip_address(hostname)
@staticmethod
def _is_path_match(req_path: str, cookie_path: str) -> bool:
"""Implements path matching adhering to RFC 6265."""
if not req_path.startswith("/"):
req_path = "/"
if req_path == cookie_path:
return True
if not req_path.startswith(cookie_path):
return False
if cookie_path.endswith("/"):
return True
non_matching = req_path[len(cookie_path) :]
return non_matching.startswith("/")
@classmethod
def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
def _parse_date(cls, date_str: str) -> Optional[int]:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
@@ -380,9 +446,7 @@ class CookieJar(AbstractCookieJar):
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None
return datetime.datetime(
year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
)
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):

View File

@@ -1,4 +1,5 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode
@@ -53,7 +54,12 @@ class FormData:
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)
filename = name
type_options: MultiDict[str] = MultiDict({"name": name})
@@ -81,7 +87,11 @@ class FormData:
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True
self._fields.append((type_options, headers, value))

View File

@@ -2,16 +2,10 @@
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
import sys
from typing import Set
from typing import Final, Set
from multidict import istr
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
METH_ANY: Final[str] = "*"
METH_CONNECT: Final[str] = "CONNECT"
METH_HEAD: Final[str] = "HEAD"

View File

@@ -3,7 +3,9 @@
import asyncio
import base64
import binascii
import contextlib
import datetime
import enum
import functools
import inspect
import netrc
@@ -12,7 +14,6 @@ import platform
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
@@ -33,62 +34,47 @@ from typing import (
List,
Mapping,
Optional,
Pattern,
Set,
Protocol,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
overload,
)
from urllib.parse import quote
from urllib.request import getproxies, proxy_bypass
import async_timeout
import attr
from multidict import MultiDict, MultiDictProxy
from multidict import MultiDict, MultiDictProxy, MultiMapping
from yarl import URL
from . import hdrs
from .log import client_logger, internal_logger
from .typedefs import PathLike, Protocol # noqa
from .log import client_logger
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = ("BasicAuth", "ChainMapProxy", "ETag")
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)
PY_38 = sys.version_info >= (3, 8)
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)
if sys.version_info < (3, 7):
import idna_ssl
idna_ssl.patch_match_hostname()
def all_tasks(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Set["asyncio.Task[Any]"]:
tasks = list(asyncio.Task.all_tasks(loop))
return {t for t in tasks if not t.done()}
else:
all_tasks = asyncio.all_tasks
_T = TypeVar("_T")
_S = TypeVar("_S")
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel
sentinel: Any = object()
NO_EXTENSIONS: bool = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
DEBUG: bool = getattr(sys.flags, "dev_mode", False) or (
DEBUG = sys.flags.dev_mode or (
not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
@@ -177,9 +163,11 @@ class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
if url.user is None:
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return None
return cls(url.user, url.password or "", encoding=encoding)
return cls(url.user or "", url.password or "", encoding=encoding)
def encode(self) -> str:
"""Encode credentials."""
@@ -188,11 +176,12 @@ class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
auth = BasicAuth.from_url(url)
if auth is None:
"""Remove user and password from URL if present and return BasicAuth object."""
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return url, None
else:
return url.with_user(None), auth
return url.with_user(None), BasicAuth(url.user or "", url.password or "")
def netrc_from_env() -> Optional[netrc.netrc]:
@@ -226,8 +215,11 @@ def netrc_from_env() -> Optional[netrc.netrc]:
except netrc.NetrcParseError as e:
client_logger.warning("Could not parse .netrc file: %s", e)
except OSError as e:
netrc_exists = False
with contextlib.suppress(OSError):
netrc_exists = netrc_path.is_file()
# we couldn't read the file (doesn't exist, permissions, etc.)
if netrc_env or netrc_path.is_file():
if netrc_env or netrc_exists:
# only warn if the environment wanted us to load it,
# or it appears like the default file does actually exist
client_logger.warning("Could not read .netrc file: %s", e)
@@ -241,6 +233,35 @@ class ProxyInfo:
proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
"""
Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.
:raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
entry is found for the ``host``.
"""
if netrc_obj is None:
raise LookupError("No .netrc file found")
auth_from_netrc = netrc_obj.authenticators(host)
if auth_from_netrc is None:
raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
login, account, password = auth_from_netrc
# TODO(PY311): username = login or account
# Up to python 3.10, account could be None if not specified,
# and login will be empty string if not specified. From 3.11,
# login and account will be empty string if not specified.
username = login if (login or account is None) else account
# TODO(PY311): Remove this, as password will be empty string
# if not specified
if password is None:
password = ""
return BasicAuth(username, password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
proxy_urls = {
k: URL(v)
@@ -258,55 +279,15 @@ def proxies_from_env() -> Dict[str, ProxyInfo]:
)
continue
if netrc_obj and auth is None:
auth_from_netrc = None
if proxy.host is not None:
auth_from_netrc = netrc_obj.authenticators(proxy.host)
if auth_from_netrc is not None:
# auth_from_netrc is a (`user`, `account`, `password`) tuple,
# `user` and `account` both can be username,
# if `user` is None, use `account`
*logins, password = auth_from_netrc
login = logins[0] if logins[0] else logins[-1]
auth = BasicAuth(cast(str, login), cast(str, password))
try:
auth = basicauth_from_netrc(netrc_obj, proxy.host)
except LookupError:
auth = None
ret[proto] = ProxyInfo(proxy, auth)
return ret
def current_task(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> "Optional[asyncio.Task[Any]]":
if sys.version_info >= (3, 7):
return asyncio.current_task(loop=loop)
else:
return asyncio.Task.current_task(loop=loop)
def get_running_loop(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> asyncio.AbstractEventLoop:
if loop is None:
loop = asyncio.get_event_loop()
if not loop.is_running():
warnings.warn(
"The object should be created within an async function",
DeprecationWarning,
stacklevel=3,
)
if loop.get_debug():
internal_logger.warning(
"The object should be created within an async function", stack_info=True
)
return loop
def isasyncgenfunction(obj: Any) -> bool:
func = getattr(inspect, "isasyncgenfunction", None)
if func is not None:
return func(obj) # type: ignore[no-any-return]
else:
return False
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Get a permitted proxy for the given URL from the env."""
if url.host is not None and proxy_bypass(url.host):
@@ -354,23 +335,15 @@ def parse_mimetype(mimetype: str) -> MimeType:
for item in parts[1:]:
if not item:
continue
key, value = cast(
Tuple[str, str], item.split("=", 1) if "=" in item else (item, "")
)
key, _, value = item.partition("=")
params.add(key.lower().strip(), value.strip(' "'))
fulltype = parts[0].strip().lower()
if fulltype == "*":
fulltype = "*/*"
mtype, stype = (
cast(Tuple[str, str], fulltype.split("/", 1))
if "/" in fulltype
else (fulltype, "")
)
stype, suffix = (
cast(Tuple[str, str], stype.split("+", 1)) if "+" in stype else (stype, "")
)
mtype, _, stype = fulltype.partition("/")
stype, _, suffix = stype.partition("+")
return MimeType(
type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
@@ -500,54 +473,54 @@ try:
except ImportError:
pass
_ipv4_pattern = (
r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
_ipv6_pattern = (
r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
r":|:(:[A-F0-9]{1,4}){7})$"
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)
def is_ipv4_address(host: Optional[Union[str, bytes]]) -> bool:
"""Check if host looks like an IPv4 address.
def _is_ip_address(
regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]]
) -> bool:
if host is None:
This function does not validate that the format is correct, only that
the host is a str or bytes, and its all numeric.
This check is only meant as a heuristic to ensure that
a host is not a domain name.
"""
if not host:
return False
# For a host to be an ipv4 address, it must be all numeric.
if isinstance(host, str):
return bool(regex.match(host))
elif isinstance(host, (bytes, bytearray, memoryview)):
return bool(regexb.match(host))
else:
raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
return host.replace(".", "").isdigit()
if isinstance(host, (bytes, bytearray, memoryview)):
return host.decode("ascii").replace(".", "").isdigit()
raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)
def is_ipv6_address(host: Optional[Union[str, bytes]]) -> bool:
"""Check if host looks like an IPv6 address.
This function does not validate that the format is correct, only that
the host contains a colon and that it is a str or bytes.
This check is only meant as a heuristic to ensure that
a host is not a domain name.
"""
if not host:
return False
# The host must contain a colon to be an IPv6 address.
if isinstance(host, str):
return ":" in host
if isinstance(host, (bytes, bytearray, memoryview)):
return b":" in host
raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
"""Check if host looks like an IP Address.
This check is only meant as a heuristic to ensure that
a host is not a domain name.
"""
return is_ipv4_address(host) or is_ipv6_address(host)
def next_whole_second() -> datetime.datetime:
"""Return current time rounded up to the next whole second."""
return datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
) + datetime.timedelta(seconds=0)
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""
@@ -601,11 +574,15 @@ def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
def weakref_handle(
ob: object, name: str, timeout: float, loop: asyncio.AbstractEventLoop
ob: object,
name: str,
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is not None and timeout > 0:
when = loop.time() + timeout
if timeout >= 5:
if timeout >= timeout_ceil_threshold:
when = ceil(when)
return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
@@ -613,24 +590,44 @@ def weakref_handle(
def call_later(
cb: Callable[[], Any], timeout: float, loop: asyncio.AbstractEventLoop
cb: Callable[[], Any],
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is not None and timeout > 0:
when = loop.time() + timeout
if timeout > 5:
when = ceil(when)
return loop.call_at(when, cb)
return None
if timeout is None or timeout <= 0:
return None
now = loop.time()
when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
return loop.call_at(when, cb)
def calculate_timeout_when(
loop_time: float,
timeout: float,
timeout_ceiling_threshold: float,
) -> float:
"""Calculate when to execute a timeout."""
when = loop_time + timeout
if timeout > timeout_ceiling_threshold:
return ceil(when)
return when
class TimeoutHandle:
"""Timeout handle"""
__slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
def __init__(
self, loop: asyncio.AbstractEventLoop, timeout: Optional[float]
self,
loop: asyncio.AbstractEventLoop,
timeout: Optional[float],
ceil_threshold: float = 5,
) -> None:
self._timeout = timeout
self._loop = loop
self._ceil_threshold = ceil_threshold
self._callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
] = []
@@ -643,11 +640,11 @@ class TimeoutHandle:
def close(self) -> None:
self._callbacks.clear()
def start(self) -> Optional[asyncio.Handle]:
def start(self) -> Optional[asyncio.TimerHandle]:
timeout = self._timeout
if timeout is not None and timeout > 0:
when = self._loop.time() + timeout
if timeout >= 5:
if timeout >= self._ceil_threshold:
when = ceil(when)
return self._loop.call_at(when, self.__call__)
else:
@@ -670,10 +667,17 @@ class TimeoutHandle:
class BaseTimerContext(ContextManager["BaseTimerContext"]):
pass
__slots__ = ()
def assert_timeout(self) -> None:
"""Raise TimeoutError if timeout has been exceeded."""
class TimerNoop(BaseTimerContext):
__slots__ = ()
def __enter__(self) -> BaseTimerContext:
return self
@@ -689,19 +693,32 @@ class TimerNoop(BaseTimerContext):
class TimerContext(BaseTimerContext):
"""Low resolution timeout context manager"""
__slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: List[asyncio.Task[Any]] = []
self._cancelled = False
self._cancelling = 0
def assert_timeout(self) -> None:
"""Raise TimeoutError if timer has already been cancelled."""
if self._cancelled:
raise asyncio.TimeoutError from None
def __enter__(self) -> BaseTimerContext:
task = current_task(loop=self._loop)
task = asyncio.current_task(loop=self._loop)
if task is None:
raise RuntimeError(
"Timeout context manager should be used " "inside a task"
)
if sys.version_info >= (3, 11):
# Remember if the task was already cancelling
# so when we __exit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = task.cancelling()
if self._cancelled:
raise asyncio.TimeoutError from None
@@ -714,11 +731,22 @@ class TimerContext(BaseTimerContext):
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
enter_task: Optional[asyncio.Task[Any]] = None
if self._tasks:
self._tasks.pop()
enter_task = self._tasks.pop()
if exc_type is asyncio.CancelledError and self._cancelled:
raise asyncio.TimeoutError from None
assert enter_task is not None
# The timeout was hit, and the task was cancelled
# so we need to uncancel the last task that entered the context manager
# since the cancellation should not leak out of the context manager
if sys.version_info >= (3, 11):
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
if enter_task.uncancel() > self._cancelling:
return None
raise asyncio.TimeoutError from exc_val
return None
def timeout(self) -> None:
@@ -729,27 +757,30 @@ class TimerContext(BaseTimerContext):
self._cancelled = True
def ceil_timeout(delay: Optional[float]) -> async_timeout.Timeout:
def ceil_timeout(
delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
if delay is None or delay <= 0:
return async_timeout.timeout(None)
loop = get_running_loop()
loop = asyncio.get_running_loop()
now = loop.time()
when = now + delay
if delay > 5:
if delay > ceil_threshold:
when = ceil(when)
return async_timeout.timeout_at(when)
class HeadersMixin:
ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])
_headers: MultiMapping[str]
_content_type: Optional[str] = None
_content_dict: Optional[Dict[str, str]] = None
_stored_content_type = sentinel
_stored_content_type: Union[str, None, _SENTINEL] = sentinel
def _parse_content_type(self, raw: str) -> None:
def _parse_content_type(self, raw: Optional[str]) -> None:
self._stored_content_type = raw
if raw is None:
# default value according to RFC 2616
@@ -758,36 +789,32 @@ class HeadersMixin:
else:
msg = HeaderParser().parsestr("Content-Type: " + raw)
self._content_type = msg.get_content_type()
params = msg.get_params()
params = msg.get_params(())
self._content_dict = dict(params[1:]) # First element is content type again
@property
def content_type(self) -> str:
"""The value of content part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE) # type: ignore[attr-defined]
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_type # type: ignore[return-value]
assert self._content_type is not None
return self._content_type
@property
def charset(self) -> Optional[str]:
"""The value of charset part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE) # type: ignore[attr-defined]
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_dict.get("charset") # type: ignore[union-attr]
assert self._content_dict is not None
return self._content_dict.get("charset")
@property
def content_length(self) -> Optional[int]:
"""The value of Content-Length HTTP header."""
content_length = self._headers.get( # type: ignore[attr-defined]
hdrs.CONTENT_LENGTH
)
if content_length is not None:
return int(content_length)
else:
return None
content_length = self._headers.get(hdrs.CONTENT_LENGTH)
return None if content_length is None else int(content_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
@@ -795,15 +822,91 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
fut.set_result(result)
def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
if not fut.done():
fut.set_exception(exc)
_EXC_SENTINEL = BaseException()
class ChainMapProxy(Mapping[str, Any]):
class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None: ... # pragma: no cover
def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.
If the future is marked as complete, this function is a no-op.
:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return
exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause
fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
"""Keys for static typing support in Application."""
__slots__ = ("_name", "_t", "__orig_class__")
# This may be set by Python when instantiating with a generic type. We need to
# support this, in order to support types that are not concrete classes,
# like Iterable, which can't be passed as the second parameter to __init__.
__orig_class__: Type[object]
def __init__(self, name: str, t: Optional[Type[_T]] = None):
# Prefix with module name to help deduplicate key names.
frame = inspect.currentframe()
while frame:
if frame.f_code.co_name == "<module>":
module: str = frame.f_globals["__name__"]
break
frame = frame.f_back
self._name = module + "." + name
self._t = t
def __lt__(self, other: object) -> bool:
if isinstance(other, AppKey):
return self._name < other._name
return True # Order AppKey above other types.
def __repr__(self) -> str:
t = self._t
if t is None:
with suppress(AttributeError):
# Set to type arg.
t = get_args(self.__orig_class__)[0]
if t is None:
t_repr = "<<Unknown>>"
elif isinstance(t, type):
if t.__module__ == "builtins":
t_repr = t.__qualname__
else:
t_repr = f"{t.__module__}.{t.__qualname__}"
else:
t_repr = repr(t)
return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
__slots__ = ("_maps",)
def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
self._maps = tuple(maps)
def __init_subclass__(cls) -> None:
@@ -812,7 +915,13 @@ class ChainMapProxy(Mapping[str, Any]):
"is forbidden".format(cls.__name__)
)
def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
for mapping in self._maps:
try:
return mapping[key]
@@ -820,15 +929,27 @@ class ChainMapProxy(Mapping[str, Any]):
pass
raise KeyError(key)
def get(self, key: str, default: Any = None) -> Any:
return self[key] if key in self else default
@overload # type: ignore[override]
def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...
@overload
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default
def __len__(self) -> int:
# reuses stored hash values if possible
return len(set().union(*self._maps)) # type: ignore[arg-type]
return len(set().union(*self._maps))
def __iter__(self) -> Iterator[str]:
d: Dict[str, Any] = {}
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
d: Dict[Union[str, AppKey[Any]], Any] = {}
for mapping in reversed(self._maps):
# reuses stored hash values if possible
d.update(mapping)
@@ -846,7 +967,7 @@ class ChainMapProxy(Mapping[str, Any]):
# https://tools.ietf.org/html/rfc7232#section-2.3
_ETAGC = r"[!#-}\x80-\xff]+"
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
@@ -876,3 +997,40 @@ def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
with suppress(ValueError):
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
"""Check if a request must return an empty body."""
return (
status_code_must_be_empty_body(code)
or method_must_be_empty_body(method)
or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT)
)
def method_must_be_empty_body(method: str) -> bool:
"""Check if a method must return an empty body."""
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
return method.upper() == hdrs.METH_HEAD
def status_code_must_be_empty_body(code: int) -> bool:
"""Check if a status code must return an empty body."""
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
return code in {204, 304} or 100 <= code < 200
def should_remove_content_length(method: str, code: int) -> bool:
"""Check if a Content-Length header should be removed.
This should always be a subset of must_be_empty_body
"""
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
# https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
return (
code in {204, 304}
or 100 <= code < 200
or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT)
)

View File

@@ -1,5 +1,5 @@
import http.server
import sys
from http import HTTPStatus
from typing import Mapping, Tuple
from . import __version__
@@ -67,4 +67,6 @@ SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format(
sys.version_info, __version__
)
RESPONSES: Mapping[int, Tuple[str, str]] = http.server.BaseHTTPRequestHandler.responses
RESPONSES: Mapping[int, Tuple[str, str]] = {
v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
}

View File

@@ -1,6 +1,5 @@
"""Low-level http related exceptions."""
from textwrap import indent
from typing import Optional, Union
@@ -87,18 +86,17 @@ class LineTooLong(BadHttpMessage):
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr: Union[bytes, str]) -> None:
if isinstance(hdr, bytes):
hdr = hdr.decode("utf-8", "surrogateescape")
super().__init__(f"Invalid HTTP Header: {hdr}")
self.hdr = hdr
hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr
super().__init__(f"Invalid HTTP header: {hdr!r}")
self.hdr = hdr_s
self.args = (hdr,)
class BadStatusLine(BadHttpMessage):
def __init__(self, line: str = "") -> None:
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
if not isinstance(line, str):
line = repr(line)
super().__init__(f"Bad status line {line!r}")
super().__init__(error or f"Bad status line {line!r}")
self.args = (line,)
self.line = line

View File

@@ -1,15 +1,16 @@
import abc
import asyncio
import collections
import re
import string
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
ClassVar,
Final,
Generic,
List,
Literal,
NamedTuple,
Optional,
Pattern,
@@ -18,7 +19,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
from multidict import CIMultiDict, CIMultiDictProxy, istr
@@ -26,28 +26,29 @@ from yarl import URL
from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
_EXC_SENTINEL,
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
set_exception,
status_code_must_be_empty_body,
)
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
ContentEncodingError,
ContentLengthError,
InvalidHeader,
InvalidURLError,
LineTooLong,
TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import Final, RawHeaders
try:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
from .typedefs import RawHeaders
__all__ = (
"HeadersParser",
@@ -58,18 +59,22 @@ __all__ = (
"RawResponseMessage",
)
_SEP = Literal[b"\r\n", b"\n"]
ASCIISET: Final[Set[str]] = set(string.printable)
# See https://tools.ietf.org/html/rfc7230#section-3.1.1
# and https://tools.ietf.org/html/rfc7230#appendix-B
# See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview
# and https://www.rfc-editor.org/rfc/rfc9110.html#name-tokens
#
# method = token
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
# token = 1*tchar
METHRE: Final[Pattern[str]] = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d+).(\d+)")
HDRRE: Final[Pattern[bytes]] = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]")
_TCHAR_SPECIALS: Final[str] = re.escape("!#$%&'*+-.^_`|~")
TOKENRE: Final[Pattern[str]] = re.compile(f"[0-9A-Za-z{_TCHAR_SPECIALS}]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d)\.(\d)", re.ASCII)
DIGITS: Final[Pattern[str]] = re.compile(r"\d+", re.ASCII)
HEXDIGITS: Final[Pattern[bytes]] = re.compile(rb"[0-9a-fA-F]+")
class RawRequestMessage(NamedTuple):
@@ -85,20 +90,16 @@ class RawRequestMessage(NamedTuple):
url: URL
RawResponseMessage = collections.namedtuple(
"RawResponseMessage",
[
"version",
"code",
"reason",
"headers",
"raw_headers",
"should_close",
"compression",
"upgrade",
"chunked",
],
)
class RawResponseMessage(NamedTuple):
version: HttpVersion
code: int
reason: str
headers: CIMultiDictProxy[str]
raw_headers: RawHeaders
should_close: bool
compression: Optional[str]
upgrade: bool
chunked: bool
_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage)
@@ -126,15 +127,18 @@ class HeadersParser:
max_line_size: int = 8190,
max_headers: int = 32768,
max_field_size: int = 8190,
lax: bool = False,
) -> None:
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
self._lax = lax
def parse_headers(
self, lines: List[bytes]
) -> Tuple["CIMultiDictProxy[str]", RawHeaders]:
headers: CIMultiDict[str] = CIMultiDict()
# note: "raw" does not mean inclusion of OWS before/after the field value
raw_headers = []
lines_idx = 1
@@ -148,18 +152,25 @@ class HeadersParser:
except ValueError:
raise InvalidHeader(line) from None
bname = bname.strip(b" \t")
bvalue = bvalue.lstrip()
if HDRRE.search(bname):
if len(bname) == 0:
raise InvalidHeader(bname)
# https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2
if {bname[0], bname[-1]} & {32, 9}: # {" ", "\t"}
raise InvalidHeader(line)
bvalue = bvalue.lstrip(b" \t")
if len(bname) > self.max_field_size:
raise LineTooLong(
"request header name {}".format(
bname.decode("utf8", "xmlcharrefreplace")
bname.decode("utf8", "backslashreplace")
),
str(self.max_field_size),
str(len(bname)),
)
name = bname.decode("utf-8", "surrogateescape")
if not TOKENRE.fullmatch(name):
raise InvalidHeader(bname)
header_length = len(bvalue)
@@ -168,8 +179,9 @@ class HeadersParser:
line = lines[lines_idx]
# consume continuation lines
continuation = line and line[0] in (32, 9) # (' ', '\t')
continuation = self._lax and line and line[0] in (32, 9) # (' ', '\t')
# Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding
if continuation:
bvalue_lst = [bvalue]
while continuation:
@@ -177,7 +189,7 @@ class HeadersParser:
if header_length > self.max_field_size:
raise LineTooLong(
"request header field {}".format(
bname.decode("utf8", "xmlcharrefreplace")
bname.decode("utf8", "backslashreplace")
),
str(self.max_field_size),
str(header_length),
@@ -198,23 +210,33 @@ class HeadersParser:
if header_length > self.max_field_size:
raise LineTooLong(
"request header field {}".format(
bname.decode("utf8", "xmlcharrefreplace")
bname.decode("utf8", "backslashreplace")
),
str(self.max_field_size),
str(header_length),
)
bvalue = bvalue.strip()
name = bname.decode("utf-8", "surrogateescape")
bvalue = bvalue.strip(b" \t")
value = bvalue.decode("utf-8", "surrogateescape")
# https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5
if "\n" in value or "\r" in value or "\x00" in value:
raise InvalidHeader(bvalue)
headers.add(name, value)
raw_headers.append((bname, bvalue))
return (CIMultiDictProxy(headers), tuple(raw_headers))
def _is_supported_upgrade(headers: CIMultiDictProxy[str]) -> bool:
"""Check if the upgrade header is supported."""
return headers.get(hdrs.UPGRADE, "").lower() in {"tcp", "websocket"}
class HttpParser(abc.ABC, Generic[_MsgT]):
lax: ClassVar[bool] = False
def __init__(
self,
protocol: Optional[BaseProtocol] = None,
@@ -226,7 +248,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
timer: Optional[BaseTimerContext] = None,
code: Optional[int] = None,
method: Optional[str] = None,
readall: bool = False,
payload_exception: Optional[Type[BaseException]] = None,
response_with_body: bool = True,
read_until_eof: bool = False,
@@ -240,7 +261,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
self.timer = timer
self.code = code
self.method = method
self.readall = readall
self.payload_exception = payload_exception
self.response_with_body = response_with_body
self.read_until_eof = read_until_eof
@@ -252,11 +272,15 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
self._payload_parser: Optional[HttpPayloadParser] = None
self._auto_decompress = auto_decompress
self._limit = limit
self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size)
self._headers_parser = HeadersParser(
max_line_size, max_headers, max_field_size, self.lax
)
@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> _MsgT:
pass
def parse_message(self, lines: List[bytes]) -> _MsgT: ...
@abc.abstractmethod
def _is_chunked_te(self, te: str) -> bool: ...
def feed_eof(self) -> Optional[_MsgT]:
if self._payload_parser is not None:
@@ -277,7 +301,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
def feed_data(
self,
data: bytes,
SEP: bytes = b"\r\n",
SEP: _SEP = b"\r\n",
EMPTY: bytes = b"",
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
METH_CONNECT: str = hdrs.METH_CONNECT,
@@ -293,6 +317,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
start_pos = 0
loop = self.loop
should_close = False
while start_pos < data_len:
# read HTTP message (request/response line + headers), \r\n\r\n
@@ -301,13 +326,19 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
pos = data.find(SEP, start_pos)
# consume \r\n
if pos == start_pos and not self._lines:
start_pos = pos + 2
start_pos = pos + len(SEP)
continue
if pos >= start_pos:
if should_close:
raise BadHttpMessage("Data after `Connection: close`")
# line found
self._lines.append(data[start_pos:pos])
start_pos = pos + 2
line = data[start_pos:pos]
if SEP == b"\n": # For lax response parsing
line = line.rstrip(b"\r")
self._lines.append(line)
start_pos = pos + len(SEP)
# \r\n\r\n found
if self._lines[-1] == EMPTY:
@@ -322,31 +353,35 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
if length_hdr is None:
return None
try:
length = int(length_hdr)
except ValueError:
# Shouldn't allow +/- or other number formats.
# https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2
# msg.headers is already stripped of leading/trailing wsp
if not DIGITS.fullmatch(length_hdr):
raise InvalidHeader(CONTENT_LENGTH)
if length < 0:
raise InvalidHeader(CONTENT_LENGTH)
return length
return int(length_hdr)
length = get_content_length()
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in msg.headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
self._upgraded = msg.upgrade
self._upgraded = msg.upgrade and _is_supported_upgrade(
msg.headers
)
method = getattr(msg, "method", self.method)
# code is only present on responses
code = getattr(msg, "code", 0)
assert self.protocol is not None
# calculate payload
if (
(length is not None and length > 0)
or msg.chunked
and not msg.upgrade
empty_body = status_code_must_be_empty_body(code) or bool(
method and method_must_be_empty_body(method)
)
if not empty_body and (
((length is not None and length > 0) or msg.chunked)
and not self._upgraded
):
payload = StreamReader(
self.protocol,
@@ -361,9 +396,9 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
method=method,
compression=msg.compression,
code=self.code,
readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
@@ -380,38 +415,34 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
payload,
method=msg.method,
compression=msg.compression,
readall=True,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
elif not empty_body and length is None and self.read_until_eof:
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
if (
getattr(msg, "code", 100) >= 199
and length is None
and self.read_until_eof
):
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD
payload = EMPTY_PAYLOAD
messages.append((msg, payload))
should_close = msg.should_close
else:
self._tail = data[start_pos:]
data = EMPTY
@@ -427,14 +458,17 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
assert not self._lines
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:])
except BaseException as exc:
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc))
)
else:
self._payload_parser.payload.set_exception(exc)
reraised_exc = self.payload_exception(str(underlying_exc))
set_exception(
self._payload_parser.payload,
reraised_exc,
underlying_exc,
)
eof = True
data = b""
@@ -470,6 +504,24 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
upgrade = False
chunked = False
# https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6
# https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf
singletons = (
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TYPE,
hdrs.ETAG,
hdrs.HOST,
hdrs.MAX_FORWARDS,
hdrs.SERVER,
hdrs.TRANSFER_ENCODING,
hdrs.USER_AGENT,
)
bad_hdr = next((h for h in singletons if len(headers.getall(h, ())) > 1), None)
if bad_hdr is not None:
raise BadHttpMessage(f"Duplicate '{bad_hdr}' header found.")
# keep-alive
conn = headers.get(hdrs.CONNECTION)
if conn:
@@ -478,7 +530,8 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
close_conn = True
elif v == "keep-alive":
close_conn = False
elif v == "upgrade":
# https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols
elif v == "upgrade" and headers.get(hdrs.UPGRADE):
upgrade = True
# encoding
@@ -491,14 +544,12 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
# chunking
te = headers.get(hdrs.TRANSFER_ENCODING)
if te is not None:
if "chunked" == te.lower():
if self._is_chunked_te(te):
chunked = True
else:
raise BadHttpMessage("Request has invalid `Transfer-Encoding`")
if hdrs.CONTENT_LENGTH in headers:
raise BadHttpMessage(
"Content-Length can't be present with Transfer-Encoding",
"Transfer-Encoding can't be present with Content-Length",
)
return (headers, raw_headers, close_conn, encoding, upgrade, chunked)
@@ -523,7 +574,7 @@ class HttpRequestParser(HttpParser[RawRequestMessage]):
# request line
line = lines[0].decode("utf-8", "surrogateescape")
try:
method, path, version = line.split(None, 2)
method, path, version = line.split(" ", maxsplit=2)
except ValueError:
raise BadStatusLine(line) from None
@@ -533,18 +584,14 @@ class HttpRequestParser(HttpParser[RawRequestMessage]):
)
# method
if not METHRE.match(method):
if not TOKENRE.fullmatch(method):
raise BadStatusLine(method)
# version
try:
if version.startswith("HTTP/"):
n1, n2 = version[5:].split(".", 1)
version_o = HttpVersion(int(n1), int(n2))
else:
raise BadStatusLine(version)
except Exception:
raise BadStatusLine(version)
match = VERSRE.fullmatch(version)
if match is None:
raise BadStatusLine(line)
version_o = HttpVersion(int(match.group(1)), int(match.group(2)))
if method == "CONNECT":
# authority-form,
@@ -566,10 +613,18 @@ class HttpRequestParser(HttpParser[RawRequestMessage]):
fragment=url_fragment,
encoded=True,
)
elif path == "*" and method == "OPTIONS":
# asterisk-form,
url = URL(path, encoded=True)
else:
# absolute-form for proxy maybe,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
url = URL(path, encoded=True)
if url.scheme == "":
# not absolute-form
raise InvalidURLError(
path.encode(errors="surrogateescape").decode("latin1")
)
# read headers
(
@@ -600,6 +655,12 @@ class HttpRequestParser(HttpParser[RawRequestMessage]):
url,
)
def _is_chunked_te(self, te: str) -> bool:
if te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked":
return True
# https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3
raise BadHttpMessage("Request has invalid `Transfer-Encoding`")
class HttpResponseParser(HttpParser[RawResponseMessage]):
"""Read response status line and headers.
@@ -608,16 +669,31 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
Returns RawResponseMessage.
"""
# Lax mode should only be enabled on response parser.
lax = not DEBUG
def feed_data(
self,
data: bytes,
SEP: Optional[_SEP] = None,
*args: Any,
**kwargs: Any,
) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]:
if SEP is None:
SEP = b"\r\n" if DEBUG else b"\n"
return super().feed_data(data, SEP, *args, **kwargs)
def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
line = lines[0].decode("utf-8", "surrogateescape")
try:
version, status = line.split(None, 1)
version, status = line.split(maxsplit=1)
except ValueError:
raise BadStatusLine(line) from None
try:
status, reason = status.split(None, 1)
status, reason = status.split(maxsplit=1)
except ValueError:
status = status.strip()
reason = ""
if len(reason) > self.max_line_size:
@@ -626,19 +702,15 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
)
# version
match = VERSRE.match(version)
match = VERSRE.fullmatch(version)
if match is None:
raise BadStatusLine(line)
version_o = HttpVersion(int(match.group(1)), int(match.group(2)))
# The status code is a three-digit number
try:
status_i = int(status)
except ValueError:
raise BadStatusLine(line) from None
if status_i > 999:
# The status code is a three-digit ASCII number, no padding
if len(status) != 3 or not DIGITS.fullmatch(status):
raise BadStatusLine(line)
status_i = int(status)
# read headers
(
@@ -651,7 +723,16 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
) = self.parse_headers(lines)
if close is None:
close = version_o <= HttpVersion10
if version_o <= HttpVersion10:
close = True
# https://www.rfc-editor.org/rfc/rfc9112.html#name-message-body-length
elif 100 <= status_i < 200 or status_i in {204, 304}:
close = False
elif hdrs.CONTENT_LENGTH in headers or hdrs.TRANSFER_ENCODING in headers:
close = False
else:
# https://www.rfc-editor.org/rfc/rfc9112.html#section-6.3-2.8
close = True
return RawResponseMessage(
version_o,
@@ -665,6 +746,10 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
chunked,
)
def _is_chunked_te(self, te: str) -> bool:
# https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2
return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked"
class HttpPayloadParser:
def __init__(
@@ -675,16 +760,17 @@ class HttpPayloadParser:
compression: Optional[str] = None,
code: Optional[int] = None,
method: Optional[str] = None,
readall: bool = False,
response_with_body: bool = True,
auto_decompress: bool = True,
lax: bool = False,
) -> None:
self._length = 0
self._type = ParseState.PARSE_NONE
self._type = ParseState.PARSE_UNTIL_EOF
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b""
self._auto_decompress = auto_decompress
self._lax = lax
self.done = False
# payload decompression wrapper
@@ -701,7 +787,6 @@ class HttpPayloadParser:
self._type = ParseState.PARSE_NONE
real_payload.feed_eof()
self.done = True
elif chunked:
self._type = ParseState.PARSE_CHUNKED
elif length is not None:
@@ -710,16 +795,6 @@ class HttpPayloadParser:
if self._length == 0:
real_payload.feed_eof()
self.done = True
else:
if readall and code != 204:
self._type = ParseState.PARSE_UNTIL_EOF
elif method in ("PUT", "POST"):
internal_logger.warning( # pragma: no cover
"Content-Length or Transfer-Encoding header is required"
)
self._type = ParseState.PARSE_NONE
real_payload.feed_eof()
self.done = True
self.payload = real_payload
@@ -736,7 +811,7 @@ class HttpPayloadParser:
)
def feed_data(
self, chunk: bytes, SEP: bytes = b"\r\n", CHUNK_EXT: bytes = b";"
self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";"
) -> Tuple[bool, bytes]:
# Read specified amount of bytes
if self._type == ParseState.PARSE_LENGTH:
@@ -770,21 +845,32 @@ class HttpPayloadParser:
i = chunk.find(CHUNK_EXT, 0, pos)
if i >= 0:
size_b = chunk[:i] # strip chunk-extensions
# Verify no LF in the chunk-extension
if b"\n" in (ext := chunk[i:pos]):
exc = BadHttpMessage(
f"Unexpected LF in chunk-extension: {ext!r}"
)
set_exception(self.payload, exc)
raise exc
else:
size_b = chunk[:pos]
try:
size = int(bytes(size_b), 16)
except ValueError:
if self._lax: # Allow whitespace in lax mode.
size_b = size_b.strip()
if not re.fullmatch(HEXDIGITS, size_b):
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
raise exc from None
set_exception(self.payload, exc)
raise exc
size = int(bytes(size_b), 16)
chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
if size == 0: # eof marker
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
else:
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
self._chunk_size = size
@@ -811,8 +897,10 @@ class HttpPayloadParser:
# toss the CRLF at the end of the chunk
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
if chunk[:2] == SEP:
chunk = chunk[2:]
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
if chunk[: len(SEP)] == SEP:
chunk = chunk[len(SEP) :]
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
else:
self._chunk_tail = chunk
@@ -820,13 +908,13 @@ class HttpPayloadParser:
# if stream does not contain trailer, after 0\r\n
# we should get another \r\n otherwise
# trailers needs to be skiped until \r\n\r\n
# trailers needs to be skipped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
head = chunk[:2]
head = chunk[: len(SEP)]
if head == SEP:
# end of stream
self.payload.feed_eof()
return True, chunk[2:]
return True, chunk[len(SEP) :]
# Both CR and LF, or only LF may not be received yet. It is
# expected that CRLF or LF will be shown at the very first
# byte next time, otherwise trailers should come. The last
@@ -844,7 +932,7 @@ class HttpPayloadParser:
if self._chunk == ChunkState.PARSE_TRAILERS:
pos = chunk.find(SEP)
if pos >= 0:
chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
self._chunk_tail = chunk
@@ -868,37 +956,23 @@ class DeflateBuffer:
self.encoding = encoding
self._started_decoding = False
self.decompressor: Union[BrotliDecompressor, ZLibDecompressor]
if encoding == "br":
if not HAS_BROTLI: # pragma: no cover
raise ContentEncodingError(
"Can not decode content-encoding: brotli (br). "
"Please install `Brotli`"
)
class BrotliDecoder:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
self._obj = brotli.Decompressor()
def decompress(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""
self.decompressor = BrotliDecoder()
self.decompressor = BrotliDecompressor()
else:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self.decompressor = zlib.decompressobj(wbits=zlib_mode)
self.decompressor = ZLibDecompressor(encoding=encoding)
def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
set_exception(self.out, exc, exc_cause)
def feed_data(self, chunk: bytes, size: int) -> None:
if not size:
@@ -916,10 +990,12 @@ class DeflateBuffer:
):
# Change the decoder to decompress incorrectly compressed data
# Actually we should issue a warning about non-RFC-compliant data.
self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
self.decompressor = ZLibDecompressor(
encoding=self.encoding, suppress_deflate_header=True
)
try:
chunk = self.decompressor.decompress(chunk)
chunk = self.decompressor.decompress_sync(chunk)
except Exception:
raise ContentEncodingError(
"Can not decode content-encoding: %s" % self.encoding
@@ -954,7 +1030,7 @@ RawResponseMessagePy = RawResponseMessage
try:
if not NO_EXTENSIONS:
from ._http_parser import ( # type: ignore[import,no-redef]
from ._http_parser import ( # type: ignore[import-not-found,no-redef]
HttpRequestParser,
HttpResponseParser,
RawRequestMessage,

View File

@@ -1,20 +1,34 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import collections
import functools
import json
import random
import re
import sys
import zlib
from enum import IntEnum
from functools import partial
from struct import Struct
from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast
from typing import (
Any,
Callable,
Final,
List,
NamedTuple,
Optional,
Pattern,
Set,
Tuple,
Union,
cast,
)
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue
from .typedefs import Final
__all__ = (
"WS_CLOSED_MESSAGE",
@@ -47,6 +61,15 @@ class WSCloseCode(IntEnum):
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng choose to use as the threshold to release the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
class WSMsgType(IntEnum):
# websocket spec types
@@ -72,6 +95,14 @@ class WSMsgType(IntEnum):
error = ERROR
MESSAGE_TYPES_WITH_CONTENT: Final = frozenset(
{
WSMsgType.BINARY,
WSMsgType.TEXT,
WSMsgType.CONTINUATION,
}
)
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
@@ -82,14 +113,18 @@ PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
DEFAULT_LIMIT: Final[int] = 2**16
MASK_LEN: Final[int] = 4
_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"])
class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]
class WSMessage(_WSMessageBase):
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.
@@ -98,8 +133,12 @@ class WSMessage(_WSMessageBase):
return loads(self.data)
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
# Constructing the tuple directly to avoid the overhead of
# the lambda and arg processing since NamedTuples are constructed
# with a run time built lambda
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
class WebSocketError(Exception):
@@ -121,7 +160,9 @@ native_byteorder: Final[str] = sys.byteorder
# Used by _websocket_mask_python
_XOR_TABLE: Final[List[bytes]] = [bytes(a ^ b for a in range(256)) for b in range(256)]
@functools.lru_cache
def _xor_table() -> List[bytes]:
return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
@@ -141,6 +182,7 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
assert len(mask) == 4, mask
if data:
_XOR_TABLE = _xor_table()
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
@@ -152,7 +194,7 @@ if NO_EXTENSIONS: # pragma: no cover
_websocket_mask = _websocket_mask_python
else:
try:
from ._websocket import _websocket_mask_cython # type: ignore[import]
from ._websocket import _websocket_mask_cython # type: ignore[import-not-found]
_websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
@@ -268,13 +310,13 @@ class WebSocketReader:
self._frame_opcode: Optional[int] = None
self._frame_payload = bytearray()
self._tail = b""
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_length = 0
self._payload_length_flag = 0
self._compressed: Optional[bool] = None
self._decompressobj: Any = None # zlib.decompressobj actually
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
def feed_eof(self) -> None:
@@ -285,17 +327,103 @@ class WebSocketReader:
return True, data
try:
return self._feed_data(data)
self._feed_data(data)
except Exception as exc:
self._exc = exc
self.queue.set_exception(exc)
set_exception(self.queue, exc)
return True, b""
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return False, b""
def _feed_data(self, data: bytes) -> None:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
if opcode == WSMsgType.CLOSE:
if opcode in MESSAGE_TYPES_WITH_CONTENT:
# load text/binary
is_continuation = opcode == WSMsgType.CONTINUATION
if not fin:
# got partial frame payload
if not is_continuation:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
continue
has_partial = bool(self._partial)
if is_continuation:
if self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = None
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(assembled_payload), self._max_msg_size
),
)
# Decompress process must to be done after all packets
# received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(
suppress_deflate_header=True
)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(assembled_payload)
if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# tuple.__new__ is used to avoid the overhead of the lambda
msg = tuple.__new__(WSMessage, (WSMsgType.TEXT, text, ""))
self.queue.feed_data(msg, len(payload_merged))
continue
# tuple.__new__ is used to avoid the overhead of the lambda
msg = tuple.__new__(WSMessage, (WSMsgType.BINARY, payload_merged, ""))
self.queue.feed_data(msg, len(payload_merged))
elif opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
@@ -309,262 +437,167 @@ class WebSocketReader:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
msg = tuple.__new__(
WSMessage, (WSMsgType.CLOSE, close_code, close_message)
)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = WSMessage(WSMsgType.CLOSE, 0, "")
msg = tuple.__new__(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == WSMsgType.PING:
self.queue.feed_data(
WSMessage(WSMsgType.PING, payload, ""), len(payload)
)
msg = tuple.__new__(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == WSMsgType.PONG:
self.queue.feed_data(
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
)
msg = tuple.__new__(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
elif (
opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
and self._opcode is None
):
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
else:
# load text/binary
if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode
self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
else:
# previous frame was non finished
# we should get continuation opcode
if self._partial:
if opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)
if opcode == WSMsgType.CONTINUATION:
assert self._opcode is not None
opcode = self._opcode
self._opcode = None
self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
# Decompress process must to be done after all packets
# received.
if compressed:
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress(
self._partial, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(self._partial)
self._partial.clear()
if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ""), len(text)
)
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
else:
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""),
len(payload_merged),
)
return False, b""
def parse_frame(
self, buf: bytes
) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
"""Return the next frame from the socket."""
frames = []
frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = []
if self._tail:
buf, self._tail = self._tail + buf, b""
start_pos = 0
start_pos: int = 0
buf_length = len(buf)
while True:
# read header
if self._state == WSParserState.READ_HEADER:
if buf_length - start_pos >= 2:
data = buf[start_pos : start_pos + 2]
start_pos += 2
first_byte, second_byte = data
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be " "larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed is None:
self._compressed = True if rsv1 else False
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._state = WSParserState.READ_PAYLOAD_LENGTH
else:
if self._state is WSParserState.READ_HEADER:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
first_byte, second_byte = data
# read payload length
if self._state == WSParserState.READ_PAYLOAD_LENGTH:
length = self._payload_length_flag
if length == 126:
if buf_length - start_pos >= 2:
data = buf[start_pos : start_pos + 2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)
else:
break
elif length > 126:
if buf_length - start_pos >= 8:
data = buf[start_pos : start_pos + 8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)
else:
break
else:
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
# read payload mask
if self._state == WSParserState.READ_PAYLOAD_MASK:
if buf_length - start_pos >= 4:
self._frame_mask = buf[start_pos : start_pos + 4]
start_pos += 4
self._state = WSParserState.READ_PAYLOAD
else:
break
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
if self._state == WSParserState.READ_PAYLOAD:
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be " "larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed is None:
self._compressed = True if rsv1 else False
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._state = WSParserState.READ_PAYLOAD_LENGTH
# read payload length
if self._state is WSParserState.READ_PAYLOAD_LENGTH:
length_flag = self._payload_length_flag
if length_flag == 126:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
self._payload_length = UNPACK_LEN2(data)[0]
elif length_flag > 126:
if buf_length - start_pos < 8:
break
data = buf[start_pos : start_pos + 8]
start_pos += 8
self._payload_length = UNPACK_LEN3(data)[0]
else:
self._payload_length = length_flag
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)
# read payload mask
if self._state is WSParserState.READ_PAYLOAD_MASK:
if buf_length - start_pos < 4:
break
self._frame_mask = buf[start_pos : start_pos + 4]
start_pos += 4
self._state = WSParserState.READ_PAYLOAD
if self._state is WSParserState.READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload
chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
payload.extend(buf[start_pos:])
payload += buf[start_pos:]
start_pos = buf_length
else:
self._payload_length = 0
payload.extend(buf[start_pos : start_pos + length])
payload += buf[start_pos : start_pos + length]
start_pos = start_pos + length
if self._payload_length == 0:
if self._has_mask:
assert self._frame_mask is not None
_websocket_mask(self._frame_mask, payload)
frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
)
self._frame_payload = bytearray()
self._state = WSParserState.READ_HEADER
else:
if self._payload_length != 0:
break
if self._has_mask:
assert self._frame_mask is not None
_websocket_mask(self._frame_mask, payload)
frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
)
self._frame_payload = bytearray()
self._state = WSParserState.READ_HEADER
self._tail = buf[start_pos:]
return frames
@@ -578,14 +611,14 @@ class WebSocketWriter:
*,
use_mask: bool = False,
limit: int = DEFAULT_LIMIT,
random: Any = random.Random(),
random: random.Random = random.Random(),
compress: int = 0,
notakeover: bool = False,
) -> None:
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.randrange = random.randrange
self.get_random_bits = partial(random.getrandbits, 32)
self.compress = compress
self.notakeover = notakeover
self._closing = False
@@ -598,78 +631,107 @@ class WebSocketWriter:
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0
# Only compress larger packets (disabled)
# Does small packet needs to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
# RSV1 (rsv = 0x40) is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
rsv = 0x40
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = zlib.compressobj(
level=zlib.Z_BEST_SPEED, wbits=-self.compress
)
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj
message = compressobj.compress(message)
message = message + compressobj.flush(
message = await compressobj.compress(message)
# Its critical that we do not return control to the event
# loop until we have finished sending all the compressed
# data. Otherwise we could end up mixing compressed frames
# if there are multiple coroutines compressing data.
message += compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
rsv = rsv | 0x40
msg_length = len(message)
use_mask = self.use_mask
if use_mask:
mask_bit = 0x80
else:
mask_bit = 0
mask_bit = 0x80 if use_mask else 0
# Depending on the message length, the header is assembled differently.
# The first byte is reserved for the opcode and the RSV bits.
first_byte = 0x80 | rsv | opcode
if msg_length < 126:
header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
elif msg_length < (1 << 16):
header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
header = PACK_LEN1(first_byte, msg_length | mask_bit)
header_len = 2
elif msg_length < 65536:
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
header_len = 4
else:
header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
header_len = 10
if self.transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
# If we are using a mask, we need to generate it randomly
# and apply it to the message before sending it. A mask is
# a 32-bit value that is applied to the message using a
# bitwise XOR operation. It is used to prevent certain types
# of attacks on the websocket protocol. The mask is only used
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = self.randrange(0, 0xFFFFFFFF)
mask = mask.to_bytes(4, "big")
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
_websocket_mask(mask, message)
self._write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
self.transport.write(header + mask + message)
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
else:
if len(message) > MSG_SIZE:
self._write(header)
self._write(message)
else:
self._write(header + message)
self.transport.write(header + message)
self._output_size += len(header) + len(message)
self._output_size += header_len + msg_length
# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()
def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
self.transport.write(data)
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
async def pong(self, message: bytes = b"") -> None:
async def pong(self, message: Union[bytes, str] = b"") -> None:
"""Send pong message."""
if isinstance(message, str):
message = message.encode("utf-8")
await self._send_frame(message, WSMsgType.PONG)
async def ping(self, message: bytes = b"") -> None:
async def ping(self, message: Union[bytes, str] = b"") -> None:
"""Send ping message."""
if isinstance(message, str):
message = message.encode("utf-8")
@@ -689,7 +751,7 @@ class WebSocketWriter:
else:
await self._send_frame(message, WSMsgType.TEXT, compress)
async def close(self, code: int = 1000, message: bytes = b"") -> None:
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")

View File

@@ -8,6 +8,8 @@ from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
@@ -43,7 +45,7 @@ class StreamWriter(AbstractStreamWriter):
self.output_size = 0
self._eof = False
self._compress: Any = None
self._compress: Optional[ZLibCompressor] = None
self._drain_waiter = None
self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
@@ -63,16 +65,15 @@ class StreamWriter(AbstractStreamWriter):
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)
async def write(
@@ -93,7 +94,7 @@ class StreamWriter(AbstractStreamWriter):
chunk = chunk.cast("c")
if self._compress is not None:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
if not chunk:
return
@@ -138,9 +139,9 @@ class StreamWriter(AbstractStreamWriter):
if self._compress:
if chunk:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
chunk = chunk + self._compress.flush()
chunk += self._compress.flush()
if chunk and self.chunked:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
@@ -189,7 +190,7 @@ def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> byte
_serialize_headers = _py_serialize_headers
try:
import aiohttp._http_writer as _http_writer # type: ignore[import]
import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
_c_serialize_headers = _http_writer._serialize_headers
if not NO_EXTENSIONS:

View File

@@ -1,41 +0,0 @@
import asyncio
import collections
from typing import Any, Deque, Optional
class EventResultOrError:
"""Event asyncio lock helper class.
Wraps the Event asyncio lock allowing either to awake the
locked Tasks without any error or raising an exception.
thanks to @vorpalsmith for the simple design.
"""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._exc: Optional[BaseException] = None
self._event = asyncio.Event()
self._waiters: Deque[asyncio.Future[Any]] = collections.deque()
def set(self, exc: Optional[BaseException] = None) -> None:
self._exc = exc
self._event.set()
async def wait(self) -> Any:
waiter = self._loop.create_task(self._event.wait())
self._waiters.append(waiter)
try:
val = await waiter
finally:
self._waiters.remove(waiter)
if self._exc is not None:
raise self._exc
return val
def cancel(self) -> None:
"""Cancel all waiters"""
for waiter in self._waiters:
waiter.cancel()

View File

@@ -2,6 +2,7 @@ import base64
import binascii
import json
import re
import sys
import uuid
import warnings
import zlib
@@ -10,7 +11,6 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Deque,
Dict,
Iterator,
@@ -25,8 +25,9 @@ from typing import (
)
from urllib.parse import parse_qsl, unquote, urlencode
from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping
from multidict import CIMultiDict, CIMultiDictProxy
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .hdrs import (
CONTENT_DISPOSITION,
CONTENT_ENCODING,
@@ -47,6 +48,13 @@ from .payload import (
)
from .streams import StreamReader
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing import TypeVar
Self = TypeVar("Self", bound="BodyPartReader")
__all__ = (
"MultipartReader",
"MultipartWriter",
@@ -58,7 +66,7 @@ __all__ = (
)
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .client_reqrep import ClientResponse
@@ -255,23 +263,32 @@ class BodyPartReader:
chunk_size = 8192
def __init__(
self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
self,
boundary: bytes,
headers: "CIMultiDictProxy[str]",
content: StreamReader,
*,
subtype: str = "mixed",
default_charset: Optional[str] = None,
) -> None:
self.headers = headers
self._boundary = boundary
self._boundary_len = len(boundary) + 2 # Boundary + \r\n
self._content = content
self._default_charset = default_charset
self._at_eof = False
length = self.headers.get(CONTENT_LENGTH, None)
self._is_form_data = subtype == "form-data"
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
self._length = int(length) if length is not None else None
self._read_bytes = 0
# TODO: typeing.Deque is not supported by Python 3.5
self._unread: Deque[bytes] = deque()
self._prev_chunk: Optional[bytes] = None
self._content_eof = 0
self._cache: Dict[str, Any] = {}
def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
return self # type: ignore[return-value]
def __aiter__(self: Self) -> Self:
return self
async def __anext__(self) -> bytes:
part = await self.next()
@@ -313,6 +330,31 @@ class BodyPartReader:
else:
chunk = await self._read_chunk_from_stream(size)
# For the case of base64 data, we must read a fragment of size with a
# remainder of 0 by dividing by 4 for string without symbols \n or \r
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
if encoding and encoding.lower() == "base64":
stripped_chunk = b"".join(chunk.split())
remainder = len(stripped_chunk) % 4
while remainder != 0 and not self.at_eof():
over_chunk_size = 4 - remainder
over_chunk = b""
if self._prev_chunk:
over_chunk = self._prev_chunk[:over_chunk_size]
self._prev_chunk = self._prev_chunk[len(over_chunk) :]
if len(over_chunk) != over_chunk_size:
over_chunk += await self._content.read(4 - len(over_chunk))
if not over_chunk:
self._at_eof = True
stripped_chunk += b"".join(over_chunk.split())
chunk += over_chunk
remainder = len(stripped_chunk) % 4
self._read_bytes += len(chunk)
if self._read_bytes == self._length:
self._at_eof = True
@@ -329,21 +371,33 @@ class BodyPartReader:
assert self._length is not None, "Content-Length required for chunked read"
chunk_size = min(size, self._length - self._read_bytes)
chunk = await self._content.read(chunk_size)
if self._content.at_eof():
self._at_eof = True
return chunk
async def _read_chunk_from_stream(self, size: int) -> bytes:
# Reads content chunk of body part with unknown length.
# The Content-Length header for body part is not necessary.
assert (
size >= len(self._boundary) + 2
size >= self._boundary_len
), "Chunk size must be greater or equal than boundary length + 2"
first_chunk = self._prev_chunk is None
if first_chunk:
self._prev_chunk = await self._content.read(size)
chunk = await self._content.read(size)
self._content_eof += int(self._content.at_eof())
assert self._content_eof < 3, "Reading after EOF"
chunk = b""
# content.read() may return less than size, so we need to loop to ensure
# we have enough data to detect the boundary.
while len(chunk) < self._boundary_len:
chunk += await self._content.read(size)
self._content_eof += int(self._content.at_eof())
assert self._content_eof < 3, "Reading after EOF"
if self._content_eof:
break
if len(chunk) > size:
self._content.unread_data(chunk[size:])
chunk = chunk[:size]
assert self._prev_chunk is not None
window = self._prev_chunk + chunk
sub = b"\r\n" + self._boundary
@@ -404,8 +458,8 @@ class BodyPartReader:
async def text(self, *, encoding: Optional[str] = None) -> str:
"""Like read(), but assumes that body part contains text data."""
data = await self.read(decode=True)
# see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA
# and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA
# see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
# and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
encoding = encoding or self.get_charset(default="utf-8")
return data.decode(encoding)
@@ -426,8 +480,13 @@ class BodyPartReader:
real_encoding = encoding
else:
real_encoding = self.get_charset(default="utf-8")
try:
decoded_data = data.rstrip().decode(real_encoding)
except UnicodeDecodeError:
raise ValueError("data cannot be decoded with %s encoding" % real_encoding)
return parse_qsl(
data.rstrip().decode(real_encoding),
decoded_data,
keep_blank_values=True,
encoding=real_encoding,
)
@@ -444,21 +503,22 @@ class BodyPartReader:
"""
if CONTENT_TRANSFER_ENCODING in self.headers:
data = self._decode_content_transfer(data)
if CONTENT_ENCODING in self.headers:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
if not self._is_form_data and CONTENT_ENCODING in self.headers:
return self._decode_content(data)
return data
def _decode_content(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_ENCODING, "").lower()
if encoding == "deflate":
return zlib.decompress(data, -zlib.MAX_WBITS)
elif encoding == "gzip":
return zlib.decompress(data, 16 + zlib.MAX_WBITS)
elif encoding == "identity":
if encoding == "identity":
return data
else:
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding in {"deflate", "gzip"}:
return ZLibDecompressor(
encoding=encoding,
suppress_deflate_header=True,
).decompress_sync(data)
raise RuntimeError(f"unknown content encoding: {encoding}")
def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
@@ -478,7 +538,7 @@ class BodyPartReader:
"""Returns charset parameter from Content-Type header or default."""
ctype = self.headers.get(CONTENT_TYPE, "")
mimetype = parse_mimetype(ctype)
return mimetype.parameters.get("charset", default)
return mimetype.parameters.get("charset", self._default_charset or default)
@reify
def name(self) -> Optional[str]:
@@ -501,6 +561,8 @@ class BodyPartReader:
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
_value: BodyPartReader
def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value, *args, **kwargs)
@@ -513,6 +575,9 @@ class BodyPartReaderPayload(Payload):
if params:
self.set_content_disposition("attachment", True, **params)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")
async def write(self, writer: Any) -> None:
field = self._value
chunk = await field.read_chunk(size=2**16)
@@ -528,23 +593,29 @@ class MultipartReader:
response_wrapper_cls = MultipartResponseWrapper
#: Multipart reader class, used to handle multipart/* body parts.
#: None points to type(self)
multipart_reader_cls = None
multipart_reader_cls: Optional[Type["MultipartReader"]] = None
#: Body part reader class for non multipart/* content types.
part_reader_cls = BodyPartReader
def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
assert self._mimetype.type == "multipart", "multipart/* content type expected"
if "boundary" not in self._mimetype.parameters:
raise ValueError(
"boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
)
self.headers = headers
self._boundary = ("--" + self._get_boundary()).encode()
self._content = content
self._default_charset: Optional[str] = None
self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
self._at_eof = False
self._at_bof = True
self._unread: List[bytes] = []
def __aiter__(
self,
) -> AsyncIterator["BodyPartReader"]:
return self # type: ignore[return-value]
def __aiter__(self: Self) -> Self:
return self
async def __anext__(
self,
@@ -587,7 +658,24 @@ class MultipartReader:
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return None
self._last_part = await self.fetch_next_part()
part = await self.fetch_next_part()
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
if (
self._last_part is None
and self._mimetype.subtype == "form-data"
and isinstance(part, BodyPartReader)
):
_, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
if params.get("name") == "_charset_":
# Longest encoding in https://encoding.spec.whatwg.org/encodings.json
# is 19 characters, so 32 should be more than enough for any valid encoding.
charset = await part.read_chunk(32)
if len(charset) > 31:
raise RuntimeError("Invalid default charset")
self._default_charset = charset.strip().decode()
part = await self.fetch_next_part()
self._last_part = part
return self._last_part
async def release(self) -> None:
@@ -623,19 +711,16 @@ class MultipartReader:
return type(self)(headers, self._content)
return self.multipart_reader_cls(headers, self._content)
else:
return self.part_reader_cls(self._boundary, headers, self._content)
def _get_boundary(self) -> str:
mimetype = parse_mimetype(self.headers[CONTENT_TYPE])
assert mimetype.type == "multipart", "multipart/* content type expected"
if "boundary" not in mimetype.parameters:
raise ValueError(
"boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
return self.part_reader_cls(
self._boundary,
headers,
self._content,
subtype=self._mimetype.subtype,
default_charset=self._default_charset,
)
boundary = mimetype.parameters["boundary"]
def _get_boundary(self) -> str:
boundary = self._mimetype.parameters["boundary"]
if len(boundary) > 70:
raise ValueError("boundary %r is too long (70 chars max)" % boundary)
@@ -710,6 +795,8 @@ _Part = Tuple[Payload, str, str]
class MultipartWriter(Payload):
"""Multipart body writer."""
_value: None
def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
boundary = boundary if boundary is not None else uuid.uuid4().hex
# The underlying Payload API demands a str (utf-8), not bytes,
@@ -726,6 +813,7 @@ class MultipartWriter(Payload):
super().__init__(None, content_type=ctype)
self._parts: List[_Part] = []
self._is_form_data = subtype == "form-data"
def __enter__(self) -> "MultipartWriter":
return self
@@ -754,7 +842,7 @@ class MultipartWriter(Payload):
def _boundary_value(self) -> str:
"""Wrap boundary parameter value in quotes, if necessary.
Reads self.boundary and returns a unicode sting.
Reads self.boundary and returns a unicode string.
"""
# Refer to RFCs 7231, 7230, 5234.
#
@@ -786,7 +874,7 @@ class MultipartWriter(Payload):
def boundary(self) -> str:
return self._boundary.decode("ascii")
def append(self, obj: Any, headers: Optional[MultiMapping[str]] = None) -> Payload:
def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
if headers is None:
headers = CIMultiDict()
@@ -803,38 +891,44 @@ class MultipartWriter(Payload):
def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# compression
encoding: Optional[str] = payload.headers.get(
CONTENT_ENCODING,
"",
).lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None
# te encoding
te_encoding: Optional[str] = payload.headers.get(
CONTENT_TRANSFER_ENCODING,
"",
).lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(
"unknown content transfer encoding: {}" "".format(te_encoding)
encoding: Optional[str] = None
te_encoding: Optional[str] = None
if self._is_form_data:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
assert (
not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
& payload.headers.keys()
)
if te_encoding == "binary":
te_encoding = None
# Set default Content-Disposition in case user doesn't create one
if CONTENT_DISPOSITION not in payload.headers:
name = f"section-{len(self._parts)}"
payload.set_content_disposition("form-data", name=name)
else:
# compression
encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None
# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)
# te encoding
te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
if te_encoding == "binary":
te_encoding = None
# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)
self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type]
return payload
def append_json(
self, obj: Any, headers: Optional[MultiMapping[str]] = None
self, obj: Any, headers: Optional[Mapping[str, str]] = None
) -> Payload:
"""Helper to append JSON part."""
if headers is None:
@@ -845,7 +939,7 @@ class MultipartWriter(Payload):
def append_form(
self,
obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
headers: Optional[MultiMapping[str]] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Payload:
"""Helper to append form urlencoded part."""
assert isinstance(obj, (Sequence, Mapping))
@@ -883,9 +977,24 @@ class MultipartWriter(Payload):
total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
return total
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return "".join(
"--"
+ self.boundary
+ "\n"
+ part._binary_headers.decode(encoding, errors)
+ part.decode()
for part, _e, _te in self._parts
)
async def write(self, writer: Any, close_boundary: bool = True) -> None:
"""Write body."""
for part, encoding, te_encoding in self._parts:
if self._is_form_data:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
assert CONTENT_DISPOSITION in part.headers
assert "name=" in part.headers[CONTENT_DISPOSITION]
await writer.write(b"--" + self._boundary + b"\r\n")
await writer.write(part._binary_headers)
@@ -910,7 +1019,7 @@ class MultipartPayloadWriter:
def __init__(self, writer: Any) -> None:
self._writer = writer
self._encoding: Optional[str] = None
self._compress: Any = None
self._compress: Optional[ZLibCompressor] = None
self._encoding_buffer: Optional[bytearray] = None
def enable_encoding(self, encoding: str) -> None:
@@ -923,8 +1032,11 @@ class MultipartPayloadWriter:
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS
self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)
self._compress = ZLibCompressor(
encoding=encoding,
suppress_deflate_header=True,
strategy=strategy,
)
async def write_eof(self) -> None:
if self._compress is not None:
@@ -940,7 +1052,7 @@ class MultipartPayloadWriter:
async def write(self, chunk: bytes) -> None:
if self._compress is not None:
if chunk:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
if not chunk:
return

View File

@@ -11,8 +11,8 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
ByteString,
Dict,
Final,
Iterable,
Optional,
TextIO,
@@ -26,14 +26,14 @@ from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
PY_36,
_SENTINEL,
content_disposition_header,
guess_filename,
parse_mimetype,
sentinel,
)
from .streams import StreamReader
from .typedefs import Final, JSONEncoder, _CIMultiDict
from .typedefs import JSONEncoder, _CIMultiDict
__all__ = (
"PAYLOAD_REGISTRY",
@@ -53,7 +53,7 @@ __all__ = (
TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from typing import List
@@ -101,6 +101,7 @@ class PayloadRegistry:
self._first: List[_PayloadRegistryItem] = []
self._normal: List[_PayloadRegistryItem] = []
self._last: List[_PayloadRegistryItem] = []
self._normal_lookup: Dict[Any, PayloadType] = {}
def get(
self,
@@ -109,12 +110,20 @@ class PayloadRegistry:
_CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
**kwargs: Any,
) -> "Payload":
if self._first:
for factory, type_ in self._first:
if isinstance(data, type_):
return factory(data, *args, **kwargs)
# Try the fast lookup first
if lookup_factory := self._normal_lookup.get(type(data)):
return lookup_factory(data, *args, **kwargs)
# Bail early if its already a Payload
if isinstance(data, Payload):
return data
for factory, type in _CHAIN(self._first, self._normal, self._last):
if isinstance(data, type):
# Fallback to the slower linear search
for factory, type_ in _CHAIN(self._normal, self._last):
if isinstance(data, type_):
return factory(data, *args, **kwargs)
raise LookupError()
def register(
@@ -124,6 +133,11 @@ class PayloadRegistry:
self._first.append((factory, type))
elif order is Order.normal:
self._normal.append((factory, type))
if isinstance(type, Iterable):
for t in type:
self._normal_lookup[t] = factory
else:
self._normal_lookup[type] = factory
elif order is Order.try_last:
self._last.append((factory, type))
else:
@@ -141,7 +155,7 @@ class Payload(ABC):
headers: Optional[
Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
] = None,
content_type: Optional[str] = sentinel,
content_type: Union[str, None, _SENTINEL] = sentinel,
filename: Optional[str] = None,
encoding: Optional[str] = None,
**kwargs: Any,
@@ -159,7 +173,8 @@ class Payload(ABC):
self._headers[hdrs.CONTENT_TYPE] = content_type
else:
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
self._headers.update(headers or {})
if headers:
self._headers.update(headers)
@property
def size(self) -> Optional[int]:
@@ -207,6 +222,13 @@ class Payload(ABC):
disptype, quote_fields=quote_fields, _charset=_charset, **params
)
@abstractmethod
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
"""Return string representation of the value.
This is named decode() to allow compatibility with bytes objects.
"""
@abstractmethod
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.
@@ -216,10 +238,11 @@ class Payload(ABC):
class BytesPayload(Payload):
def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
_value: bytes
def __init__(
self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
) -> None:
if "content_type" not in kwargs:
kwargs["content_type"] = "application/octet-stream"
@@ -227,14 +250,13 @@ class BytesPayload(Payload):
if isinstance(value, memoryview):
self._size = value.nbytes
else:
elif isinstance(value, (bytes, bytearray)):
self._size = len(value)
else:
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
if self._size > TOO_LARGE_BYTES_BODY:
if PY_36:
kwargs = {"source": self}
else:
kwargs = {}
kwargs = {"source": self}
warnings.warn(
"Sending a large body directly with raw bytes might"
" lock the event loop. You should probably pass an "
@@ -243,6 +265,9 @@ class BytesPayload(Payload):
**kwargs,
)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.decode(encoding, errors)
async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)
@@ -284,7 +309,7 @@ class StringIOPayload(StringPayload):
class IOBasePayload(Payload):
_value: IO[Any]
_value: io.IOBase
def __init__(
self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
@@ -308,9 +333,12 @@ class IOBasePayload(Payload):
finally:
await loop.run_in_executor(None, self._value.close)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return "".join(r.decode(encoding, errors) for r in self._value.readlines())
class TextIOPayload(IOBasePayload):
_value: TextIO
_value: io.TextIOBase
def __init__(
self,
@@ -347,6 +375,9 @@ class TextIOPayload(IOBasePayload):
except OSError:
return None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read()
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
@@ -364,6 +395,8 @@ class TextIOPayload(IOBasePayload):
class BytesIOPayload(IOBasePayload):
_value: io.BytesIO
@property
def size(self) -> int:
position = self._value.tell()
@@ -371,17 +404,27 @@ class BytesIOPayload(IOBasePayload):
self._value.seek(position)
return end - position
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)
class BufferedReaderPayload(IOBasePayload):
_value: io.BufferedIOBase
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
# For some file-like objects (e.g. tarfile), the fileno() attribute may
# not exist at all, and will instead raise an AttributeError.
return None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)
class JsonPayload(BytesPayload):
def __init__(
@@ -403,7 +446,7 @@ class JsonPayload(BytesPayload):
)
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from typing import AsyncIterable, AsyncIterator
_AsyncIterator = AsyncIterator[bytes]
@@ -418,6 +461,7 @@ else:
class AsyncIterablePayload(Payload):
_iter: Optional[_AsyncIterator] = None
_value: _AsyncIterable
def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
@@ -445,6 +489,9 @@ class AsyncIterablePayload(Payload):
except StopAsyncIteration:
self._iter = None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")
class StreamReaderPayload(AsyncIterablePayload):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:

View File

@@ -1,5 +1,5 @@
"""
Payload implemenation for coroutines as data provider.
Payload implementation for coroutines as data provider.
As a simple case, you can upload data from file::
@@ -44,7 +44,7 @@ class _stream_wrapper:
self.kwargs = kwargs
async def __call__(self, writer: AbstractStreamWriter) -> None:
await self.coro(writer, *self.args, **self.kwargs) # type: ignore[operator]
await self.coro(writer, *self.args, **self.kwargs)
class streamer:
@@ -65,6 +65,9 @@ class StreamWrapperPayload(Payload):
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")
@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):

View File

@@ -1,14 +1,21 @@
import asyncio
import contextlib
import inspect
import warnings
from collections.abc import Callable
from typing import Any, Awaitable, Callable, Dict, Generator, Optional, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterator,
Optional,
Protocol,
Type,
Union,
)
import pytest
from aiohttp.helpers import PY_37, isasyncgenfunction
from aiohttp.web import Application
from .test_utils import (
BaseTestServer,
RawTestServer,
@@ -19,18 +26,35 @@ from .test_utils import (
teardown_test_loop,
unused_port as _unused_port,
)
from .web import Application
from .web_protocol import _RequestHandler
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None
uvloop = None # type: ignore[assignment]
try:
import tokio
except ImportError: # pragma: no cover
tokio = None
AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]]
class AiohttpClient(Protocol):
def __call__(
self,
__param: Union[Application, BaseTestServer],
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> Awaitable[TestClient]: ...
class AiohttpServer(Protocol):
def __call__(
self, app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> Awaitable[TestServer]: ...
class AiohttpRawServer(Protocol):
def __call__(
self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
) -> Awaitable[RawTestServer]: ...
def pytest_addoption(parser): # type: ignore[no-untyped-def]
@@ -44,7 +68,7 @@ def pytest_addoption(parser): # type: ignore[no-untyped-def]
"--aiohttp-loop",
action="store",
default="pyloop",
help="run tests with specific loop: pyloop, uvloop, tokio or all",
help="run tests with specific loop: pyloop, uvloop or all",
)
parser.addoption(
"--aiohttp-enable-loop-debug",
@@ -61,7 +85,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
"""
func = fixturedef.func
if isasyncgenfunction(func):
if inspect.isasyncgenfunction(func):
# async generator fixture
is_async_gen = True
elif asyncio.iscoroutinefunction(func):
@@ -193,16 +217,14 @@ def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
return
loops = metafunc.config.option.aiohttp_loop
avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]]
avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}
if uvloop is not None: # pragma: no cover
avail_factories["uvloop"] = uvloop.EventLoopPolicy
if tokio is not None: # pragma: no cover
avail_factories["tokio"] = tokio.EventLoopPolicy
if loops == "all":
loops = "pyloop,uvloop?,tokio?"
loops = "pyloop,uvloop?"
factories = {} # type: ignore[var-annotated]
for name in loops.split(","):
@@ -236,12 +258,8 @@ def loop(loop_factory, fast, loop_debug): # type: ignore[no-untyped-def]
@pytest.fixture
def proactor_loop(): # type: ignore[no-untyped-def]
if not PY_37:
policy = asyncio.get_event_loop_policy()
policy._loop_factory = asyncio.ProactorEventLoop # type: ignore[attr-defined]
else:
policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore[attr-defined]
asyncio.set_event_loop_policy(policy)
policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore[attr-defined]
asyncio.set_event_loop_policy(policy)
with loop_context(policy.new_event_loop) as _loop:
asyncio.set_event_loop(_loop)
@@ -249,7 +267,7 @@ def proactor_loop(): # type: ignore[no-untyped-def]
@pytest.fixture
def unused_port(aiohttp_unused_port): # type: ignore[no-untyped-def] # pragma: no cover
def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
warnings.warn(
"Deprecated, use aiohttp_unused_port fixture instead",
DeprecationWarning,
@@ -259,20 +277,22 @@ def unused_port(aiohttp_unused_port): # type: ignore[no-untyped-def] # pragma:
@pytest.fixture
def aiohttp_unused_port(): # type: ignore[no-untyped-def]
def aiohttp_unused_port() -> Callable[[], int]:
"""Return a port that is unused on the current host."""
return _unused_port
@pytest.fixture
def aiohttp_server(loop): # type: ignore[no-untyped-def]
def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
"""Factory to create a TestServer instance, given an app.
aiohttp_server(app, **kwargs)
"""
servers = []
async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def]
async def go(
app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> TestServer:
server = TestServer(app, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
@@ -298,14 +318,16 @@ def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no c
@pytest.fixture
def aiohttp_raw_server(loop): # type: ignore[no-untyped-def]
def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
"""Factory to create a RawTestServer instance, given a web handler.
aiohttp_raw_server(handler, **kwargs)
"""
servers = []
async def go(handler, *, port=None, **kwargs): # type: ignore[no-untyped-def]
async def go(
handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
) -> RawTestServer:
server = RawTestServer(handler, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
@@ -335,7 +357,7 @@ def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover
@pytest.fixture
def aiohttp_client(
loop: asyncio.AbstractEventLoop,
) -> Generator[AiohttpClient, None, None]:
) -> Iterator[AiohttpClient]:
"""Factory to create a TestClient instance.
aiohttp_client(app, **kwargs)

View File

@@ -1,20 +1,25 @@
import asyncio
import socket
from typing import Any, Dict, List, Optional, Type, Union
import sys
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from .abc import AbstractResolver
from .helpers import get_running_loop
from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
try:
import aiodns
# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError: # pragma: no cover
aiodns = None
aiodns = None # type: ignore[assignment]
aiodns_default = False
aiodns_default = False
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)
class ThreadedResolver(AbstractResolver):
@@ -25,48 +30,48 @@ class ThreadedResolver(AbstractResolver):
"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = get_running_loop(loop)
self._loop = loop or asyncio.get_running_loop()
async def resolve(
self, hostname: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
infos = await self._loop.getaddrinfo(
hostname,
host,
port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)
hosts = []
hosts: List[ResolveResult] = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6:
if len(address) < 3:
# IPv6 is not supported by Python build,
# or IPv6 is not enabled in the host
continue
if address[3]: # type: ignore[misc]
if address[3] and _SUPPORTS_SCOPE_ID:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
host, _port = socket.getnameinfo(
address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
resolved_host, _port = await self._loop.getnameinfo(
address, _NAME_SOCKET_FLAGS
)
port = int(_port)
else:
host, port = address[:2]
resolved_host, port = address[:2]
else: # IPv4
assert family == socket.AF_INET
host, port = address # type: ignore[misc]
resolved_host, port = address # type: ignore[misc]
hosts.append(
{
"hostname": hostname,
"host": host,
"port": port,
"family": family,
"proto": proto,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=proto,
flags=_NUMERIC_SOCKET_FLAGS,
)
)
return hosts
@@ -87,36 +92,60 @@ class AsyncResolver(AbstractResolver):
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
self._loop = get_running_loop(loop)
self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
self._resolver = aiodns.DNSResolver(*args, **kwargs)
if not hasattr(self._resolver, "gethostbyname"):
# aiodns 1.1 is not available, fallback to DNSResolver.query
self.resolve = self._resolve_with_query # type: ignore
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
try:
resp = await self._resolver.gethostbyname(host, family)
resp = await self._resolver.getaddrinfo(
host,
port=port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
hosts = []
for address in resp.addresses:
raise OSError(None, msg) from exc
hosts: List[ResolveResult] = []
for node in resp.nodes:
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
family = node.family
if family == socket.AF_INET6:
if len(address) > 3 and address[3] and _SUPPORTS_SCOPE_ID:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
result = await self._resolver.getnameinfo(
(address[0].decode("ascii"), *address[1:]),
_NAME_SOCKET_FLAGS,
)
resolved_host = result.node
else:
resolved_host = address[0].decode("ascii")
port = address[1]
else: # IPv4
assert family == socket.AF_INET
resolved_host = address[0].decode("ascii")
port = address[1]
hosts.append(
{
"hostname": host,
"host": address,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=0,
flags=_NUMERIC_SOCKET_FLAGS,
)
)
if not hosts:
raise OSError("DNS lookup failed")
raise OSError(None, "DNS lookup failed")
return hosts
@@ -132,7 +161,7 @@ class AsyncResolver(AbstractResolver):
resp = await self._resolver.query(host, qtype)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
raise OSError(None, msg) from exc
hosts = []
for rr in resp:
@@ -148,7 +177,7 @@ class AsyncResolver(AbstractResolver):
)
if not hosts:
raise OSError("DNS lookup failed")
raise OSError(None, "DNS lookup failed")
return hosts

View File

@@ -1,12 +1,27 @@
import asyncio
import collections
import warnings
from typing import Awaitable, Callable, Deque, Generic, List, Optional, Tuple, TypeVar
from typing import (
Awaitable,
Callable,
Deque,
Final,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
TimerNoop,
set_exception,
set_result,
)
from .log import internal_logger
from .typedefs import Final
__all__ = (
"EMPTY_PAYLOAD",
@@ -59,19 +74,11 @@ class AsyncStreamReaderMixin:
return AsyncStreamIterator(self.readline) # type: ignore[attr-defined]
def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
"""Returns an asynchronous iterator that yields chunks of size n.
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(
lambda: self.read(n) # type: ignore[attr-defined,no-any-return]
)
"""Returns an asynchronous iterator that yields chunks of size n."""
return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined]
def iter_any(self) -> AsyncStreamIterator[bytes]:
"""Yield all available data as soon as it is received.
Python-3.5 available for Python 3.5+ only
"""
"""Yield all available data as soon as it is received."""
return AsyncStreamIterator(self.readany) # type: ignore[attr-defined]
def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
@@ -79,8 +86,6 @@ class AsyncStreamReaderMixin:
The yielded objects are tuples
of (bytes, bool) as returned by the StreamReader.readchunk method.
Python-3.5 available for Python 3.5+ only
"""
return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type]
@@ -124,7 +129,7 @@ class StreamReader(AsyncStreamReaderMixin):
self._waiter: Optional[asyncio.Future[None]] = None
self._eof_waiter: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timer = timer
self._timer = TimerNoop() if timer is None else timer
self._eof_callbacks: List[Callable[[], None]] = []
def __repr__(self) -> str:
@@ -147,19 +152,23 @@ class StreamReader(AsyncStreamReaderMixin):
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._exception = exc
self._eof_callbacks.clear()
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc)
set_exception(waiter, exc, exc_cause)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_exception(waiter, exc)
set_exception(waiter, exc, exc_cause)
def on_eof(self, callback: Callable[[], None]) -> None:
if self._eof:
@@ -236,9 +245,10 @@ class StreamReader(AsyncStreamReaderMixin):
if not data:
return
self._size += len(data)
data_len = len(data)
self._size += data_len
self._buffer.append(data)
self.total_bytes += len(data)
self.total_bytes += data_len
waiter = self._waiter
if waiter is not None:
@@ -266,7 +276,7 @@ class StreamReader(AsyncStreamReaderMixin):
# self._http_chunk_splits contains logical byte offsets from start of
# the body transfer. Each offset is the offset of the end of a chunk.
# "Logical" means bytes, accessible for a user.
# If no chunks containig logical data were received, current position
# If no chunks containing logical data were received, current position
# is difinitely zero.
pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
@@ -287,6 +297,9 @@ class StreamReader(AsyncStreamReaderMixin):
set_result(waiter, None)
async def _wait(self, func_name: str) -> None:
if not self._protocol.connected:
raise RuntimeError("Connection closed.")
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not possible to know
@@ -299,10 +312,7 @@ class StreamReader(AsyncStreamReaderMixin):
waiter = self._waiter = self._loop.create_future()
try:
if self._timer:
with self._timer:
await waiter
else:
with self._timer:
await waiter
finally:
self._waiter = None
@@ -327,7 +337,9 @@ class StreamReader(AsyncStreamReaderMixin):
offset = self._buffer_offset
ichar = self._buffer[0].find(separator, offset) + 1
# Read from current offset to found separator or to the end.
data = self._read_nowait_chunk(ichar - offset if ichar else -1)
data = self._read_nowait_chunk(
ichar - offset + seplen - 1 if ichar else -1
)
chunk += data
chunk_size += len(data)
if ichar:
@@ -491,8 +503,9 @@ class StreamReader(AsyncStreamReaderMixin):
def _read_nowait(self, n: int) -> bytes:
"""Read not more than n bytes, or whole buffer if n == -1"""
chunks = []
self._timer.assert_timeout()
chunks = []
while self._buffer:
chunk = self._read_nowait_chunk(n)
chunks.append(chunk)
@@ -506,12 +519,19 @@ class StreamReader(AsyncStreamReaderMixin):
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
def __init__(self) -> None:
pass
self._read_eof_chunk = False
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__
def exception(self) -> Optional[BaseException]:
return None
def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
pass
def on_eof(self, callback: Callable[[], None]) -> None:
@@ -547,6 +567,10 @@ class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
return b""
async def readchunk(self) -> Tuple[bytes, bool]:
if not self._read_eof_chunk:
self._read_eof_chunk = True
return (b"", False)
return (b"", True)
async def readexactly(self, n: int) -> bytes:
@@ -582,14 +606,18 @@ class DataQueue(Generic[_T]):
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc)
set_exception(waiter, exc, exc_cause)
def feed_data(self, data: _T, size: int = 0) -> None:
self._size += size

View File

@@ -11,32 +11,28 @@ import sys
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Optional,
Type,
Union,
cast,
)
from unittest import mock
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, cast
from unittest import IsolatedAsyncioTestCase, mock
from aiosignal import Signal
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
import aiohttp
from aiohttp.client import _RequestContextManager, _WSRequestContextManager
from aiohttp.client import (
_RequestContextManager,
_RequestOptions,
_WSRequestContextManager,
)
from . import ClientSession, hdrs
from .abc import AbstractCookieJar
from .client_reqrep import ClientResponse
from .client_ws import ClientWebSocketResponse
from .helpers import PY_38, sentinel
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import StrOrURL
from .web import (
Application,
AppRunner,
@@ -49,15 +45,13 @@ from .web import (
)
from .web_protocol import _RequestHandler
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from ssl import SSLContext
else:
SSLContext = None
if PY_38:
from unittest import IsolatedAsyncioTestCase as TestCase
else:
from asynctest import TestCase # type: ignore[no-redef]
if sys.version_info >= (3, 11) and TYPE_CHECKING:
from typing import Unpack
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
@@ -94,7 +88,7 @@ class BaseTestServer(ABC):
def __init__(
self,
*,
scheme: Union[str, object] = sentinel,
scheme: str = "",
loop: Optional[asyncio.AbstractEventLoop] = None,
host: str = "127.0.0.1",
port: Optional[int] = None,
@@ -121,14 +115,17 @@ class BaseTestServer(ABC):
return
self._loop = loop
self._ssl = kwargs.pop("ssl", None)
self.runner = await self._make_runner(**kwargs)
self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
await self.runner.setup()
if not self.port:
self.port = 0
absolute_host = self.host
try:
version = ipaddress.ip_address(self.host).version
except ValueError:
version = 4
if version == 6:
absolute_host = f"[{self.host}]"
family = socket.AF_INET6 if version == 6 else socket.AF_INET
_sock = self.socket_factory(self.host, self.port, family)
self.host, self.port = _sock.getsockname()[:2]
@@ -136,29 +133,25 @@ class BaseTestServer(ABC):
await site.start()
server = site._server
assert server is not None
sockets = server.sockets
sockets = server.sockets # type: ignore[attr-defined]
assert sockets is not None
self.port = sockets[0].getsockname()[1]
if self.scheme is sentinel:
if self._ssl:
scheme = "https"
else:
scheme = "http"
self.scheme = scheme
self._root = URL(f"{self.scheme}://{self.host}:{self.port}")
if not self.scheme:
self.scheme = "https" if self._ssl else "http"
self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
@abstractmethod # pragma: no cover
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
pass
def make_url(self, path: str) -> URL:
def make_url(self, path: StrOrURL) -> URL:
assert self._root is not None
url = URL(path)
if not self.skip_url_asserts:
assert not url.is_absolute()
assert not url.absolute
return self._root.join(url)
else:
return URL(str(self._root) + path)
return URL(str(self._root) + str(path))
@property
def started(self) -> bool:
@@ -226,7 +219,7 @@ class TestServer(BaseTestServer):
self,
app: Application,
*,
scheme: Union[str, object] = sentinel,
scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
@@ -243,7 +236,7 @@ class RawTestServer(BaseTestServer):
self,
handler: _RequestHandler,
*,
scheme: Union[str, object] = sentinel,
scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
@@ -317,54 +310,114 @@ class TestClient:
"""
return self._session
def make_url(self, path: str) -> URL:
def make_url(self, path: StrOrURL) -> URL:
return self._server.make_url(path)
async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
async def _request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> ClientResponse:
resp = await self._session.request(method, self.make_url(path), **kwargs)
# save it to close later
self._responses.append(resp)
return resp
def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
"""Routes a request to tested http server.
if sys.version_info >= (3, 11) and TYPE_CHECKING:
The interface is identical to aiohttp.ClientSession.request,
except the loop kwarg is overridden by the instance used by the
test server.
def request(
self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
) -> _RequestContextManager: ...
"""
return _RequestContextManager(self._request(method, path, **kwargs))
def get(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP GET request."""
return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
def options(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP POST request."""
return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
def head(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP OPTIONS request."""
return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))
def post(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP HEAD request."""
return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
def put(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PUT request."""
return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
def patch(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))
def delete(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))
else:
def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
def request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> _RequestContextManager:
"""Routes a request to tested http server.
The interface is identical to aiohttp.ClientSession.request,
except the loop kwarg is overridden by the instance used by the
test server.
"""
return _RequestContextManager(self._request(method, path, **kwargs))
def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP GET request."""
return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP POST request."""
return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP OPTIONS request."""
return _RequestContextManager(
self._request(hdrs.METH_OPTIONS, path, **kwargs)
)
def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP HEAD request."""
return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PUT request."""
return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_PATCH, path, **kwargs)
)
def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_DELETE, path, **kwargs)
)
def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
"""Initiate websocket connection.
The api corresponds to aiohttp.ClientSession.ws_connect.
@@ -372,7 +425,9 @@ class TestClient:
"""
return _WSRequestContextManager(self._ws_connect(path, **kwargs))
async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
async def _ws_connect(
self, path: StrOrURL, **kwargs: Any
) -> ClientWebSocketResponse:
ws = await self._session.ws_connect(self.make_url(path), **kwargs)
self._websockets.append(ws)
return ws
@@ -423,7 +478,7 @@ class TestClient:
await self.close()
class AioHTTPTestCase(TestCase):
class AioHTTPTestCase(IsolatedAsyncioTestCase):
"""A base class to allow for unittest web applications using aiohttp.
Provides the following:
@@ -454,16 +509,8 @@ class AioHTTPTestCase(TestCase):
"""
raise RuntimeError("Did you forget to define get_application()?")
def setUp(self) -> None:
if not PY_38:
asyncio.get_event_loop().run_until_complete(self.asyncSetUp())
async def asyncSetUp(self) -> None:
try:
self.loop = asyncio.get_running_loop()
except (AttributeError, RuntimeError): # AttributeError->py36
self.loop = asyncio.get_event_loop_policy().get_event_loop()
self.loop = asyncio.get_running_loop()
return await self.setUpAsync()
async def setUpAsync(self) -> None:
@@ -473,10 +520,6 @@ class AioHTTPTestCase(TestCase):
await self.client.start_server()
def tearDown(self) -> None:
if not PY_38:
self.loop.run_until_complete(self.asyncTearDown())
async def asyncTearDown(self) -> None:
return await self.tearDownAsync()
@@ -531,28 +574,7 @@ def setup_test_loop(
once they are done with the loop.
"""
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = "uvloop" in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != "win32" and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher: asyncio.AbstractChildWatcher
try: # Python >= 3.8
# Refs:
# * https://github.com/pytest-dev/pytest-xdist/issues/620
# * https://stackoverflow.com/a/58614689/595220
# * https://bugs.python.org/issue35621
# * https://github.com/python/cpython/pull/14344
watcher = asyncio.ThreadedChildWatcher()
except AttributeError: # Python < 3.8
watcher = asyncio.SafeChildWatcher()
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
return loop
@@ -613,7 +635,7 @@ def make_mocked_request(
writer: Any = sentinel,
protocol: Any = sentinel,
transport: Any = sentinel,
payload: Any = sentinel,
payload: StreamReader = EMPTY_PAYLOAD,
sslcontext: Optional[SSLContext] = None,
client_max_size: int = 1024**2,
loop: Any = ...,
@@ -625,8 +647,15 @@ def make_mocked_request(
"""
task = mock.Mock()
if loop is ...:
loop = mock.Mock()
loop.create_future.return_value = ()
# no loop passed, try to get the current one if
# its is running as we need a real loop to create
# executor jobs to be able to do testing
# with a real executor
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = mock.Mock()
loop.create_future.return_value = ()
if version < HttpVersion(1, 1):
closing = True
@@ -675,9 +704,6 @@ def make_mocked_request(
protocol.transport = transport
protocol.writer = writer
if payload is sentinel:
payload = mock.Mock()
req = Request(
message, payload, protocol, writer, task, loop, client_max_size=client_max_size
)

View File

@@ -1,5 +1,5 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Type, TypeVar
from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar
import attr
from aiosignal import Signal
@@ -8,9 +8,8 @@ from yarl import URL
from .client_reqrep import ClientResponse
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .client import ClientSession
from .typedefs import Protocol
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
@@ -20,8 +19,7 @@ if TYPE_CHECKING: # pragma: no cover
__client_session: ClientSession,
__trace_config_ctx: SimpleNamespace,
__params: _ParamT_contra,
) -> Awaitable[None]:
...
) -> Awaitable[None]: ...
__all__ = (
@@ -51,9 +49,9 @@ class TraceConfig:
def __init__(
self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
) -> None:
self._on_request_start: Signal[
_SignalCallback[TraceRequestStartParams]
] = Signal(self)
self._on_request_start: Signal[_SignalCallback[TraceRequestStartParams]] = (
Signal(self)
)
self._on_request_chunk_sent: Signal[
_SignalCallback[TraceRequestChunkSentParams]
] = Signal(self)
@@ -90,12 +88,12 @@ class TraceConfig:
self._on_dns_resolvehost_end: Signal[
_SignalCallback[TraceDnsResolveHostEndParams]
] = Signal(self)
self._on_dns_cache_hit: Signal[
_SignalCallback[TraceDnsCacheHitParams]
] = Signal(self)
self._on_dns_cache_miss: Signal[
_SignalCallback[TraceDnsCacheMissParams]
] = Signal(self)
self._on_dns_cache_hit: Signal[_SignalCallback[TraceDnsCacheHitParams]] = (
Signal(self)
)
self._on_dns_cache_miss: Signal[_SignalCallback[TraceDnsCacheMissParams]] = (
Signal(self)
)
self._on_request_headers_sent: Signal[
_SignalCallback[TraceRequestHeadersSentParams]
] = Signal(self)
@@ -103,7 +101,7 @@ class TraceConfig:
self._trace_config_ctx_factory = trace_config_ctx_factory
def trace_config_ctx(
self, trace_request_ctx: Optional[SimpleNamespace] = None
self, trace_request_ctx: Optional[Mapping[str, str]] = None
) -> SimpleNamespace:
"""Return a new trace_config_ctx instance"""
return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx)

View File

@@ -1,6 +1,5 @@
import json
import os
import sys
from typing import (
TYPE_CHECKING,
Any,
@@ -8,27 +7,20 @@ from typing import (
Callable,
Iterable,
Mapping,
Protocol,
Tuple,
Union,
)
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr
from yarl import URL
from yarl import URL, Query as _Query
# These are for other modules to use (to avoid repeating the conditional import).
if sys.version_info >= (3, 8):
from typing import Final as Final, Protocol as Protocol, TypedDict as TypedDict
else:
from typing_extensions import ( # noqa: F401
Final,
Protocol as Protocol,
TypedDict as TypedDict,
)
Query = _Query
DEFAULT_JSON_ENCODER = json.dumps
DEFAULT_JSON_DECODER = json.loads
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
_CIMultiDict = CIMultiDict[str]
_CIMultiDictProxy = CIMultiDictProxy[str]
_MultiDict = MultiDict[str]
@@ -45,7 +37,13 @@ else:
Byteish = Union[bytes, bytearray, memoryview]
JSONEncoder = Callable[[Any], str]
JSONDecoder = Callable[[str], Any]
LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy]
LooseHeaders = Union[
Mapping[str, str],
Mapping[istr, str],
_CIMultiDict,
_CIMultiDictProxy,
Iterable[Tuple[Union[str, istr], str]],
]
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
StrOrURL = Union[str, URL]
@@ -61,4 +59,11 @@ LooseCookies = Union[
Handler = Callable[["Request"], Awaitable["StreamResponse"]]
class Middleware(Protocol):
def __call__(
self, request: "Request", handler: Handler
) -> Awaitable["StreamResponse"]: ...
PathLike = Union[str, "os.PathLike[str]"]

View File

@@ -1,9 +1,12 @@
import asyncio
import logging
import os
import socket
import sys
import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
from importlib import import_module
from typing import (
Any,
@@ -19,8 +22,9 @@ from typing import (
)
from .abc import AbstractAccessLogger
from .helpers import all_tasks
from .helpers import AppKey as AppKey
from .log import access_logger
from .typedefs import PathLike
from .web_app import Application as Application, CleanupError as CleanupError
from .web_exceptions import (
HTTPAccepted as HTTPAccepted,
@@ -42,6 +46,7 @@ from .web_exceptions import (
HTTPLengthRequired as HTTPLengthRequired,
HTTPMethodNotAllowed as HTTPMethodNotAllowed,
HTTPMisdirectedRequest as HTTPMisdirectedRequest,
HTTPMove as HTTPMove,
HTTPMovedPermanently as HTTPMovedPermanently,
HTTPMultipleChoices as HTTPMultipleChoices,
HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired,
@@ -80,6 +85,7 @@ from .web_exceptions import (
HTTPUseProxy as HTTPUseProxy,
HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates,
HTTPVersionNotSupported as HTTPVersionNotSupported,
NotAppKeyWarning as NotAppKeyWarning,
)
from .web_fileresponse import FileResponse as FileResponse
from .web_log import AccessLogger
@@ -152,9 +158,11 @@ from .web_ws import (
__all__ = (
# web_app
"AppKey",
"Application",
"CleanupError",
# web_exceptions
"NotAppKeyWarning",
"HTTPAccepted",
"HTTPBadGateway",
"HTTPBadRequest",
@@ -174,6 +182,7 @@ __all__ = (
"HTTPLengthRequired",
"HTTPMethodNotAllowed",
"HTTPMisdirectedRequest",
"HTTPMove",
"HTTPMovedPermanently",
"HTTPMultipleChoices",
"HTTPNetworkAuthenticationRequired",
@@ -283,6 +292,9 @@ try:
except ImportError: # pragma: no cover
SSLContext = Any # type: ignore[misc,assignment]
# Only display warning when using -Wdefault, -We, -X dev or similar.
warnings.filterwarnings("ignore", category=NotAppKeyWarning, append=True)
HostSequence = TypingIterable[str]
@@ -291,12 +303,12 @@ async def _run_app(
*,
host: Optional[Union[str, HostSequence]] = None,
port: Optional[int] = None,
path: Optional[str] = None,
path: Union[PathLike, TypingIterable[PathLike], None] = None,
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
shutdown_timeout: float = 60.0,
keepalive_timeout: float = 75.0,
ssl_context: Optional[SSLContext] = None,
print: Callable[..., None] = print,
print: Optional[Callable[..., None]] = print,
backlog: int = 128,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log_format: str = AccessLogger.LOG_FORMAT,
@@ -304,10 +316,11 @@ async def _run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
# A internal functio to actually do all dirty job for application running
# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app # type: ignore[misc]
app = await app
app = cast(Application, app)
@@ -318,6 +331,8 @@ async def _run_app(
access_log_format=access_log_format,
access_log=access_log,
keepalive_timeout=keepalive_timeout,
shutdown_timeout=shutdown_timeout,
handler_cancellation=handler_cancellation,
)
await runner.setup()
@@ -332,7 +347,6 @@ async def _run_app(
runner,
host,
port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
@@ -346,7 +360,6 @@ async def _run_app(
runner,
h,
port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
@@ -358,7 +371,6 @@ async def _run_app(
TCPSite(
runner,
port=port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
@@ -367,12 +379,11 @@ async def _run_app(
)
if path is not None:
if isinstance(path, (str, bytes, bytearray, memoryview)):
if isinstance(path, (str, os.PathLike)):
sites.append(
UnixSite(
runner,
path,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
@@ -383,7 +394,6 @@ async def _run_app(
UnixSite(
runner,
p,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
@@ -395,7 +405,6 @@ async def _run_app(
SockSite(
runner,
sock,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
@@ -406,7 +415,6 @@ async def _run_app(
SockSite(
runner,
s,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
@@ -422,15 +430,8 @@ async def _run_app(
)
# sleep forever by 1 hour intervals,
# on Windows before Python 3.8 wake up every 1 second to handle
# Ctrl+C smoothly
if sys.platform == "win32" and sys.version_info < (3, 8):
delay = 1
else:
delay = 3600
while True:
await asyncio.sleep(delay)
await asyncio.sleep(3600)
finally:
await runner.cleanup()
@@ -464,12 +465,12 @@ def run_app(
*,
host: Optional[Union[str, HostSequence]] = None,
port: Optional[int] = None,
path: Optional[str] = None,
path: Union[PathLike, TypingIterable[PathLike], None] = None,
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
shutdown_timeout: float = 60.0,
keepalive_timeout: float = 75.0,
ssl_context: Optional[SSLContext] = None,
print: Callable[..., None] = print,
print: Optional[Callable[..., None]] = print,
backlog: int = 128,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log_format: str = AccessLogger.LOG_FORMAT,
@@ -477,6 +478,7 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
@@ -508,6 +510,7 @@ def run_app(
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)
@@ -517,10 +520,14 @@ def run_app(
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
_cancel_tasks({main_task}, loop)
_cancel_tasks(all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
try:
main_task.cancel()
with suppress(asyncio.CancelledError):
loop.run_until_complete(main_task)
finally:
_cancel_tasks(asyncio.all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
def main(argv: List[str]) -> None:

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from functools import partial, update_wrapper
from functools import lru_cache, partial, update_wrapper
from typing import (
TYPE_CHECKING,
Any,
@@ -18,8 +18,10 @@ from typing import (
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from aiosignal import Signal
@@ -32,10 +34,12 @@ from .abc import (
AbstractRouter,
AbstractStreamWriter,
)
from .helpers import DEBUG
from .helpers import DEBUG, AppKey
from .http_parser import RawRequestMessage
from .log import web_logger
from .streams import StreamReader
from .typedefs import Handler, Middleware
from .web_exceptions import NotAppKeyWarning
from .web_log import AccessLogger
from .web_middlewares import _fix_request_current_app
from .web_protocol import RequestHandler
@@ -50,35 +54,46 @@ from .web_urldispatcher import (
MaskDomain,
MatchedSubAppResource,
PrefixedSubAppResource,
SystemRoute,
UrlDispatcher,
)
__all__ = ("Application", "CleanupError")
if TYPE_CHECKING: # pragma: no cover
from .typedefs import Handler
if TYPE_CHECKING:
_AppSignal = Signal[Callable[["Application"], Awaitable[None]]]
_RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]]
_Middleware = Union[
Callable[[Request, Handler], Awaitable[StreamResponse]],
Callable[["Application", Handler], Awaitable[Handler]], # old-style
]
_Middlewares = FrozenList[_Middleware]
_MiddlewaresHandlers = Optional[Sequence[Tuple[_Middleware, bool]]]
_Middlewares = FrozenList[Middleware]
_MiddlewaresHandlers = Optional[Sequence[Tuple[Middleware, bool]]]
_Subapps = List["Application"]
else:
# No type checker mode, skip types
_AppSignal = Signal
_RespPrepareSignal = Signal
_Middleware = Callable
_Middlewares = FrozenList
_MiddlewaresHandlers = Optional[Sequence]
_Subapps = List
_T = TypeVar("_T")
_U = TypeVar("_U")
_Resource = TypeVar("_Resource", bound=AbstractResource)
class Application(MutableMapping[str, Any]):
def _build_middlewares(
handler: Handler, apps: Tuple["Application", ...]
) -> Callable[[Request], Awaitable[StreamResponse]]:
"""Apply middlewares to handler."""
for app in apps[::-1]:
for m, _ in app._middlewares_handlers: # type: ignore[union-attr]
handler = update_wrapper(partial(m, handler=handler), handler) # type: ignore[misc]
return handler
_cached_build_middleware = lru_cache(maxsize=1024)(_build_middlewares)
class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
ATTRS = frozenset(
[
"logger",
@@ -88,6 +103,7 @@ class Application(MutableMapping[str, Any]):
"_handler_args",
"_middlewares",
"_middlewares_handlers",
"_has_legacy_middlewares",
"_run_middlewares",
"_state",
"_frozen",
@@ -107,7 +123,7 @@ class Application(MutableMapping[str, Any]):
*,
logger: logging.Logger = web_logger,
router: Optional[UrlDispatcher] = None,
middlewares: Iterable[_Middleware] = (),
middlewares: Iterable[Middleware] = (),
handler_args: Optional[Mapping[str, Any]] = None,
client_max_size: int = 1024**2,
loop: Optional[asyncio.AbstractEventLoop] = None,
@@ -142,8 +158,9 @@ class Application(MutableMapping[str, Any]):
self._middlewares_handlers: _MiddlewaresHandlers = None
# initialized on freezing
self._run_middlewares: Optional[bool] = None
self._has_legacy_middlewares: bool = True
self._state: Dict[str, Any] = {}
self._state: Dict[Union[AppKey[Any], str], object] = {}
self._frozen = False
self._pre_frozen = False
self._subapps: _Subapps = []
@@ -162,7 +179,7 @@ class Application(MutableMapping[str, Any]):
"Inheritance class {} from web.Application "
"is discouraged".format(cls.__name__),
DeprecationWarning,
stacklevel=2,
stacklevel=3,
)
if DEBUG: # pragma: no cover
@@ -182,7 +199,13 @@ class Application(MutableMapping[str, Any]):
def __eq__(self, other: object) -> bool:
return self is other
def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
return self._state[key]
def _check_frozen(self) -> None:
@@ -193,26 +216,55 @@ class Application(MutableMapping[str, Any]):
stacklevel=3,
)
def __setitem__(self, key: str, value: Any) -> None:
@overload # type: ignore[override]
def __setitem__(self, key: AppKey[_T], value: _T) -> None: ...
@overload
def __setitem__(self, key: str, value: Any) -> None: ...
def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None:
self._check_frozen()
if not isinstance(key, AppKey):
warnings.warn(
"It is recommended to use web.AppKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#application-s-config",
category=NotAppKeyWarning,
stacklevel=2,
)
self._state[key] = value
def __delitem__(self, key: str) -> None:
def __delitem__(self, key: Union[str, AppKey[_T]]) -> None:
self._check_frozen()
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
return iter(self._state)
def __hash__(self) -> int:
return id(self)
@overload # type: ignore[override]
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ...
@overload
def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
return self._state.get(key, default)
########
@property
def loop(self) -> asyncio.AbstractEventLoop:
# Technically the loop can be None
# but we mask it by explicit type cast
# to provide more convinient type annotation
# to provide more convenient type annotation
warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2)
return cast(asyncio.AbstractEventLoop, self._loop)
@@ -251,6 +303,9 @@ class Application(MutableMapping[str, Any]):
self._on_shutdown.freeze()
self._on_cleanup.freeze()
self._middlewares_handlers = tuple(self._prepare_middleware())
self._has_legacy_middlewares = any(
not new_style for _, new_style in self._middlewares_handlers
)
# If current app and any subapp do not have middlewares avoid run all
# of the code footprint that it implies, which have a middleware
@@ -295,7 +350,7 @@ class Application(MutableMapping[str, Any]):
reg_handler("on_shutdown")
reg_handler("on_cleanup")
def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource:
def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource:
if not isinstance(prefix, str):
raise TypeError("Prefix must be str")
prefix = prefix.rstrip("/")
@@ -305,8 +360,8 @@ class Application(MutableMapping[str, Any]):
return self._add_subapp(factory, subapp)
def _add_subapp(
self, resource_factory: Callable[[], AbstractResource], subapp: "Application"
) -> AbstractResource:
self, resource_factory: Callable[[], _Resource], subapp: "Application"
) -> _Resource:
if self.frozen:
raise RuntimeError("Cannot add sub application to frozen application")
if subapp.frozen:
@@ -320,7 +375,7 @@ class Application(MutableMapping[str, Any]):
subapp._set_loop(self._loop)
return resource
def add_domain(self, domain: str, subapp: "Application") -> AbstractResource:
def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource:
if not isinstance(domain, str):
raise TypeError("Domain must be str")
elif "*" in domain:
@@ -453,7 +508,7 @@ class Application(MutableMapping[str, Any]):
client_max_size=self._client_max_size,
)
def _prepare_middleware(self) -> Iterator[Tuple[_Middleware, bool]]:
def _prepare_middleware(self) -> Iterator[Tuple[Middleware, bool]]:
for m in reversed(self._middlewares):
if getattr(m, "__middleware_version__", None) == 1:
yield m, True
@@ -481,29 +536,35 @@ class Application(MutableMapping[str, Any]):
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
if request.headers.get(hdrs.EXPECT):
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is not None:
return resp
if resp is None:
handler = match_info.handler
handler = match_info.handler
if self._run_middlewares:
if self._run_middlewares:
# If its a SystemRoute, don't cache building the middlewares since
# they are constructed for every MatchInfoError as a new handler
# is made each time.
if not self._has_legacy_middlewares and not isinstance(
match_info.route, SystemRoute
):
handler = _cached_build_middleware(handler, match_info.apps)
else:
for app in match_info.apps[::-1]:
for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] # noqa
for m, new_style in app._middlewares_handlers: # type: ignore[union-attr]
if new_style:
handler = update_wrapper(
partial(m, handler=handler), handler
partial(m, handler=handler), handler # type: ignore[misc]
)
else:
handler = await m(app, handler) # type: ignore[arg-type]
handler = await m(app, handler) # type: ignore[arg-type,assignment]
resp = await handler(request)
return resp
return await handler(request)
def __call__(self) -> "Application":
"""gunicorn compatibility"""
@@ -522,7 +583,7 @@ class CleanupError(RuntimeError):
return cast(List[BaseException], self.args[1])
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
_CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]]
else:
_CleanupContextBase = FrozenList
@@ -546,7 +607,7 @@ class CleanupContext(_CleanupContextBase):
await it.__anext__()
except StopAsyncIteration:
pass
except Exception as exc:
except (Exception, asyncio.CancelledError) as exc:
errors.append(exc)
else:
errors.append(RuntimeError(f"{it!r} has more than one 'yield'"))

View File

@@ -18,6 +18,7 @@ __all__ = (
"HTTPNoContent",
"HTTPResetContent",
"HTTPPartialContent",
"HTTPMove",
"HTTPMultipleChoices",
"HTTPMovedPermanently",
"HTTPFound",
@@ -67,6 +68,10 @@ __all__ = (
)
class NotAppKeyWarning(UserWarning):
"""Warning when not using AppKey in Application."""
############################################################
# HTTP Exceptions
############################################################
@@ -160,7 +165,7 @@ class HTTPPartialContent(HTTPSuccessful):
############################################################
class _HTTPMove(HTTPRedirection):
class HTTPMove(HTTPRedirection):
def __init__(
self,
location: StrOrURL,
@@ -184,21 +189,21 @@ class _HTTPMove(HTTPRedirection):
self.location = location
class HTTPMultipleChoices(_HTTPMove):
class HTTPMultipleChoices(HTTPMove):
status_code = 300
class HTTPMovedPermanently(_HTTPMove):
class HTTPMovedPermanently(HTTPMove):
status_code = 301
class HTTPFound(_HTTPMove):
class HTTPFound(HTTPMove):
status_code = 302
# This one is safe after a POST (the redirected location will be
# retrieved with GET):
class HTTPSeeOther(_HTTPMove):
class HTTPSeeOther(HTTPMove):
status_code = 303
@@ -208,16 +213,16 @@ class HTTPNotModified(HTTPRedirection):
empty_body = True
class HTTPUseProxy(_HTTPMove):
class HTTPUseProxy(HTTPMove):
# Not a move, but looks a little like one
status_code = 305
class HTTPTemporaryRedirect(_HTTPMove):
class HTTPTemporaryRedirect(HTTPMove):
status_code = 307
class HTTPPermanentRedirect(_HTTPMove):
class HTTPPermanentRedirect(HTTPMove):
status_code = 308
@@ -366,7 +371,7 @@ class HTTPUnavailableForLegalReasons(HTTPClientError):
def __init__(
self,
link: str,
link: Optional[StrOrURL],
*,
headers: Optional[LooseHeaders] = None,
reason: Optional[str] = None,
@@ -381,8 +386,14 @@ class HTTPUnavailableForLegalReasons(HTTPClientError):
text=text,
content_type=content_type,
)
self.headers["Link"] = '<%s>; rel="blocked-by"' % link
self.link = link
self._link = None
if link:
self._link = URL(link)
self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"'
@property
def link(self) -> Optional[URL]:
return self._link
############################################################

View File

@@ -1,14 +1,18 @@
import asyncio
import mimetypes
import os
import pathlib
import sys
from contextlib import suppress
from mimetypes import MimeTypes
from stat import S_ISREG
from types import MappingProxyType
from typing import ( # noqa
IO,
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Final,
Iterator,
List,
Optional,
@@ -19,9 +23,11 @@ from typing import ( # noqa
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import ETAG_ANY, ETag
from .typedefs import Final, LooseHeaders
from .helpers import ETAG_ANY, ETag, must_be_empty_body
from .typedefs import LooseHeaders, PathLike
from .web_exceptions import (
HTTPForbidden,
HTTPNotFound,
HTTPNotModified,
HTTPPartialContent,
HTTPPreconditionFailed,
@@ -31,7 +37,7 @@ from .web_response import StreamResponse
__all__ = ("FileResponse",)
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_request import BaseRequest
@@ -40,13 +46,42 @@ _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
CONTENT_TYPES: Final[MimeTypes] = MimeTypes()
if sys.version_info < (3, 9):
CONTENT_TYPES.encodings_map[".br"] = "br"
# File extension to IANA encodings map that will be checked in the order defined.
ENCODING_EXTENSIONS = MappingProxyType(
{ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")}
)
FALLBACK_CONTENT_TYPE = "application/octet-stream"
# Provide additional MIME type/extension pairs to be recognized.
# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only
ADDITIONAL_CONTENT_TYPES = MappingProxyType(
{
"application/gzip": ".gz",
"application/x-brotli": ".br",
"application/x-bzip2": ".bz2",
"application/x-compress": ".Z",
"application/x-xz": ".xz",
}
)
# Add custom pairs and clear the encodings map so guess_type ignores them.
CONTENT_TYPES.encodings_map.clear()
for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined]
class FileResponse(StreamResponse):
"""A response object can be used to send files."""
def __init__(
self,
path: Union[str, pathlib.Path],
path: PathLike,
chunk_size: int = 256 * 1024,
status: int = 200,
reason: Optional[str] = None,
@@ -54,10 +89,7 @@ class FileResponse(StreamResponse):
) -> None:
super().__init__(status=status, reason=reason, headers=headers)
if isinstance(path, str):
path = pathlib.Path(path)
self._path = path
self._path = pathlib.Path(path)
self._chunk_size = chunk_size
async def _sendfile_fallback(
@@ -88,7 +120,7 @@ class FileResponse(StreamResponse):
writer = await super().prepare(request)
assert writer is not None
if NOSENDFILE or sys.version_info < (3, 7) or self.compression:
if NOSENDFILE or self.compression:
return await self._sendfile_fallback(writer, fobj, offset, count)
loop = request._loop
@@ -104,10 +136,12 @@ class FileResponse(StreamResponse):
return writer
@staticmethod
def _strong_etag_match(etag_value: str, etags: Tuple[ETag, ...]) -> bool:
def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool:
if len(etags) == 1 and etags[0].value == ETAG_ANY:
return True
return any(etag.value == etag_value for etag in etags if not etag.is_weak)
return any(
etag.value == etag_value for etag in etags if weak or not etag.is_weak
)
async def _not_modified(
self, request: "BaseRequest", etag_value: str, last_modified: float
@@ -127,26 +161,60 @@ class FileResponse(StreamResponse):
self.content_length = 0
return await super().prepare(request)
def _get_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
"""Return the file path, stat result, and encoding.
If an uncompressed file is returned, the encoding is set to
:py:data:`None`.
This method should be called from a thread executor
since it calls os.stat which may block.
"""
file_path = self._path
for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
if file_encoding not in accept_encoding:
continue
compressed_path = file_path.with_suffix(file_path.suffix + file_extension)
with suppress(OSError):
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
return compressed_path, st, file_encoding
# Fallback to the uncompressed file
return file_path, file_path.stat(), None
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
filepath = self._path
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
file_path, st, file_encoding = await loop.run_in_executor(
None, self._get_file_path_stat_encoding, accept_encoding
)
except OSError:
# Most likely to be FileNotFoundError or OSError for circular
# symlinks in python >= 3.13, so respond with 404.
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)
gzip = False
if "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, ""):
gzip_path = filepath.with_name(filepath.name + ".gz")
if gzip_path.is_file():
filepath = gzip_path
gzip = True
loop = asyncio.get_event_loop()
st: os.stat_result = await loop.run_in_executor(None, filepath.stat)
# Forbid special files like sockets, pipes, devices, etc.
if not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime
# https://tools.ietf.org/html/rfc7232#section-6
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
if ifmatch is not None and not self._strong_etag_match(etag_value, ifmatch):
if ifmatch is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return await self._precondition_failed(request)
unmodsince = request.if_unmodified_since
@@ -157,8 +225,11 @@ class FileResponse(StreamResponse):
):
return await self._precondition_failed(request)
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
if ifnonematch is not None and self._strong_etag_match(etag_value, ifnonematch):
if ifnonematch is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return await self._not_modified(request, etag_value, last_modified)
modsince = request.if_modified_since
@@ -169,15 +240,6 @@ class FileResponse(StreamResponse):
):
return await self._not_modified(request, etag_value, last_modified)
if hdrs.CONTENT_TYPE not in self.headers:
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = "application/octet-stream"
should_set_ct = True
else:
encoding = "gzip" if gzip else None
should_set_ct = False
status = self._status
file_size = st.st_size
count = file_size
@@ -252,12 +314,21 @@ class FileResponse(StreamResponse):
# return a HTTP 206 for a Range request.
self.set_status(status)
if should_set_ct:
self.content_type = ct # type: ignore[assignment]
if encoding:
self.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
# If the Content-Type header is not already set, guess it based on the
# extension of the request path. The encoding returned by guess_type
# can be ignored since the map was cleared above.
if hdrs.CONTENT_TYPE not in self.headers:
self.content_type = (
CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE
)
if file_encoding:
self.headers[hdrs.CONTENT_ENCODING] = file_encoding
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
# Disable compression if we are already sending
# a compressed file since we don't want to double
# compress.
self._compression = False
self.etag = etag_value # type: ignore[assignment]
self.last_modified = st.st_mtime # type: ignore[assignment]
@@ -273,10 +344,15 @@ class FileResponse(StreamResponse):
)
# If we are sending 0 bytes calling sendfile() will throw a ValueError
if count == 0 or request.method == hdrs.METH_HEAD or self.status in [204, 304]:
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)
try:
fobj = await loop.run_in_executor(None, file_path.open, "rb")
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
fobj = await loop.run_in_executor(None, filepath.open, "rb")
if start: # be aware that start could be None or int=0 here.
offset = start
else:
@@ -285,4 +361,4 @@ class FileResponse(StreamResponse):
try:
return await self._sendfile(request, fobj, offset, count)
finally:
await loop.run_in_executor(None, fobj.close)
await asyncio.shield(loop.run_in_executor(None, fobj.close))

View File

@@ -3,6 +3,7 @@ import functools
import logging
import os
import re
import time as time_mod
from collections import namedtuple
from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa
@@ -142,9 +143,10 @@ class AccessLogger(AbstractAccessLogger):
@staticmethod
def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str:
now = datetime.datetime.utcnow()
tz = datetime.timezone(datetime.timedelta(seconds=-time_mod.timezone))
now = datetime.datetime.now(tz)
start_time = now - datetime.timedelta(seconds=time)
return start_time.strftime("[%d/%b/%Y:%H:%M:%S +0000]")
return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")
@staticmethod
def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str:
@@ -187,6 +189,9 @@ class AccessLogger(AbstractAccessLogger):
return [(key, method(request, response, time)) for key, method in self._methods]
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
if not self.logger.isEnabledFor(logging.INFO):
# Avoid formatting the log line if it will not be emitted.
return
try:
fmt_info = self._format_line(request, response, time)

View File

@@ -1,8 +1,8 @@
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Tuple, Type, TypeVar
from .typedefs import Handler
from .web_exceptions import HTTPPermanentRedirect, _HTTPMove
from .typedefs import Handler, Middleware
from .web_exceptions import HTTPMove, HTTPPermanentRedirect
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import SystemRoute
@@ -12,7 +12,7 @@ __all__ = (
"normalize_path_middleware",
)
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_app import Application
_Func = TypeVar("_Func")
@@ -35,16 +35,13 @@ def middleware(f: _Func) -> _Func:
return f
_Middleware = Callable[[Request, Handler], Awaitable[StreamResponse]]
def normalize_path_middleware(
*,
append_slash: bool = True,
remove_slash: bool = False,
merge_slashes: bool = True,
redirect_class: Type[_HTTPMove] = HTTPPermanentRedirect,
) -> _Middleware:
redirect_class: Type[HTTPMove] = HTTPPermanentRedirect,
) -> Middleware:
"""Factory for producing a middleware that normalizes the path of a request.
Normalizing means:
@@ -110,10 +107,15 @@ def normalize_path_middleware(
return impl
def _fix_request_current_app(app: "Application") -> _Middleware:
def _fix_request_current_app(app: "Application") -> Middleware:
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
with request.match_info.set_current_app(app):
match_info = request.match_info
prev = match_info.current_app
match_info.current_app = app
try:
return await handler(request)
finally:
match_info.current_app = prev
return impl

View File

@@ -1,5 +1,6 @@
import asyncio
import asyncio.streams
import sys
import traceback
import warnings
from collections import deque
@@ -37,14 +38,14 @@ from .http import (
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_exceptions import HTTPException, HTTPInternalServerError
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_server import Server
@@ -83,6 +84,9 @@ class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""
_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _ErrInfo:
status: int
@@ -127,9 +131,11 @@ class RequestHandler(BaseProtocol):
max_headers -- Optional maximum header size
"""
timeout_ceil_threshold -- Optional value to specify
threshold to ceil() timeout
values
KEEPALIVE_RESCHEDULE_DELAY = 1
"""
__slots__ = (
"_request_count",
@@ -138,12 +144,13 @@ class RequestHandler(BaseProtocol):
"_request_handler",
"_request_factory",
"_tcp_keepalive",
"_keepalive_time",
"_next_keepalive_close_time",
"_keepalive_handle",
"_keepalive_timeout",
"_lingering_time",
"_messages",
"_message_tail",
"_handler_waiter",
"_waiter",
"_task_handler",
"_upgrade",
@@ -157,6 +164,8 @@ class RequestHandler(BaseProtocol):
"_close",
"_force_close",
"_current_request",
"_timeout_ceil_threshold",
"_request_in_progress",
)
def __init__(
@@ -177,6 +186,7 @@ class RequestHandler(BaseProtocol):
lingering_time: float = 10.0,
read_bufsize: int = 2**16,
auto_decompress: bool = True,
timeout_ceil_threshold: float = 5,
):
super().__init__(loop)
@@ -189,7 +199,7 @@ class RequestHandler(BaseProtocol):
self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
self._keepalive_time = 0.0
self._next_keepalive_close_time = 0.0
self._keepalive_handle: Optional[asyncio.Handle] = None
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)
@@ -198,6 +208,7 @@ class RequestHandler(BaseProtocol):
self._message_tail = b""
self._waiter: Optional[asyncio.Future[None]] = None
self._handler_waiter: Optional[asyncio.Future[None]] = None
self._task_handler: Optional[asyncio.Task[None]] = None
self._upgrade = False
@@ -213,6 +224,12 @@ class RequestHandler(BaseProtocol):
auto_decompress=auto_decompress,
)
self._timeout_ceil_threshold: float = 5
try:
self._timeout_ceil_threshold = float(timeout_ceil_threshold)
except (TypeError, ValueError):
pass
self.logger = logger
self.debug = debug
self.access_log = access_log
@@ -225,6 +242,7 @@ class RequestHandler(BaseProtocol):
self._close = False
self._force_close = False
self._request_in_progress = False
def __repr__(self) -> str:
return "<{} {}>".format(
@@ -247,25 +265,44 @@ class RequestHandler(BaseProtocol):
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
if self._waiter:
self._waiter.cancel()
# wait for handlers
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
# Wait for graceful handler completion
if self._request_in_progress:
# The future is only created when we are shutting
# down while the handler is still processing a request
# to avoid creating a future for every request.
self._handler_waiter = self._loop.create_future()
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._handler_waiter = None
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())
if self._task_handler is not None and not self._task_handler.done():
await self._task_handler
await asyncio.shield(self._task_handler)
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# force-close non-idle handler
if self._task_handler is not None:
self._task_handler.cancel()
if self.transport is not None:
self.transport.close()
self.transport = None
self.force_close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)
@@ -274,19 +311,27 @@ class RequestHandler(BaseProtocol):
if self._tcp_keepalive:
tcp_keepalive(real_transport)
self._task_handler = self._loop.create_task(self.start())
assert self._manager is not None
self._manager.connection_made(self, real_transport)
loop = self._loop
if sys.version_info >= (3, 12):
task = asyncio.Task(self.start(), loop=loop, eager_start=True)
else:
task = loop.create_task(self.start())
self._task_handler = task
def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._manager is None:
return
self._manager.connection_lost(self, exc)
super().connection_lost(exc)
# Grab value before setting _manager to None.
handler_cancellation = self._manager.handler_cancellation
self.force_close()
super().connection_lost(exc)
self._manager = None
self._force_close = True
self._request_factory = None
self._request_handler = None
self._request_parser = None
@@ -299,8 +344,8 @@ class RequestHandler(BaseProtocol):
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)
if self._waiter is not None:
self._waiter.cancel()
if handler_cancellation and self._task_handler is not None:
self._task_handler.cancel()
self._task_handler = None
@@ -403,22 +448,21 @@ class RequestHandler(BaseProtocol):
self.logger.exception(*args, **kw)
def _process_keepalive(self) -> None:
self._keepalive_handle = None
if self._force_close or not self._keepalive:
return
next = self._keepalive_time + self._keepalive_timeout
loop = self._loop
now = loop.time()
close_time = self._next_keepalive_close_time
if now <= close_time:
# Keep alive close check fired too early, reschedule
self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
return
# handler in idle state
if self._waiter:
if self._loop.time() > next:
self.force_close()
return
# not all request handlers are done,
# reschedule itself to next second
self._keepalive_handle = self._loop.call_later(
self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
)
if self._waiter and not self._waiter.done():
self.force_close()
async def _handle_request(
self,
@@ -426,7 +470,7 @@ class RequestHandler(BaseProtocol):
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
self._request_in_progress = True
try:
try:
self._current_request = request
@@ -435,16 +479,16 @@ class RequestHandler(BaseProtocol):
self._current_request = None
except HTTPException as exc:
resp = exc
reset = await self.finish_response(request, resp, start_time)
resp, reset = await self.finish_response(request, resp, start_time)
except asyncio.CancelledError:
raise
except asyncio.TimeoutError as exc:
self.log_debug("Request handler timed out.", exc_info=exc)
resp = self.handle_error(request, 504)
reset = await self.finish_response(request, resp, start_time)
resp, reset = await self.finish_response(request, resp, start_time)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
reset = await self.finish_response(request, resp, start_time)
resp, reset = await self.finish_response(request, resp, start_time)
else:
# Deprecation warning (See #2415)
if getattr(resp, "__http_exception__", False):
@@ -455,7 +499,11 @@ class RequestHandler(BaseProtocol):
DeprecationWarning,
)
reset = await self.finish_response(request, resp, start_time)
resp, reset = await self.finish_response(request, resp, start_time)
finally:
self._request_in_progress = False
if self._handler_waiter is not None:
self._handler_waiter.set_result(None)
return resp, reset
@@ -469,7 +517,7 @@ class RequestHandler(BaseProtocol):
keep_alive(True) specified.
"""
loop = self._loop
handler = self._task_handler
handler = asyncio.current_task(loop)
assert handler is not None
manager = self._manager
assert manager is not None
@@ -484,8 +532,6 @@ class RequestHandler(BaseProtocol):
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
finally:
self._waiter = None
@@ -505,12 +551,14 @@ class RequestHandler(BaseProtocol):
request = self._request_factory(message, payload, self, writer, handler)
try:
# a new task is used for copy context vars (#3406)
task = self._loop.create_task(
self._handle_request(request, start, request_handler)
)
coro = self._handle_request(request, start, request_handler)
if sys.version_info >= (3, 12):
task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
task = loop.create_task(coro)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break
@@ -534,27 +582,30 @@ class RequestHandler(BaseProtocol):
now = loop.time()
end_t = now + lingering_time
with suppress(asyncio.TimeoutError, asyncio.CancelledError):
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise
# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
self.log_debug("Uncompleted request.")
self.close()
payload.set_exception(PayloadAccessError())
payload.set_exception(_PAYLOAD_ACCESS_ERROR)
except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection ")
break
except RuntimeError as exc:
if self.debug:
self.log_exception("Unhandled runtime exception", exc_info=exc)
self.force_close()
self.log_debug("Ignored premature client disconnection")
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
@@ -565,11 +616,12 @@ class RequestHandler(BaseProtocol):
if self._keepalive and not self._close:
# start keep-alive timer
if keepalive_timeout is not None:
now = self._loop.time()
self._keepalive_time = now
now = loop.time()
close_time = now + keepalive_timeout
self._next_keepalive_close_time = close_time
if self._keepalive_handle is None:
self._keepalive_handle = loop.call_at(
now + keepalive_timeout, self._process_keepalive
close_time, self._process_keepalive
)
else:
break
@@ -582,7 +634,7 @@ class RequestHandler(BaseProtocol):
async def finish_response(
self, request: BaseRequest, resp: StreamResponse, start_time: float
) -> bool:
) -> Tuple[StreamResponse, bool]:
"""Prepare the response and write_eof, then log access.
This has to
@@ -590,6 +642,7 @@ class RequestHandler(BaseProtocol):
can get exception information. Returns True if the client disconnects
prematurely.
"""
request._finish()
if self._request_parser is not None:
self._request_parser.set_upgraded(False)
self._upgrade = False
@@ -600,22 +653,26 @@ class RequestHandler(BaseProtocol):
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
raise RuntimeError("Missing return " "statement on request handler")
self.log_exception("Missing return statement on request handler")
else:
raise RuntimeError(
"Web-handler should return "
"a response instance, "
self.log_exception(
"Web-handler should return a response instance, "
"got {!r}".format(resp)
)
exc = HTTPInternalServerError()
resp = Response(
status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
)
prepare_meth = resp.prepare
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
self.log_access(request, resp, start_time)
return True
else:
self.log_access(request, resp, start_time)
return False
return resp, True
self.log_access(request, resp, start_time)
return resp, False
def handle_error(
self,

View File

@@ -13,6 +13,7 @@ from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
Iterator,
Mapping,
MutableMapping,
@@ -25,12 +26,19 @@ from typing import (
from urllib.parse import parse_qsl
import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from multidict import (
CIMultiDict,
CIMultiDictProxy,
MultiDict,
MultiDictProxy,
MultiMapping,
)
from yarl import URL
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
_SENTINEL,
DEBUG,
ETAG_ANY,
LIST_QUOTED_ETAG_RE,
@@ -40,6 +48,7 @@ from .helpers import (
parse_http_date,
reify,
sentinel,
set_exception,
)
from .http_parser import RawRequestMessage
from .http_writer import HttpVersion
@@ -47,7 +56,6 @@ from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
Final,
JSONDecoder,
LooseHeaders,
RawHeaders,
@@ -59,7 +67,7 @@ from .web_response import StreamResponse
__all__ = ("BaseRequest", "FileField", "Request")
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_app import Application
from .web_protocol import RequestHandler
from .web_urldispatcher import UrlMappingMatchInfo
@@ -71,7 +79,7 @@ class FileField:
filename: str
file: io.BufferedReader
content_type: str
headers: "CIMultiDictProxy[str]"
headers: CIMultiDictProxy[str]
_TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"
@@ -91,10 +99,10 @@ _QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format(
qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR
)
_FORWARDED_PAIR: Final[
str
] = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format(
token=_TOKEN, quoted_string=_QUOTED_STRING
_FORWARDED_PAIR: Final[str] = (
r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format(
token=_TOKEN, quoted_string=_QUOTED_STRING
)
)
_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])")
@@ -161,12 +169,16 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._payload_writer = payload_writer
self._payload = payload
self._headers = message.headers
self._headers: CIMultiDictProxy[str] = message.headers
self._method = message.method
self._version = message.version
self._cache: Dict[str, Any] = {}
url = message.url
if url.is_absolute():
if url.absolute:
if scheme is not None:
url = url.with_scheme(scheme)
if host is not None:
url = url.with_host(host)
# absolute URL is given,
# override auto-calculating url, host, and scheme
# all other properties should be good
@@ -176,6 +188,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._rel_url = url.relative()
else:
self._rel_url = message.url
if scheme is not None:
self._cache["scheme"] = scheme
if host is not None:
self._cache["host"] = host
self._post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None
self._read_bytes: Optional[bytes] = None
@@ -189,22 +205,19 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._transport_sslcontext = transport.get_extra_info("sslcontext")
self._transport_peername = transport.get_extra_info("peername")
if scheme is not None:
self._cache["scheme"] = scheme
if host is not None:
self._cache["host"] = host
if remote is not None:
self._cache["remote"] = remote
def clone(
self,
*,
method: str = sentinel,
rel_url: StrOrURL = sentinel,
headers: LooseHeaders = sentinel,
scheme: str = sentinel,
host: str = sentinel,
remote: str = sentinel,
method: Union[str, _SENTINEL] = sentinel,
rel_url: Union[StrOrURL, _SENTINEL] = sentinel,
headers: Union[LooseHeaders, _SENTINEL] = sentinel,
scheme: Union[str, _SENTINEL] = sentinel,
host: Union[str, _SENTINEL] = sentinel,
remote: Union[str, _SENTINEL] = sentinel,
client_max_size: Union[int, _SENTINEL] = sentinel,
) -> "BaseRequest":
"""Clone itself with replacement some attributes.
@@ -219,14 +232,15 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
if method is not sentinel:
dct["method"] = method
if rel_url is not sentinel:
new_url = URL(rel_url)
new_url: URL = URL(rel_url)
dct["url"] = new_url
dct["path"] = str(new_url)
if headers is not sentinel:
# a copy semantic
dct["headers"] = CIMultiDictProxy(CIMultiDict(headers))
dct["raw_headers"] = tuple(
(k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
(k.encode("utf-8"), v.encode("utf-8"))
for k, v in dct["headers"].items()
)
message = self._message._replace(**dct)
@@ -238,6 +252,8 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
kwargs["host"] = host
if remote is not sentinel:
kwargs["remote"] = remote
if client_max_size is sentinel:
client_max_size = self._client_max_size
return self.__class__(
message,
@@ -246,7 +262,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._payload_writer,
self._task,
self._loop,
client_max_size=self._client_max_size,
client_max_size=client_max_size,
state=self._state.copy(),
**kwargs,
)
@@ -269,6 +285,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
def writer(self) -> AbstractStreamWriter:
return self._payload_writer
@property
def client_max_size(self) -> int:
return self._client_max_size
@reify
def message(self) -> RawRequestMessage:
warnings.warn("Request.message is deprecated", DeprecationWarning, stacklevel=3)
@@ -411,6 +431,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
- overridden value by .clone(host=new_host) call.
- HOST HTTP header
- socket.getfqdn() value
For example, 'example.com' or 'localhost:8080'.
For historical reasons, the port number may be included.
"""
host = self._message.headers.get(hdrs.HOST)
if host is not None:
@@ -434,8 +458,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
@reify
def url(self) -> URL:
url = URL.build(scheme=self.scheme, host=self.host)
return url.join(self._rel_url)
"""The full URL of the request."""
# authority is used here because it may include the port number
# and we want yarl to parse it correctly
return URL.build(scheme=self.scheme, authority=self.host).join(self._rel_url)
@reify
def path(self) -> str:
@@ -464,9 +490,9 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
return self._message.path
@reify
def query(self) -> "MultiDictProxy[str]":
def query(self) -> "MultiMapping[str]":
"""A multidict with all the variables in the query string."""
return MultiDictProxy(self._rel_url.query)
return self._rel_url.query
@reify
def query_string(self) -> str:
@@ -477,7 +503,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
return self._rel_url.query_string
@reify
def headers(self) -> "CIMultiDictProxy[str]":
def headers(self) -> CIMultiDictProxy[str]:
"""A case-insensitive multidict proxy with all headers."""
return self._headers
@@ -568,7 +594,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
A read-only dictionary-like object.
"""
raw = self.headers.get(hdrs.COOKIE, "")
parsed: SimpleCookie[str] = SimpleCookie(raw)
parsed = SimpleCookie(raw)
return MappingProxyType({key: val.value for key, val in parsed.items()})
@reify
@@ -711,19 +737,21 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
# https://tools.ietf.org/html/rfc7578#section-4.4
if field.filename:
# store file in temp file
tmp = tempfile.TemporaryFile()
tmp = await self._loop.run_in_executor(
None, tempfile.TemporaryFile
)
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
await self._loop.run_in_executor(None, tmp.write, chunk)
size += len(chunk)
if 0 < max_size < size:
tmp.close()
await self._loop.run_in_executor(None, tmp.close)
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)
await self._loop.run_in_executor(None, tmp.seek, 0)
if field_ct is None:
field_ct = "application/octet-stream"
@@ -800,7 +828,19 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
return
def _cancel(self, exc: BaseException) -> None:
self._payload.set_exception(exc)
set_exception(self._payload, exc)
def _finish(self) -> None:
if self._post is None or self.content_type != "multipart/form-data":
return
# NOTE: Release file descriptors for the
# NOTE: `tempfile.Temporaryfile`-created `_io.BufferedRandom`
# NOTE: instances of files sent within multipart request body
# NOTE: via HTTP POST request.
for file_name, file_field_object in self._post.items():
if isinstance(file_field_object, FileField):
file_field_object.file.close()
class Request(BaseRequest):
@@ -831,12 +871,13 @@ class Request(BaseRequest):
def clone(
self,
*,
method: str = sentinel,
rel_url: StrOrURL = sentinel,
headers: LooseHeaders = sentinel,
scheme: str = sentinel,
host: str = sentinel,
remote: str = sentinel,
method: Union[str, _SENTINEL] = sentinel,
rel_url: Union[StrOrURL, _SENTINEL] = sentinel,
headers: Union[LooseHeaders, _SENTINEL] = sentinel,
scheme: Union[str, _SENTINEL] = sentinel,
host: Union[str, _SENTINEL] = sentinel,
remote: Union[str, _SENTINEL] = sentinel,
client_max_size: Union[int, _SENTINEL] = sentinel,
) -> "Request":
ret = super().clone(
method=method,
@@ -845,6 +886,7 @@ class Request(BaseRequest):
scheme=scheme,
host=host,
remote=remote,
client_max_size=client_max_size,
)
new_ret = cast(Request, ret)
new_ret._match_info = self._match_info
@@ -879,4 +921,5 @@ class Request(BaseRequest):
if match_info is None:
return
for app in match_info._apps:
await app.on_response_prepare.send(self, response)
if on_response_prepare := app.on_response_prepare:
await on_response_prepare.send(self, response)

View File

@@ -6,18 +6,16 @@ import json
import math
import time
import warnings
import zlib
from concurrent.futures import Executor
from http.cookies import Morsel, SimpleCookie
from http import HTTPStatus
from http.cookies import SimpleCookie
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
@@ -26,25 +24,30 @@ from multidict import CIMultiDict, istr
from . import hdrs, payload
from .abc import AbstractStreamWriter
from .compression_utils import ZLibCompressor
from .helpers import (
ETAG_ANY,
PY_38,
QUOTED_ETAG_RE,
ETag,
HeadersMixin,
must_be_empty_body,
parse_http_date,
rfc822_formatted_time,
sentinel,
should_remove_content_length,
validate_etag_value,
)
from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus}
LARGE_BODY_SIZE = 1024**2
__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response")
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_request import BaseRequest
BaseClass = MutableMapping[str, Any]
@@ -52,12 +55,7 @@ else:
BaseClass = collections.abc.MutableMapping
if not PY_38:
# allow samesite to be used in python < 3.8
# already permitted in python 3.8, see https://bugs.python.org/issue29613
Morsel._reserved["samesite"] = "SameSite" # type: ignore[attr-defined]
# TODO(py311): Convert to StrEnum for wider use
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
@@ -68,6 +66,8 @@ class ContentCoding(enum.Enum):
identity = "identity"
CONTENT_CODINGS = {coding.value: coding for coding in ContentCoding}
############################################################
# HTTP Response classes
############################################################
@@ -77,6 +77,8 @@ class StreamResponse(BaseClass, HeadersMixin):
_length_check = True
_body: Union[None, bytes, bytearray, Payload]
def __init__(
self,
*,
@@ -89,11 +91,12 @@ class StreamResponse(BaseClass, HeadersMixin):
self._chunked = False
self._compression = False
self._compression_force: Optional[ContentCoding] = None
self._cookies: SimpleCookie[str] = SimpleCookie()
self._cookies = SimpleCookie()
self._req: Optional[BaseRequest] = None
self._payload_writer: Optional[AbstractStreamWriter] = None
self._eof_sent = False
self._must_be_empty_body: Optional[bool] = None
self._body_length = 0
self._state: Dict[str, Any] = {}
@@ -102,11 +105,11 @@ class StreamResponse(BaseClass, HeadersMixin):
else:
self._headers = CIMultiDict()
self.set_status(status, reason)
self._set_status(status, reason)
@property
def prepared(self) -> bool:
return self._payload_writer is not None
return self._eof_sent or self._payload_writer is not None
@property
def task(self) -> "Optional[asyncio.Task[None]]":
@@ -135,17 +138,18 @@ class StreamResponse(BaseClass, HeadersMixin):
self,
status: int,
reason: Optional[str] = None,
_RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES,
) -> None:
assert not self.prepared, (
"Cannot change the response status code after " "the headers have been sent"
)
assert (
not self.prepared
), "Cannot change the response status code after the headers have been sent"
self._set_status(status, reason)
def _set_status(self, status: int, reason: Optional[str]) -> None:
self._status = int(status)
if reason is None:
try:
reason = _RESPONSES[self._status][0]
except Exception:
reason = ""
reason = REASON_PHRASES.get(self._status, "")
elif "\n" in reason:
raise ValueError("Reason cannot contain \\n")
self._reason = reason
@property
@@ -181,7 +185,7 @@ class StreamResponse(BaseClass, HeadersMixin):
) -> None:
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
if type(force) == bool:
if isinstance(force, bool):
force = ContentCoding.deflate if force else ContentCoding.identity
warnings.warn(
"Using boolean for force is deprecated #3318", DeprecationWarning
@@ -199,7 +203,7 @@ class StreamResponse(BaseClass, HeadersMixin):
return self._headers
@property
def cookies(self) -> "SimpleCookie[str]":
def cookies(self) -> SimpleCookie:
return self._cookies
def set_cookie(
@@ -394,30 +398,33 @@ class StreamResponse(BaseClass, HeadersMixin):
self._headers[CONTENT_TYPE] = ctype
async def _do_start_compression(self, coding: ContentCoding) -> None:
if coding != ContentCoding.identity:
assert self._payload_writer is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(coding.value)
# Compressed payload may have different content length,
# remove the header
self._headers.popall(hdrs.CONTENT_LENGTH, None)
if coding is ContentCoding.identity:
return
assert self._payload_writer is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(coding.value)
# Compressed payload may have different content length,
# remove the header
self._headers.popall(hdrs.CONTENT_LENGTH, None)
async def _start_compression(self, request: "BaseRequest") -> None:
if self._compression_force:
await self._do_start_compression(self._compression_force)
else:
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
for coding in ContentCoding:
if coding.value in accept_encoding:
await self._do_start_compression(coding)
return
return
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
for value, coding in CONTENT_CODINGS.items():
if value in accept_encoding:
await self._do_start_compression(coding)
return
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
if self._eof_sent:
return None
if self._payload_writer is not None:
return self._payload_writer
self._must_be_empty_body = must_be_empty_body(request.method, self.status)
return await self._start(request)
async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
@@ -443,9 +450,10 @@ class StreamResponse(BaseClass, HeadersMixin):
version = request.version
headers = self._headers
for cookie in self._cookies.values():
value = cookie.output(header="")[1:]
headers.add(hdrs.SET_COOKIE, value)
if self._cookies:
for cookie in self._cookies.values():
value = cookie.output(header="")[1:]
headers.add(hdrs.SET_COOKIE, value)
if self._compression:
await self._start_compression(request)
@@ -456,26 +464,34 @@ class StreamResponse(BaseClass, HeadersMixin):
"Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(request.version)
)
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
if not self._must_be_empty_body:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
elif self._length_check:
elif self._length_check: # Disabled for WebSockets
writer.length = self.content_length
if writer.length is None:
if version >= HttpVersion11 and self.status != 204:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
else:
if version >= HttpVersion11:
if not self._must_be_empty_body:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
elif not self._must_be_empty_body:
keep_alive = False
# HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
# HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
elif version >= HttpVersion11 and self.status in (100, 101, 102, 103, 204):
del headers[hdrs.CONTENT_LENGTH]
if self.status not in (204, 304):
# HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
# HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
if self._must_be_empty_body:
if hdrs.CONTENT_LENGTH in headers and should_remove_content_length(
request.method, self.status
):
del headers[hdrs.CONTENT_LENGTH]
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-10
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13
if hdrs.TRANSFER_ENCODING in headers:
del headers[hdrs.TRANSFER_ENCODING]
elif (writer.length if self._length_check else self.content_length) != 0:
# https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
headers.setdefault(hdrs.DATE, rfc822_formatted_time())
headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)
@@ -485,9 +501,8 @@ class StreamResponse(BaseClass, HeadersMixin):
if keep_alive:
if version == HttpVersion10:
headers[hdrs.CONNECTION] = "keep-alive"
else:
if version == HttpVersion11:
headers[hdrs.CONNECTION] = "close"
elif version == HttpVersion11:
headers[hdrs.CONNECTION] = "close"
async def _write_headers(self) -> None:
request = self._req
@@ -496,9 +511,7 @@ class StreamResponse(BaseClass, HeadersMixin):
assert writer is not None
# status line
version = request.version
status_line = "HTTP/{}.{} {} {}".format(
version[0], version[1], self._status, self._reason
)
status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
await writer.write_headers(status_line, self._headers)
async def write(self, data: bytes) -> None:
@@ -617,19 +630,17 @@ class Response(StreamResponse):
real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset
body = text.encode(charset)
text = None
else:
if hdrs.CONTENT_TYPE in real_headers:
if content_type is not None or charset is not None:
raise ValueError(
"passing both Content-Type header and "
"content_type or charset params "
"is forbidden"
)
else:
if content_type is not None:
if charset is not None:
content_type += "; charset=" + charset
real_headers[hdrs.CONTENT_TYPE] = content_type
elif hdrs.CONTENT_TYPE in real_headers:
if content_type is not None or charset is not None:
raise ValueError(
"passing both Content-Type header and "
"content_type or charset params "
"is forbidden"
)
elif content_type is not None:
if charset is not None:
content_type += "; charset=" + charset
real_headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, headers=real_headers)
@@ -647,41 +658,26 @@ class Response(StreamResponse):
return self._body
@body.setter
def body(
self,
body: bytes,
CONTENT_TYPE: istr = hdrs.CONTENT_TYPE,
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
) -> None:
def body(self, body: Any) -> None:
if body is None:
self._body: Optional[bytes] = None
self._body_payload: bool = False
self._body = None
elif isinstance(body, (bytes, bytearray)):
self._body = body
self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError("Unsupported body type %r" % type(body))
self._body_payload = True
headers = self._headers
# set content-length header if needed
if not self._chunked and CONTENT_LENGTH not in headers:
size = body.size
if size is not None:
headers[CONTENT_LENGTH] = str(size)
# set content-type
if CONTENT_TYPE not in headers:
headers[CONTENT_TYPE] = body.content_type
if hdrs.CONTENT_TYPE not in headers:
headers[hdrs.CONTENT_TYPE] = body.content_type
# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
for key, value in body.headers.items():
if key not in headers:
headers[key] = value
@@ -705,7 +701,6 @@ class Response(StreamResponse):
self.charset = "utf-8"
self._body = text.encode(self.charset)
self._body_payload = False
self._compressed_body = None
@property
@@ -714,12 +709,12 @@ class Response(StreamResponse):
return None
if hdrs.CONTENT_LENGTH in self._headers:
return super().content_length
return int(self._headers[hdrs.CONTENT_LENGTH])
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
elif self._body_payload:
elif isinstance(self._body, Payload):
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
@@ -741,62 +736,57 @@ class Response(StreamResponse):
assert not data, f"data arg is not supported, got {data!r}"
assert self._req is not None
assert self._payload_writer is not None
if body is not None:
if self._req._method == hdrs.METH_HEAD or self._status in [204, 304]:
await super().write_eof()
elif self._body_payload:
payload = cast(Payload, body)
await payload.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
else:
if body is None or self._must_be_empty_body:
await super().write_eof()
elif isinstance(self._body, Payload):
await self._body.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
if not self._body_payload:
if self._body is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
else:
self._headers[hdrs.CONTENT_LENGTH] = "0"
if hdrs.CONTENT_LENGTH in self._headers:
if should_remove_content_length(request.method, self.status):
del self._headers[hdrs.CONTENT_LENGTH]
elif not self._chunked:
if isinstance(self._body, Payload):
if self._body.size is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size)
else:
body_len = len(self._body) if self._body else "0"
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7
if body_len != "0" or (
self.status != 304 and request.method.upper() != hdrs.METH_HEAD
):
self._headers[hdrs.CONTENT_LENGTH] = str(body_len)
return await super()._start(request)
def _compress_body(self, zlib_mode: int) -> None:
assert zlib_mode > 0
compressobj = zlib.compressobj(wbits=zlib_mode)
body_in = self._body
assert body_in is not None
self._compressed_body = compressobj.compress(body_in) + compressobj.flush()
async def _do_start_compression(self, coding: ContentCoding) -> None:
if self._body_payload or self._chunked:
if self._chunked or isinstance(self._body, Payload):
return await super()._do_start_compression(coding)
if coding != ContentCoding.identity:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (
16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS
if coding is ContentCoding.identity:
return
# Instead of using _payload_writer.enable_compression,
# compress the whole body
compressor = ZLibCompressor(
encoding=coding.value,
max_sync_chunk_size=self._zlib_executor_size,
executor=self._zlib_executor,
)
assert self._body is not None
if self._zlib_executor_size is None and len(self._body) > LARGE_BODY_SIZE:
warnings.warn(
"Synchronous compression of large response bodies "
f"({len(self._body)} bytes) might block the async event loop. "
"Consider providing a custom value to zlib_executor_size/"
"zlib_executor response properties or disabling compression on it."
)
body_in = self._body
assert body_in is not None
if (
self._zlib_executor_size is not None
and len(body_in) > self._zlib_executor_size
):
await asyncio.get_event_loop().run_in_executor(
self._zlib_executor, self._compress_body, zlib_mode
)
else:
self._compress_body(zlib_mode)
body_out = self._compressed_body
assert body_out is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
self._compressed_body = (
await compressor.compress(self._body) + compressor.flush()
)
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body))
def json_response(

View File

@@ -20,7 +20,7 @@ from . import hdrs
from .abc import AbstractView
from .typedefs import Handler, PathLike
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import AbstractRoute, UrlDispatcher
@@ -162,12 +162,10 @@ class RouteTableDef(Sequence[AbstractRouteDef]):
return f"<RouteTableDef count={len(self._items)}>"
@overload
def __getitem__(self, index: int) -> AbstractRouteDef:
...
def __getitem__(self, index: int) -> AbstractRouteDef: ...
@overload
def __getitem__(self, index: slice) -> List[AbstractRouteDef]:
...
def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ...
def __getitem__(self, index): # type: ignore[no-untyped-def]
return self._items[index]

View File

@@ -1,11 +1,13 @@
import asyncio
import signal
import socket
import warnings
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Set
from yarl import URL
from .typedefs import PathLike
from .web_app import Application
from .web_server import Server
@@ -37,7 +39,7 @@ def _raise_graceful_exit() -> None:
class BaseSite(ABC):
__slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")
__slots__ = ("_runner", "_ssl_context", "_backlog", "_server")
def __init__(
self,
@@ -49,8 +51,11 @@ class BaseSite(ABC):
) -> None:
if runner.server is None:
raise RuntimeError("Call runner.setup() before making a site")
if shutdown_timeout != 60.0:
msg = "shutdown_timeout should be set on BaseRunner"
warnings.warn(msg, DeprecationWarning, stacklevel=2)
runner._shutdown_timeout = shutdown_timeout
self._runner = runner
self._shutdown_timeout = shutdown_timeout
self._ssl_context = ssl_context
self._backlog = backlog
self._server: Optional[asyncio.AbstractServer] = None
@@ -66,16 +71,9 @@ class BaseSite(ABC):
async def stop(self) -> None:
self._runner._check_site(self)
if self._server is None:
self._runner._unreg_site(self)
return # not started yet
self._server.close()
# named pipes do not have wait_closed property
if hasattr(self._server, "wait_closed"):
await self._server.wait_closed()
await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
if self._server is not None: # Maybe not started yet
self._server.close()
self._runner._unreg_site(self)
@@ -110,7 +108,7 @@ class TCPSite(BaseSite):
@property
def name(self) -> str:
scheme = "https" if self._ssl_context else "http"
host = "0.0.0.0" if self._host is None else self._host
host = "0.0.0.0" if not self._host else self._host
return str(URL.build(scheme=scheme, host=host, port=self._port))
async def start(self) -> None:
@@ -135,7 +133,7 @@ class UnixSite(BaseSite):
def __init__(
self,
runner: "BaseRunner",
path: str,
path: PathLike,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
@@ -160,7 +158,10 @@ class UnixSite(BaseSite):
server = self._runner.server
assert server is not None
self._server = await loop.create_unix_server(
server, self._path, ssl=self._ssl_context, backlog=self._backlog
server,
self._path,
ssl=self._ssl_context,
backlog=self._backlog,
)
@@ -237,13 +238,20 @@ class SockSite(BaseSite):
class BaseRunner(ABC):
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites")
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")
def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
def __init__(
self,
*,
handle_signals: bool = False,
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
self._sites: List[BaseSite] = []
self._shutdown_timeout = shutdown_timeout
@property
def server(self) -> Optional[Server]:
@@ -255,7 +263,7 @@ class BaseRunner(ABC):
for site in self._sites:
server = site._server
if server is not None:
sockets = server.sockets
sockets = server.sockets # type: ignore[attr-defined]
if sockets is not None:
for sock in sockets:
ret.append(sock.getsockname())
@@ -280,20 +288,28 @@ class BaseRunner(ABC):
@abstractmethod
async def shutdown(self) -> None:
pass # pragma: no cover
"""Call any shutdown hooks to help server close gracefully."""
async def cleanup(self) -> None:
loop = asyncio.get_event_loop()
# The loop over sites is intentional, an exception on gather()
# leaves self._sites in unpredictable state.
# The loop guaranties that a site is either deleted on success or
# still present on failure
for site in list(self._sites):
await site.stop()
if self._server: # If setup succeeded
# Yield to event loop to ensure incoming requests prior to stopping the sites
# have all started to be handled before we proceed to close idle connections.
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()
await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()
self._server = None
if self._handle_signals:
loop = asyncio.get_running_loop()
try:
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)

View File

@@ -1,9 +1,9 @@
"""Low level HTTP server."""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
from .abc import AbstractStreamWriter
from .helpers import get_running_loop
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
@@ -18,15 +18,17 @@ class Server:
handler: _RequestHandler,
*,
request_factory: Optional[_RequestFactory] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any
) -> None:
self._loop = get_running_loop(loop)
self._loop = loop or asyncio.get_running_loop()
self._connections: Dict[RequestHandler, asyncio.Transport] = {}
self._kwargs = kwargs
self.requests_count = 0
self.request_handler = handler
self.request_factory = request_factory or self._make_request
self.handler_cancellation = handler_cancellation
@property
def connections(self) -> List[RequestHandler]:
@@ -41,7 +43,12 @@ class Server:
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
del self._connections[handler]
if handler._task_handler:
handler._task_handler.add_done_callback(
lambda f: self._connections.pop(handler, None)
)
else:
del self._connections[handler]
def _make_request(
self,
@@ -53,10 +60,23 @@ class Server:
) -> BaseRequest:
return BaseRequest(message, payload, protocol, writer, task, self._loop)
def pre_shutdown(self) -> None:
for conn in self._connections:
conn.close()
async def shutdown(self, timeout: Optional[float] = None) -> None:
coros = [conn.shutdown(timeout) for conn in self._connections]
coros = (conn.shutdown(timeout) for conn in self._connections)
await asyncio.gather(*coros)
self._connections.clear()
def __call__(self) -> RequestHandler:
return RequestHandler(self, loop=self._loop, **self._kwargs)
try:
return RequestHandler(self, loop=self._loop, **self._kwargs)
except TypeError:
# Failsafe creation: remove all custom handler_args
kwargs = {
k: v
for k, v in self._kwargs.items()
if k in ["debug", "access_log_class"]
}
return RequestHandler(self, loop=self._loop, **kwargs)

View File

@@ -1,13 +1,15 @@
import abc
import asyncio
import base64
import functools
import hashlib
import html
import inspect
import keyword
import os
import re
import sys
import warnings
from contextlib import contextmanager
from functools import wraps
from pathlib import Path
from types import MappingProxyType
@@ -18,28 +20,31 @@ from typing import (
Callable,
Container,
Dict,
Final,
Generator,
Iterable,
Iterator,
List,
Mapping,
NoReturn,
Optional,
Pattern,
Set,
Sized,
Tuple,
Type,
TypedDict,
Union,
cast,
)
from yarl import URL, __version__ as yarl_version # type: ignore[attr-defined]
from yarl import URL, __version__ as yarl_version
from . import hdrs
from .abc import AbstractMatchInfo, AbstractRouter, AbstractView
from .helpers import DEBUG
from .http import HttpVersion11
from .typedefs import Final, Handler, PathLike, TypedDict
from .typedefs import Handler, PathLike
from .web_exceptions import (
HTTPException,
HTTPExpectationFailed,
@@ -66,13 +71,19 @@ __all__ = (
)
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from .web_app import Application
BaseDict = Dict[str, str]
else:
BaseDict = dict
CIRCULAR_SYMLINK_ERROR = (
(OSError,)
if sys.version_info < (3, 10) and sys.platform.startswith("win32")
else (RuntimeError,) if sys.version_info < (3, 13) else ()
)
YARL_VERSION: Final[Tuple[int, ...]] = tuple(map(int, yarl_version.split(".")[:2]))
HTTP_METHOD_RE: Final[Pattern[str]] = re.compile(
@@ -84,9 +95,11 @@ ROUTE_RE: Final[Pattern[str]] = re.compile(
PATH_SEP: Final[str] = re.escape("/")
_ExpectHandler = Callable[[Request], Awaitable[None]]
_ExpectHandler = Callable[[Request], Awaitable[Optional[StreamResponse]]]
_Resolve = Tuple[Optional["UrlMappingMatchInfo"], Set[str]]
html_escape = functools.partial(html.escape, quote=True)
class _InfoDict(TypedDict, total=False):
path: str
@@ -192,10 +205,11 @@ class AbstractRoute(abc.ABC):
@wraps(handler)
async def handler_wrapper(request: Request) -> StreamResponse:
result = old_handler(request)
result = old_handler(request) # type: ignore[call-arg]
if asyncio.iscoroutine(result):
return await result
return result # type: ignore[return-value]
result = await result
assert isinstance(result, StreamResponse)
return result
old_handler = handler
handler = handler_wrapper
@@ -230,8 +244,8 @@ class AbstractRoute(abc.ABC):
def url_for(self, *args: str, **kwargs: str) -> URL:
"""Construct url for route with additional params."""
async def handle_expect_header(self, request: Request) -> None:
await self._expect_handler(request)
async def handle_expect_header(self, request: Request) -> Optional[StreamResponse]:
return await self._expect_handler(request)
class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
@@ -278,8 +292,8 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
assert app is not None
return app
@contextmanager
def set_current_app(self, app: "Application") -> Generator[None, None, None]:
@current_app.setter
def current_app(self, app: "Application") -> None:
if DEBUG: # pragma: no cover
if app not in self._apps:
raise RuntimeError(
@@ -287,12 +301,7 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
self._apps, app
)
)
prev = self._current_app
self._current_app = app
try:
yield
finally:
self._current_app = prev
def freeze(self) -> None:
self._frozen = True
@@ -326,6 +335,8 @@ async def _default_expect_handler(request: Request) -> None:
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
# Reset output_size as we haven't started the main body yet.
request.writer.output_size = 0
else:
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)
@@ -364,7 +375,7 @@ class Resource(AbstractResource):
async def resolve(self, request: Request) -> _Resolve:
allowed_methods: Set[str] = set()
match_dict = self._match(request.rel_url.raw_path)
match_dict = self._match(request.rel_url.path_safe)
if match_dict is None:
return None, allowed_methods
@@ -384,7 +395,7 @@ class Resource(AbstractResource):
def __len__(self) -> int:
return len(self._routes)
def __iter__(self) -> Iterator[AbstractRoute]:
def __iter__(self) -> Iterator["ResourceRoute"]:
return iter(self._routes)
# TODO: implement all abstract methods
@@ -414,8 +425,7 @@ class PlainResource(Resource):
# string comparison is about 10 times faster than regexp matching
if self._path == path:
return {}
else:
return None
return None
def raw_match(self, path: str) -> bool:
return self._path == path
@@ -439,6 +449,7 @@ class DynamicResource(Resource):
def __init__(self, path: str, *, name: Optional[str] = None) -> None:
super().__init__(name=name)
self._orig_path = path
pattern = ""
formatter = ""
for part in ROUTE_RE.split(path):
@@ -485,13 +496,12 @@ class DynamicResource(Resource):
match = self._pattern.fullmatch(path)
if match is None:
return None
else:
return {
key: _unquote_path(value) for key, value in match.groupdict().items()
}
return {
key: _unquote_path_safe(value) for key, value in match.groupdict().items()
}
def raw_match(self, path: str) -> bool:
return self._formatter == path
return self._orig_path == path
def get_info(self) -> _InfoDict:
return {"formatter": self._formatter, "pattern": self._pattern}
@@ -549,14 +559,11 @@ class StaticResource(PrefixResource):
) -> None:
super().__init__(prefix, name=name)
try:
directory = Path(directory)
if str(directory).startswith("~"):
directory = Path(os.path.expanduser(str(directory)))
directory = directory.resolve()
if not directory.is_dir():
raise ValueError("Not a directory")
except (FileNotFoundError, ValueError) as error:
raise ValueError(f"No directory exists at '{directory}'") from error
directory = Path(directory).expanduser().resolve(strict=True)
except FileNotFoundError as error:
raise ValueError(f"'{directory}' does not exist") from error
if not directory.is_dir():
raise ValueError(f"'{directory}' is not a directory")
self._directory = directory
self._show_index = show_index
self._chunk_size = chunk_size
@@ -576,14 +583,12 @@ class StaticResource(PrefixResource):
def url_for( # type: ignore[override]
self,
*,
filename: Union[str, Path],
filename: PathLike,
append_version: Optional[bool] = None,
) -> URL:
if append_version is None:
append_version = self._append_version
if isinstance(filename, Path):
filename = str(filename)
filename = filename.lstrip("/")
filename = str(filename).lstrip("/")
url = URL.build(path=self._prefix, encoded=True)
# filename is not encoded
@@ -593,9 +598,14 @@ class StaticResource(PrefixResource):
url = url / filename
if append_version:
unresolved_path = self._directory.joinpath(filename)
try:
filepath = self._directory.joinpath(filename).resolve()
if not self._follow_symlinks:
if self._follow_symlinks:
normalized_path = Path(os.path.normpath(unresolved_path))
normalized_path.relative_to(self._directory)
filepath = normalized_path.resolve()
else:
filepath = unresolved_path.resolve()
filepath.relative_to(self._directory)
except (ValueError, FileNotFoundError):
# ValueError for case when path point to symlink
@@ -633,7 +643,7 @@ class StaticResource(PrefixResource):
)
async def resolve(self, request: Request) -> _Resolve:
path = request.rel_url.raw_path
path = request.rel_url.path_safe
method = request.method
allowed_methods = set(self._routes)
if not path.startswith(self._prefix2) and path != self._prefix:
@@ -642,7 +652,7 @@ class StaticResource(PrefixResource):
if method not in allowed_methods:
return None, allowed_methods
match_dict = {"filename": _unquote_path(path[len(self._prefix) + 1 :])}
match_dict = {"filename": _unquote_path_safe(path[len(self._prefix) + 1 :])}
return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods)
def __len__(self) -> int:
@@ -653,58 +663,68 @@ class StaticResource(PrefixResource):
async def _handle(self, request: Request) -> StreamResponse:
rel_url = request.match_info["filename"]
filename = Path(rel_url)
if filename.anchor:
# rel_url is an absolute name like
# /static/\\machine_name\c$ or /static/D:\path
# where the static dir is totally different
raise HTTPForbidden()
unresolved_path = self._directory.joinpath(filename)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, self._resolve_path_to_response, unresolved_path
)
def _resolve_path_to_response(self, unresolved_path: Path) -> StreamResponse:
"""Take the unresolved path and query the file system to form a response."""
# Check for access outside the root directory. For follow symlinks, URI
# cannot traverse out, but symlinks can. Otherwise, no access outside
# root is permitted.
try:
filename = Path(rel_url)
if filename.anchor:
# rel_url is an absolute name like
# /static/\\machine_name\c$ or /static/D:\path
# where the static dir is totally different
raise HTTPForbidden()
filepath = self._directory.joinpath(filename).resolve()
if not self._follow_symlinks:
filepath.relative_to(self._directory)
except (ValueError, FileNotFoundError) as error:
# relatively safe
raise HTTPNotFound() from error
except HTTPForbidden:
raise
except Exception as error:
# perm error or other kind!
request.app.logger.exception(error)
raise HTTPNotFound() from error
# on opening a dir, load its contents if allowed
if filepath.is_dir():
if self._show_index:
try:
return Response(
text=self._directory_as_html(filepath), content_type="text/html"
)
except PermissionError:
raise HTTPForbidden()
if self._follow_symlinks:
normalized_path = Path(os.path.normpath(unresolved_path))
normalized_path.relative_to(self._directory)
file_path = normalized_path.resolve()
else:
raise HTTPForbidden()
elif filepath.is_file():
return FileResponse(filepath, chunk_size=self._chunk_size)
else:
raise HTTPNotFound
file_path = unresolved_path.resolve()
file_path.relative_to(self._directory)
except (ValueError, *CIRCULAR_SYMLINK_ERROR) as error:
# ValueError is raised for the relative check. Circular symlinks
# raise here on resolving for python < 3.13.
raise HTTPNotFound() from error
def _directory_as_html(self, filepath: Path) -> str:
# returns directory's index as html
# if path is a directory, return the contents if permitted. Note the
# directory check will raise if a segment is not readable.
try:
if file_path.is_dir():
if self._show_index:
return Response(
text=self._directory_as_html(file_path),
content_type="text/html",
)
else:
raise HTTPForbidden()
except PermissionError as error:
raise HTTPForbidden() from error
# sanity check
assert filepath.is_dir()
# Return the file response, which handles all other checks.
return FileResponse(file_path, chunk_size=self._chunk_size)
relative_path_to_dir = filepath.relative_to(self._directory).as_posix()
index_of = f"Index of /{relative_path_to_dir}"
def _directory_as_html(self, dir_path: Path) -> str:
"""returns directory's index as html."""
assert dir_path.is_dir()
relative_path_to_dir = dir_path.relative_to(self._directory).as_posix()
index_of = f"Index of /{html_escape(relative_path_to_dir)}"
h1 = f"<h1>{index_of}</h1>"
index_list = []
dir_index = filepath.iterdir()
dir_index = dir_path.iterdir()
for _file in sorted(dir_index):
# show file url as relative to static path
rel_path = _file.relative_to(self._directory).as_posix()
file_url = self._prefix + "/" + rel_path
quoted_file_url = _quote_path(f"{self._prefix}/{rel_path}")
# if file is a directory, add '/' to the end of the name
if _file.is_dir():
@@ -713,9 +733,7 @@ class StaticResource(PrefixResource):
file_name = _file.name
index_list.append(
'<li><a href="{url}">{name}</a></li>'.format(
url=file_url, name=file_name
)
f'<li><a href="{quoted_file_url}">{html_escape(file_name)}</a></li>'
)
ul = "<ul>\n{}\n</ul>".format("\n".join(index_list))
body = f"<body>\n{h1}\n{ul}\n</body>"
@@ -736,13 +754,20 @@ class PrefixedSubAppResource(PrefixResource):
def __init__(self, prefix: str, app: "Application") -> None:
super().__init__(prefix)
self._app = app
for resource in app.router.resources():
resource.add_prefix(prefix)
self._add_prefix_to_resources(prefix)
def add_prefix(self, prefix: str) -> None:
super().add_prefix(prefix)
for resource in self._app.router.resources():
self._add_prefix_to_resources(prefix)
def _add_prefix_to_resources(self, prefix: str) -> None:
router = self._app.router
for resource in router.resources():
# Since the canonical path of a resource is about
# to change, we need to unindex it and then reindex
router.unindex_resource(resource)
resource.add_prefix(prefix)
router.index_resource(resource)
def url_for(self, *args: str, **kwargs: str) -> URL:
raise RuntimeError(".url_for() is not supported " "by sub-application root")
@@ -751,11 +776,6 @@ class PrefixedSubAppResource(PrefixResource):
return {"app": self._app, "prefix": self._prefix}
async def resolve(self, request: Request) -> _Resolve:
if (
not request.url.raw_path.startswith(self._prefix2)
and request.url.raw_path != self._prefix
):
return None, set()
match_info = await self._app.router.resolve(request)
match_info.add_app(self._app)
if isinstance(match_info.http_exception, HTTPMethodNotAllowed):
@@ -946,18 +966,18 @@ class View(AbstractView):
async def _iter(self) -> StreamResponse:
if self.request.method not in hdrs.METH_ALL:
self._raise_allowed_methods()
method: Callable[[], Awaitable[StreamResponse]] = getattr(
self, self.request.method.lower(), None
)
method: Optional[Callable[[], Awaitable[StreamResponse]]]
method = getattr(self, self.request.method.lower(), None)
if method is None:
self._raise_allowed_methods()
resp = await method()
return resp
ret = await method()
assert isinstance(ret, StreamResponse)
return ret
def __await__(self) -> Generator[Any, None, StreamResponse]:
return self._iter().__await__()
def _raise_allowed_methods(self) -> None:
def _raise_allowed_methods(self) -> NoReturn:
allowed_methods = {m for m in hdrs.METH_ALL if hasattr(self, m.lower())}
raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
@@ -1001,12 +1021,39 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
super().__init__()
self._resources: List[AbstractResource] = []
self._named_resources: Dict[str, AbstractResource] = {}
self._resource_index: dict[str, list[AbstractResource]] = {}
self._matched_sub_app_resources: List[MatchedSubAppResource] = []
async def resolve(self, request: Request) -> UrlMappingMatchInfo:
method = request.method
resource_index = self._resource_index
allowed_methods: Set[str] = set()
for resource in self._resources:
# Walk the url parts looking for candidates. We walk the url backwards
# to ensure the most explicit match is found first. If there are multiple
# candidates for a given url part because there are multiple resources
# registered for the same canonical path, we resolve them in a linear
# fashion to ensure registration order is respected.
url_part = request.rel_url.path_safe
while url_part:
for candidate in resource_index.get(url_part, ()):
match_dict, allowed = await candidate.resolve(request)
if match_dict is not None:
return match_dict
else:
allowed_methods |= allowed
if url_part == "/":
break
url_part = url_part.rpartition("/")[0] or "/"
#
# We didn't find any candidates, so we'll try the matched sub-app
# resources which we have to walk in a linear fashion because they
# have regex/wildcard match rules and we cannot index them.
#
# For most cases we do not expect there to be many of these since
# currently they are only added by `add_domain`
#
for resource in self._matched_sub_app_resources:
match_dict, allowed = await resource.resolve(request)
if match_dict is not None:
return match_dict
@@ -1014,9 +1061,9 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
allowed_methods |= allowed
if allowed_methods:
return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods))
else:
return MatchInfoError(HTTPNotFound())
return MatchInfoError(HTTPMethodNotAllowed(request.method, allowed_methods))
return MatchInfoError(HTTPNotFound())
def __iter__(self) -> Iterator[str]:
return iter(self._named_resources)
@@ -1072,6 +1119,36 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
self._named_resources[name] = resource
self._resources.append(resource)
if isinstance(resource, MatchedSubAppResource):
# We cannot index match sub-app resources because they have match rules
self._matched_sub_app_resources.append(resource)
else:
self.index_resource(resource)
def _get_resource_index_key(self, resource: AbstractResource) -> str:
"""Return a key to index the resource in the resource index."""
if "{" in (index_key := resource.canonical):
# strip at the first { to allow for variables, and than
# rpartition at / to allow for variable parts in the path
# For example if the canonical path is `/core/locations{tail:.*}`
# the index key will be `/core` since index is based on the
# url parts split by `/`
index_key = index_key.partition("{")[0].rpartition("/")[0]
return index_key.rstrip("/") or "/"
def index_resource(self, resource: AbstractResource) -> None:
"""Add a resource to the resource index."""
resource_key = self._get_resource_index_key(resource)
# There may be multiple resources for a canonical path
# so we keep them in a list to ensure that registration
# order is respected.
self._resource_index.setdefault(resource_key, []).append(resource)
def unindex_resource(self, resource: AbstractResource) -> None:
"""Remove a resource from the resource index."""
resource_key = self._get_resource_index_key(resource)
self._resource_index[resource_key].remove(resource)
def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource:
if path and not path.startswith("/"):
raise ValueError("path should be started with / or be empty")
@@ -1081,7 +1158,7 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
if resource.name == name and resource.raw_match(path):
return cast(Resource, resource)
if not ("{" in path or "}" in path or ROUTE_RE.search(path)):
resource = PlainResource(_requote_path(path), name=name)
resource = PlainResource(path, name=name)
self.register_resource(resource)
return resource
resource = DynamicResource(path, name=name)
@@ -1207,8 +1284,10 @@ def _quote_path(value: str) -> str:
return URL.build(path=value, encoded=False).raw_path
def _unquote_path(value: str) -> str:
return URL.build(path=value, encoded=True).path
def _unquote_path_safe(value: str) -> str:
if "%" not in value:
return value
return value.replace("%2F", "/").replace("%25", "%")
def _requote_path(value: str) -> str:

View File

@@ -3,15 +3,15 @@ import base64
import binascii
import hashlib
import json
from typing import Any, Iterable, Optional, Tuple, cast
import sys
from typing import Any, Final, Iterable, Optional, Tuple, cast
import async_timeout
import attr
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import call_later, set_result
from .helpers import calculate_timeout_when, set_exception, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
@@ -27,11 +27,16 @@ from .http import (
)
from .log import ws_logger
from .streams import EofStream, FlowControlDataQueue
from .typedefs import Final, JSONDecoder, JSONEncoder
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_request import BaseRequest
from .web_response import StreamResponse
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = (
"WebSocketResponse",
"WebSocketReady",
@@ -76,58 +81,119 @@ class WebSocketResponse(StreamResponse):
self._conn_lost = 0
self._close_code: Optional[int] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timeout = timeout
self._receive_timeout = receive_timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_when = 0.0
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
self._ping_task: Optional[asyncio.Task[None]] = None
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
if self._heartbeat is not None:
assert self._loop is not None
self._heartbeat_cb = call_later(
self._send_heartbeat, self._heartbeat, self._loop
)
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
)
loop = self._loop
assert loop is not None
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
assert self._loop is not None
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping()) # type: ignore[union-attr]
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received, self._pong_heartbeat, self._loop
self._heartbeat_cb = None
loop = self._loop
assert loop is not None and self._writer is not None
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
self._req.transport.close()
self._handle_ping_pong_exception(asyncio.TimeoutError())
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = exc
if self._waiting and not self._closing and self._reader is not None:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
@@ -286,6 +352,19 @@ class WebSocketResponse(StreamResponse):
def compress(self) -> bool:
return self._compress
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""Get optional transport information.
If no value associated with ``name`` is found, ``default`` is returned.
"""
writer = self._writer
if writer is None:
return default
transport = writer.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> Optional[BaseException]:
return self._exception
@@ -300,14 +379,14 @@ class WebSocketResponse(StreamResponse):
raise RuntimeError("Call .prepare() first")
await self._writer.pong(message)
async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)
async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, (bytes, bytearray, memoryview)):
@@ -317,7 +396,7 @@ class WebSocketResponse(StreamResponse):
async def send_json(
self,
data: Any,
compress: Optional[bool] = None,
compress: Optional[int] = None,
*,
dumps: JSONEncoder = json.dumps,
) -> None:
@@ -332,69 +411,84 @@ class WebSocketResponse(StreamResponse):
await self.close()
self._eof_sent = True
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
async def close(
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
"""Close websocket connection."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")
self._cancel_heartbeat()
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
if drain:
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
reader = self._reader
assert reader is not None
# we need to break `receive()` cycle before we can call
# `reader.read()` as `close()` may be called from different task
if self._waiting:
assert self._loop is not None
assert self._close_wait is None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
if not self._closed:
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closing:
return True
reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
if self._closing:
self._close_transport()
return True
else:
return False
try:
async with async_timeout.timeout(self._timeout):
while True:
msg = await reader.read()
if msg.type is WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
except asyncio.CancelledError:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
def _set_closing(self, code: WSCloseCode) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code
self._cancel_heartbeat()
def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
self._close_transport()
def _close_transport(self) -> None:
"""Close the transport."""
if self._req is not None and self._req.transport is not None:
self._req.transport.close()
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if self._reader is None:
raise RuntimeError("Call .prepare() first")
loop = self._loop
assert loop is not None
receive_timeout = timeout or self._receive_timeout
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
@@ -406,17 +500,23 @@ class WebSocketResponse(StreamResponse):
return WS_CLOSING_MESSAGE
try:
self._waiting = loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(timeout or self._receive_timeout):
if receive_timeout:
# Entering the context manager and creating
# Timeout() object can take almost 50% of the
# run time in this loop so we avoid it if
# there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
set_result(waiter, True)
self._waiting = None
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except asyncio.TimeoutError:
raise
except EofStream:
self._close_code = WSCloseCode.OK
@@ -428,29 +528,32 @@ class WebSocketResponse(StreamResponse):
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if msg.type is WSMsgType.CLOSE:
self._set_closing(msg.data)
# Could be closed while awaiting reader.
if not self._closed and self._autoclose:
await self.close()
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
# The client is likely going to close the
# connection out from under us so we do not
# want to drain any pending writes as it will
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type is WSMsgType.CLOSING:
self._set_closing(WSCloseCode.OK)
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type == WSMsgType.PONG and self._autoping:
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type != WSMsgType.TEXT:
if msg.type is not WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not WSMsgType.TEXT".format(
msg.type, msg.data
@@ -460,7 +563,7 @@ class WebSocketResponse(StreamResponse):
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type != WSMsgType.BINARY:
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
@@ -483,5 +586,9 @@ class WebSocketResponse(StreamResponse):
return msg
def _cancel(self, exc: BaseException) -> None:
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
self._cancel_heartbeat()
if self._reader is not None:
self._reader.set_exception(exc)
set_exception(self._reader, exc)

View File

@@ -26,7 +26,7 @@ except ImportError: # pragma: no cover
SSLContext = object # type: ignore[misc,assignment]
__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker")
__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker")
class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
@@ -89,6 +89,7 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format
),
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await runner.setup()
@@ -103,7 +104,6 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
runner,
sock,
ssl_context=ctx,
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await site.start()
@@ -114,7 +114,7 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
self.notify()
cnt = server.requests_count
if self.cfg.max_requests and cnt > self.cfg.max_requests:
if self.max_requests and cnt > self.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)
@@ -182,14 +182,8 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
signal.siginterrupt(signal.SIGUSR1, False)
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# See: https://github.com/aio-libs/aiohttp/issues/6130
if sys.version_info < (3, 8):
# Starting from Python 3.8,
# the default child watcher is ThreadedChildWatcher.
# The watcher doesn't depend on SIGCHLD signal,
# there is no need to reset it.
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
def handle_quit(self, sig: int, frame: FrameType) -> None:
def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False
# worker_int callback
@@ -198,7 +192,7 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
# wakeup closing process
self._notify_waiter_done()
def handle_abort(self, sig: int, frame: FrameType) -> None:
def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
@@ -251,19 +245,3 @@ class GunicornUVLoopWebWorker(GunicornWebWorker):
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
super().init_process()
class GunicornTokioWebWorker(GunicornWebWorker):
def init_process(self) -> None: # pragma: no cover
import tokio
# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
# Setup tokio policy, so that every
# asyncio.get_event_loop() will create an instance
# of tokio event loop.
asyncio.set_event_loop_policy(tokio.EventLoopPolicy())
super().init_process()

View File

@@ -1,37 +1,41 @@
#!/usr/bin/env python3
"""Example of aiohttp.web.Application.on_startup signal handler"""
import asyncio
from typing import List
import aioredis
from aiohttp import web
redis_listener = web.AppKey("redis_listener", asyncio.Task[None])
websockets = web.AppKey("websockets", List[web.WebSocketResponse])
async def websocket_handler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
request.app["websockets"].append(ws)
request.app[websockets].append(ws)
try:
async for msg in ws:
print(msg)
await asyncio.sleep(1)
finally:
request.app["websockets"].remove(ws)
request.app[websockets].remove(ws)
return ws
async def on_shutdown(app):
for ws in app["websockets"]:
await ws.close(code=999, message="Server shutdown")
async def on_shutdown(app: web.Application) -> None:
for ws in app[websockets]:
await ws.close(code=999, message=b"Server shutdown")
async def listen_to_redis(app):
try:
sub = await aioredis.create_redis(("localhost", 6379), loop=app.loop)
sub = await aioredis.Redis(host="localhost", port=6379)
ch, *_ = await sub.subscribe("news")
async for msg in ch.iter(encoding="utf-8"):
# Forward message to all connected websockets:
for ws in app["websockets"]:
for ws in app[websockets]:
await ws.send_str(f"{ch.name}: {msg}")
print(f"message in {ch.name}: {msg}")
except asyncio.CancelledError:
@@ -44,18 +48,19 @@ async def listen_to_redis(app):
async def start_background_tasks(app: web.Application) -> None:
app["redis_listener"] = asyncio.create_task(listen_to_redis(app))
app[redis_listener] = asyncio.create_task(listen_to_redis(app))
async def cleanup_background_tasks(app):
print("cleanup background tasks...")
app["redis_listener"].cancel()
await app["redis_listener"]
app[redis_listener].cancel()
await app[redis_listener]
def init():
app = web.Application()
app["websockets"] = []
l: List[web.WebSocketResponse] = []
app[websockets] = l
app.router.add_get("/news", websocket_handler)
app.on_startup.append(start_background_tasks)
app.on_cleanup.append(cleanup_background_tasks)

View File

@@ -13,13 +13,15 @@ needs (i.e. ``-H``, ``-P`` & ``entry-func``) and passes on any additional
arguments to the `cli_app:init` function for processing.
"""
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from aiohttp import web
args_key = web.AppKey("args_key", Namespace)
def display_message(req):
args = req.app["args"]
async def display_message(req: web.Request) -> web.StreamResponse:
args = req.app[args_key]
text = "\n".join([args.message] * args.repeat)
return web.Response(text=text)
@@ -45,7 +47,7 @@ def init(argv):
args = arg_parser.parse_args(argv)
app = web.Application()
app["args"] = args
app[args_key] = args
app.router.add_get("/", display_message)
return app

View File

@@ -1,27 +1,18 @@
#!/usr/bin/env python3
"""websocket cmd client for wssrv.py example."""
"""websocket cmd client for web_ws.py example."""
import argparse
import asyncio
import signal
import sys
from contextlib import suppress
import aiohttp
async def start_client(loop, url):
async def start_client(url: str) -> None:
name = input("Please enter your name: ")
# input reader
def stdin_callback():
line = sys.stdin.buffer.readline().decode("utf-8")
if not line:
loop.stop()
else:
ws.send_str(name + ": " + line)
loop.add_reader(sys.stdin.fileno(), stdin_callback)
async def dispatch():
async def dispatch(ws: aiohttp.ClientWebSocketResponse) -> None:
while True:
msg = await ws.receive()
@@ -30,7 +21,7 @@ async def start_client(loop, url):
elif msg.type == aiohttp.WSMsgType.BINARY:
print("Binary: ", msg.data)
elif msg.type == aiohttp.WSMsgType.PING:
ws.pong()
await ws.pong()
elif msg.type == aiohttp.WSMsgType.PONG:
print("Pong received")
else:
@@ -43,10 +34,18 @@ async def start_client(loop, url):
break
# send request
async with aiohttp.ClientSession() as session:
async with session.ws_connect(url, autoclose=False, autoping=False) as ws:
await dispatch()
# send request
dispatch_task = asyncio.create_task(dispatch(ws))
# Exit with Ctrl+D
while line := await asyncio.to_thread(sys.stdin.readline):
await ws.send_str(name + ": " + line)
dispatch_task.cancel()
with suppress(asyncio.CancelledError):
await dispatch_task
ARGS = argparse.ArgumentParser(
@@ -67,8 +66,4 @@ if __name__ == "__main__":
url = f"http://{args.host}:{args.port}"
loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT, loop.stop)
loop.create_task(start_client(loop, url))
loop.run_forever()
asyncio.run(start_client(url))

View File

@@ -2,6 +2,7 @@
import argparse
import asyncio
import sys
import aiohttp
@@ -25,10 +26,11 @@ if __name__ == "__main__":
)
options = ARGS.parse_args()
if options.iocp:
if options.iocp and sys.platform == "win32":
from asyncio import events, windows_events
el = windows_events.ProactorEventLoop()
# https://github.com/python/mypy/issues/12286
el = windows_events.ProactorEventLoop() # type: ignore[attr-defined]
events.set_event_loop(el)
loop = asyncio.get_event_loop()

View File

@@ -3,10 +3,11 @@ import asyncio
import pathlib
import socket
import ssl
from typing import List
import aiohttp
from aiohttp import web
from aiohttp.abc import AbstractResolver
from aiohttp.abc import AbstractResolver, ResolveResult
from aiohttp.resolver import DefaultResolver
from aiohttp.test_utils import unused_port
@@ -19,7 +20,12 @@ class FakeResolver(AbstractResolver):
self._fakes = fakes
self._resolver = DefaultResolver(loop=loop)
async def resolve(self, host, port=0, family=socket.AF_INET):
async def resolve(
self,
host: str,
port: int = 0,
family: socket.AddressFamily = socket.AF_INET,
) -> List[ResolveResult]:
fake_port = self._fakes.get(host)
if fake_port is not None:
return [
@@ -36,7 +42,7 @@ class FakeResolver(AbstractResolver):
return await self._resolver.resolve(host, port, family)
async def close(self) -> None:
self._resolver.close()
await self._resolver.close()
class FakeFacebook:

View File

@@ -1,14 +1,20 @@
#!/usr/bin/env python3
"""Example for aiohttp.web websocket server."""
# The extra strict mypy settings are here to help test that `Application[AppKey()]`
# syntax is working correctly. A regression will cause mypy to raise an error.
# mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any
import os
from typing import List
from aiohttp import web
WS_FILE = os.path.join(os.path.dirname(__file__), "websocket.html")
sockets = web.AppKey("sockets", List[web.WebSocketResponse])
async def wshandler(request):
async def wshandler(request: web.Request) -> web.StreamResponse:
resp = web.WebSocketResponse()
available = resp.can_prepare(request)
if not available:
@@ -21,34 +27,35 @@ async def wshandler(request):
try:
print("Someone joined.")
for ws in request.app["sockets"]:
for ws in request.app[sockets]:
await ws.send_str("Someone joined")
request.app["sockets"].append(resp)
request.app[sockets].append(resp)
async for msg in resp:
if msg.type == web.WSMsgType.TEXT:
for ws in request.app["sockets"]:
async for msg in resp: # type: ignore[misc]
if msg.type == web.WSMsgType.TEXT: # type: ignore[misc]
for ws in request.app[sockets]:
if ws is not resp:
await ws.send_str(msg.data)
await ws.send_str(msg.data) # type: ignore[misc]
else:
return resp
return resp
finally:
request.app["sockets"].remove(resp)
request.app[sockets].remove(resp)
print("Someone disconnected.")
for ws in request.app["sockets"]:
for ws in request.app[sockets]:
await ws.send_str("Someone disconnected.")
async def on_shutdown(app):
for ws in app["sockets"]:
async def on_shutdown(app: web.Application) -> None:
for ws in app[sockets]:
await ws.close()
def init():
def init() -> web.Application:
app = web.Application()
app["sockets"] = []
l: List[web.WebSocketResponse] = []
app[sockets] = l
app.router.add_get("/", wshandler)
app.on_shutdown.append(on_shutdown)
return app

View File

@@ -2,79 +2,92 @@
<meta charset="utf-8" />
<html>
<head>
<script src="http://ajax.googleapis.com/ajax/libs/jquery/1.4.2/jquery.min.js">
</script>
<script language="javascript" type="text/javascript">
$(function() {
var conn = null;
function log(msg) {
var control = $('#log');
control.html(control.html() + msg + '<br/>');
control.scrollTop(control.scrollTop() + 1000);
}
function connect() {
var socket = null;
function log(msg) {
const logElem = document.getElementById("log");
const p = document.createElement("p");
p.textContent = msg;
logElem.appendChild(p);
logElem.scroll(0, logElem.scrollHeight);
}
function connect() {
disconnect();
var wsUri = (window.location.protocol=='https:'&&'wss://'||'ws://')+window.location.host;
conn = new WebSocket(wsUri);
log('Connecting...');
conn.onopen = function() {
log('Connected.');
update_ui();
};
conn.onmessage = function(e) {
log('Received: ' + e.data);
};
conn.onclose = function() {
log('Disconnected.');
conn = null;
update_ui();
};
}
function disconnect() {
if (conn != null) {
log('Disconnecting...');
conn.close();
conn = null;
update_ui();
socket = new WebSocket(document.getElementById("wsUri").value);
log("Connecting...");
socket.addEventListener("open", function() {
log("Connected.");
update_ui();
});
socket.addEventListener("message", function(e) {
log("Received: " + e.data);
});
socket.addEventListener("close", function() {
log("Disconnected.");
socket = null;
update_ui();
});
}
function disconnect() {
if (socket !== null) {
log("Disconnecting...");
socket.close();
socket = null;
update_ui();
}
}
function update_ui() {
if (conn == null) {
$('#status').text('disconnected');
$('#connect').html('Connect');
}
function update_ui() {
const status = document.getElementById("status");
const connect = document.getElementById("connect");
if (socket === null) {
status.innerText = "disconnected";
connect.innerText = "Connect";
} else {
$('#status').text('connected (' + conn.protocol + ')');
$('#connect').html('Disconnect');
status.innerText = "connected (" + socket.protocol + ")";
connect.innerText = "Disconnect";
}
}
$('#connect').click(function() {
if (conn == null) {
connect();
} else {
disconnect();
}
update_ui();
return false;
});
$('#send').click(function() {
var text = $('#text').val();
log('Sending: ' + text);
conn.send(text);
$('#text').val('').focus();
return false;
});
$('#text').keyup(function(e) {
if (e.keyCode === 13) {
$('#send').click();
return false;
}
});
}
window.addEventListener("DOMContentLoaded", function() {
const protocol = (window.location.protocol=="https:" && "wss://" || "ws://");
document.getElementById("wsUri").value = protocol + (window.location.host || "localhost:8080");
document.getElementById("connect").addEventListener("click", function() {
if (socket == null) {
connect();
} else {
disconnect();
}
update_ui();
return false;
});
document.getElementById("send").addEventListener("click", function() {
const text = document.getElementById("text");
log("Sending: " + text.value);
socket.send(text.value);
text.value = "";
text.focus();
return false;
});
document.getElementById("text").addEventListener("keyup", function(e) {
if (e.keyCode === 13) {
document.getElementById("send").click();
return false;
}
});
});
</script>
</head>
<body>
<h3>Chat!</h3>
<div>
<input id="wsUri" type="text" />
<button id="connect">Connect</button>&nbsp;|&nbsp;Status:
<span id="status">disconnected</span>
</div>

View File

@@ -5,15 +5,85 @@ requires = [
build-backend = "setuptools.build_meta"
[tool.towncrier]
package = "aiohttp"
filename = "CHANGES.rst"
directory = "CHANGES/"
title_format = "{version} ({project_date})"
template = "CHANGES/.TEMPLATE.rst"
issue_format = "`#{issue} <https://github.com/aio-libs/aiohttp/issues/{issue}>`_"
package = "aiohttp"
filename = "CHANGES.rst"
directory = "CHANGES/"
title_format = "{version} ({project_date})"
template = "CHANGES/.TEMPLATE.rst"
issue_format = "{issue}"
# NOTE: The types are declared because:
# NOTE: - there is no mechanism to override just the value of
# NOTE: `tool.towncrier.type.misc.showcontent`;
# NOTE: - and, we want to declare extra non-default types for
# NOTE: clarity and flexibility.
[[tool.towncrier.section]]
path = ""
[[tool.towncrier.type]]
# Something we deemed an improper undesired behavior that got corrected
# in the release to match pre-agreed expectations.
directory = "bugfix"
name = "Bug fixes"
showcontent = true
[[tool.towncrier.type]]
# New behaviors, public APIs. That sort of stuff.
directory = "feature"
name = "Features"
showcontent = true
[[tool.towncrier.type]]
# Declarations of future API removals and breaking changes in behavior.
directory = "deprecation"
name = "Deprecations (removal in next major release)"
showcontent = true
[[tool.towncrier.type]]
# When something public gets removed in a breaking way. Could be
# deprecated in an earlier release.
directory = "breaking"
name = "Removals and backward incompatible breaking changes"
showcontent = true
[[tool.towncrier.type]]
# Notable updates to the documentation structure or build process.
directory = "doc"
name = "Improved documentation"
showcontent = true
[[tool.towncrier.type]]
# Notes for downstreams about unobvious side effects and tooling. Changes
# in the test invocation considerations and runtime assumptions.
directory = "packaging"
name = "Packaging updates and notes for downstreams"
showcontent = true
[[tool.towncrier.type]]
# Stuff that affects the contributor experience. e.g. Running tests,
# building the docs, setting up the development environment.
directory = "contrib"
name = "Contributor-facing changes"
showcontent = true
[[tool.towncrier.type]]
# Changes that are hard to assign to any of the above categories.
directory = "misc"
name = "Miscellaneous internal changes"
showcontent = true
[tool.cibuildwheel]
test-command = ""
# don't build PyPy wheels, install from source instead
skip = "pp*"
[tool.codespell]
skip = '.git,*.pdf,*.svg,Makefile,CONTRIBUTORS.txt,venvs,_build'
ignore-words-list = 'te'
[tool.slotscheck]
# TODO(3.13): Remove aiohttp.helpers once https://github.com/python/cpython/pull/106771
# is available in all supported cpython versions
exclude-modules = "(^aiohttp\\.helpers)"

View File

@@ -0,0 +1 @@
7b50f4e32516f7a808dbe40b1c88ab367699d62151edae4eb989010c35da30e4 /home/runner/work/aiohttp/aiohttp/requirements/cython.txt

View File

@@ -0,0 +1,4 @@
-r runtime-deps.in
gunicorn
uvloop; platform_system != "Windows" and implementation_name == "cpython" # MagicStack/uvloop#14

View File

@@ -0,0 +1,42 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/base.txt --strip-extras requirements/base.in
#
aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.4
# via -r requirements/runtime-deps.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
async-timeout==4.0.3 ; python_version < "3.11"
# via -r requirements/runtime-deps.in
attrs==23.2.0
# via -r requirements/runtime-deps.in
brotli==1.1.0 ; platform_python_implementation == "CPython"
# via -r requirements/runtime-deps.in
cffi==1.17.0
# via pycares
frozenlist==1.4.1
# via
# -r requirements/runtime-deps.in
# aiosignal
gunicorn==22.0.0
# via -r requirements/base.in
idna==3.4
# via yarl
multidict==6.0.5
# via
# -r requirements/runtime-deps.in
# yarl
packaging==23.1
# via gunicorn
pycares==4.3.0
# via aiodns
pycparser==2.21
# via cffi
uvloop==0.21.0b1 ; platform_system != "Windows" and implementation_name == "cpython"
# via -r requirements/base.in
yarl==1.13.0
# via -r requirements/runtime-deps.in

View File

@@ -0,0 +1 @@
Pillow < 10 # https://github.com/blockdiag/sphinxcontrib-blockdiag/issues/26

View File

@@ -0,0 +1,4 @@
-r cython.in
-r dev.in
-r doc-spelling.in
-r lint.in

View File

@@ -0,0 +1,309 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/constraints.txt --strip-extras requirements/constraints.in
#
aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin"
# via
# -r requirements/lint.in
# -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.4
# via -r requirements/runtime-deps.in
aiohttp-theme==0.1.7
# via -r requirements/doc.in
aioredis==2.0.1
# via -r requirements/lint.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
alabaster==0.7.12
# via sphinx
annotated-types==0.7.0
# via pydantic
async-timeout==4.0.3 ; python_version < "3.11"
# via
# -r requirements/runtime-deps.in
# aioredis
attrs==23.2.0
# via -r requirements/runtime-deps.in
babel==2.9.1
# via sphinx
backports-entry-points-selectable==1.1.1
# via virtualenv
blockdiag==2.0.1
# via sphinxcontrib-blockdiag
brotli==1.1.0 ; platform_python_implementation == "CPython"
# via -r requirements/runtime-deps.in
build==1.0.3
# via pip-tools
certifi==2023.7.22
# via requests
cffi==1.17.1
# via
# cryptography
# pycares
# pytest-codspeed
cfgv==3.3.1
# via pre-commit
charset-normalizer==3.2.0
# via requests
cherry-picker==2.2.0
# via -r requirements/dev.in
click==8.0.3
# via
# cherry-picker
# pip-tools
# slotscheck
# towncrier
# typer
# wait-for-it
coverage==7.6.0
# via
# -r requirements/test.in
# pytest-cov
cryptography==41.0.2
# via
# pyjwt
# trustme
cython==3.0.10
# via -r requirements/cython.in
distlib==0.3.3
# via virtualenv
docutils==0.20.1
# via sphinx
exceptiongroup==1.1.2
# via pytest
execnet==2.1.1
# via pytest-xdist
filelock==3.16.1
# via
# pytest-codspeed
# virtualenv
freezegun==1.5.1
# via
# -r requirements/lint.in
# -r requirements/test.in
frozenlist==1.4.1
# via
# -r requirements/runtime-deps.in
# aiosignal
funcparserlib==1.0.1
# via blockdiag
gidgethub==5.0.1
# via cherry-picker
gunicorn==22.0.0
# via -r requirements/base.in
identify==2.3.5
# via pre-commit
idna==3.3
# via
# requests
# trustme
# yarl
imagesize==1.3.0
# via sphinx
importlib-metadata==7.0.0
# via
# build
# sphinx
importlib-resources==6.1.1
# via towncrier
incremental==22.10.0
# via towncrier
iniconfig==1.1.1
# via pytest
jinja2==3.0.3
# via
# sphinx
# towncrier
markupsafe==2.0.1
# via jinja2
multidict==6.0.5
# via
# -r requirements/multidict.in
# -r requirements/runtime-deps.in
# yarl
mypy==1.11.1 ; implementation_name == "cpython"
# via
# -r requirements/lint.in
# -r requirements/test.in
mypy-extensions==1.0.0
# via mypy
nodeenv==1.6.0
# via pre-commit
packaging==21.2
# via
# build
# gunicorn
# pytest
# sphinx
pillow==9.5.0
# via
# -c requirements/broken-projects.in
# blockdiag
pip-tools==7.4.1
# via -r requirements/dev.in
platformdirs==2.4.0
# via virtualenv
pluggy==1.5.0
# via pytest
pre-commit==3.5.0
# via -r requirements/lint.in
proxy-py==2.4.4
# via -r requirements/test.in
pycares==4.3.0
# via aiodns
pycparser==2.21
# via cffi
pydantic==2.9.2
# via python-on-whales
pydantic-core==2.23.4
# via pydantic
pyenchant==3.2.2
# via sphinxcontrib-spelling
pygments==2.15.1
# via sphinx
pyjwt==2.3.0
# via
# gidgethub
# pyjwt
pyparsing==2.4.7
# via packaging
pyproject-hooks==1.0.0
# via
# build
# pip-tools
pytest==8.3.2
# via
# -r requirements/lint.in
# -r requirements/test.in
# pytest-codspeed
# pytest-cov
# pytest-mock
# pytest-xdist
pytest-codspeed==2.2.1
# via
# -r requirements/lint.in
# -r requirements/test.in
pytest-cov==5.0.0
# via -r requirements/test.in
pytest-mock==3.14.0
# via
# -r requirements/lint.in
# -r requirements/test.in
pytest-xdist==3.6.1
# via -r requirements/test.in
python-dateutil==2.8.2
# via freezegun
python-on-whales==0.72.0
# via
# -r requirements/lint.in
# -r requirements/test.in
pytz==2023.3.post1
# via babel
pyyaml==6.0.1
# via pre-commit
re-assert==1.1.0
# via -r requirements/test.in
regex==2024.9.11
# via re-assert
requests==2.31.0
# via
# cherry-picker
# python-on-whales
# sphinx
setuptools-git==1.2
# via -r requirements/test.in
six==1.16.0
# via
# python-dateutil
# virtualenv
slotscheck==0.19.0
# via -r requirements/lint.in
snowballstemmer==2.1.0
# via sphinx
sphinx==7.1.2
# via
# -r requirements/doc.in
# sphinxcontrib-blockdiag
# sphinxcontrib-spelling
# sphinxcontrib-towncrier
sphinxcontrib-applehelp==1.0.2
# via sphinx
sphinxcontrib-blockdiag==3.0.0
# via -r requirements/doc.in
sphinxcontrib-devhelp==1.0.2
# via sphinx
sphinxcontrib-htmlhelp==2.0.0
# via sphinx
sphinxcontrib-jsmath==1.0.1
# via sphinx
sphinxcontrib-qthelp==1.0.3
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sphinxcontrib-spelling==8.0.0 ; platform_system != "Windows"
# via -r requirements/doc-spelling.in
sphinxcontrib-towncrier==0.4.0a0
# via -r requirements/doc.in
tomli==2.0.1
# via
# build
# cherry-picker
# coverage
# mypy
# pip-tools
# pyproject-hooks
# pytest
# slotscheck
# towncrier
towncrier==23.11.0
# via
# -r requirements/doc.in
# sphinxcontrib-towncrier
tqdm==4.62.3
# via python-on-whales
trustme==1.1.0 ; platform_machine != "i686"
# via
# -r requirements/lint.in
# -r requirements/test.in
typer==0.6.1
# via python-on-whales
typing-extensions==4.12.2
# via
# aioredis
# annotated-types
# mypy
# pydantic
# pydantic-core
# python-on-whales
uritemplate==4.1.1
# via gidgethub
urllib3==1.26.7
# via requests
uvloop==0.21.0b1 ; platform_system != "Windows"
# via
# -r requirements/base.in
# -r requirements/lint.in
virtualenv==20.10.0
# via pre-commit
wait-for-it==2.2.2
# via -r requirements/test.in
webcolors==1.11.1
# via blockdiag
wheel==0.37.0
# via pip-tools
yarl==1.13.0
# via -r requirements/runtime-deps.in
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources
# The following packages are considered to be unsafe in a requirements file:
pip==23.2.1
# via pip-tools
setuptools==68.0.0
# via
# blockdiag
# pip-tools

View File

@@ -0,0 +1,3 @@
-r multidict.in
Cython

View File

@@ -0,0 +1,12 @@
#
# This file is autogenerated by pip-compile with python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/cython.txt --resolver=backtracking --strip-extras requirements/cython.in
#
cython==3.0.10
# via -r requirements/cython.in
multidict==6.0.5
# via -r requirements/multidict.in
typing-extensions==4.12.2
# via -r requirements/typing-extensions.in

View File

@@ -0,0 +1,6 @@
-r lint.in
-r test.in
-r doc.in
cherry_picker
pip-tools

View File

@@ -0,0 +1,278 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile --allow-unsafe --output-file=requirements/dev.txt --resolver=backtracking --strip-extras requirements/dev.in
#
aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin"
# via
# -r requirements/lint.in
# -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.4
# via -r requirements/runtime-deps.in
aiohttp-theme==0.1.7
# via -r requirements/doc.in
aioredis==2.0.1
# via -r requirements/lint.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
alabaster==0.7.13
# via sphinx
annotated-types==0.7.0
# via pydantic
async-timeout==4.0.3 ; python_version < "3.11"
# via
# -r requirements/runtime-deps.in
# aioredis
attrs==23.2.0
# via -r requirements/runtime-deps.in
babel==2.12.1
# via sphinx
blockdiag==3.0.0
# via sphinxcontrib-blockdiag
brotli==1.1.0 ; platform_python_implementation == "CPython"
# via -r requirements/runtime-deps.in
build==1.0.3
# via pip-tools
certifi==2023.7.22
# via requests
cffi==1.17.0
# via
# cryptography
# pycares
cfgv==3.3.1
# via pre-commit
charset-normalizer==3.2.0
# via requests
cherry-picker==2.2.0
# via -r requirements/dev.in
click==8.1.6
# via
# cherry-picker
# pip-tools
# slotscheck
# towncrier
# typer
# wait-for-it
coverage==7.6.0
# via
# -r requirements/test.in
# pytest-cov
cryptography==41.0.3
# via
# pyjwt
# trustme
distlib==0.3.7
# via virtualenv
docutils==0.20.1
# via sphinx
exceptiongroup==1.1.2
# via pytest
filelock==3.12.2
# via virtualenv
freezegun==1.5.1
# via -r requirements/test.in
frozenlist==1.4.1
# via
# -r requirements/runtime-deps.in
# aiosignal
funcparserlib==1.0.1
# via blockdiag
gidgethub==5.3.0
# via cherry-picker
gunicorn==22.0.0
# via -r requirements/base.in
identify==2.5.26
# via pre-commit
idna==3.4
# via
# requests
# trustme
# yarl
imagesize==1.4.1
# via sphinx
importlib-metadata==7.0.0
# via
# build
# sphinx
importlib-resources==6.1.1
# via towncrier
incremental==22.10.0
# via towncrier
iniconfig==2.0.0
# via pytest
jinja2==3.1.2
# via
# sphinx
# towncrier
markupsafe==2.1.3
# via jinja2
multidict==6.0.5
# via
# -r requirements/runtime-deps.in
# yarl
mypy==1.11.1 ; implementation_name == "cpython"
# via
# -r requirements/lint.in
# -r requirements/test.in
mypy-extensions==1.0.0
# via mypy
nodeenv==1.8.0
# via pre-commit
packaging==23.1
# via
# build
# gunicorn
# pytest
# sphinx
pillow==9.5.0
# via
# -c requirements/broken-projects.in
# blockdiag
pip-tools==7.4.1
# via -r requirements/dev.in
platformdirs==3.10.0
# via virtualenv
pluggy==1.5.0
# via pytest
pre-commit==3.5.0
# via -r requirements/lint.in
proxy-py==2.4.4
# via -r requirements/test.in
pycares==4.3.0
# via aiodns
pycparser==2.21
# via cffi
pydantic==2.9.2
# via python-on-whales
pydantic-core==2.23.4
# via pydantic
pygments==2.15.1
# via sphinx
pyjwt==2.8.0
# via
# gidgethub
# pyjwt
pyproject-hooks==1.0.0
# via
# build
# pip-tools
pytest==8.3.2
# via
# -r requirements/lint.in
# -r requirements/test.in
# pytest-cov
# pytest-mock
pytest-cov==5.0.0
# via -r requirements/test.in
pytest-mock==3.14.0
# via -r requirements/test.in
python-dateutil==2.8.2
# via freezegun
python-on-whales==0.72.0
# via
# -r requirements/lint.in
# -r requirements/test.in
pytz==2023.3.post1
# via babel
pyyaml==6.0.1
# via pre-commit
re-assert==1.1.0
# via -r requirements/test.in
regex==2024.9.11
# via re-assert
requests==2.31.0
# via
# cherry-picker
# python-on-whales
# sphinx
setuptools-git==1.2
# via -r requirements/test.in
six==1.16.0
# via python-dateutil
slotscheck==0.19.0
# via -r requirements/lint.in
snowballstemmer==2.2.0
# via sphinx
sphinx==7.1.2
# via
# -r requirements/doc.in
# sphinxcontrib-blockdiag
# sphinxcontrib-towncrier
sphinxcontrib-applehelp==1.0.4
# via sphinx
sphinxcontrib-blockdiag==3.0.0
# via -r requirements/doc.in
sphinxcontrib-devhelp==1.0.2
# via sphinx
sphinxcontrib-htmlhelp==2.0.1
# via sphinx
sphinxcontrib-jsmath==1.0.1
# via sphinx
sphinxcontrib-qthelp==1.0.3
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sphinxcontrib-towncrier==0.4.0a0
# via -r requirements/doc.in
tomli==2.0.1
# via
# build
# cherry-picker
# coverage
# mypy
# pip-tools
# pyproject-hooks
# pytest
# slotscheck
# towncrier
towncrier==23.11.0
# via
# -r requirements/doc.in
# sphinxcontrib-towncrier
tqdm==4.65.0
# via python-on-whales
trustme==1.1.0 ; platform_machine != "i686"
# via -r requirements/test.in
typer==0.9.0
# via python-on-whales
typing-extensions==4.12.2
# via
# aioredis
# annotated-types
# mypy
# pydantic
# pydantic-core
# python-on-whales
# typer
uritemplate==4.1.1
# via gidgethub
urllib3==2.0.4
# via requests
uvloop==0.21.0b1 ; platform_system != "Windows" and implementation_name == "cpython"
# via
# -r requirements/base.in
# -r requirements/lint.in
virtualenv==20.24.2
# via pre-commit
wait-for-it==2.2.2
# via -r requirements/test.in
webcolors==1.13
# via blockdiag
wheel==0.41.0
# via pip-tools
yarl==1.13.0
# via -r requirements/runtime-deps.in
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources
# The following packages are considered to be unsafe in a requirements file:
pip==23.2.1
# via pip-tools
setuptools==68.0.0
# via
# blockdiag
# nodeenv
# pip-tools

Some files were not shown because too many files have changed in this diff Show More