diff --git a/.cherry_picker.toml b/.cherry_picker.toml new file mode 100644 index 00000000000..d5fc3e7bcc5 --- /dev/null +++ b/.cherry_picker.toml @@ -0,0 +1,4 @@ +team = "aio-libs" +repo = "aiohttp" +check_sha = "f382b5ffc445e45a110734f5396728da7914aeb6" +fix_commit_msg = false diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000000..a12881ed542 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,38 @@ +codecov: + branch: 3.7 + +coverage: + range: "95..100" + + status: + project: no + +flags: + library: + paths: + - aiohttp/ + configs: + paths: + - requirements/ + - ".git*" + - "*.toml" + - "*.yml" + changelog: + paths: + - CHANGES/ + - CHANGES.rst + docs: + paths: + - docs/ + - "*.md" + - "*.rst" + - "*.txt" + tests: + paths: + - tests/ + tools: + paths: + - tools/ + third-party: + paths: + - vendor/ diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index cd926cd0e56..00000000000 --- a/.coveragerc +++ /dev/null @@ -1,7 +0,0 @@ -[run] -branch = True -source = aiohttp, tests -omit = site-packages - -[html] -directory = coverage \ No newline at end of file diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000000..ae54f90b0b4 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,19 @@ +# EditorConfig is awesome: http://EditorConfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +charset = utf-8 + +[Makefile] +indent_style = tab + +[*.{yml,yaml}] +indent_size = 2 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000000..1fdd659bbc9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +tests/data.unknown_mime_type binary +tests/hello.txt.gz binary +tests/sample.* binary diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000000..9cf15d04afc --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,22 @@ +* @asvetlov +/.github/* @webknjaz @asvetlov +/.circleci/* @webknjaz @asvetlov +/CHANGES/* @asvetlov +/docs/* @asvetlov +/examples/* @asvetlov +/requirements/* @webknjaz @asvetlov +/tests/* @asvetlov +/tools/* @webknjaz @asvetlov +/vendor/* @webknjaz @asvetlov +*.ini @webknjaz @asvetlov +*.md @webknjaz @asvetlov +*.rst @webknjaz @asvetlov +*.toml @webknjaz @asvetlov +*.txt @webknjaz @asvetlov +*.yml @webknjaz @asvetlov +*.yaml @webknjaz @asvetlov +.editorconfig @webknjaz @asvetlov +.git* @webknjaz +Makefile @webknjaz @asvetlov +setup.py @webknjaz @asvetlov +setup.cfg @webknjaz @asvetlov diff --git a/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md similarity index 64% rename from ISSUE_TEMPLATE.md rename to .github/ISSUE_TEMPLATE.md index 319163f92a5..4c93982fd0d 100644 --- a/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -19,5 +19,10 @@ ## Your environment + This includes the aiohttp version, OS, proxy server and other bits that + are related to your case. + + IMPORTANT: aiohttp is both a server framework and a client library. + To avoid confusion, please state 'server', 'client' or 'both' + here. + -->
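The ``.cherry_picker.toml`` added above configures the ``cherry_picker`` backporting tool; a minimal usage sketch, assuming the tool is installed from PyPI and with ``<sha>`` standing in for a real commit:

.. code-block:: shell

   pip install cherry-picker
   # backport commit <sha> from master onto the 3.7 maintenance branch
   cherry_picker <sha> 3.7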
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000000..076abc3b3a9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,22 @@ +--- +name: 🚀 Feature request +about: Suggest an idea for this project +labels: enhancement +assignees: aio-libs/triagers + +--- + +🐣 **Is your feature request related to a problem? Please describe.** + + + +💡 **Describe the solution you'd like** + + + +❓ **Describe alternatives you've considered** + + + +📋 **Additional context** + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000000..237c61a659f --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,32 @@ + + +## What do these changes do? + + + +## Are there changes in behavior for the user? + + + +## Related issue number + + + +## Checklist + +- [ ] I think the code is well written +- [ ] Unit tests for the changes exist +- [ ] Documentation reflects the changes +- [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` + * The format is <Name> <Surname>. + * Please keep alphabetical order, the file is sorted by names. +- [ ] Add a new news fragment into the `CHANGES` folder + * name it `<issue_id>.<type>`, for example `588.bugfix` + * if you don't have an `issue_id`, change it to the PR id after creating the PR + * ensure type is one of the following: + * `.feature`: Signifying a new feature. + * `.bugfix`: Signifying a bug fix. + * `.doc`: Signifying a documentation improvement. + * `.removal`: Signifying a deprecation or removal of public API. + * `.misc`: A ticket has been closed, but it is not of interest to users. + * Make sure to use full sentences with correct case and punctuation, for example: "Fix issue with non-ascii contents in doctest text files." diff --git a/.github/config.yml b/.github/config.yml new file mode 100644 index 00000000000..ce72ac53c34 --- /dev/null +++ b/.github/config.yml @@ -0,0 +1,7 @@ +chronographer: + exclude: + bots: + - dependabot-preview + - dependabot + humans: + - pyup-bot diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..d09e636005b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,40 @@ +version: 2 +updates: + + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + labels: + - dependencies + - autosquash + schedule: + interval: "daily" + + # Maintain dependencies for Python + - package-ecosystem: "pip" + directory: "/" + labels: + - dependencies + - autosquash + schedule: + interval: "daily" + + # Maintain dependencies for GitHub Actions aiohttp 3.7 + - package-ecosystem: "github-actions" + directory: "/" + labels: + - dependencies + - autosquash + target-branch: "3.7" + schedule: + interval: "daily" + + # Maintain dependencies for Python aiohttp 3.7 + - package-ecosystem: "pip" + directory: "/" + labels: + - dependencies + - autosquash + target-branch: "3.7" + schedule: + interval: "daily"
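The news-fragment convention described in the pull request template above can be exercised like this (a sketch; ``588`` stands in for a real issue or PR number):

.. code-block:: shell

   # create a one-sentence bug-fix fragment for towncrier to pick up
   echo "Fix issue with non-ascii contents in doctest text files." > CHANGES/588.bugfix
   git add CHANGES/588.bugfix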
diff --git a/.github/workflows/autosquash.yml b/.github/workflows/autosquash.yml new file mode 100644 index 00000000000..63d6868daf6 --- /dev/null +++ b/.github/workflows/autosquash.yml @@ -0,0 +1,39 @@ +name: Autosquash +on: + check_run: + types: + # Check runs completing successfully can unblock the + # corresponding pull requests and make them mergeable. + - completed + pull_request: + types: + # A closed pull request makes the checks on the other + # pull request on the same base outdated. + - closed + # Adding the autosquash label to a pull request can + # trigger an update or a merge. + - labeled + pull_request_review: + types: + # Review approvals can unblock the pull request and + # make it mergeable. + - submitted + # Success statuses can unblock the corresponding + # pull requests and make them mergeable. + status: {} + +jobs: + autosquash: + name: Autosquash + runs-on: ubuntu-latest + # not available for forks, skip the workflow + if: ${{ github.event.pull_request.head.repo.full_name == 'aio-libs/aiohttp' }} + steps: + - id: generate_token + uses: tibdex/github-app-token@v1 + with: + app_id: ${{ secrets.BOT_APP_ID }} + private_key: ${{ secrets.BOT_PRIVATE_KEY }} + - uses: tibdex/autosquash@v2 + with: + github_token: ${{ steps.generate_token.outputs.token }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000000..2537929c09a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,281 @@ +name: CI + +on: + push: + branches: + - 'master' + - '[0-9].[0-9]+' # matches to backport branches, e.g. 3.6 + tags: [ 'v*' ] + pull_request: + branches: + - 'master' + - '[0-9].[0-9]+' + schedule: + - cron: '0 6 * * *' # Daily 6AM UTC build + + +jobs: + + lint: + name: Linter + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Cache PyPI + uses: actions/cache@v2 + with: + key: pip-lint-${{ hashFiles('requirements/*.txt') }} + path: ~/.cache/pip + restore-keys: | + pip-lint- + - name: Install dependencies + uses: py-actions/py-dependency-install@v2 + with: + path: requirements/lint.txt + - name: Pre-Commit hooks + uses: pre-commit/action@v2.0.0 + - name: Install itself + run: | + python setup.py install + env: + AIOHTTP_NO_EXTENSIONS: 1 + - name: Run linters + run: | + make mypy + - name: Install spell checker + run: | + sudo apt install libenchant-dev + pip install -r requirements/doc-spelling.txt + - name: Run docs spelling + run: | + # towncrier --yes # uncomment me after publishing a release + make doc-spelling + - name: Prepare twine checker + run: | + pip install -U twine wheel + python setup.py sdist bdist_wheel + env: + AIOHTTP_NO_EXTENSIONS: 1 + - name: Run twine checker + run: | + twine check dist/* + - name: Making sure that CONTRIBUTORS.txt remains sorted + run: | + LC_ALL=C sort -c CONTRIBUTORS.txt + + test: + name: Test + needs: lint + strategy: + matrix: + pyver: [3.6, 3.7, 3.8, 3.9] + no-extensions: ['', 'Y'] + os: [ubuntu, macos, windows] + exclude: + - os: macos + no-extensions: 'Y' + - os: macos + pyver: 3.7 + - os: macos + pyver: 3.8 + - os: windows + no-extensions: 'Y' + include: + - pyver: pypy3 + no-extensions: 'Y' + os: ubuntu + fail-fast: true + runs-on: ${{ matrix.os }}-latest + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Python ${{ matrix.pyver }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.pyver }} + - name: Get pip cache dir + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" # - name: Cache + - name: Cache PyPI + uses: actions/cache@v2 + with: + key: pip-ci-${{ runner.os }}-${{ matrix.pyver }}-${{ matrix.no-extensions }}-${{ hashFiles('requirements/*.txt') }} + path: ${{ steps.pip-cache.outputs.dir }} + restore-keys: | + pip-ci-${{ runner.os }}-${{ matrix.pyver }}-${{ matrix.no-extensions }}- + - name: Cythonize + if: ${{ matrix.no-extensions == '' }} + run: | + make cythonize + - name: Run unittests + env: + COLOR: 'yes' + AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} + run: | + make vvtest + python -m coverage xml + - name: Upload coverage + uses: codecov/codecov-action@v1 + with: + file: ./coverage.xml + flags: unit + fail_ci_if_error: false
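The test job above toggles the C extensions through ``AIOHTTP_NO_EXTENSIONS``; the same run can be reproduced locally with the Makefile targets the workflow invokes (a sketch, assuming a working dev environment):

.. code-block:: shell

   # pure-Python run, mirroring the no-extensions matrix leg
   AIOHTTP_NO_EXTENSIONS=Y make vvtest
   python -m coverage xml   # produce coverage.xml for the codecov upload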
+ + pre-deploy: + name: Pre-Deploy + runs-on: ubuntu-latest + needs: test + # Run only on pushing a tag + if: github.event_name == 'push' && contains(github.ref, 'refs/tags/') + steps: + - name: Dummy + run: | + echo "Predeploy step" + + build-tarball: + name: Tarball + runs-on: ubuntu-latest + needs: pre-deploy + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Cythonize + run: | + make cythonize + - name: Make sdist + run: + python setup.py sdist + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: dist + path: dist + + build-linux: + name: Linux + strategy: + matrix: + pyver: [cp36-cp36m, cp37-cp37m, cp38-cp38, cp39-cp39] + arch: [x86_64, aarch64, i686, ppc64le, s390x] + fail-fast: false + runs-on: ubuntu-latest + env: + py: /opt/python/${{ matrix.pyver }}/bin/python + img: quay.io/pypa/manylinux2014_${{ matrix.arch }} + needs: pre-deploy + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + - name: Set up QEMU + id: qemu + uses: docker/setup-qemu-action@v1 + - name: Available platforms + run: echo ${{ steps.qemu.outputs.platforms }} + - name: Setup Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Cythonize + if: ${{ matrix.no-extensions == '' }} + run: | + make cythonize + - name: Install tools + run: | + docker run --rm -v ${{ github.workspace }}:/ws:rw --workdir=/ws \ + ${{ env.img }} ${{ env.py }} -m pip install -U setuptools wheel + - name: Make wheel + run: | + docker run --rm -v ${{ github.workspace }}:/ws:rw --workdir=/ws \ + ${{ env.img }} ${{ env.py }} setup.py bdist_wheel + - name: Repair wheel + run: | + docker run --rm -v ${{ github.workspace }}:/ws:rw --workdir=/ws \ + ${{ env.img }} auditwheel repair dist/*.whl --wheel-dir wheelhouse/ + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: dist + path: wheelhouse/* + + build-binary: + name: Binary wheels + strategy: + matrix: + pyver: [3.6, 3.7, 3.8, 3.9] + os: [macos, windows] + arch: [x86, x64] + exclude: + - os: macos + arch: x86 + fail-fast: false + runs-on: ${{ matrix.os }}-latest + needs: pre-deploy + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Python ${{ matrix.pyver }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.pyver }} + architecture: ${{ matrix.arch }} + - name: Cythonize + if: ${{ matrix.no-extensions == '' }} + run: | + make cythonize + - name: Install dependencies + run: | + python -m pip install -U setuptools wheel + - name: Make wheel + run: + python setup.py bdist_wheel + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: dist + path: dist + + deploy: + name: Deploy + needs: [build-linux, build-binary, build-tarball] + runs-on: ubuntu-latest + steps: + - name: Setup Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install twine + run: | + python -m pip install twine + - name: Download dists + uses: actions/download-artifact@v2 + with: + name: dist + path: dist + - name: PyPI upload + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: | + twine upload dist/*
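The deploy job uploads with twine; the same artifacts can be sanity-checked locally using the commands the workflow itself runs (a sketch; the token value is a placeholder):

.. code-block:: shell

   python setup.py sdist bdist_wheel
   twine check dist/*
   TWINE_USERNAME=__token__ TWINE_PASSWORD=<pypi-token> twine upload dist/*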
diff --git a/.gitignore b/.gitignore index 4754ff9eb9d..69f52e10d87 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,13 @@ -*.swp *.bak *.egg *.egg-info *.eggs +*.md5 *.pyc *.pyd *.pyo *.so +*.swp *.tar.gz *~ .DS_Store @@ -14,35 +15,53 @@ .cache .coverage .coverage.* +.develop +.direnv +.envrc +.flake +.gitconfig +.hash .idea +.install-cython +.install-deps .installed.cfg +.mypy_cache .noseids +.pytest_cache +.python-version +.test-results .tox .vimrc -aiohttp/_multidict.c -aiohttp/_multidict.html +.vscode +aiohttp/_find_header.c +aiohttp/_frozenlist.c +aiohttp/_frozenlist.html +aiohttp/_headers.html +aiohttp/_headers.pxi +aiohttp/_helpers.c +aiohttp/_helpers.html aiohttp/_http_parser.c +aiohttp/_http_parser.html +aiohttp/_http_writer.c +aiohttp/_http_writer.html +aiohttp/_websocket.c +aiohttp/_websocket.html bin build -cover -coverage +coverage.xml develop-eggs dist docs/_build/ eggs +htmlcov include/ lib/ man/ nosetests.xml parts +pip-wheel-metadata pyvenv sources var/* venv virtualenv.py -aiohttp/_websocket.c -aiohttp/_websocket.html -.install-deps -.develop -.gitconfig -.flake diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..c58945aa9fc --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "vendor/http-parser"] + path = vendor/http-parser + url = git://github.com/nodejs/http-parser.git + branch = 54f55a2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..55826f603ee --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,69 @@ +repos: +- repo: local + hooks: + - id: check-changes + name: Check CHANGES + language: system + entry: ./tools/check_changes.py + pass_filenames: false +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: 'v3.3.0' + hooks: + - id: check-merge-conflict +- repo: https://github.com/asottile/yesqa + rev: v1.2.2 + hooks: + - id: yesqa +- repo: https://github.com/pre-commit/mirrors-isort + rev: 'v5.6.4' + hooks: + - id: isort +- repo: https://github.com/psf/black + rev: '20.8b1' + hooks: + - id: black + language_version: python3 # Should be a command that runs python3.6+ +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: 'v3.3.0' + hooks: + - id: end-of-file-fixer + exclude: >- + ^docs/[^/]*\.svg$ + - id: requirements-txt-fixer + - id: trailing-whitespace + - id: file-contents-sorter + files: | + CONTRIBUTORS.txt| + docs/spelling_wordlist.txt| + .gitignore| + .gitattributes + - id: check-case-conflict + - id: check-json + - id: check-xml + - id: check-executables-have-shebangs + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: check-added-large-files + - id: check-symlinks + - id: detect-aws-credentials + args: ['--allow-missing-credentials'] + - id: detect-private-key + exclude: ^examples/ +- repo: https://github.com/asottile/pyupgrade + rev: 'v2.7.3' + hooks: + - id: pyupgrade + args: ['--py36-plus'] +- repo: https://gitlab.com/pycqa/flake8 + rev: '3.8.4' + hooks: + - id: flake8 + exclude: "^docs/" + +- repo: git://github.com/Lucas-C/pre-commit-hooks-markup + rev: v1.0.0 + hooks: + - id: rst-linter + files: >- + ^[^/]+[.]rst$ diff --git a/.pyup.yml b/.pyup.yml new file mode 100644 index 00000000000..75f971166aa --- /dev/null +++ b/.pyup.yml @@ -0,0 +1,4 @@ +# Label PRs with `deps-update` label +label_prs: deps-update + +schedule: every week diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 00000000000..e2e8d918392 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,5 @@ +build: + image: latest +python: + version: 3.6 + pip_install: false
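The pre-commit configuration above wires in the lint hooks; a typical local setup uses pre-commit's standard CLI (a sketch):

.. code-block:: shell

   pip install pre-commit
   pre-commit install            # run the hooks automatically on each commit
   pre-commit run --all-files    # one-off check of the whole tree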
diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index a0f29174742..00000000000 --- a/.travis.yml +++ /dev/null @@ -1,68 +0,0 @@ -sudo: required -services: - - docker - -language: python - -python: - # python3.4.2 has bug in http.cookies module, aiohttp provides fix - - 3.4.2 - - 3.4.3 - - 3.5.2 - # - 3.5.3 - - 3.5 - - 3.6 - # - 'nightly' - -os: - - linux -# - osx # doesn't work on MacOSX -- the system has no Python installed - -matrix: - allow_failures: - - python: '3.6-dev' - - python: 'nightly' - - os: osx - -cache: - directories: - - $HOME/.cache/pip - -before_cache: - - rm -f $HOME/.cache/pip/log/debug.log - -install: - - sudo apt-get install enchant - - pip install --upgrade pip wheel - - pip install --upgrade setuptools - - pip install --upgrade setuptools-git - - pip install -r requirements-ci.txt - - pip install aiodns - - pip install codecov - - if python -c "import sys; sys.exit(sys.version_info < (3,5))"; then pip install uvloop; fi - - pip install sphinxcontrib-spelling - -script: - - make cov-dev-full - - - if python -c "import sys; sys.exit(sys.version_info < (3,5))"; then make doc; make doc-spelling; fi - -after_success: - - codecov - - ./run_docker.sh - -deploy: - provider: pypi - user: andrew.svetlov - password: - secure: ZQKbdPT9BlNqP5CTbWRQyeyig7Bpf7wsnYVQIQPOZc9Ec74A+dsbagstR1sPkAO+d+5PN0pZMovvmU7OQhSVPAnJ74nsN90/fL4ux3kqYecMbevv0rJg20hMXSSkwMEIpjUsMdMjJvZAcaKytGWmKL0qAlOJHhixd1pBbWyuIUE= - distributions: "sdist" - on: - tags: true - all_branches: true - python: 3.5 diff --git a/CHANGES.rst b/CHANGES.rst index 2b33e0937de..e0f2b6da270 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,155 +1,714 @@ -Changes -======= +========= +Changelog +========= -2.0.7 (2017-04-12) ------------------- +.. + You should *NOT* be adding new change log entries to this file, this + file is managed by towncrier. You *may* edit previous change logs to + fix problems like typo corrections or such. + To add a new change log entry, please see + https://pip.pypa.io/en/latest/development/#adding-a-news-entry + we named the news folder "CHANGES". -- Fix *pypi* distribution + WARNING: Don't drop the next directive! -- Fix exception description #1807 +.. towncrier release notes start -- Handle socket error in FileResponse #1773 +3.7.4 (2021-02-25) +================== -- Cancel websocket heartbeat on close #1793 +Bugfixes +-------- +- **(SECURITY BUG)** Started preventing open redirects in the + ``aiohttp.web.normalize_path_middleware`` middleware. For + more details, see + https://github.com/aio-libs/aiohttp/security/advisories/GHSA-v6wp-4m6f-gcjg. -2.0.6 (2017-04-06) ------------------- + Thanks to `Beast Glatisant `__ for + finding the first instance of this issue and `Jelmer Vernooij + `__ for reporting and tracking it down + in aiohttp. + `#5497 `_ +- Fix an interpretation difference in how the pure-Python and the Cython-based + HTTP parsers construct a ``yarl.URL`` object for the HTTP request-target. -- Fix ``web.run_app`` not to bind to default host-port pair if only socket is - passed #1786 + Before this fix, the Python parser would turn the URI's absolute-path + for ``//some-path`` into ``/`` while the Cython code preserved it as + ``//some-path``. Now, both do the latter. + `#5498 `_ -- Keeping blank values for `request.post()` and `multipart.form()` #1765 -- TypeError in ResponseHandler.data_received #1770 +---- -2.0.5 (2017-03-29) ------------------- +3.7.3 (2020-11-18) +================== -- Memory leak with aiohttp.request #1756 +Features +-------- -- Disable cleanup closed ssl transports by default. +- Use Brotli instead of brotlipy + `#3803 `_ +- Made exceptions pickleable. Also changed the repr of some exceptions.
+ `#4077 `_ -- Exception in request handling if the server responds before the body is sent #1761 +Bugfixes +-------- -2.0.4 (2017-03-27) ------------------- +- Raise a ClientResponseError instead of an AssertionError for a blank + HTTP Reason Phrase. + `#3532 `_ +- Fix ``web_middlewares.normalize_path_middleware`` behavior for patch without slash. + `#3669 `_ +- Fix overshadowing of overlapped sub-applications prefixes. + `#3701 `_ +- Make `BaseConnector.close()` a coroutine and wait until the client closes all connections. Drop deprecated "with Connector():" syntax. + `#3736 `_ +- Reset the ``sock_read`` timeout each time data is received for a ``aiohttp.client`` response. + `#3808 `_ +- Fixed type annotation for add_view method of UrlDispatcher to accept any subclass of View + `#3880 `_ +- Fixed querying the address families from DNS that the current host supports. + `#5156 `_ +- Change return type of MultipartReader.__aiter__() and BodyPartReader.__aiter__() to AsyncIterator. + `#5163 `_ +- Provide x86 Windows wheels. + `#5230 `_ + + +Improved Documentation +---------------------- + +- Add documentation for ``aiohttp.web.FileResponse``. + `#3958 `_ +- Removed deprecation warning in tracing example docs + `#3964 `_ +- Fixed wrong "Usage" docstring of ``aiohttp.client.request``. + `#4603 `_ +- Add aiohttp-pydantic to third party libraries + `#5228 `_ + + +Misc +---- + +- `#4102 `_ + + +---- + + +3.7.2 (2020-10-27) +================== + +Bugfixes +-------- + +- Fixed static files handling for loops without ``.sendfile()`` support + `#5149 `_ + + +---- + + +3.7.1 (2020-10-25) +================== + +Bugfixes +-------- + +- Fixed a type error caused by the conditional import of `Protocol`. + `#5111 `_ +- Server doesn't send Content-Length for 1xx or 204 + `#4901 `_ +- Fix run_app typing + `#4957 `_ +- Always require ``typing_extensions`` library. + `#5107 `_ +- Fix a variable-shadowing bug causing `ThreadedResolver.resolve` to + return the resolved IP as the ``hostname`` in each record, which prevented + validation of HTTPS connections. + `#5110 `_ +- Added annotations to all public attributes. + `#5115 `_ +- Fix flaky test_when_timeout_smaller_second + `#5116 `_ +- Ensure sending a zero byte file does not throw an exception + `#5124 `_ +- Fix a bug in ``web.run_app()`` about Python version checking on Windows + `#5127 `_ + + +---- + + +3.7.0 (2020-10-24) +================== + +Features +-------- + +- Response headers are now prepared prior to running ``on_response_prepare`` hooks, directly before headers are sent to the client. + `#1958 `_ +- Add a ``quote_cookie`` option to ``CookieJar``, a way to skip quotation wrapping of cookies containing special characters. + `#2571 `_ +- Call ``AccessLogger.log`` with the current exception available from ``sys.exc_info()``. + `#3557 `_ +- `web.UrlDispatcher.add_routes` and `web.Application.add_routes` return a list + of registered `AbstractRoute` instances. `AbstractRouteDef.register` (and all + subclasses) return a list of registered resources. + `#3866 `_ +- Added properties of default ClientSession params to ClientSession class so it is available for introspection + `#3882 `_ +- Don't cancel web handler on peer disconnection, raise `OSError` on reading/writing instead. + `#4080 `_ +- Implement BaseRequest.get_extra_info() to access a protocol transports' extra info. + `#4189 `_ +- Added `ClientSession.timeout` property. + `#4191 `_ +- Allow use of SameSite in cookies.
+ `#4224 `_ +- Use ``loop.sendfile()`` instead of custom implementation if available. + `#4269 `_ +- Apply SO_REUSEADDR to test server's socket. + `#4393 `_ +- Use .raw_host instead of slower .host in client API + `#4402 `_ +- Allow configuring the buffer size of input stream by passing ``read_bufsize`` argument. + `#4453 `_ +- Pass tests on Python 3.8 for Windows. + `#4513 `_ +- Add `method` and `url` attributes to `TraceRequestChunkSentParams` and `TraceResponseChunkReceivedParams`. + `#4674 `_ +- Add ClientResponse.ok property for checking status code under 400. + `#4711 `_ +- Don't ceil timeouts that are smaller than 5 seconds. + `#4850 `_ +- TCPSite now listens by default on all interfaces instead of just IPv4 when `None` is passed in as the host. + `#4894 `_ +- Bump ``http_parser`` to 2.9.4 + `#5070 `_ + + +Bugfixes +-------- + +- Fix keepalive connections not being closed in time + `#3296 `_ +- Fix failed websocket handshake leaving connection hanging. + `#3380 `_ +- Fix tasks cancellation order on exit. The run_app task needs to be cancelled first for cleanup hooks to run with all tasks intact. + `#3805 `_ +- Don't start heartbeat until _writer is set + `#4062 `_ +- Fix handling of multipart file uploads without a content type. + `#4089 `_ +- Preserve view handler function attributes across middlewares + `#4174 `_ +- Fix the string representation of ``ServerDisconnectedError``. + `#4175 `_ +- Raising RuntimeError when trying to get encoding from not read body + `#4214 `_ +- Remove warning messages from noop. + `#4282 `_ +- Raise ClientPayloadError if FormData re-processed. + `#4345 `_ +- Fix a warning about unfinished task in ``web_protocol.py`` + `#4408 `_ +- Fixed 'deflate' compression. According to RFC 2616 now. + `#4506 `_ +- Fixed OverflowError on platforms with 32-bit time_t + `#4515 `_ +- Fixed request.body_exists returns wrong value for methods without body. + `#4528 `_ +- Fix connecting to link-local IPv6 addresses. + `#4554 `_ +- Fix a problem with connection waiters that are never awaited. + `#4562 `_ +- Always make sure transport is not closing before reuse a connection. + + Reuse a protocol based on keepalive in headers is unreliable. + For example, uWSGI will not support keepalive even it serves a + HTTP 1.1 request, except explicitly configure uWSGI with a + ``--http-keepalive`` option. + + Servers designed like uWSGI could cause aiohttp intermittently + raise a ConnectionResetException when the protocol poll runs + out and some protocol is reused. + `#4587 `_ +- Handle the last CRLF correctly even if it is received via separate TCP segment. + `#4630 `_ +- Fix the register_resource function to validate route name before splitting it so that route name can include python keywords. + `#4691 `_ +- Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and + ``multipart`` module. + `#4736 `_ +- Fix resolver task is not awaited when connector is cancelled + `#4795 `_ +- Fix a bug "Aiohttp doesn't return any error on invalid request methods" + `#4798 `_ +- Fix HEAD requests for static content. + `#4809 `_ +- Fix incorrect size calculation for memoryview + `#4890 `_ +- Add HTTPMove to _all__. + `#4897 `_ +- Fixed the type annotations in the ``tracing`` module. + `#4912 `_ +- Fix typing for multipart ``__aiter__``. + `#4931 `_ +- Fix for race condition on connections in BaseConnector that leads to exceeding the connection limit. + `#4936 `_ +- Add forced UTF-8 encoding for ``application/rdap+json`` responses. 
+ `#4938 `_ +- Fix inconsistency between Python and C http request parsers in parsing pct-encoded URL. + `#4972 `_ +- Fix connection closing issue in HEAD request. + `#5012 `_ +- Fix type hint on BaseRunner.addresses (from ``List[str]`` to ``List[Any]``) + `#5086 `_ +- Make `web.run_app()` more responsive to Ctrl+C on Windows for Python < 3.8. It slightly + increases CPU load as a side effect. + `#5098 `_ + + +Improved Documentation +---------------------- + +- Fix example code in client quick-start + `#3376 `_ +- Updated the docs so there is no contradiction in ``ttl_dns_cache`` default value + `#3512 `_ +- Add 'Deploy with SSL' to docs. + `#4201 `_ +- Change typing of the secure argument on StreamResponse.set_cookie from ``Optional[str]`` to ``Optional[bool]`` + `#4204 `_ +- Changes ``ttl_dns_cache`` type from int to Optional[int]. + `#4270 `_ +- Simplify README hello word example and add a documentation page for people coming from requests. + `#4272 `_ +- Improve some code examples in the documentation involving websockets and starting a simple HTTP site with an AppRunner. + `#4285 `_ +- Fix typo in code example in Multipart docs + `#4312 `_ +- Fix code example in Multipart section. + `#4314 `_ +- Update contributing guide so new contributors read the most recent version of that guide. Update command used to create test coverage reporting. + `#4810 `_ +- Spelling: Change "canonize" to "canonicalize". + `#4986 `_ +- Add ``aiohttp-sse-client`` library to third party usage list. + `#5084 `_ + + +Misc +---- + +- `#2856 `_, `#4218 `_, `#4250 `_ + + +---- + + +3.6.3 (2020-10-12) +================== + +Bugfixes +-------- + +- Pin yarl to ``<1.6.0`` to avoid buggy behavior that will be fixed by the next aiohttp + release. + +3.6.2 (2019-10-09) +================== + +Features +-------- + +- Made exceptions pickleable. Also changed the repr of some exceptions. + `#4077 `_ +- Use ``Iterable`` type hint instead of ``Sequence`` for ``Application`` *middleware* + parameter. `#4125 `_ + + +Bugfixes +-------- + +- Reset the ``sock_read`` timeout each time data is received for a + ``aiohttp.ClientResponse``. `#3808 + `_ +- Fix handling of expired cookies so they are not stored in CookieJar. + `#4063 `_ +- Fix misleading message in the string representation of ``ClientConnectorError``; + ``self.ssl == None`` means default SSL context, not SSL disabled `#4097 + `_ +- Don't clobber HTTP status when using FileResponse. + `#4106 `_ + + +Improved Documentation +---------------------- + +- Added minimal required logging configuration to logging documentation. + `#2469 `_ +- Update docs to reflect proxy support. + `#4100 `_ +- Fix typo in code example in testing docs. + `#4108 `_ + + +Misc +---- + +- `#4102 `_ + + +---- + + +3.6.1 (2019-09-19) +================== + +Features +-------- + +- Compatibility with Python 3.8. + `#4056 `_ + + +Bugfixes +-------- + +- correct some exception string format + `#4068 `_ +- Emit a warning when ``ssl.OP_NO_COMPRESSION`` is + unavailable because the runtime is built against + an outdated OpenSSL. + `#4052 `_ +- Update multidict requirement to >= 4.5 + `#4057 `_ + + +Improved Documentation +---------------------- + +- Provide pytest-aiohttp namespace for pytest fixtures in docs. 
+ `#3723 `_ + + +---- -- Memory leak with aiohttp.request #1756 -- Encoding is always UTF-8 in POST data #1750 - -- Do not add "Content-Disposition" header by default #1755 - - -2.0.3 (2017-03-24) ------------------- - -- Call https website through proxy will cause error #1745 - -- Fix exception on multipart/form-data post if content-type is not set #1743 - - -2.0.2 (2017-03-21) ------------------- - -- Fixed Application.on_loop_available signal #1739 - -- Remove debug code - - -2.0.1 (2017-03-21) ------------------- - -- Fix allow-head to include name on route #1737 - -- Fixed AttributeError in WebSocketResponse.can_prepare #1736 - - -2.0.0 (2017-03-20) ------------------- - -- Added `json` to `ClientSession.request()` method #1726 - -- Added session's `raise_for_status` parameter, automatically calls raise_for_status() on any request. #1724 - -- `response.json()` raises `ClientReponseError` exception if response's - content type does not match #1723 - -- Cleanup timer and loop handle on any client exception. - -- Deprecate `loop` parameter for Application's constructor - - -`2.0.0rc1` (2017-03-15) ------------------------ - -- Properly handle payload errors #1710 - -- Added `ClientWebSocketResponse.get_extra_info()` #1717 - -- It is not possible to combine Transfer-Encoding and chunked parameter, - same for compress and Content-Encoding #1655 - -- Connector's `limit` parameter indicates total concurrent connections. - New `limit_per_host` added, indicates total connections per endpoint. #1601 - -- Use url's `raw_host` for name resolution #1685 - -- Change `ClientResponse.url` to `yarl.URL` instance #1654 - -- Add max_size parameter to web.Request reading methods #1133 - -- Web Request.post() stores data in temp files #1469 - -- Add the `allow_head=True` keyword argument for `add_get` #1618 - -- `run_app` and the Command Line Interface now support serving over - Unix domain sockets for faster inter-process communication. - -- `run_app` now supports passing a preexisting socket object. This can be useful - e.g. for socket-based activated applications, when binding of a socket is - done by the parent process. 
- -- Implementation for Trailer headers parser is broken #1619 - -- Fix FileResponse to not fall on bad request (range out of file size) - -- Fix FileResponse to correct stream video to Chromes - -- Deprecate public low-level api #1657 - -- Deprecate `encoding` parameter for ClientSession.request() method - -- Dropped aiohttp.wsgi #1108 - -- Dropped `version` from ClientSession.request() method - -- Dropped websocket version 76 support #1160 - -- Dropped: `aiohttp.protocol.HttpPrefixParser` #1590 - -- Dropped: Servers response's `.started`, `.start()` and `.can_start()` method #1591 - -- Dropped: Adding `sub app` via `app.router.add_subapp()` is deprecated - use `app.add_subapp()` instead #1592 - -- Dropped: `Application.finish()` and `Application.register_on_finish()` #1602 - -- Dropped: `web.Request.GET` and `web.Request.POST` - -- Dropped: aiohttp.get(), aiohttp.options(), aiohttp.head(), - aiohttp.post(), aiohttp.put(), aiohttp.patch(), aiohttp.delete(), and - aiohttp.ws_connect() #1593 - -- Dropped: `aiohttp.web.WebSocketResponse.receive_msg()` #1605 - -- Dropped: `ServerHttpProtocol.keep_alive_timeout` attribute and - `keep-alive`, `keep_alive_on`, `timeout`, `log` constructor parameters #1606 - -- Dropped: `TCPConnector's`` `.resolve`, `.resolved_hosts`, `.clear_resolved_hosts()` - attributes and `resolve` constructor parameter #1607 - -- Dropped `ProxyConnector` #1609 +3.6.0 (2019-09-06) +================== + +Features +-------- + +- Add support for Named Pipes (Site and Connector) under Windows. This feature requires + Proactor event loop to work. `#3629 + `_ +- Removed ``Transfer-Encoding: chunked`` header from websocket responses to be + compatible with more http proxy servers. `#3798 + `_ +- Accept non-GET request for starting websocket handshake on server side. + `#3980 `_ + + +Bugfixes +-------- + +- Raise a ClientResponseError instead of an AssertionError for a blank + HTTP Reason Phrase. + `#3532 `_ +- Fix an issue where cookies would sometimes not be set during a redirect. + `#3576 `_ +- Change normalize_path_middleware to use 308 redirect instead of 301. + + This behavior should prevent clients from being unable to use PUT/POST + methods on endpoints that are redirected because of a trailing slash. + `#3579 `_ +- Drop the processed task from ``all_tasks()`` list early. It prevents logging about a + task with unhandled exception when the server is used in conjunction with + ``asyncio.run()``. `#3587 `_ +- ``Signal`` type annotation changed from ``Signal[Callable[['TraceConfig'], + Awaitable[None]]]`` to ``Signal[Callable[ClientSession, SimpleNamespace, ...]``. + `#3595 `_ +- Use sanitized URL as Location header in redirects + `#3614 `_ +- Improve typing annotations for multipart.py along with changes required + by mypy in files that references multipart.py. + `#3621 `_ +- Close session created inside ``aiohttp.request`` when unhandled exception occurs + `#3628 `_ +- Cleanup per-chunk data in generic data read. Memory leak fixed. + `#3631 `_ +- Use correct type for add_view and family + `#3633 `_ +- Fix _keepalive field in __slots__ of ``RequestHandler``. + `#3644 `_ +- Properly handle ConnectionResetError, to silence the "Cannot write to closing + transport" exception when clients disconnect uncleanly. + `#3648 `_ +- Suppress pytest warnings due to ``test_utils`` classes + `#3660 `_ +- Fix overshadowing of overlapped sub-application prefixes. 
+ `#3701 `_ +- Fixed return type annotation for WSMessage.json() + `#3720 `_ +- Properly expose TooManyRedirects publicly as documented. + `#3818 `_ +- Fix missing brackets for IPv6 in proxy CONNECT request + `#3841 `_ +- Make the signature of ``aiohttp.test_utils.TestClient.request`` match + ``asyncio.ClientSession.request`` according to the docs `#3852 + `_ +- Use correct style for re-exported imports, makes mypy ``--strict`` mode happy. + `#3868 `_ +- Fixed type annotation for add_view method of UrlDispatcher to accept any subclass of + View `#3880 `_ +- Made cython HTTP parser set Reason-Phrase of the response to an empty string if it is + missing. `#3906 `_ +- Add URL to the string representation of ClientResponseError. + `#3959 `_ +- Accept ``istr`` keys in ``LooseHeaders`` type hints. + `#3976 `_ +- Fixed race conditions in _resolve_host caching and throttling when tracing is enabled. + `#4013 `_ +- For URLs like "unix://localhost/..." set Host HTTP header to "localhost" instead of + "localhost:None". `#4039 `_ + + +Improved Documentation +---------------------- + +- Modify documentation for Background Tasks to remove deprecated usage of event loop. + `#3526 `_ +- use ``if __name__ == '__main__':`` in server examples. + `#3775 `_ +- Update documentation reference to the default access logger. + `#3783 `_ +- Improve documentation for ``web.BaseRequest.path`` and ``web.BaseRequest.raw_path``. + `#3791 `_ +- Removed deprecation warning in tracing example docs + `#3964 `_ + + +---- + + +3.5.4 (2019-01-12) +================== + +Bugfixes +-------- + +- Fix stream ``.read()`` / ``.readany()`` / ``.iter_any()`` which used to return a + partial content only in case of compressed content + `#3525 `_ + + +3.5.3 (2019-01-10) +================== + +Bugfixes +-------- + +- Fix type stubs for ``aiohttp.web.run_app(access_log=True)`` and fix edge case of + ``access_log=True`` and the event loop being in debug mode. `#3504 + `_ +- Fix ``aiohttp.ClientTimeout`` type annotations to accept ``None`` for fields + `#3511 `_ +- Send custom per-request cookies even if session jar is empty + `#3515 `_ +- Restore Linux binary wheels publishing on PyPI + +---- + + +3.5.2 (2019-01-08) +================== + +Features +-------- + +- ``FileResponse`` from ``web_fileresponse.py`` uses a ``ThreadPoolExecutor`` to work + with files asynchronously. I/O based payloads from ``payload.py`` uses a + ``ThreadPoolExecutor`` to work with I/O objects asynchronously. `#3313 + `_ +- Internal Server Errors in plain text if the browser does not support HTML. + `#3483 `_ + + +Bugfixes +-------- + +- Preserve MultipartWriter parts headers on write. Refactor the way how + ``Payload.headers`` are handled. Payload instances now always have headers and + Content-Type defined. Fix Payload Content-Disposition header reset after initial + creation. `#3035 `_ +- Log suppressed exceptions in ``GunicornWebWorker``. + `#3464 `_ +- Remove wildcard imports. + `#3468 `_ +- Use the same task for app initialization and web server handling in gunicorn workers. + It allows to use Python3.7 context vars smoothly. + `#3471 `_ +- Fix handling of chunked+gzipped response when first chunk does not give uncompressed + data `#3477 `_ +- Replace ``collections.MutableMapping`` with ``collections.abc.MutableMapping`` to + avoid a deprecation warning. `#3480 + `_ +- ``Payload.size`` type annotation changed from ``Optional[float]`` to + ``Optional[int]``. `#3484 `_ +- Ignore done tasks when cancels pending activities on ``web.run_app`` finalization. 
+ `#3497 `_ + + +Improved Documentation +---------------------- + +- Add documentation for ``aiohttp.web.HTTPException``. + `#3490 `_ + + +Misc +---- + +- `#3487 `_ + + +---- + + +3.5.1 (2018-12-24) +==================== + +- Fix a regression about ``ClientSession._requote_redirect_url`` modification in debug + mode. + +3.5.0 (2018-12-22) +==================== + +Features +-------- + +- The library type annotations are checked in strict mode now. +- Add support for setting cookies for individual request (`#2387 + `_) +- Application.add_domain implementation (`#2809 + `_) +- The default ``app`` in the request returned by ``test_utils.make_mocked_request`` can + now have objects assigned to it and retrieved using the ``[]`` operator. (`#3174 + `_) +- Make ``request.url`` accessible when transport is closed. (`#3177 + `_) +- Add ``zlib_executor_size`` argument to ``Response`` constructor to allow compression + to run in a background executor to avoid blocking the main thread and potentially + triggering health check failures. (`#3205 + `_) +- Enable users to set ``ClientTimeout`` in ``aiohttp.request`` (`#3213 + `_) +- Don't raise a warning if ``NETRC`` environment variable is not set and ``~/.netrc`` + file doesn't exist. (`#3267 `_) +- Add default logging handler to web.run_app If the ``Application.debug``` flag is set + and the default logger ``aiohttp.access`` is used, access logs will now be output + using a *stderr* ``StreamHandler`` if no handlers are attached. Furthermore, if the + default logger has no log level set, the log level will be set to ``DEBUG``. (`#3324 + `_) +- Add method argument to ``session.ws_connect()``. Sometimes server API requires a + different HTTP method for WebSocket connection establishment. For example, ``Docker + exec`` needs POST. (`#3378 `_) +- Create a task per request handling. (`#3406 + `_) + + +Bugfixes +-------- + +- Enable passing ``access_log_class`` via ``handler_args`` (`#3158 + `_) +- Return empty bytes with end-of-chunk marker in empty stream reader. (`#3186 + `_) +- Accept ``CIMultiDictProxy`` instances for ``headers`` argument in ``web.Response`` + constructor. (`#3207 `_) +- Don't uppercase HTTP method in parser (`#3233 + `_) +- Make method match regexp RFC-7230 compliant (`#3235 + `_) +- Add ``app.pre_frozen`` state to properly handle startup signals in + sub-applications. (`#3237 `_) +- Enhanced parsing and validation of helpers.BasicAuth.decode. (`#3239 + `_) +- Change imports from collections module in preparation for 3.8. (`#3258 + `_) +- Ensure Host header is added first to ClientRequest to better replicate browser (`#3265 + `_) +- Fix forward compatibility with Python 3.8: importing ABCs directly from the + collections module will not be supported anymore. (`#3273 + `_) +- Keep the query string by ``normalize_path_middleware``. (`#3278 + `_) +- Fix missing parameter ``raise_for_status`` for aiohttp.request() (`#3290 + `_) +- Bracket IPv6 addresses in the HOST header (`#3304 + `_) +- Fix default message for server ping and pong frames. (`#3308 + `_) +- Fix tests/test_connector.py typo and tests/autobahn/server.py duplicate loop + def. (`#3337 `_) +- Fix false-negative indicator end_of_HTTP_chunk in StreamReader.readchunk function + (`#3361 `_) +- Release HTTP response before raising status exception (`#3364 + `_) +- Fix task cancellation when ``sendfile()`` syscall is used by static file + handling. (`#3383 `_) +- Fix stack trace for ``asyncio.TimeoutError`` which was not logged, when it is caught + in the handler. 
(`#3414 `_) + + +Improved Documentation +---------------------- + +- Improve documentation of ``Application.make_handler`` parameters. (`#3152 + `_) +- Fix BaseRequest.raw_headers doc. (`#3215 + `_) +- Fix typo in TypeError exception reason in ``web.Application._handle`` (`#3229 + `_) +- Make server access log format placeholder %b documentation reflect + behavior and docstring. (`#3307 `_) + + +Deprecations and Removals +------------------------- + +- Deprecate modification of ``session.requote_redirect_url`` (`#2278 + `_) +- Deprecate ``stream.unread_data()`` (`#3260 + `_) +- Deprecated use of boolean in ``resp.enable_compression()`` (`#3318 + `_) +- Encourage creation of aiohttp public objects inside a coroutine (`#3331 + `_) +- Drop dead ``Connection.detach()`` and ``Connection.writer``. Both methods were broken + for more than 2 years. (`#3358 `_) +- Deprecate ``app.loop``, ``request.loop``, ``client.loop`` and ``connector.loop`` + properties. (`#3374 `_) +- Deprecate explicit debug argument. Use asyncio debug mode instead. (`#3381 + `_) +- Deprecate body parameter in HTTPException (and derived classes) constructor. (`#3385 + `_) +- Deprecate bare connector close, use ``async with connector:`` and ``await + connector.close()`` instead. (`#3417 + `_) +- Deprecate obsolete ``read_timeout`` and ``conn_timeout`` in ``ClientSession`` + constructor. (`#3438 `_) + + +Misc +---- + +- #3341, #3351 diff --git a/CHANGES/.TEMPLATE.rst b/CHANGES/.TEMPLATE.rst new file mode 100644 index 00000000000..bc6016baf5c --- /dev/null +++ b/CHANGES/.TEMPLATE.rst @@ -0,0 +1,36 @@ +{# TOWNCRIER TEMPLATE #} +{% for section, _ in sections.items() %} +{% set underline = underlines[0] %}{% if section %}{{section}} +{{ underline * section|length }}{% set underline = underlines[1] %} + +{% endif %} + +{% if sections[section] %} +{% for category, val in definitions.items() if category in sections[section]%} +{{ definitions[category]['name'] }} +{{ underline * definitions[category]['name']|length }} + +{% if definitions[category]['showcontent'] %} +{% for text, values in sections[section][category].items() %} +- {{ text }} + {{ values|join(',\n ') }} +{% endfor %} + +{% else %} +- {{ sections[section][category]['']|join(', ') }} + +{% endif %} +{% if sections[section][category]|length == 0 %} +No significant changes. + +{% else %} +{% endif %} + +{% endfor %} +{% else %} +No significant changes. + + +{% endif %} +{% endfor %} +---- diff --git a/CHANGES/.gitignore b/CHANGES/.gitignore new file mode 100644 index 00000000000..f935021a8f8 --- /dev/null +++ b/CHANGES/.gitignore @@ -0,0 +1 @@ +!.gitignore diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000000..70e1010d426 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,46 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at andrew.svetlov@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 137b6649091..cb5ce3431b5 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -12,158 +12,25 @@ I hope everybody knows how to work with git and github nowadays :) Workflow is pretty straightforward: - 1. Clone the GitHub_ repo + 1. Clone the GitHub_ repo using the ``--recurse-submodules`` argument - 2. Make a change + 2. Setup your machine with the required dev environment - 3. Make sure all tests passed + 3. Make a change - 4. Commit changes to own aiohttp clone + 4. Make sure all tests passed - 5. Make pull request from github page for your clone against master branch + 5. Add a file into the ``CHANGES`` folder, named after the ticket or PR number - .. 
note:: - If your PR has long history or many commits - please rebase it from main repo before creating PR. + 6. Commit changes to your own aiohttp clone -Preconditions for running aiohttp test suite --------------------------------------------- + 7. Make a pull request from the GitHub page of your clone against the master branch -We expect you to use a python virtual environment to run our tests. + 8. Optionally make backport Pull Request(s) for landing a bug fix into released aiohttp versions. -There are several ways to make a virtual environment. - -If you like to use *virtualenv* please run: - -.. code-block:: shell - - $ cd aiohttp - $ virtualenv --python=`which python3` venv - $ . venv/bin/activate - -For standard python *venv*: - -.. code-block:: shell - - $ cd aiohttp - $ python3 -m venv venv - $ . venv/bin/activate - -For *virtualenvwrapper* (my choice): - -.. code-block:: shell - - $ cd aiohttp - $ mkvirtualenv --python=`which python3` aiohttp - -There are other tools like *pyvenv* but you know the rule of thumb -now: create a python3 virtual environment and activate it. - -After that please install libraries required for development: - -.. code-block:: shell - - $ pip install -r requirements-dev.txt - -We also recommend to install ipdb_ but it's on your own: : .. code-block:: shell - - $ pip install ipdb - -.. note:: - If you plan to use ``ipdb`` within the test suite, execute: - -.. code-block:: shell - - $ py.test tests -s -p no:timeout - - command to run the tests with disabled timeout guard and output - capturing. - -Congratulations, you are ready to run the test suite - - -Run aiohttp test suite ---------------------- - -After all the preconditions are met you can run tests typing the next -command: - -.. code-block:: shell - - $ make test - -The command at first will run the *flake8* tool (sorry, we don't accept -pull requests with pep8 or pyflakes errors). - -On *flake8* success the tests will be run. - -Please take a look on the produced output. - -Any extra texts (print statements and so on) should be removed. - - -Tests coverage -------------- - -We are trying hard to have good test coverage; please don't make it worse. - -Use: - -.. code-block:: shell - - $ make cov - -to run test suite and collect coverage information. Once the command -has finished check your coverage at the file that appears in the last -line of the output: -``open file:///.../aiohttp/coverage/index.html`` - -Please go to the link and make sure that your code change is covered. - - -Documentation ------------- - -We encourage documentation improvements. - -Please before making a Pull Request about documentation changes run: - -.. code-block:: shell - - $ make doc - -Once it finishes it will output the index html page -``open file:///.../aiohttp/docs/_build/html/index.html``. - -Go to the link and make sure your doc changes looks good. - -Spell checking -------------- - -We use ``pyenchant`` and ``sphinxcontrib-spelling`` for running spell -checker for documentation: - -.. code-block:: shell - - $ make doc-spelling - -Unfortunately there are problems with running spell checker on MacOS X. - -To run spell checker on Linux box you should install it first: - -.. code-block:: shell - - $ sudo apt-get install enchant - $ pip install sphinxcontrib-spelling - -The End ------- - -After finishing all steps make a GitHub_ Pull Request, thanks. +.. important:: + Please open the "`contributing `_" + documentation page to get detailed information about all steps.
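Step 1 of the workflow matters because aiohttp vendors ``http-parser`` as a git submodule (see ``.gitmodules`` above); a sketch of the initial checkout:

.. code-block:: shell

   git clone --recurse-submodules https://github.com/aio-libs/aiohttp.git
   cd aiohttp
   git submodule update --init   # in case the clone was made without the flag

..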
_GitHub: https://github.com/aio-libs/aiohttp - -.. _ipdb: https://pypi.python.org/pypi/ipdb diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 773c8ebb9ca..ad63ce9e4de 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -1,12 +1,21 @@ -Contributors ------------- - +- Contributors - +---------------- A. Jesse Jiryu Davis +Adam Bannister +Adam Cooper Adam Mills +Adrian Krupa +Adrián Chaves +Alan Tse +Alec Hanefeld Alejandro Gómez Aleksandr Danshyn Aleksey Kutepov Alex Hayes +Alex Key +Alex Khomchenko +Alex Kuzmenko +Alex Lisovoy Alexander Bayandin Alexander Karpinsky Alexander Koshevoy @@ -17,153 +26,287 @@ Alexander Travov Alexandru Mihai Alexey Firsov Alexey Popravka -Alex Key -Alex Khomchenko -Alex Lisovoy +Alexey Stepanov +Amin Etesamian +Amit Tulshyan Amy Boyle +Anders Melchiorsen Andrei Ursulenko Andrej Antonov Andrew Leech +Andrew Lytvyn Andrew Svetlov +Andrew Zhou Andrii Soldatenko +Antoine Pietri Anton Kasyanov +Anton Zhdan-Pushkin +Arseny Timoniq +Artem Yushkovskiy Arthur Darcet Ben Bader +Ben Timby Benedikt Reinartz Boris Feld +Boyi Chen Brett Cannon Brian C. Lane Brian Muller +Bruce Merry +Bryan Kok +Bryce Drennan Carl George +Cecile Tonglet Chien-Wei Huang Chih-Yuan Chen Chris AtLee Chris Laws Chris Moore Christopher Schmitt +Claudiu Popa +Colin Dunklau +Cong Xu +Damien Nadé +Dan Xu Daniel García +Daniel Grossmann-Kavanagh Daniel Nelson Danny Song +David Bibb David Michael Brown +Denilson Amorim Denis Matiychuk +Dennis Kliban Dima Veselov Dimitar Dimitrov +Dmitriy Safonov Dmitry Doroshev +Dmitry Erlikh +Dmitry Lukashin +Dmitry Marakasov Dmitry Shamov Dmitry Trofimov +Dmytro Bohomiakov Dmytro Kuznetsov Dustin J. Mitchell Eduard Iskandarov +Eli Ribble Elizabeth Leddy Enrique Saez +Eric Sheng Erich Healy +Erik Peterson Eugene Chernyshov Eugene Naydenov +Eugene Nikolaiev +Eugene Tolmachev +Evan Kepner +Evert Lammerts +Felix Yan +Fernanda Guimarães +FichteFoll +Florian Scheffler Frederik Gladhorn Frederik Peter Aalund Gabriel Tremblay +Gary Wilson Jr. Gennady Andreyev Georges Dubus Greg Holt Gregory Haynes +Gus Goulart +Gustavo Carneiro Günther Jena +Hans Adema +Harmon Y. +Hrishikesh Paranjape Hu Bo +Hugh Young Hugo Herter +Hynek Schlawack +Igor Alexandrov +Igor Davydenko +Igor Mozharovsky Igor Pavlov +Illia Volochii +Ilya Chichak +Ilya Gruzinov Ingmar Steen Jacob Champion Jaesung Lee Jake Davis +Jakob Ackermann Jakub Wilk Jashandeep Sohi +Jens Steinhauser +Jeonghun Lee Jeongkyu Shin Jeroen van der Heijden Jesus Cea +Jian Zeng Jinkyu Yi Joel Watts +Jon Nabozny +Jonas Krüger Svensson +Jonas Obrist +Jonathan Wright +Jonny Tan Joongi Kim +Josep Cugat +Josh Junon +Joshu Coats +Julia Tsemusheva Julien Duponchelle +Jungkook Park Junjie Tao +Junyeong Jeong Justas Trimailovas +Justin Foo Justin Turner Arthur Kay Zheng +Kevin Samuel Kimmo Parviainen-Jalanko Kirill Klenov Kirill Malovitsa +Konstantin Valetov +Krzysztof Blazewicz Kyrylo Perevozchikov +Kyungmin Lee Lars P. 
Søndergaard +Liu Hua Louis-Philippe Huberdeau +Loïc Lajeanne Lu Gong Lubomir Gelo Ludovic Gasc +Luis Pedrosa Lukasz Marcin Dobrzanski Makc Belousow Manuel Miranda +Marat Sharafutdinov Marco Paolini Mariano Anaya +Martijn Pieters Martin Melka Martin Richard Mathias Fröjdman +Mathieu Dugré Matthieu Hauglustaine +Matthieu Rigal Michael Ihnatenko +Michał Górny +Mikhail Burshteyn Mikhail Kashkin Mikhail Lukyanchenko +Mikhail Nacharov Misha Behersky +Mitchell Ferree Morgan Delahaye-Prat Moss Collum Mun Gwan-gyeong +Navid Sheikhol Nicolas Braem +Nikolay Kim Nikolay Novik +Oisin Aylward Olaf Conradi Pahaz Blinov +Panagiotis Kolokotronis Pankaj Pandey -Pawel Miech Pau Freixes Paul Colomiets +Paulius Šileikis Paulus Schoutsen +Pavel Kamaev +Pavel Polyakov +Pawel Kowalski +Pawel Miech +Pepe Osca Philipp A. +Pieter van Beek Rafael Viotti +Raphael Bialon Raúl Cumplido Required Field Robert Lu +Robert Nikolich Roman Podoliaka Samuel Colvin +Sean Hunt +Sebastian Acuna Sebastian Hanula Sebastian Hüther -Sean Hunt +Sebastien Geffroy SeongSoo Cho Sergey Ninua Sergey Skripnick +Serhii Charykov Serhii Kostel +Serhiy Storchaka Simon Kennedy Sin-Woo Bang Stanislas Plum Stanislav Prokop +Stefan Tjarks +Stepan Pletnev +Stephan Jaensch Stephen Granade Steven Seguin +Sunghyun Hwang +Sunit Deshpande Sviatoslav Bulbakha +Sviatoslav Sydorenko Taha Jahangir Taras Voinarovskyi Terence Honles +Thanos Lefteris +Thijs Vermeir +Thomas Forbes Thomas Grainger Tolga Tezel +Tomasz Trebski +Toshiaki Tanaka +Trinh Hoang Nhu +Vadim Suharnikov Vaibhav Sagar Vamsi Krishna Avula +Vasiliy Faronov Vasyl Baran +Viacheslav Greshilov +Victor Collod +Victor Kovtun Vikas Kawadia +Viktor Danyliuk +Ville Skyttä +Vincent Maillol Vitalik Verhovodov Vitaly Haritonsky Vitaly Magerya +Vladimir Kamarzin Vladimir Kozlovski Vladimir Rutsky Vladimir Shulyak Vladimir Zakharov +Vladyslav Bohaichuk +Vladyslav Bondar +W. Trevor King +Wei Lin +Weiwei Wang +Will McGugan Willem de Groot +William Grzybowski +William S. Wilson Ong -W. Trevor King +Yang Zhou Yannick Koechlin +Yannick Péroux +Ye Cao +Yegor Roganov +Yifei Kong Young-Ho Cha +Yuriy Shatrov Yury Selivanov Yusuke Tsutsumi +Zlatan Sičanica Марк Коренберг Семён Марьясин diff --git a/HISTORY.rst b/HISTORY.rst index 6833e368b60..44c1484f917 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -1,11 +1,1045 @@ +3.4.4 (2018-09-05) +================== + +- Fix installation from sources when compiling toolkit is not available (`#3241 `_) + +3.4.3 (2018-09-04) +================== + +- Add ``app.pre_frozen`` state to properly handle startup signals in sub-applications. (`#3237 `_) + + +3.4.2 (2018-09-01) +================== + +- Fix ``iter_chunks`` type annotation (`#3230 `_) + +3.4.1 (2018-08-28) +================== + +- Fix empty header parsing regression. (`#3218 `_) +- Fix BaseRequest.raw_headers doc. (`#3215 `_) +- Fix documentation building on ReadTheDocs (`#3221 `_) + + +3.4.0 (2018-08-25) +================== + +Features +-------- + +- Add type hints (`#3049 `_) +- Add ``raise_for_status`` request parameter (`#3073 `_) +- Add type hints to HTTP client (`#3092 `_) +- Minor server optimizations (`#3095 `_) +- Preserve the cause when `HTTPException` is raised from another exception. (`#3096 `_) +- Add `close_boundary` option in `MultipartWriter.write` method. Support streaming (`#3104 `_) +- Added a ``remove_slash`` option to the ``normalize_path_middleware`` factory. (`#3173 `_) +- The class `AbstractRouteDef` is importable from `aiohttp.web`. 
(`#3183 `_) + + +Bugfixes +-------- + +- Prevent double closing when client connection is released before the + last ``data_received()`` callback. (`#3031 `_) +- Make redirect with `normalize_path_middleware` work when using url encoded paths. (`#3051 `_) +- Postpone web task creation to connection establishment. (`#3052 `_) +- Fix ``sock_read`` timeout. (`#3053 `_) +- When using a server-request body as the `data=` argument of a client request, iterate over the content with `readany` instead of `readline` to avoid `Line too long` errors. (`#3054 `_) +- fix `UrlDispatcher` has no attribute `add_options`, add `web.options` (`#3062 `_) +- correct filename in content-disposition with multipart body (`#3064 `_) +- Many HTTP proxies has buggy keepalive support. + Let's not reuse connection but close it after processing every response. (`#3070 `_) +- raise 413 "Payload Too Large" rather than raising ValueError in request.post() + Add helpful debug message to 413 responses (`#3087 `_) +- Fix `StreamResponse` equality, now that they are `MutableMapping` objects. (`#3100 `_) +- Fix server request objects comparison (`#3116 `_) +- Do not hang on `206 Partial Content` response with `Content-Encoding: gzip` (`#3123 `_) +- Fix timeout precondition checkers (`#3145 `_) + + +Improved Documentation +---------------------- + +- Add a new FAQ entry that clarifies that you should not reuse response + objects in middleware functions. (`#3020 `_) +- Add FAQ section "Why is creating a ClientSession outside of an event loop dangerous?" (`#3072 `_) +- Fix link to Rambler (`#3115 `_) +- Fix TCPSite documentation on the Server Reference page. (`#3146 `_) +- Fix documentation build configuration file for Windows. (`#3147 `_) +- Remove no longer existing lingering_timeout parameter of Application.make_handler from documentation. (`#3151 `_) +- Mention that ``app.make_handler`` is deprecated, recommend to use runners + API instead. (`#3157 `_) + + +Deprecations and Removals +------------------------- + +- Drop ``loop.current_task()`` from ``helpers.current_task()`` (`#2826 `_) +- Drop ``reader`` parameter from ``request.multipart()``. (`#3090 `_) + + +3.3.2 (2018-06-12) +================== + +- Many HTTP proxies has buggy keepalive support. Let's not reuse connection but + close it after processing every response. (`#3070 `_) + +- Provide vendor source files in tarball (`#3076 `_) + + +3.3.1 (2018-06-05) +================== + +- Fix ``sock_read`` timeout. (`#3053 `_) +- When using a server-request body as the ``data=`` argument of a client request, + iterate over the content with ``readany`` instead of ``readline`` to avoid ``Line + too long`` errors. (`#3054 `_) + + +3.3.0 (2018-06-01) +================== + +Features +-------- + +- Raise ``ConnectionResetError`` instead of ``CancelledError`` on trying to + write to a closed stream. (`#2499 `_) +- Implement ``ClientTimeout`` class and support socket read timeout. (`#2768 `_) +- Enable logging when ``aiohttp.web`` is used as a program (`#2956 `_) +- Add canonical property to resources (`#2968 `_) +- Forbid reading response BODY after release (`#2983 `_) +- Implement base protocol class to avoid a dependency from internal + ``asyncio.streams.FlowControlMixin`` (`#2986 `_) +- Cythonize ``@helpers.reify``, 5% boost on macro benchmark (`#2995 `_) +- Optimize HTTP parser (`#3015 `_) +- Implement ``runner.addresses`` property. (`#3036 `_) +- Use ``bytearray`` instead of a list of ``bytes`` in websocket reader. It + improves websocket message reading a little. 
(`#3039 `_) +- Remove heartbeat on closing connection on keepalive timeout. The used hack + violates HTTP protocol. (`#3041 `_) +- Limit websocket message size on reading to 4 MB by default. (`#3045 `_) + + +Bugfixes +-------- + +- Don't reuse a connection with the same URL but different proxy/TLS settings + (`#2981 `_) +- When parsing the Forwarded header, the optional port number is now preserved. + (`#3009 `_) + + +Improved Documentation +---------------------- + +- Make Change Log more visible in docs (`#3029 `_) +- Make style and grammar improvements on the FAQ page. (`#3030 `_) +- Document that signal handlers should be async functions since aiohttp 3.0 + (`#3032 `_) + + +Deprecations and Removals +------------------------- + +- Deprecate custom application's router. (`#3021 `_) + + +Misc +---- + +- #3008, #3011 + + +3.2.1 (2018-05-10) +================== + +- Don't reuse a connection with the same URL but different proxy/TLS settings + (`#2981 `_) + + +3.2.0 (2018-05-06) +================== + +Features +-------- + +- Raise ``TooManyRedirects`` exception when client gets redirected too many + times instead of returning last response. (`#2631 `_) +- Extract route definitions into separate ``web_routedef.py`` file (`#2876 `_) +- Raise an exception on request body reading after sending response. (`#2895 `_) +- ClientResponse and RequestInfo now have real_url property, which is request + url without fragment part being stripped (`#2925 `_) +- Speed up connector limiting (`#2937 `_) +- Added and links property for ClientResponse object (`#2948 `_) +- Add ``request.config_dict`` for exposing nested applications data. (`#2949 `_) +- Speed up HTTP headers serialization, server micro-benchmark runs 5% faster + now. (`#2957 `_) +- Apply assertions in debug mode only (`#2966 `_) + + +Bugfixes +-------- + +- expose property `app` for TestClient (`#2891 `_) +- Call on_chunk_sent when write_eof takes as a param the last chunk (`#2909 `_) +- A closing bracket was added to `__repr__` of resources (`#2935 `_) +- Fix compression of FileResponse (`#2942 `_) +- Fixes some bugs in the limit connection feature (`#2964 `_) + + +Improved Documentation +---------------------- + +- Drop ``async_timeout`` usage from documentation for client API in favor of + ``timeout`` parameter. (`#2865 `_) +- Improve Gunicorn logging documentation (`#2921 `_) +- Replace multipart writer `.serialize()` method with `.write()` in + documentation. (`#2965 `_) + + +Deprecations and Removals +------------------------- + +- Deprecate Application.make_handler() (`#2938 `_) + + +Misc +---- + +- #2958 + + +3.1.3 (2018-04-12) +================== + +- Fix cancellation broadcast during DNS resolve (`#2910 `_) + + +3.1.2 (2018-04-05) +================== + +- Make ``LineTooLong`` exception more detailed about actual data size (`#2863 `_) +- Call ``on_chunk_sent`` when write_eof takes as a param the last chunk (`#2909 `_) + + +3.1.1 (2018-03-27) +================== + +- Support *asynchronous iterators* (and *asynchronous generators* as + well) in both client and server API as request / response BODY + payloads. (`#2802 `_) + + +3.1.0 (2018-03-21) +================== + +Welcome to aiohttp 3.1 release. + +This is an *incremental* release, fully backward compatible with *aiohttp 3.0*. + +But we have added several new features. + +The most visible one is ``app.add_routes()`` (an alias for existing +``app.router.add_routes()``. The addition is very important because +all *aiohttp* docs now uses ``app.add_routes()`` call in code +snippets. 
All your existing code still registers routes / resources
+without any warning, but you've got the idea for a favorite way: noisy
+``app.router.add_get()`` is replaced by ``app.add_routes()``.
+
+The library does not make a preference between decorators::
+
+   routes = web.RouteTableDef()
+
+   @routes.get('/')
+   async def hello(request):
+       return web.Response(text="Hello, world")
+
+   app.add_routes(routes)
+
+and route tables as a list::
+
+   async def hello(request):
+       return web.Response(text="Hello, world")
+
+   app.add_routes([web.get('/', hello)])
+
+Both ways are equivalent; the user may decide based on their own code taste.
+
+Also we have a lot of minor features, bug fixes and documentation
+updates, see below.
+
+Features
+--------
+
+- Relax JSON content-type checking in the ``ClientResponse.json()`` to allow
+  "application/xxx+json" instead of strict "application/json". (`#2206 `_)
+- Bump C HTTP parser to version 2.8 (`#2730 `_)
+- Accept a coroutine as an application factory in ``web.run_app`` and gunicorn
+  worker. (`#2739 `_)
+- Implement application cleanup context (``app.cleanup_ctx`` property). (`#2747 `_)
+- Make ``writer.write_headers`` a coroutine. (`#2762 `_)
+- Add tracking signals for getting request/response bodies. (`#2767 `_)
+- Deprecate ClientResponseError.code in favor of .status to keep similarity
+  with response classes. (`#2781 `_)
+- Implement ``app.add_routes()`` method. (`#2787 `_)
+- Implement ``web.static()`` and ``RouteTableDef.static()`` API. (`#2795 `_)
+- Install a test event loop as default by ``asyncio.set_event_loop()``. The
+  change affects aiohttp test utils but backward compatibility is not broken
+  for 99.99% of use cases. (`#2804 `_)
+- Refactor ``ClientResponse`` constructor: make logically required constructor
+  arguments mandatory, drop ``_post_init()`` method. (`#2820 `_)
+- Use ``app.add_routes()`` in server docs everywhere (`#2830 `_)
+- Websockets refactoring, all websocket writer methods are converted into
+  coroutines. (`#2836 `_)
+- Provide ``Content-Range`` header for ``Range`` requests (`#2844 `_)
+
+
+Bugfixes
+--------
+
+- Fix websocket client return EofStream. (`#2784 `_)
+- Fix websocket demo. (`#2789 `_)
+- Property ``BaseRequest.http_range`` now returns a python-like slice when
+  requesting the tail of the range. It's now indicated by a negative value in
+  ``range.start`` rather than in ``range.stop`` (`#2805 `_)
+- Close a connection if an unexpected exception occurs while sending a request
+  (`#2827 `_)
+- Fix firing DNS tracing events. (`#2841 `_)
+
+
+Improved Documentation
+----------------------
+
+- Document behavior when cchardet detects encodings that are unknown to Python.
+  (`#2732 `_)
+- Add diagrams for tracing request life style. (`#2748 `_)
+- Drop removed functionality for passing ``StreamReader`` as data at client
+  side. (`#2793 `_)
+
+3.0.9 (2018-03-14)
+==================
+
+- Close a connection if an unexpected exception occurs while sending a request
+  (`#2827 `_)
+
+
+3.0.8 (2018-03-12)
+==================
+
+- Use ``asyncio.current_task()`` on Python 3.7 (`#2825 `_)
+3.0.7 (2018-03-08)
+==================
+
+- Fix SSL proxy support by client. (`#2810 `_)
+- Restore an imperative check in ``setup.py`` for python version. The check
+  works in parallel to environment marker. As an effect, an error about unsupported
+  Python versions is raised even on outdated systems with very old
+  ``setuptools`` version installed.
(`#2813 `_) + + +3.0.6 (2018-03-05) +================== + +- Add ``_reuse_address`` and ``_reuse_port`` to + ``web_runner.TCPSite.__slots__``. (`#2792 `_) + +3.0.5 (2018-02-27) +================== + +- Fix ``InvalidStateError`` on processing a sequence of two + ``RequestHandler.data_received`` calls on web server. (`#2773 `_) + +3.0.4 (2018-02-26) +================== + +- Fix ``IndexError`` in HTTP request handling by server. (`#2752 `_) +- Fix MultipartWriter.append* no longer returning part/payload. (`#2759 `_) + + +3.0.3 (2018-02-25) +================== + +- Relax ``attrs`` dependency to minimal actually supported version + 17.0.3 The change allows to avoid version conflicts with currently + existing test tools. + +3.0.2 (2018-02-23) +================== + +Security Fix +------------ + +- Prevent Windows absolute URLs in static files. Paths like + ``/static/D:\path`` and ``/static/\\hostname\drive\path`` are + forbidden. + +3.0.1 +===== + +- Technical release for fixing distribution problems. + +3.0.0 (2018-02-12) +================== + +Features +-------- + +- Speed up the `PayloadWriter.write` method for large request bodies. (`#2126 `_) +- StreamResponse and Response are now MutableMappings. (`#2246 `_) +- ClientSession publishes a set of signals to track the HTTP request execution. + (`#2313 `_) +- Content-Disposition fast access in ClientResponse (`#2455 `_) +- Added support to Flask-style decorators with class-based Views. (`#2472 `_) +- Signal handlers (registered callbacks) should be coroutines. (`#2480 `_) +- Support ``async with test_client.ws_connect(...)`` (`#2525 `_) +- Introduce *site* and *application runner* as underlying API for `web.run_app` + implementation. (`#2530 `_) +- Only quote multipart boundary when necessary and sanitize input (`#2544 `_) +- Make the `aiohttp.ClientResponse.get_encoding` method public with the + processing of invalid charset while detecting content encoding. (`#2549 `_) +- Add optional configurable per message compression for + `ClientWebSocketResponse` and `WebSocketResponse`. (`#2551 `_) +- Add hysteresis to `StreamReader` to prevent flipping between paused and + resumed states too often. (`#2555 `_) +- Support `.netrc` by `trust_env` (`#2581 `_) +- Avoid to create a new resource when adding a route with the same name and + path of the last added resource (`#2586 `_) +- `MultipartWriter.boundary` is `str` now. (`#2589 `_) +- Allow a custom port to be used by `TestServer` (and associated pytest + fixtures) (`#2613 `_) +- Add param access_log_class to web.run_app function (`#2615 `_) +- Add ``ssl`` parameter to client API (`#2626 `_) +- Fixes performance issue introduced by #2577. When there are no middlewares + installed by the user, no additional and useless code is executed. (`#2629 `_) +- Rename PayloadWriter to StreamWriter (`#2654 `_) +- New options *reuse_port*, *reuse_address* are added to `run_app` and + `TCPSite`. (`#2679 `_) +- Use custom classes to pass client signals parameters (`#2686 `_) +- Use ``attrs`` library for data classes, replace `namedtuple`. (`#2690 `_) +- Pytest fixtures renaming, add ``aiohttp_`` prefix (`#2578 `_) +- Add ``aiohttp-`` prefix for ``pytest-aiohttp`` command line + parameters (`#2578 `_) + +Bugfixes +-------- + +- Correctly process upgrade request from server to HTTP2. ``aiohttp`` does not + support HTTP2 yet, the protocol is not upgraded but response is handled + correctly. 
(`#2277 `_) +- Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy + connector (`#2408 `_) +- Fix connector convert OSError to ClientConnectorError (`#2423 `_) +- Fix connection attempts for multiple dns hosts (`#2424 `_) +- Fix writing to closed transport by raising `asyncio.CancelledError` (`#2499 `_) +- Fix warning in `ClientSession.__del__` by stopping to try to close it. + (`#2523 `_) +- Fixed race-condition for iterating addresses from the DNSCache. (`#2620 `_) +- Fix default value of `access_log_format` argument in `web.run_app` (`#2649 `_) +- Freeze sub-application on adding to parent app (`#2656 `_) +- Do percent encoding for `.url_for()` parameters (`#2668 `_) +- Correctly process request start time and multiple request/response + headers in access log extra (`#2641 `_) + +Improved Documentation +---------------------- + +- Improve tutorial docs, using `literalinclude` to link to the actual files. + (`#2396 `_) +- Small improvement docs: better example for file uploads. (`#2401 `_) +- Rename `from_env` to `trust_env` in client reference. (`#2451 `_) +- Fixed mistype in `Proxy Support` section where `trust_env` parameter was + used in `session.get("http://python.org", trust_env=True)` method instead of + aiohttp.ClientSession constructor as follows: + `aiohttp.ClientSession(trust_env=True)`. (`#2688 `_) +- Fix issue with unittest example not compiling in testing docs. (`#2717 `_) + +Deprecations and Removals +------------------------- + +- Simplify HTTP pipelining implementation (`#2109 `_) +- Drop `StreamReaderPayload` and `DataQueuePayload`. (`#2257 `_) +- Drop `md5` and `sha1` finger-prints (`#2267 `_) +- Drop WSMessage.tp (`#2321 `_) +- Drop Python 3.4 and Python 3.5.0, 3.5.1, 3.5.2. Minimal supported Python + versions are 3.5.3 and 3.6.0. `yield from` is gone, use `async/await` syntax. + (`#2343 `_) +- Drop `aiohttp.Timeout` and use `async_timeout.timeout` instead. (`#2348 `_) +- Drop `resolve` param from TCPConnector. (`#2377 `_) +- Add DeprecationWarning for returning HTTPException (`#2415 `_) +- `send_str()`, `send_bytes()`, `send_json()`, `ping()` and `pong()` are + genuine async functions now. (`#2475 `_) +- Drop undocumented `app.on_pre_signal` and `app.on_post_signal`. Signal + handlers should be coroutines, support for regular functions is dropped. + (`#2480 `_) +- `StreamResponse.drain()` is not a part of public API anymore, just use `await + StreamResponse.write()`. `StreamResponse.write` is converted to async + function. (`#2483 `_) +- Drop deprecated `slow_request_timeout` param and `**kwargs`` from + `RequestHandler`. (`#2500 `_) +- Drop deprecated `resource.url()`. (`#2501 `_) +- Remove `%u` and `%l` format specifiers from access log format. (`#2506 `_) +- Drop deprecated `request.GET` property. (`#2547 `_) +- Simplify stream classes: drop `ChunksQueue` and `FlowControlChunksQueue`, + merge `FlowControlStreamReader` functionality into `StreamReader`, drop + `FlowControlStreamReader` name. (`#2555 `_) +- Do not create a new resource on `router.add_get(..., allow_head=True)` + (`#2585 `_) +- Drop access to TCP tuning options from PayloadWriter and Response classes + (`#2604 `_) +- Drop deprecated `encoding` parameter from client API (`#2606 `_) +- Deprecate ``verify_ssl``, ``ssl_context`` and ``fingerprint`` parameters in + client API (`#2626 `_) +- Get rid of the legacy class StreamWriter. (`#2651 `_) +- Forbid non-strings in `resource.url_for()` parameters. 
(`#2668 `_) +- Deprecate inheritance from ``ClientSession`` and ``web.Application`` and + custom user attributes for ``ClientSession``, ``web.Request`` and + ``web.Application`` (`#2691 `_) +- Drop `resp = await aiohttp.request(...)` syntax for sake of `async with + aiohttp.request(...) as resp:`. (`#2540 `_) +- Forbid synchronous context managers for `ClientSession` and test + server/client. (`#2362 `_) + + +Misc +---- + +- #2552 + + +2.3.10 (2018-02-02) +=================== + +- Fix 100% CPU usage on HTTP GET and websocket connection just after it (`#1955 `_) + +- Patch broken `ssl.match_hostname()` on Python<3.7 (`#2674 `_) + +2.3.9 (2018-01-16) +================== + +- Fix colon handing in path for dynamic resources (`#2670 `_) + +2.3.8 (2018-01-15) +================== + +- Do not use `yarl.unquote` internal function in aiohttp. Fix + incorrectly unquoted path part in URL dispatcher (`#2662 `_) + +- Fix compatibility with `yarl==1.0.0` (`#2662 `_) + +2.3.7 (2017-12-27) +================== + +- Fixed race-condition for iterating addresses from the DNSCache. (`#2620 `_) +- Fix docstring for request.host (`#2591 `_) +- Fix docstring for request.remote (`#2592 `_) + + +2.3.6 (2017-12-04) +================== + +- Correct `request.app` context (for handlers not just middlewares). (`#2577 `_) + + +2.3.5 (2017-11-30) +================== + +- Fix compatibility with `pytest` 3.3+ (`#2565 `_) + + +2.3.4 (2017-11-29) +================== + +- Make `request.app` point to proper application instance when using nested + applications (with middlewares). (`#2550 `_) +- Change base class of ClientConnectorSSLError to ClientSSLError from + ClientConnectorError. (`#2563 `_) +- Return client connection back to free pool on error in `connector.connect()`. + (`#2567 `_) + + +2.3.3 (2017-11-17) +================== + +- Having a `;` in Response content type does not assume it contains a charset + anymore. (`#2197 `_) +- Use `getattr(asyncio, 'async')` for keeping compatibility with Python 3.7. + (`#2476 `_) +- Ignore `NotImplementedError` raised by `set_child_watcher` from `uvloop`. + (`#2491 `_) +- Fix warning in `ClientSession.__del__` by stopping to try to close it. + (`#2523 `_) +- Fixed typo's in Third-party libraries page. And added async-v20 to the list + (`#2510 `_) + + +2.3.2 (2017-11-01) +================== + +- Fix passing client max size on cloning request obj. (`#2385 `_) +- Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy + connector. (`#2408 `_) +- Drop generated `_http_parser` shared object from tarball distribution. (`#2414 `_) +- Fix connector convert OSError to ClientConnectorError. (`#2423 `_) +- Fix connection attempts for multiple dns hosts. (`#2424 `_) +- Fix ValueError for AF_INET6 sockets if a preexisting INET6 socket to the + `aiohttp.web.run_app` function. (`#2431 `_) +- `_SessionRequestContextManager` closes the session properly now. (`#2441 `_) +- Rename `from_env` to `trust_env` in client reference. (`#2451 `_) + + +2.3.1 (2017-10-18) +================== + +- Relax attribute lookup in warning about old-styled middleware (`#2340 `_) + + +2.3.0 (2017-10-18) +================== + +Features +-------- + +- Add SSL related params to `ClientSession.request` (`#1128 `_) +- Make enable_compression work on HTTP/1.0 (`#1828 `_) +- Deprecate registering synchronous web handlers (`#1993 `_) +- Switch to `multidict 3.0`. All HTTP headers preserve casing now but compared + in case-insensitive way. (`#1994 `_) +- Improvement for `normalize_path_middleware`. 
Added possibility to handle URLs + with query string. (`#1995 `_) +- Use towncrier for CHANGES.txt build (`#1997 `_) +- Implement `trust_env=True` param in `ClientSession`. (`#1998 `_) +- Added variable to customize proxy headers (`#2001 `_) +- Implement `router.add_routes` and router decorators. (`#2004 `_) +- Deprecated `BaseRequest.has_body` in favor of + `BaseRequest.can_read_body` Added `BaseRequest.body_exists` + attribute that stays static for the lifetime of the request (`#2005 `_) +- Provide `BaseRequest.loop` attribute (`#2024 `_) +- Make `_CoroGuard` awaitable and fix `ClientSession.close` warning message + (`#2026 `_) +- Responses to redirects without Location header are returned instead of + raising a RuntimeError (`#2030 `_) +- Added `get_client`, `get_server`, `setUpAsync` and `tearDownAsync` methods to + AioHTTPTestCase (`#2032 `_) +- Add automatically a SafeChildWatcher to the test loop (`#2058 `_) +- add ability to disable automatic response decompression (`#2110 `_) +- Add support for throttling DNS request, avoiding the requests saturation when + there is a miss in the DNS cache and many requests getting into the connector + at the same time. (`#2111 `_) +- Use request for getting access log information instead of message/transport + pair. Add `RequestBase.remote` property for accessing to IP of client + initiated HTTP request. (`#2123 `_) +- json() raises a ContentTypeError exception if the content-type does not meet + the requirements instead of raising a generic ClientResponseError. (`#2136 `_) +- Make the HTTP client able to return HTTP chunks when chunked transfer + encoding is used. (`#2150 `_) +- add `append_version` arg into `StaticResource.url` and + `StaticResource.url_for` methods for getting an url with hash (version) of + the file. (`#2157 `_) +- Fix parsing the Forwarded header. * commas and semicolons are allowed inside + quoted-strings; * empty forwarded-pairs (as in for=_1;;by=_2) are allowed; * + non-standard parameters are allowed (although this alone could be easily done + in the previous parser). (`#2173 `_) +- Don't require ssl module to run. aiohttp does not require SSL to function. + The code paths involved with SSL will only be hit upon SSL usage. Raise + `RuntimeError` if HTTPS protocol is required but ssl module is not present. + (`#2221 `_) +- Accept coroutine fixtures in pytest plugin (`#2223 `_) +- Call `shutdown_asyncgens` before event loop closing on Python 3.6. (`#2227 `_) +- Speed up Signals when there are no receivers (`#2229 `_) +- Raise `InvalidURL` instead of `ValueError` on fetches with invalid URL. + (`#2241 `_) +- Move `DummyCookieJar` into `cookiejar.py` (`#2242 `_) +- `run_app`: Make `print=None` disable printing (`#2260 `_) +- Support `brotli` encoding (generic-purpose lossless compression algorithm) + (`#2270 `_) +- Add server support for WebSockets Per-Message Deflate. Add client option to + add deflate compress header in WebSockets request header. If calling + ClientSession.ws_connect() with `compress=15` the client will support deflate + compress negotiation. (`#2273 `_) +- Support `verify_ssl`, `fingerprint`, `ssl_context` and `proxy_headers` by + `client.ws_connect`. (`#2292 `_) +- Added `aiohttp.ClientConnectorSSLError` when connection fails due + `ssl.SSLError` (`#2294 `_) +- `aiohttp.web.Application.make_handler` support `access_log_class` (`#2315 `_) +- Build HTTP parser extension in non-strict mode by default. 
(`#2332 `_) + + +Bugfixes +-------- + +- Clear auth information on redirecting to other domain (`#1699 `_) +- Fix missing app.loop on startup hooks during tests (`#2060 `_) +- Fix issue with synchronous session closing when using `ClientSession` as an + asynchronous context manager. (`#2063 `_) +- Fix issue with `CookieJar` incorrectly expiring cookies in some edge cases. + (`#2084 `_) +- Force use of IPv4 during test, this will make tests run in a Docker container + (`#2104 `_) +- Warnings about unawaited coroutines now correctly point to the user's code. + (`#2106 `_) +- Fix issue with `IndexError` being raised by the `StreamReader.iter_chunks()` + generator. (`#2112 `_) +- Support HTTP 308 Permanent redirect in client class. (`#2114 `_) +- Fix `FileResponse` sending empty chunked body on 304. (`#2143 `_) +- Do not add `Content-Length: 0` to GET/HEAD/TRACE/OPTIONS requests by default. + (`#2167 `_) +- Fix parsing the Forwarded header according to RFC 7239. (`#2170 `_) +- Securely determining remote/scheme/host #2171 (`#2171 `_) +- Fix header name parsing, if name is split into multiple lines (`#2183 `_) +- Handle session close during connection, `KeyError: + ` (`#2193 `_) +- Fixes uncaught `TypeError` in `helpers.guess_filename` if `name` is not a + string (`#2201 `_) +- Raise OSError on async DNS lookup if resolved domain is an alias for another + one, which does not have an A or CNAME record. (`#2231 `_) +- Fix incorrect warning in `StreamReader`. (`#2251 `_) +- Properly clone state of web request (`#2284 `_) +- Fix C HTTP parser for cases when status line is split into different TCP + packets. (`#2311 `_) +- Fix `web.FileResponse` overriding user supplied Content-Type (`#2317 `_) + + +Improved Documentation +---------------------- + +- Add a note about possible performance degradation in `await resp.text()` if + charset was not provided by `Content-Type` HTTP header. Pass explicit + encoding to solve it. (`#1811 `_) +- Drop `disqus` widget from documentation pages. (`#2018 `_) +- Add a graceful shutdown section to the client usage documentation. (`#2039 `_) +- Document `connector_owner` parameter. (`#2072 `_) +- Update the doc of web.Application (`#2081 `_) +- Fix mistake about access log disabling. (`#2085 `_) +- Add example usage of on_startup and on_shutdown signals by creating and + disposing an aiopg connection engine. (`#2131 `_) +- Document `encoded=True` for `yarl.URL`, it disables all yarl transformations. + (`#2198 `_) +- Document that all app's middleware factories are run for every request. + (`#2225 `_) +- Reflect the fact that default resolver is threaded one starting from aiohttp + 1.1 (`#2228 `_) + + +Deprecations and Removals +------------------------- + +- Drop deprecated `Server.finish_connections` (`#2006 `_) +- Drop %O format from logging, use %b instead. Drop %e format from logging, + environment variables are not supported anymore. (`#2123 `_) +- Drop deprecated secure_proxy_ssl_header support (`#2171 `_) +- Removed TimeService in favor of simple caching. TimeService also had a bug + where it lost about 0.5 seconds per second. 
(`#2176 `_) +- Drop unused response_factory from static files API (`#2290 `_) + + +Misc +---- + +- #2013, #2014, #2048, #2094, #2149, #2187, #2214, #2225, #2243, #2248 + + +2.2.5 (2017-08-03) +================== + +- Don't raise deprecation warning on + `loop.run_until_complete(client.close())` (`#2065 `_) + +2.2.4 (2017-08-02) +================== + +- Fix issue with synchronous session closing when using ClientSession + as an asynchronous context manager. (`#2063 `_) + +2.2.3 (2017-07-04) +================== + +- Fix `_CoroGuard` for python 3.4 + +2.2.2 (2017-07-03) +================== + +- Allow `await session.close()` along with `yield from session.close()` + + +2.2.1 (2017-07-02) +================== + +- Relax `yarl` requirement to 0.11+ + +- Backport #2026: `session.close` *is* a coroutine (`#2029 `_) + + +2.2.0 (2017-06-20) +================== + +- Add doc for add_head, update doc for add_get. (`#1944 `_) + +- Fixed consecutive calls for `Response.write_eof`. + +- Retain method attributes (e.g. :code:`__doc__`) when registering synchronous + handlers for resources. (`#1953 `_) + +- Added signal TERM handling in `run_app` to gracefully exit (`#1932 `_) + +- Fix websocket issues caused by frame fragmentation. (`#1962 `_) + +- Raise RuntimeError is you try to set the Content Length and enable + chunked encoding at the same time (`#1941 `_) + +- Small update for `unittest_run_loop` + +- Use CIMultiDict for ClientRequest.skip_auto_headers (`#1970 `_) + +- Fix wrong startup sequence: test server and `run_app()` are not raise + `DeprecationWarning` now (`#1947 `_) + +- Make sure cleanup signal is sent if startup signal has been sent (`#1959 `_) + +- Fixed server keep-alive handler, could cause 100% cpu utilization (`#1955 `_) + +- Connection can be destroyed before response get processed if + `await aiohttp.request(..)` is used (`#1981 `_) + +- MultipartReader does not work with -OO (`#1969 `_) + +- Fixed `ClientPayloadError` with blank `Content-Encoding` header (`#1931 `_) + +- Support `deflate` encoding implemented in `httpbin.org/deflate` (`#1918 `_) + +- Fix BadStatusLine caused by extra `CRLF` after `POST` data (`#1792 `_) + +- Keep a reference to `ClientSession` in response object (`#1985 `_) + +- Deprecate undocumented `app.on_loop_available` signal (`#1978 `_) + + + +2.1.0 (2017-05-26) +================== + +- Added support for experimental `async-tokio` event loop written in Rust + https://github.com/PyO3/tokio + +- Write to transport ``\r\n`` before closing after keepalive timeout, + otherwise client can not detect socket disconnection. (`#1883 `_) + +- Only call `loop.close` in `run_app` if the user did *not* supply a loop. + Useful for allowing clients to specify their own cleanup before closing the + asyncio loop if they wish to tightly control loop behavior + +- Content disposition with semicolon in filename (`#917 `_) + +- Added `request_info` to response object and `ClientResponseError`. (`#1733 `_) + +- Added `history` to `ClientResponseError`. (`#1741 `_) + +- Allow to disable redirect url re-quoting (`#1474 `_) + +- Handle RuntimeError from transport (`#1790 `_) + +- Dropped "%O" in access logger (`#1673 `_) + +- Added `args` and `kwargs` to `unittest_run_loop`. Useful with other + decorators, for example `@patch`. (`#1803 `_) + +- Added `iter_chunks` to response.content object. (`#1805 `_) + +- Avoid creating TimerContext when there is no timeout to allow + compatibility with Tornado. 
(`#1817 `_) (`#1180 `_) + +- Add `proxy_from_env` to `ClientRequest` to read from environment + variables. (`#1791 `_) + +- Add DummyCookieJar helper. (`#1830 `_) + +- Fix assertion errors in Python 3.4 from noop helper. (`#1847 `_) + +- Do not unquote `+` in match_info values (`#1816 `_) + +- Use Forwarded, X-Forwarded-Scheme and X-Forwarded-Host for better scheme and + host resolution. (`#1134 `_) + +- Fix sub-application middlewares resolution order (`#1853 `_) + +- Fix applications comparison (`#1866 `_) + +- Fix static location in index when prefix is used (`#1662 `_) + +- Make test server more reliable (`#1896 `_) + +- Extend list of web exceptions, add HTTPUnprocessableEntity, + HTTPFailedDependency, HTTPInsufficientStorage status codes (`#1920 `_) + + +2.0.7 (2017-04-12) +================== + +- Fix *pypi* distribution + +- Fix exception description (`#1807 `_) + +- Handle socket error in FileResponse (`#1773 `_) + +- Cancel websocket heartbeat on close (`#1793 `_) + + +2.0.6 (2017-04-04) +================== + +- Keeping blank values for `request.post()` and `multipart.form()` (`#1765 `_) + +- TypeError in data_received of ResponseHandler (`#1770 `_) + +- Fix ``web.run_app`` not to bind to default host-port pair if only socket is + passed (`#1786 `_) + + +2.0.5 (2017-03-29) +================== + +- Memory leak with aiohttp.request (`#1756 `_) + +- Disable cleanup closed ssl transports by default. + +- Exception in request handling if the server responds before the body + is sent (`#1761 `_) + + +2.0.4 (2017-03-27) +================== + +- Memory leak with aiohttp.request (`#1756 `_) + +- Encoding is always UTF-8 in POST data (`#1750 `_) + +- Do not add "Content-Disposition" header by default (`#1755 `_) + + +2.0.3 (2017-03-24) +================== + +- Call https website through proxy will cause error (`#1745 `_) + +- Fix exception on multipart/form-data post if content-type is not set (`#1743 `_) + + +2.0.2 (2017-03-21) +================== + +- Fixed Application.on_loop_available signal (`#1739 `_) + +- Remove debug code + + +2.0.1 (2017-03-21) +================== + +- Fix allow-head to include name on route (`#1737 `_) + +- Fixed AttributeError in WebSocketResponse.can_prepare (`#1736 `_) + + +2.0.0 (2017-03-20) +================== + +- Added `json` to `ClientSession.request()` method (`#1726 `_) + +- Added session's `raise_for_status` parameter, automatically calls + raise_for_status() on any request. (`#1724 `_) + +- `response.json()` raises `ClientReponseError` exception if response's + content type does not match (`#1723 `_) + + - Cleanup timer and loop handle on any client exception. + +- Deprecate `loop` parameter for Application's constructor + + +`2.0.0rc1` (2017-03-15) +======================= + +- Properly handle payload errors (`#1710 `_) + +- Added `ClientWebSocketResponse.get_extra_info()` (`#1717 `_) + +- It is not possible to combine Transfer-Encoding and chunked parameter, + same for compress and Content-Encoding (`#1655 `_) + +- Connector's `limit` parameter indicates total concurrent connections. + New `limit_per_host` added, indicates total connections per endpoint. 
(`#1601 `_) + +- Use url's `raw_host` for name resolution (`#1685 `_) + +- Change `ClientResponse.url` to `yarl.URL` instance (`#1654 `_) + +- Add max_size parameter to web.Request reading methods (`#1133 `_) + +- Web Request.post() stores data in temp files (`#1469 `_) + +- Add the `allow_head=True` keyword argument for `add_get` (`#1618 `_) + +- `run_app` and the Command Line Interface now support serving over + Unix domain sockets for faster inter-process communication. + +- `run_app` now supports passing a preexisting socket object. This can be useful + e.g. for socket-based activated applications, when binding of a socket is + done by the parent process. + +- Implementation for Trailer headers parser is broken (`#1619 `_) + +- Fix FileResponse to not fall on bad request (range out of file size) + +- Fix FileResponse to correct stream video to Chromes + +- Deprecate public low-level api (`#1657 `_) + +- Deprecate `encoding` parameter for ClientSession.request() method + +- Dropped aiohttp.wsgi (`#1108 `_) + +- Dropped `version` from ClientSession.request() method + +- Dropped websocket version 76 support (`#1160 `_) + +- Dropped: `aiohttp.protocol.HttpPrefixParser` (`#1590 `_) + +- Dropped: Servers response's `.started`, `.start()` and + `.can_start()` method (`#1591 `_) + +- Dropped: Adding `sub app` via `app.router.add_subapp()` is deprecated + use `app.add_subapp()` instead (`#1592 `_) + +- Dropped: `Application.finish()` and `Application.register_on_finish()` (`#1602 `_) + +- Dropped: `web.Request.GET` and `web.Request.POST` + +- Dropped: aiohttp.get(), aiohttp.options(), aiohttp.head(), + aiohttp.post(), aiohttp.put(), aiohttp.patch(), aiohttp.delete(), and + aiohttp.ws_connect() (`#1593 `_) + +- Dropped: `aiohttp.web.WebSocketResponse.receive_msg()` (`#1605 `_) + +- Dropped: `ServerHttpProtocol.keep_alive_timeout` attribute and + `keep-alive`, `keep_alive_on`, `timeout`, `log` constructor parameters (`#1606 `_) + +- Dropped: `TCPConnector's`` `.resolve`, `.resolved_hosts`, + `.clear_resolved_hosts()` attributes and `resolve` constructor + parameter (`#1607 `_) + +- Dropped `ProxyConnector` (`#1609 `_) + + 1.3.5 (2017-03-16) ------------------- +================== -- Fixed None timeout support #1720 +- Fixed None timeout support (`#1720 `_) 1.3.4 (2017-03-14) ------------------- +================== - Revert timeout handling in client request @@ -15,92 +1049,96 @@ - Fix file_sender to correct stream video to Chromes -- Fix NotImplementedError server exception #1703 +- Fix NotImplementedError server exception (`#1703 `_) -- Clearer error message for URL without a host name. #1691 +- Clearer error message for URL without a host name. (`#1691 `_) -- Silence deprecation warning in __repr__ #1690 +- Silence deprecation warning in __repr__ (`#1690 `_) -- IDN + HTTPS = `ssl.CertificateError` #1685 +- IDN + HTTPS = `ssl.CertificateError` (`#1685 `_) 1.3.3 (2017-02-19) ------------------- +================== -- Fixed memory leak in time service #1656 +- Fixed memory leak in time service (`#1656 `_) 1.3.2 (2017-02-16) ------------------- +================== -- Awaiting on WebSocketResponse.send_* does not work #1645 +- Awaiting on WebSocketResponse.send_* does not work (`#1645 `_) -- Fix multiple calls to client ws_connect when using a shared header dict #1643 +- Fix multiple calls to client ws_connect when using a shared header + dict (`#1643 `_) -- Make CookieJar.filter_cookies() accept plain string parameter. #1636 +- Make CookieJar.filter_cookies() accept plain string parameter. 
(`#1636 `_) 1.3.1 (2017-02-09) ------------------- +================== - Handle CLOSING in WebSocketResponse.__anext__ -- Fixed AttributeError 'drain' for server websocket handler #1613 +- Fixed AttributeError 'drain' for server websocket handler (`#1613 `_) 1.3.0 (2017-02-08) ------------------- +================== -- Multipart writer validates the data on append instead of on a request send #920 +- Multipart writer validates the data on append instead of on a + request send (`#920 `_) - Multipart reader accepts multipart messages with or without their epilogue - to consistently handle valid and legacy behaviors #1526 #1581 + to consistently handle valid and legacy behaviors (`#1526 `_) (`#1581 `_) - Separate read + connect + request timeouts # 1523 -- Do not swallow Upgrade header #1587 +- Do not swallow Upgrade header (`#1587 `_) -- Fix polls demo run application #1487 +- Fix polls demo run application (`#1487 `_) -- Ignore unknown 1XX status codes in client #1353 +- Ignore unknown 1XX status codes in client (`#1353 `_) -- Fix sub-Multipart messages missing their headers on serialization #1525 +- Fix sub-Multipart messages missing their headers on serialization (`#1525 `_) - Do not use readline when reading the content of a part - in the multipart reader #1535 + in the multipart reader (`#1535 `_) -- Add optional flag for quoting `FormData` fields #916 +- Add optional flag for quoting `FormData` fields (`#916 `_) -- 416 Range Not Satisfiable if requested range end > file size #1588 +- 416 Range Not Satisfiable if requested range end > file size (`#1588 `_) -- Having a `:` or `@` in a route does not work #1552 +- Having a `:` or `@` in a route does not work (`#1552 `_) -- Added `receive_timeout` timeout for websocket to receive complete message. #1325 +- Added `receive_timeout` timeout for websocket to receive complete + message. (`#1325 `_) -- Added `heartbeat` parameter for websocket to automatically send `ping` message. #1024 #777 +- Added `heartbeat` parameter for websocket to automatically send + `ping` message. 
(`#1024 `_) (`#777 `_) -- Remove `web.Application` dependency from `web.UrlDispatcher` #1510 +- Remove `web.Application` dependency from `web.UrlDispatcher` (`#1510 `_) -- Accepting back-pressure from slow websocket clients #1367 +- Accepting back-pressure from slow websocket clients (`#1367 `_) -- Do not pause transport during set_parser stage #1211 +- Do not pause transport during set_parser stage (`#1211 `_) -- Lingering close doesn't terminate before timeout #1559 +- Lingering close does not terminate before timeout (`#1559 `_) -- `setsockopt` may raise `OSError` exception if socket is closed already #1595 +- `setsockopt` may raise `OSError` exception if socket is closed already (`#1595 `_) -- Lots of CancelledError when requests are interrupted #1565 +- Lots of CancelledError when requests are interrupted (`#1565 `_) - Allow users to specify what should happen to decoding errors - when calling a responses `text()` method #1542 + when calling a responses `text()` method (`#1542 `_) -- Back port std module `http.cookies` for python3.4.2 #1566 +- Back port std module `http.cookies` for python3.4.2 (`#1566 `_) -- Maintain url's fragment in client response #1314 +- Maintain url's fragment in client response (`#1314 `_) -- Allow concurrently close WebSocket connection #754 +- Allow concurrently close WebSocket connection (`#754 `_) -- Gzipped responses with empty body raises ContentEncodingError #609 +- Gzipped responses with empty body raises ContentEncodingError (`#609 `_) - Return 504 if request handle raises TimeoutError. @@ -109,39 +1147,40 @@ - Close response connection if we can not consume whole http message during client response release -- Abort closed ssl client transports, broken servers can keep socket open un-limit time #1568 +- Abort closed ssl client transports, broken servers can keep socket + open un-limit time (`#1568 `_) - Log warning instead of `RuntimeError` is websocket connection is closed. 
- Deprecated: `aiohttp.protocol.HttpPrefixParser` - will be removed in 1.4 #1590 + will be removed in 1.4 (`#1590 `_) -- Deprecated: Servers response's `.started`, `.start()` and `.can_start()` method - will be removed in 1.4 #1591 +- Deprecated: Servers response's `.started`, `.start()` and + `.can_start()` method will be removed in 1.4 (`#1591 `_) - Deprecated: Adding `sub app` via `app.router.add_subapp()` is deprecated - use `app.add_subapp()` instead, will be removed in 1.4 #1592 + use `app.add_subapp()` instead, will be removed in 1.4 (`#1592 `_) - Deprecated: aiohttp.get(), aiohttp.options(), aiohttp.head(), aiohttp.post(), aiohttp.put(), aiohttp.patch(), aiohttp.delete(), and aiohttp.ws_connect() - will be removed in 1.4 #1593 + will be removed in 1.4 (`#1593 `_) - Deprecated: `Application.finish()` and `Application.register_on_finish()` - will be removed in 1.4 #1602 + will be removed in 1.4 (`#1602 `_) 1.2.0 (2016-12-17) ------------------- +================== - Extract `BaseRequest` from `web.Request`, introduce `web.Server` (former `RequestHandlerFactory`), introduce new low-level web server - which is not coupled with `web.Application` and routing #1362 + which is not coupled with `web.Application` and routing (`#1362 `_) -- Make `TestServer.make_url` compatible with `yarl.URL` #1389 +- Make `TestServer.make_url` compatible with `yarl.URL` (`#1389 `_) -- Implement range requests for static files #1382 +- Implement range requests for static files (`#1382 `_) -- Support task attribute for StreamResponse #1410 +- Support task attribute for StreamResponse (`#1410 `_) - Drop `TestClient.app` property, use `TestClient.server.app` instead (BACKWARD INCOMPATIBLE) @@ -152,108 +1191,108 @@ - `TestClient.server` property returns a test server instance, was `asyncio.AbstractServer` (BACKWARD INCOMPATIBLE) -- Follow gunicorn's signal semantics in `Gunicorn[UVLoop]WebWorker` #1201 +- Follow gunicorn's signal semantics in `Gunicorn[UVLoop]WebWorker` (`#1201 `_) - Call worker_int and worker_abort callbacks in - `Gunicorn[UVLoop]WebWorker` #1202 + `Gunicorn[UVLoop]WebWorker` (`#1202 `_) -- Has functional tests for client proxy #1218 +- Has functional tests for client proxy (`#1218 `_) -- Fix bugs with client proxy target path and proxy host with port #1413 +- Fix bugs with client proxy target path and proxy host with port (`#1413 `_) -- Fix bugs related to the use of unicode hostnames #1444 +- Fix bugs related to the use of unicode hostnames (`#1444 `_) -- Preserve cookie quoting/escaping #1453 +- Preserve cookie quoting/escaping (`#1453 `_) -- FileSender will send gzipped response if gzip version available #1426 +- FileSender will send gzipped response if gzip version available (`#1426 `_) - Don't override `Content-Length` header in `web.Response` if no body - was set #1400 + was set (`#1400 `_) -- Introduce `router.post_init()` for solving #1373 +- Introduce `router.post_init()` for solving (`#1373 `_) - Fix raise error in case of multiple calls of `TimeServive.stop()` -- Allow to raise web exceptions on router resolving stage #1460 +- Allow to raise web exceptions on router resolving stage (`#1460 `_) -- Add a warning for session creation outside of coroutine #1468 +- Add a warning for session creation outside of coroutine (`#1468 `_) - Avoid a race when application might start accepting incoming requests but startup signals are not processed yet e98e8c6 - Raise a `RuntimeError` when trying to change the status of the HTTP response - after the headers have been sent #1480 + after the 
headers have been sent (`#1480 `_) -- Fix bug with https proxy acquired cleanup #1340 +- Fix bug with https proxy acquired cleanup (`#1340 `_) -- Use UTF-8 as the default encoding for multipart text parts #1484 +- Use UTF-8 as the default encoding for multipart text parts (`#1484 `_) 1.1.6 (2016-11-28) ------------------- +================== - Fix `BodyPartReader.read_chunk` bug about returns zero bytes before - `EOF` #1428 + `EOF` (`#1428 `_) 1.1.5 (2016-11-16) ------------------- +================== -- Fix static file serving in fallback mode #1401 +- Fix static file serving in fallback mode (`#1401 `_) 1.1.4 (2016-11-14) ------------------- +================== -- Make `TestServer.make_url` compatible with `yarl.URL` #1389 +- Make `TestServer.make_url` compatible with `yarl.URL` (`#1389 `_) - Generate informative exception on redirects from server which - doesn't provide redirection headers #1396 + does not provide redirection headers (`#1396 `_) 1.1.3 (2016-11-10) ------------------- +================== -- Support *root* resources for sub-applications #1379 +- Support *root* resources for sub-applications (`#1379 `_) 1.1.2 (2016-11-08) ------------------- +================== -- Allow starting variables with an underscore #1379 +- Allow starting variables with an underscore (`#1379 `_) -- Properly process UNIX sockets by gunicorn worker #1375 +- Properly process UNIX sockets by gunicorn worker (`#1375 `_) - Fix ordering for `FrozenList` -- Don't propagate pre and post signals to sub-application #1377 +- Don't propagate pre and post signals to sub-application (`#1377 `_) 1.1.1 (2016-11-04) ------------------- +================== -- Fix documentation generation #1120 +- Fix documentation generation (`#1120 `_) 1.1.0 (2016-11-03) ------------------- +================== - Drop deprecated `WSClientDisconnectedError` (BACKWARD INCOMPATIBLE) - Use `yarl.URL` in client API. The change is 99% backward compatible - but `ClientResponse.url` is an `yarl.URL` instance now. #1217 + but `ClientResponse.url` is an `yarl.URL` instance now. (`#1217 `_) -- Close idle keep-alive connections on shutdown #1222 +- Close idle keep-alive connections on shutdown (`#1222 `_) -- Modify regex in AccessLogger to accept underscore and numbers #1225 +- Modify regex in AccessLogger to accept underscore and numbers (`#1225 `_) - Use `yarl.URL` in web server API. `web.Request.rel_url` and `web.Request.url` are added. URLs and templates are percent-encoded - now. #1224 + now. (`#1224 `_) -- Accept `yarl.URL` by server redirections #1278 +- Accept `yarl.URL` by server redirections (`#1278 `_) -- Return `yarl.URL` by `.make_url()` testing utility #1279 +- Return `yarl.URL` by `.make_url()` testing utility (`#1279 `_) -- Properly format IPv6 addresses by `aiohttp.web.run_app` #1139 +- Properly format IPv6 addresses by `aiohttp.web.run_app` (`#1139 `_) -- Use `yarl.URL` by server API #1288 +- Use `yarl.URL` by server API (`#1288 `_) * Introduce `resource.url_for()`, deprecate `resource.url()`. @@ -264,38 +1303,38 @@ * Drop old-style routes: `Route`, `PlainRoute`, `DynamicRoute`, `StaticRoute`, `ResourceAdapter`. 
-- Revert `resp.url` back to `str`, introduce `resp.url_obj` #1292 +- Revert `resp.url` back to `str`, introduce `resp.url_obj` (`#1292 `_) -- Raise ValueError if BasicAuth login has a ":" character #1307 +- Raise ValueError if BasicAuth login has a ":" character (`#1307 `_) - Fix bug when ClientRequest send payload file with opened as - open('filename', 'r+b') #1306 + open('filename', 'r+b') (`#1306 `_) -- Enhancement to AccessLogger (pass *extra* dict) #1303 +- Enhancement to AccessLogger (pass *extra* dict) (`#1303 `_) -- Show more verbose message on import errors #1319 +- Show more verbose message on import errors (`#1319 `_) -- Added save and load functionality for `CookieJar` #1219 +- Added save and load functionality for `CookieJar` (`#1219 `_) -- Added option on `StaticRoute` to follow symlinks #1299 +- Added option on `StaticRoute` to follow symlinks (`#1299 `_) -- Force encoding of `application/json` content type to utf-8 #1339 +- Force encoding of `application/json` content type to utf-8 (`#1339 `_) -- Fix invalid invocations of `errors.LineTooLong` #1335 +- Fix invalid invocations of `errors.LineTooLong` (`#1335 `_) -- Websockets: Stop `async for` iteration when connection is closed #1144 +- Websockets: Stop `async for` iteration when connection is closed (`#1144 `_) -- Ensure TestClient HTTP methods return a context manager #1318 +- Ensure TestClient HTTP methods return a context manager (`#1318 `_) - Raise `ClientDisconnectedError` to `FlowControlStreamReader` read function - if `ClientSession` object is closed by client when reading data. #1323 + if `ClientSession` object is closed by client when reading data. (`#1323 `_) -- Document deployment without `Gunicorn` #1120 +- Document deployment without `Gunicorn` (`#1120 `_) - Add deprecation warning for MD5 and SHA1 digests when used for fingerprint - of site certs in TCPConnector. #1186 + of site certs in TCPConnector. (`#1186 `_) -- Implement sub-applications #1301 +- Implement sub-applications (`#1301 `_) - Don't inherit `web.Request` from `dict` but implement `MutableMapping` protocol. @@ -320,55 +1359,55 @@ boost of your application -- a couple DB requests and business logic is still the main bottleneck. -- Boost performance by adding a custom time service #1350 +- Boost performance by adding a custom time service (`#1350 `_) - Extend `ClientResponse` with `content_type` and `charset` - properties like in `web.Request`. #1349 + properties like in `web.Request`. (`#1349 `_) -- Disable aiodns by default #559 +- Disable aiodns by default (`#559 `_) - Don't flap `tcp_cork` in client code, use TCP_NODELAY mode by default. 
-- Implement `web.Request.clone()` #1361 +- Implement `web.Request.clone()` (`#1361 `_) 1.0.5 (2016-10-11) ------------------- +================== - Fix StreamReader._read_nowait to return all available - data up to the requested amount #1297 + data up to the requested amount (`#1297 `_) 1.0.4 (2016-09-22) ------------------- +================== - Fix FlowControlStreamReader.read_nowait so that it checks - whether the transport is paused #1206 + whether the transport is paused (`#1206 `_) 1.0.2 (2016-09-22) ------------------- +================== -- Make CookieJar compatible with 32-bit systems #1188 +- Make CookieJar compatible with 32-bit systems (`#1188 `_) -- Add missing `WSMsgType` to `web_ws.__all__`, see #1200 +- Add missing `WSMsgType` to `web_ws.__all__`, see (`#1200 `_) -- Fix `CookieJar` ctor when called with `loop=None` #1203 +- Fix `CookieJar` ctor when called with `loop=None` (`#1203 `_) -- Fix broken upper-casing in wsgi support #1197 +- Fix broken upper-casing in wsgi support (`#1197 `_) 1.0.1 (2016-09-16) ------------------- +================== - Restore `aiohttp.web.MsgType` alias for `aiohttp.WSMsgType` for sake - of backward compatibility #1178 + of backward compatibility (`#1178 `_) - Tune alabaster schema. - Use `text/html` content type for displaying index pages by static file handler. -- Fix `AssertionError` in static file handling #1177 +- Fix `AssertionError` in static file handling (`#1177 `_) - Fix access log formats `%O` and `%b` for static file handling @@ -377,12 +1416,12 @@ 1.0.0 (2016-09-16) -------------------- +================== - Change default size for client session's connection pool from - unlimited to 20 #977 + unlimited to 20 (`#977 `_) -- Add IE support for cookie deletion. #994 +- Add IE support for cookie deletion. (`#994 `_) - Remove deprecated `WebSocketResponse.wait_closed` method (BACKWARD INCOMPATIBLE) @@ -391,26 +1430,26 @@ method (BACKWARD INCOMPATIBLE) - Avoid using of mutable CIMultiDict kw param in make_mocked_request - #997 + (`#997 `_) - Make WebSocketResponse.close a little bit faster by avoiding new task creating just for timeout measurement - Add `proxy` and `proxy_auth` params to `client.get()` and family, - deprecate `ProxyConnector` #998 + deprecate `ProxyConnector` (`#998 `_) - Add support for websocket send_json and receive_json, synchronize - server and client API for websockets #984 + server and client API for websockets (`#984 `_) - Implement router shourtcuts for most useful HTTP methods, use `app.router.add_get()`, `app.router.add_post()` etc. 
instead of - `app.router.add_route()` #986 + `app.router.add_route()` (`#986 `_) -- Support SSL connections for gunicorn worker #1003 +- Support SSL connections for gunicorn worker (`#1003 `_) - Move obsolete examples to legacy folder -- Switch to multidict 2.0 and title-cased strings #1015 +- Switch to multidict 2.0 and title-cased strings (`#1015 `_) - `{FOO}e` logger format is case-sensitive now @@ -426,9 +1465,9 @@ - Remove deprecated decode param from resp.read(decode=True) -- Use 5min default client timeout #1028 +- Use 5min default client timeout (`#1028 `_) -- Relax HTTP method validation in UrlDispatcher #1037 +- Relax HTTP method validation in UrlDispatcher (`#1037 `_) - Pin minimal supported asyncio version to 3.4.2+ (`loop.is_close()` should be present) @@ -438,105 +1477,105 @@ - Link header for 451 status code is mandatory -- Fix test_client fixture to allow multiple clients per test #1072 +- Fix test_client fixture to allow multiple clients per test (`#1072 `_) -- make_mocked_request now accepts dict as headers #1073 +- make_mocked_request now accepts dict as headers (`#1073 `_) - Add Python 3.5.2/3.6+ compatibility patch for async generator - protocol change #1082 + protocol change (`#1082 `_) -- Improvement test_client can accept instance object #1083 +- Improvement test_client can accept instance object (`#1083 `_) -- Simplify ServerHttpProtocol implementation #1060 +- Simplify ServerHttpProtocol implementation (`#1060 `_) - Add a flag for optional showing directory index for static file - handling #921 + handling (`#921 `_) -- Define `web.Application.on_startup()` signal handler #1103 +- Define `web.Application.on_startup()` signal handler (`#1103 `_) -- Drop ChunkedParser and LinesParser #1111 +- Drop ChunkedParser and LinesParser (`#1111 `_) -- Call `Application.startup` in GunicornWebWorker #1105 +- Call `Application.startup` in GunicornWebWorker (`#1105 `_) - Fix client handling hostnames with 63 bytes when a port is given in - the url #1044 + the url (`#1044 `_) -- Implement proxy support for ClientSession.ws_connect #1025 +- Implement proxy support for ClientSession.ws_connect (`#1025 `_) -- Return named tuple from WebSocketResponse.can_prepare #1016 +- Return named tuple from WebSocketResponse.can_prepare (`#1016 `_) -- Fix access_log_format in `GunicornWebWorker` #1117 +- Fix access_log_format in `GunicornWebWorker` (`#1117 `_) -- Setup Content-Type to application/octet-stream by default #1124 +- Setup Content-Type to application/octet-stream by default (`#1124 `_) - Deprecate debug parameter from app.make_handler(), use - `Application(debug=True)` instead #1121 + `Application(debug=True)` instead (`#1121 `_) -- Remove fragment string in request path #846 +- Remove fragment string in request path (`#846 `_) -- Use aiodns.DNSResolver.gethostbyname() if available #1136 +- Use aiodns.DNSResolver.gethostbyname() if available (`#1136 `_) -- Fix static file sending on uvloop when sendfile is available #1093 +- Fix static file sending on uvloop when sendfile is available (`#1093 `_) -- Make prettier urls if query is empty dict #1143 +- Make prettier urls if query is empty dict (`#1143 `_) -- Fix redirects for HEAD requests #1147 +- Fix redirects for HEAD requests (`#1147 `_) -- Default value for `StreamReader.read_nowait` is -1 from now #1150 +- Default value for `StreamReader.read_nowait` is -1 from now (`#1150 `_) - `aiohttp.StreamReader` is not inherited from `asyncio.StreamReader` from now - (BACKWARD INCOMPATIBLE) #1150 + (BACKWARD INCOMPATIBLE) (`#1150 `_) -- 
Streams documentation added #1150 +- Streams documentation added (`#1150 `_) -- Add `multipart` coroutine method for web Request object #1067 +- Add `multipart` coroutine method for web Request object (`#1067 `_) -- Publish ClientSession.loop property #1149 +- Publish ClientSession.loop property (`#1149 `_) -- Fix static file with spaces #1140 +- Fix static file with spaces (`#1140 `_) -- Fix piling up asyncio loop by cookie expiration callbacks #1061 +- Fix piling up asyncio loop by cookie expiration callbacks (`#1061 `_) - Drop `Timeout` class for sake of `async_timeout` external library. `aiohttp.Timeout` is an alias for `async_timeout.timeout` - `use_dns_cache` parameter of `aiohttp.TCPConnector` is `True` by - default (BACKWARD INCOMPATIBLE) #1152 + default (BACKWARD INCOMPATIBLE) (`#1152 `_) - `aiohttp.TCPConnector` uses asynchronous DNS resolver if available by - default (BACKWARD INCOMPATIBLE) #1152 + default (BACKWARD INCOMPATIBLE) (`#1152 `_) -- Conform to RFC3986 - do not include url fragments in client requests #1174 +- Conform to RFC3986 - do not include url fragments in client requests (`#1174 `_) -- Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) #1173 +- Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) (`#1173 `_) -- Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) #1173 +- Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) (`#1173 `_) - Fix clashing cookies with have the same name but belong to different - domains (BACKWARD INCOMPATIBLE) #1125 + domains (BACKWARD INCOMPATIBLE) (`#1125 `_) -- Support binary Content-Transfer-Encoding #1169 +- Support binary Content-Transfer-Encoding (`#1169 `_) 0.22.5 (08-02-2016) -------------------- +=================== - Pin miltidict version to >=1.2.2 0.22.3 (07-26-2016) -------------------- +=================== -- Do not filter cookies if unsafe flag provided #1005 +- Do not filter cookies if unsafe flag provided (`#1005 `_) 0.22.2 (07-23-2016) -------------------- +=================== -- Suppress CancelledError when Timeout raises TimeoutError #970 +- Suppress CancelledError when Timeout raises TimeoutError (`#970 `_) - Don't expose `aiohttp.__version__` -- Add unsafe parameter to CookieJar #968 +- Add unsafe parameter to CookieJar (`#968 `_) - Use unsafe cookie jar in test client tools @@ -544,91 +1583,91 @@ 0.22.1 (07-16-2016) -------------------- +=================== -- Large cookie expiration/max-age doesn't break an event loop from now - (fixes #967) +- Large cookie expiration/max-age does not break an event loop from now + (fixes (`#967 `_)) 0.22.0 (07-15-2016) -------------------- +=================== -- Fix bug in serving static directory #803 +- Fix bug in serving static directory (`#803 `_) -- Fix command line arg parsing #797 +- Fix command line arg parsing (`#797 `_) -- Fix a documentation chapter about cookie usage #790 +- Fix a documentation chapter about cookie usage (`#790 `_) -- Handle empty body with gzipped encoding #758 +- Handle empty body with gzipped encoding (`#758 `_) -- Support 451 Unavailable For Legal Reasons http status #697 +- Support 451 Unavailable For Legal Reasons http status (`#697 `_) -- Fix Cookie share example and few small typos in docs #817 +- Fix Cookie share example and few small typos in docs (`#817 `_) -- UrlDispatcher.add_route with partial coroutine handler #814 +- UrlDispatcher.add_route with partial coroutine handler (`#814 `_) -- Optional support for aiodns #728 +- Optional support for aiodns (`#728 `_) -- Add ServiceRestart and TryAgainLater 
websocket close codes #828 +- Add ServiceRestart and TryAgainLater websocket close codes (`#828 `_) -- Fix prompt message for `web.run_app` #832 +- Fix prompt message for `web.run_app` (`#832 `_) -- Allow to pass None as a timeout value to disable timeout logic #834 +- Allow to pass None as a timeout value to disable timeout logic (`#834 `_) -- Fix leak of connection slot during connection error #835 +- Fix leak of connection slot during connection error (`#835 `_) - Gunicorn worker with uvloop support - `aiohttp.worker.GunicornUVLoopWebWorker` #878 + `aiohttp.worker.GunicornUVLoopWebWorker` (`#878 `_) -- Don't send body in response to HEAD request #838 +- Don't send body in response to HEAD request (`#838 `_) -- Skip the preamble in MultipartReader #881 +- Skip the preamble in MultipartReader (`#881 `_) -- Implement BasicAuth decode classmethod. #744 +- Implement BasicAuth decode classmethod. (`#744 `_) -- Don't crash logger when transport is None #889 +- Don't crash logger when transport is None (`#889 `_) - Use a create_future compatibility wrapper instead of creating - Futures directly #896 + Futures directly (`#896 `_) -- Add test utilities to aiohttp #902 +- Add test utilities to aiohttp (`#902 `_) -- Improve Request.__repr__ #875 +- Improve Request.__repr__ (`#875 `_) -- Skip DNS resolving if provided host is already an ip address #874 +- Skip DNS resolving if provided host is already an ip address (`#874 `_) -- Add headers to ClientSession.ws_connect #785 +- Add headers to ClientSession.ws_connect (`#785 `_) -- Document that server can send pre-compressed data #906 +- Document that server can send pre-compressed data (`#906 `_) -- Don't add Content-Encoding and Transfer-Encoding if no body #891 +- Don't add Content-Encoding and Transfer-Encoding if no body (`#891 `_) -- Add json() convenience methods to websocket message objects #897 +- Add json() convenience methods to websocket message objects (`#897 `_) -- Add client_resp.raise_for_status() #908 +- Add client_resp.raise_for_status() (`#908 `_) -- Implement cookie filter #799 +- Implement cookie filter (`#799 `_) -- Include an example of middleware to handle error pages #909 +- Include an example of middleware to handle error pages (`#909 `_) -- Fix error handling in StaticFileMixin #856 +- Fix error handling in StaticFileMixin (`#856 `_) -- Add mocked request helper #900 +- Add mocked request helper (`#900 `_) -- Fix empty ALLOW Response header for cls based View #929 +- Fix empty ALLOW Response header for cls based View (`#929 `_) -- Respect CONNECT method to implement a proxy server #847 +- Respect CONNECT method to implement a proxy server (`#847 `_) -- Add pytest_plugin #914 +- Add pytest_plugin (`#914 `_) - Add tutorial - Add backlog option to support more than 128 (default value in - "create_server" function) concurrent connections #892 + "create_server" function) concurrent connections (`#892 `_) -- Allow configuration of header size limits #912 +- Allow configuration of header size limits (`#912 `_) -- Separate sending file logic from StaticRoute dispatcher #901 +- Separate sending file logic from StaticRoute dispatcher (`#901 `_) - Drop deprecated share_cookies connector option (BACKWARD INCOMPATIBLE) @@ -641,65 +1680,65 @@ - Drop all mentions about api changes in documentation for versions older than 0.16 -- Allow to override default cookie jar #963 +- Allow to override default cookie jar (`#963 `_) - Add manylinux wheel builds -- Dup a socket for sendfile usage #964 +- Dup a socket for sendfile usage (`#964 `_) 0.21.6 
(05-05-2016) -------------------- +=================== -- Drop initial query parameters on redirects #853 +- Drop initial query parameters on redirects (`#853 `_) 0.21.5 (03-22-2016) -------------------- +=================== -- Fix command line arg parsing #797 +- Fix command line arg parsing (`#797 `_) 0.21.4 (03-12-2016) -------------------- +=================== - Fix ResourceAdapter: don't add method to allowed if resource is not - match #826 + match (`#826 `_) - Fix Resource: append found method to returned allowed methods 0.21.2 (02-16-2016) -------------------- +=================== - Fix a regression: support for handling ~/path in static file routes was - broken #782 + broken (`#782 `_) 0.21.1 (02-10-2016) -------------------- +=================== -- Make new resources classes public #767 +- Make new resources classes public (`#767 `_) - Add `router.resources()` view - Fix cmd-line parameter names in doc 0.21.0 (02-04-2016) --------------------- +=================== -- Introduce on_shutdown signal #722 +- Introduce on_shutdown signal (`#722 `_) -- Implement raw input headers #726 +- Implement raw input headers (`#726 `_) -- Implement web.run_app utility function #734 +- Implement web.run_app utility function (`#734 `_) - Introduce on_cleanup signal - Deprecate Application.finish() / Application.register_on_finish() in favor of on_cleanup. -- Get rid of bare aiohttp.request(), aiohttp.get() and family in docs #729 +- Get rid of bare aiohttp.request(), aiohttp.get() and family in docs (`#729 `_) -- Deprecate bare aiohttp.request(), aiohttp.get() and family #729 +- Deprecate bare aiohttp.request(), aiohttp.get() and family (`#729 `_) -- Refactor keep-alive support #737: +- Refactor keep-alive support (`#737 `_): - Enable keepalive for HTTP 1.0 by default @@ -718,18 +1757,18 @@ - don't send `Connection` header for HTTP 1.0 - Add version parameter to ClientSession constructor, - deprecate it for session.request() and family #736 + deprecate it for session.request() and family (`#736 `_) -- Enable access log by default #735 +- Enable access log by default (`#735 `_) - Deprecate app.router.register_route() (the method was not documented intentionally BTW). - Deprecate app.router.named_routes() in favor of app.router.named_resources() -- route.add_static accepts pathlib.Path now #743 +- route.add_static accepts pathlib.Path now (`#743 `_) -- Add command line support: `$ python -m aiohttp.web package.main` #740 +- Add command line support: `$ python -m aiohttp.web package.main` (`#740 `_) - FAQ section was added to docs. Enjoy and fill free to contribute new topics @@ -737,63 +1776,63 @@ - Document ClientResponse's host, method, url properties -- Use CORK/NODELAY in client API #748 +- Use CORK/NODELAY in client API (`#748 `_) - ClientSession.close and Connector.close are coroutines now - Close client connection on exception in ClientResponse.release() -- Allow to read multipart parts without content-length specified #750 +- Allow to read multipart parts without content-length specified (`#750 `_) -- Add support for unix domain sockets to gunicorn worker #470 +- Add support for unix domain sockets to gunicorn worker (`#470 `_) -- Add test for default Expect handler #601 +- Add test for default Expect handler (`#601 `_) - Add the first demo project -- Rename `loader` keyword argument in `web.Request.json` method. #646 +- Rename `loader` keyword argument in `web.Request.json` method. 
(`#646 `_) -- Add local socket binding for TCPConnector #678 +- Add local socket binding for TCPConnector (`#678 `_) 0.20.2 (01-07-2016) --------------------- +=================== -- Enable use of `await` for a class based view #717 +- Enable use of `await` for a class based view (`#717 `_) -- Check address family to fill wsgi env properly #718 +- Check address family to fill wsgi env properly (`#718 `_) -- Fix memory leak in headers processing (thanks to Marco Paolini) #723 +- Fix memory leak in headers processing (thanks to Marco Paolini) (`#723 `_) 0.20.1 (12-30-2015) -------------------- +=================== - Raise RuntimeError is Timeout context manager was used outside of task context. -- Add number of bytes to stream.read_nowait #700 +- Add number of bytes to stream.read_nowait (`#700 `_) - Use X-FORWARDED-PROTO for wsgi.url_scheme when available 0.20.0 (12-28-2015) -------------------- +=================== - Extend list of web exceptions, add HTTPMisdirectedRequest, HTTPUpgradeRequired, HTTPPreconditionRequired, HTTPTooManyRequests, HTTPRequestHeaderFieldsTooLarge, HTTPVariantAlsoNegotiates, - HTTPNotExtended, HTTPNetworkAuthenticationRequired status codes #644 + HTTPNotExtended, HTTPNetworkAuthenticationRequired status codes (`#644 `_) -- Do not remove AUTHORIZATION header by WSGI handler #649 +- Do not remove AUTHORIZATION header by WSGI handler (`#649 `_) -- Fix broken support for https proxies with authentication #617 +- Fix broken support for https proxies with authentication (`#617 `_) - Get REMOTE_* and SEVER_* http vars from headers when listening on - unix socket #654 + unix socket (`#654 `_) -- Add HTTP 308 support #663 +- Add HTTP 308 support (`#663 `_) - Add Tf format (time to serve request in seconds, %06f format) to - access log #669 + access log (`#669 `_) - Remove one and a half years long deprecated ClientResponse.read_and_close() method @@ -802,105 +1841,105 @@ on sending chunked encoded data - Use TCP_CORK and TCP_NODELAY to optimize network latency and - throughput #680 + throughput (`#680 `_) -- Websocket XOR performance improved #687 +- Websocket XOR performance improved (`#687 `_) -- Avoid sending cookie attributes in Cookie header #613 +- Avoid sending cookie attributes in Cookie header (`#613 `_) - Round server timeouts to seconds for grouping pending calls. That - leads to less amount of poller syscalls e.g. epoll.poll(). #702 + leads to less amount of poller syscalls e.g. epoll.poll(). 
(`#702 `_) -- Close connection on websocket handshake error #703 +- Close connection on websocket handshake error (`#703 `_) -- Implement class based views #684 +- Implement class based views (`#684 `_) -- Add *headers* parameter to ws_connect() #709 +- Add *headers* parameter to ws_connect() (`#709 `_) -- Drop unused function `parse_remote_addr()` #708 +- Drop unused function `parse_remote_addr()` (`#708 `_) -- Close session on exception #707 +- Close session on exception (`#707 `_) -- Store http code and headers in WSServerHandshakeError #706 +- Store http code and headers in WSServerHandshakeError (`#706 `_) -- Make some low-level message properties readonly #710 +- Make some low-level message properties readonly (`#710 `_) 0.19.0 (11-25-2015) -------------------- +=================== -- Memory leak in ParserBuffer #579 +- Memory leak in ParserBuffer (`#579 `_) - Support gunicorn's `max_requests` settings in gunicorn worker -- Fix wsgi environment building #573 +- Fix wsgi environment building (`#573 `_) -- Improve access logging #572 +- Improve access logging (`#572 `_) -- Drop unused host and port from low-level server #586 +- Drop unused host and port from low-level server (`#586 `_) -- Add Python 3.5 `async for` implementation to server websocket #543 +- Add Python 3.5 `async for` implementation to server websocket (`#543 `_) - Add Python 3.5 `async for` implementation to client websocket - Add Python 3.5 `async with` implementation to client websocket -- Add charset parameter to web.Response constructor #593 +- Add charset parameter to web.Response constructor (`#593 `_) - Forbid passing both Content-Type header and content_type or charset params into web.Response constructor -- Forbid duplicating of web.Application and web.Request #602 +- Forbid duplicating of web.Application and web.Request (`#602 `_) -- Add an option to pass Origin header in ws_connect #607 +- Add an option to pass Origin header in ws_connect (`#607 `_) -- Add json_response function #592 +- Add json_response function (`#592 `_) -- Make concurrent connections respect limits #581 +- Make concurrent connections respect limits (`#581 `_) -- Collect history of responses if redirects occur #614 +- Collect history of responses if redirects occur (`#614 `_) -- Enable passing pre-compressed data in requests #621 +- Enable passing pre-compressed data in requests (`#621 `_) -- Expose named routes via UrlDispatcher.named_routes() #622 +- Expose named routes via UrlDispatcher.named_routes() (`#622 `_) -- Allow disabling sendfile by environment variable AIOHTTP_NOSENDFILE #629 +- Allow disabling sendfile by environment variable AIOHTTP_NOSENDFILE (`#629 `_) - Use ensure_future if available -- Always quote params for Content-Disposition #641 +- Always quote params for Content-Disposition (`#641 `_) -- Support async for in multipart reader #640 +- Support async for in multipart reader (`#640 `_) -- Add Timeout context manager #611 +- Add Timeout context manager (`#611 `_) 0.18.4 (13-11-2015) -------------------- +=================== - Relax rule for router names again by adding dash to allowed characters: they may contain identifiers, dashes, dots and columns 0.18.3 (25-10-2015) -------------------- +=================== -- Fix formatting for _RequestContextManager helper #590 +- Fix formatting for _RequestContextManager helper (`#590 `_) 0.18.2 (22-10-2015) -------------------- +=================== -- Fix regression for OpenSSL < 1.0.0 #583 +- Fix regression for OpenSSL < 1.0.0 (`#583 `_) 0.18.1 (20-10-2015) 
-------------------- +================== - Relax rule for router names: they may contain dots and columns starting from now 0.18.0 (19-10-2015) -------------------- +================== - Use errors.HttpProcessingError.message as HTTP error reason and - message #459 + message (`#459 `_) - Optimize cythonized multidict a bit @@ -908,27 +1947,27 @@ - default headers in ClientSession are now case-insensitive -- Make '=' char and 'wss://' schema safe in urls #477 +- Make '=' char and 'wss://' schema safe in urls (`#477 `_) -- `ClientResponse.close()` forces connection closing by default from now #479 +- `ClientResponse.close()` forces connection closing by default from now (`#479 `_) N.B. Backward incompatible change: was `.close(force=False) Using `force` parameter for the method is deprecated: use `.release()` instead. -- Properly requote URL's path #480 +- Properly requote URL's path (`#480 `_) -- add `skip_auto_headers` parameter for client API #486 +- add `skip_auto_headers` parameter for client API (`#486 `_) -- Properly parse URL path in aiohttp.web.Request #489 +- Properly parse URL path in aiohttp.web.Request (`#489 `_) -- Raise RuntimeError when chunked enabled and HTTP is 1.0 #488 +- Raise RuntimeError when chunked enabled and HTTP is 1.0 (`#488 `_) -- Fix a bug with processing io.BytesIO as data parameter for client API #500 +- Fix a bug with processing io.BytesIO as data parameter for client API (`#500 `_) -- Skip auto-generation of Content-Type header #507 +- Skip auto-generation of Content-Type header (`#507 `_) -- Use sendfile facility for static file handling #503 +- Use sendfile facility for static file handling (`#503 `_) - Default `response_factory` in `app.router.add_static` now is `StreamResponse`, not `None`. The functionality is not changed if @@ -937,97 +1976,97 @@ - Drop `ClientResponse.message` attribute, it was always implementation detail. - Streams are optimized for speed and mostly memory in case of a big - HTTP message sizes #496 + HTTP message sizes (`#496 `_) - Fix a bug for server-side cookies for dropping cookie and setting it again without Max-Age parameter. -- Don't trim redirect URL in client API #499 +- Don't trim redirect URL in client API (`#499 `_) -- Extend precision of access log "D" to milliseconds #527 +- Extend precision of access log "D" to milliseconds (`#527 `_) - Deprecate `StreamResponse.start()` method in favor of - `StreamResponse.prepare()` coroutine #525 + `StreamResponse.prepare()` coroutine (`#525 `_) `.start()` is still supported but responses begun with `.start()` - doesn't call signal for response preparing to be sent. + do not call signal for response preparing to be sent.
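A minimal sketch of the `prepare()` flow that replaced the deprecated `.start()`, written against the modern `aiohttp.web` API; the route path and response body are illustrative, not from the changelog:

.. code-block:: python

    from aiohttp import web

    async def handler(request):
        resp = web.StreamResponse()
        # prepare() sends the status line and headers and, unlike the
        # deprecated start(), fires the response-preparation signal.
        await resp.prepare(request)
        await resp.write(b"streamed body")
        await resp.write_eof()
        return resp

    app = web.Application()
    app.add_routes([web.get("/", handler)])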
- Add `StreamReader.__repr__` - Drop Python 3.3 support, from now minimal required version is Python - 3.4.1 #541 + 3.4.1 (`#541 `_) -- Add `async with` support for `ClientSession.request()` and family #536 +- Add `async with` support for `ClientSession.request()` and family (`#536 `_) -- Ignore message body on 204 and 304 responses #505 +- Ignore message body on 204 and 304 responses (`#505 `_) -- `TCPConnector` processed both IPv4 and IPv6 by default #559 +- `TCPConnector` processed both IPv4 and IPv6 by default (`#559 `_) -- Add `.routes()` view for urldispatcher #519 +- Add `.routes()` view for urldispatcher (`#519 `_) -- Route name should be a valid identifier name from now #567 +- Route name should be a valid identifier name from now (`#567 `_) -- Implement server signals #562 +- Implement server signals (`#562 `_) - Drop a year-old deprecated *files* parameter from client API. -- Added `async for` support for aiohttp stream #542 +- Added `async for` support for aiohttp stream (`#542 `_) 0.17.4 (09-29-2015) -------------------- +=================== -- Properly parse URL path in aiohttp.web.Request #489 +- Properly parse URL path in aiohttp.web.Request (`#489 `_) - Add missing coroutine decorator, the client api is await-compatible now 0.17.3 (08-28-2015) ---------------------- +=================== -- Remove Content-Length header on compressed responses #450 +- Remove Content-Length header on compressed responses (`#450 `_) - Support Python 3.5 -- Improve performance of transport in-use list #472 +- Improve performance of transport in-use list (`#472 `_) -- Fix connection pooling #473 +- Fix connection pooling (`#473 `_) 0.17.2 (08-11-2015) ---------------------- +=================== -- Don't forget to pass `data` argument forward #462 +- Don't forget to pass `data` argument forward (`#462 `_) -- Fix multipart read bytes count #463 +- Fix multipart read bytes count (`#463 `_) 0.17.1 (08-10-2015) ---------------------- +=================== - Fix multidict comparison to arbitrary abc.Mapping 0.17.0 (08-04-2015) ---------------------- +=================== -- Make StaticRoute support Last-Modified and If-Modified-Since headers #386 +- Make StaticRoute support Last-Modified and If-Modified-Since headers (`#386 `_) - Add Request.if_modified_since and Stream.Response.last_modified properties -- Fix deflate compression when writing a chunked response #395 +- Fix deflate compression when writing a chunked response (`#395 `_) - Request`s content-length header is cleared now after redirect from - POST method #391 + POST method (`#391 `_) -- Return a 400 if server received a non HTTP content #405 +- Return a 400 if server received a non HTTP content (`#405 `_) -- Fix keep-alive support for aiohttp clients #406 +- Fix keep-alive support for aiohttp clients (`#406 `_) -- Allow gzip compression in high-level server response interface #403 +- Allow gzip compression in high-level server response interface (`#403 `_) -- Rename TCPConnector.resolve and family to dns_cache #415 +- Rename TCPConnector.resolve and family to dns_cache (`#415 `_) -- Make UrlDispatcher ignore quoted characters during url matching #414 +- Make UrlDispatcher ignore quoted characters during url matching (`#414 `_) Backward-compatibility warning: this may change the url matched by - your queries if they send quoted character (like %2F for /) #414 + your queries if they send quoted character (like %2F for /) (`#414 `_) -- Use optional cchardet accelerator if present #418 +- Use optional cchardet accelerator if present (`#418 `_) - 
Borrow loop from Connector in ClientSession if loop is not set @@ -1036,79 +2075,79 @@ - Add toplevel get(), post(), put(), head(), delete(), options(), patch() coroutines. -- Fix IPv6 support for client API #425 +- Fix IPv6 support for client API (`#425 `_) -- Pass SSL context through proxy connector #421 +- Pass SSL context through proxy connector (`#421 `_) - Make the rule: path for add_route should start with slash - Don't process request finishing by low-level server on closed event loop -- Don't override data if multiple files are uploaded with same key #433 +- Don't override data if multiple files are uploaded with same key (`#433 `_) - Ensure multipart.BodyPartReader.read_chunk read all the necessary data to avoid false assertions about malformed multipart payload -- Don't send body for 204, 205 and 304 http exceptions #442 +- Don't send body for 204, 205 and 304 http exceptions (`#442 `_) -- Correctly skip Cython compilation in MSVC not found #453 +- Correctly skip Cython compilation in MSVC not found (`#453 `_) -- Add response factory to StaticRoute #456 +- Add response factory to StaticRoute (`#456 `_) -- Don't append trailing CRLF for multipart.BodyPartReader #454 +- Don't append trailing CRLF for multipart.BodyPartReader (`#454 `_) 0.16.6 (07-15-2015) -------------------- +=================== -- Skip compilation on Windows if vcvarsall.bat cannot be found #438 +- Skip compilation on Windows if vcvarsall.bat cannot be found (`#438 `_) 0.16.5 (06-13-2015) -------------------- +=================== -- Get rid of all comprehensions and yielding in _multidict #410 +- Get rid of all comprehensions and yielding in _multidict (`#410 `_) 0.16.4 (06-13-2015) -------------------- +=================== - Don't clear current exception in multidict's `__repr__` (cythonized - versions) #410 + versions) (`#410 `_) 0.16.3 (05-30-2015) -------------------- +=================== -- Fix StaticRoute vulnerability to directory traversal attacks #380 +- Fix StaticRoute vulnerability to directory traversal attacks (`#380 `_) 0.16.2 (05-27-2015) -------------------- +=================== - Update python version required for `__del__` usage: it's actually 3.4.1 instead of 3.4.0 - Add check for presence of loop.is_closed() method before call the - former #378 + former (`#378 `_) 0.16.1 (05-27-2015) -------------------- +=================== -- Fix regression in static file handling #377 +- Fix regression in static file handling (`#377 `_) 0.16.0 (05-26-2015) -------------------- +=================== -- Unset waiter future after cancellation #363 +- Unset waiter future after cancellation (`#363 `_) -- Update request url with query parameters #372 +- Update request url with query parameters (`#372 `_) - Support new `fingerprint` param of TCPConnector to enable verifying - SSL certificates via MD5, SHA1, or SHA256 digest #366 + SSL certificates via MD5, SHA1, or SHA256 digest (`#366 `_) - Setup uploaded filename if field value is binary and transfer - encoding is not specified #349 + encoding is not specified (`#349 `_) - Implement `ClientSession.close()` method @@ -1123,34 +2162,34 @@ - Add `__del__` to client-side objects: sessions, connectors, connections, requests, responses. 
-- Refactor connections cleanup by connector #357 +- Refactor connections cleanup by connector (`#357 `_) -- Add `limit` parameter to connector constructor #358 +- Add `limit` parameter to connector constructor (`#358 `_) -- Add `request.has_body` property #364 +- Add `request.has_body` property (`#364 `_) -- Add `response_class` parameter to `ws_connect()` #367 +- Add `response_class` parameter to `ws_connect()` (`#367 `_) -- `ProxyConnector` doesn't support keep-alive requests by default - starting from now #368 +- `ProxyConnector` does not support keep-alive requests by default + starting from now (`#368 `_) - Add `connector.force_close` property -- Add ws_connect to ClientSession #374 +- Add ws_connect to ClientSession (`#374 `_) - Support optional `chunk_size` parameter in `router.add_static()` 0.15.3 (04-22-2015) -------------------- +=================== - Fix graceful shutdown handling -- Fix `Expect` header handling for not found and not allowed routes #340 +- Fix `Expect` header handling for not found and not allowed routes (`#340 `_) 0.15.2 (04-19-2015) -------------------- +=================== - Flow control subsystem refactoring @@ -1158,21 +2197,21 @@ - Allow to match any request method with `*` -- Explicitly call drain on transport #316 +- Explicitly call drain on transport (`#316 `_) -- Make chardet module dependency mandatory #318 +- Make chardet module dependency mandatory (`#318 `_) -- Support keep-alive for HTTP 1.0 #325 +- Support keep-alive for HTTP 1.0 (`#325 `_) -- Do not chunk single file during upload #327 +- Do not chunk single file during upload (`#327 `_) -- Add ClientSession object for cookie storage and default headers #328 +- Add ClientSession object for cookie storage and default headers (`#328 `_) - Add `keep_alive_on` argument for HTTP server handler. 
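The `ClientSession` and connector `limit` entries above combine into what is now the standard client setup; a hedged sketch against the current API (the URL and header values are placeholders):

.. code-block:: python

    import asyncio

    import aiohttp

    async def main():
        connector = aiohttp.TCPConnector(limit=10)    # cap on pooled connections
        async with aiohttp.ClientSession(
            connector=connector,
            headers={"User-Agent": "my-app/1.0"},     # session-wide default headers
            cookie_jar=aiohttp.CookieJar(),           # session-level cookie storage
        ) as session:
            async with session.get("http://example.com/") as resp:
                print(resp.status)

    asyncio.run(main())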
0.15.1 (03-31-2015) -------------------- +=================== - Pass Autobahn Testsuite tests @@ -1188,17 +2227,17 @@ 0.15.0 (03-27-2015) -------------------- +=================== - Client WebSockets support -- New Multipart system #273 +- New Multipart system (`#273 `_) -- Support for "Except" header #287 #267 +- Support for "Expect" header (`#287 `_) (`#267 `_) -- Set default Content-Type for post requests #184 +- Set default Content-Type for post requests (`#184 `_) -- Fix issue with construction dynamic route with regexps and trailing slash #266 +- Fix issue with construction dynamic route with regexps and trailing slash (`#266 `_) - Add repr to web.Request - Add repr for web.Application -- Add repr to UrlMappingMatchInfo #217 +- Add repr to UrlMappingMatchInfo (`#217 `_) - Gunicorn 19.2.x compatibility 0.14.4 (01-29-2015) -------------------- +=================== -- Fix issue with error during constructing of url with regex parts #264 +- Fix issue with error during constructing of url with regex parts (`#264 `_) 0.14.3 (01-28-2015) -------------------- +=================== -- Use path='/' by default for cookies #261 +- Use path='/' by default for cookies (`#261 `_) 0.14.2 (01-23-2015) -------------------- +=================== -- Connections leak in BaseConnector #253 +- Connections leak in BaseConnector (`#253 `_) -- Do not swallow websocket reader exceptions #255 +- Do not swallow websocket reader exceptions (`#255 `_) -- web.Request's read, text, json are memorized #250 +- web.Request's read, text, json are memorized (`#250 `_) 0.14.1 (01-15-2015) -------------------- +=================== -- HttpMessage._add_default_headers does not overwrite existing headers #216 +- HttpMessage._add_default_headers does not overwrite existing headers (`#216 `_) - Expose multidict classes at package level @@ -1277,56 +2316,56 @@ - Server has 75 seconds keepalive timeout now, was non-keepalive by default. -- Application doesn't accept `**kwargs` anymore (#243). +- Application does not accept `**kwargs` anymore ((`#243 `_)). - Request is inherited from dict now for making per-request storage to - middlewares (#242). + middlewares ((`#242 `_)). 0.13.1 (12-31-2014) --------------------- +=================== -- Add `aiohttp.web.StreamResponse.started` property #213 +- Add `aiohttp.web.StreamResponse.started` property (`#213 `_) - HTML escape traceback text in `ServerHttpProtocol.handle_error` - Mention handler and middlewares in `aiohttp.web.RequestHandler.handle_request` - on error (#218) + on error ((`#218 `_)) 0.13.0 (12-29-2014) -------------------- +=================== - `StreamResponse.charset` converts value to lower-case on assigning. - Chain exceptions when raise `ClientRequestError`. -- Support custom regexps in route variables #204 +- Support custom regexps in route variables (`#204 `_) - Fixed graceful shutdown, disable keep-alive on connection closing. - Decode HTTP message with `utf-8` encoding, some servers send headers - in utf-8 encoding #207 + in utf-8 encoding (`#207 `_) -- Support `aiohtt.web` middlewares #209 +- Support `aiohttp.web` middlewares (`#209 `_) -- Add ssl_context to TCPConnector #206 +- Add ssl_context to TCPConnector (`#206 `_) 0.12.0 (12-12-2014) -------------------- +=================== - Deep refactoring of `aiohttp.web` in backward-incompatible manner. Sorry, we have to do this.
- Automatically force aiohttp.web handlers to coroutines in - `UrlDispatcher.add_route()` #186 + `UrlDispatcher.add_route()` (`#186 `_) - Rename `Request.POST()` function to `Request.post()` - Added POST attribute -- Response processing refactoring: constructor doesn't accept Request +- Response processing refactoring: constructor does not accept Request instance anymore. - Pass application instance to finish callback @@ -1347,21 +2386,21 @@ 0.11.0 (11-29-2014) -------------------- +=================== -- Support named routes in `aiohttp.web.UrlDispatcher` #179 +- Support named routes in `aiohttp.web.UrlDispatcher` (`#179 `_) -- Make websocket subprotocols conform to spec #181 +- Make websocket subprotocols conform to spec (`#181 `_) 0.10.2 (11-19-2014) -------------------- +=================== -- Don't unquote `environ['PATH_INFO']` in wsgi.py #177 +- Don't unquote `environ['PATH_INFO']` in wsgi.py (`#177 `_) 0.10.1 (11-17-2014) -------------------- +=================== - aiohttp.web.HTTPException and descendants now files response body with string like `404: NotFound` @@ -1371,7 +2410,7 @@ 0.10.0 (11-13-2014) -------------------- +=================== - Add aiohttp.web subpackage for highlevel HTTP server support. @@ -1383,64 +2422,64 @@ from 'Can not read status line' to explicit 'Connection closed by server' -- Drop closed connections from connector #173 +- Drop closed connections from connector (`#173 `_) -- Set server.transport to None on .closing() #172 +- Set server.transport to None on .closing() (`#172 `_) 0.9.3 (10-30-2014) ------------------- +================== -- Fix compatibility with asyncio 3.4.1+ #170 +- Fix compatibility with asyncio 3.4.1+ (`#170 `_) 0.9.2 (10-16-2014) ------------------- +================== -- Improve redirect handling #157 +- Improve redirect handling (`#157 `_) -- Send raw files as is #153 +- Send raw files as is (`#153 `_) -- Better websocket support #150 +- Better websocket support (`#150 `_) 0.9.1 (08-30-2014) ------------------- +================== -- Added MultiDict support for client request params and data #114. +- Added MultiDict support for client request params and data (`#114 `_). -- Fixed parameter type for IncompleteRead exception #118. +- Fixed parameter type for IncompleteRead exception (`#118 `_). -- Strictly require ASCII headers names and values #137 +- Strictly require ASCII headers names and values (`#137 `_) -- Keep port in ProxyConnector #128. +- Keep port in ProxyConnector (`#128 `_). -- Python 3.4.1 compatibility #131. +- Python 3.4.1 compatibility (`#131 `_). 0.9.0 (07-08-2014) ------------------- +================== -- Better client basic authentication support #112. +- Better client basic authentication support (`#112 `_). -- Fixed incorrect line splitting in HttpRequestParser #97. +- Fixed incorrect line splitting in HttpRequestParser (`#97 `_). - Support StreamReader and DataQueue as request data. -- Client files handling refactoring #20. +- Client files handling refactoring (`#20 `_). - Backward incompatible: Replace DataQueue with StreamReader for - request payload #87. + request payload (`#87 `_). 0.8.4 (07-04-2014) ------------------- +================== - Change ProxyConnector authorization parameters. 0.8.3 (07-03-2014) ------------------- +================== - Publish TCPConnector properties: verify_ssl, family, resolve, resolved_hosts. @@ -1450,7 +2489,7 @@ 0.8.2 (06-22-2014) ------------------- +================== - Make ProxyConnector.proxy immutable property. 
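The `MultiDict` request-params and basic-auth entries above look roughly like this in today's client API; a sketch only, with placeholder credentials and URL:

.. code-block:: python

    import asyncio

    import aiohttp
    from multidict import MultiDict

    async def main():
        # MultiDict allows repeated query keys: ?tag=a&tag=b
        params = MultiDict([("tag", "a"), ("tag", "b")])
        auth = aiohttp.BasicAuth("user", "secret")
        async with aiohttp.ClientSession(auth=auth) as session:
            async with session.get("http://example.com/", params=params) as resp:
                print(resp.status)

    asyncio.run(main())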
@@ -1462,7 +2501,7 @@ 0.8.1 (06-18-2014) ------------------- +================== - Use case insensitive multidict for server request/response headers. @@ -1476,7 +2515,7 @@ 0.8.0 (06-06-2014) ------------------- +================== - Add support for utf-8 values in HTTP headers @@ -1494,13 +2533,13 @@ 0.7.3 (05-20-2014) ------------------- +================== - Simple HTTP proxy support. 0.7.2 (05-14-2014) ------------------- +================== - Get rid of `__del__` methods @@ -1508,7 +2547,7 @@ 0.7.1 (04-28-2014) ------------------- +================== - Do not unquote client request urls. @@ -1521,7 +2560,7 @@ 0.7.0 (04-16-2014) ------------------- +================== - Connection flow control. @@ -1531,7 +2570,7 @@ 0.6.5 (03-29-2014) ------------------- +================== - Added client session reuse timeout. @@ -1543,21 +2582,21 @@ 0.6.4 (02-27-2014) ------------------- +================== - Log content-length missing warning only for put and post requests. 0.6.3 (02-27-2014) ------------------- +================== - Better support for server exit. -- Read response body until EOF if content-length is not defined #14 +- Read response body until EOF if content-length is not defined (`#14 `_) 0.6.2 (02-18-2014) ------------------- +================== - Fix trailing char in allowed_methods. @@ -1565,7 +2604,7 @@ 0.6.1 (02-17-2014) ------------------- +================== - Added utility method HttpResponse.read_and_close() @@ -1575,13 +2614,13 @@ 0.6.0 (02-12-2014) ------------------- +================== - Better handling for process exit. 0.5.0 (01-29-2014) ------------------- +================== - Allow to use custom HttpRequest client class. @@ -1593,20 +2632,20 @@ 0.4.4 (11-15-2013) ------------------- +================== - Resolve only AF_INET family, because it is not clear how to pass extra info to asyncio. 0.4.3 (11-15-2013) ------------------- +================== - Allow to wait completion of request with `HttpResponse.wait_for_close()` 0.4.2 (11-14-2013) ------------------- +================== - Handle exception in client request stream. @@ -1614,13 +2653,13 @@ 0.4.1 (11-12-2013) ------------------- +================== - Added client support for `expect: 100-continue` header. 0.4 (11-06-2013) ----------------- +================ - Added custom wsgi application close procedure @@ -1628,7 +2667,7 @@ 0.3 (11-04-2013) ----------------- +================ - Added PortMapperWorker @@ -1642,6 +2681,6 @@ 0.2 ---- +=== - Fix packaging diff --git a/LICENSE.txt b/LICENSE.txt index 378d07d4706..90c9d01bc5a 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -186,7 +186,7 @@ Apache License same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2013-2017 Nikolay Kim and Andrew Svetlov + Copyright 2013-2020 aiohttp maintainers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/MANIFEST.in b/MANIFEST.in index 2bade871b0b..05084efddb9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -7,13 +7,14 @@ graft aiohttp graft docs graft examples graft tests +recursive-include vendor * +global-include aiohttp *.pyi global-exclude *.pyc -exclude aiohttp/_multidict.html -exclude aiohttp/_multidict.*.so -exclude aiohttp/_multidict.pyd -exclude aiohttp/_multidict.*.pyd -exclude aiohttp/_websocket.html -exclude aiohttp/_websocket.*.so -exclude aiohttp/_websocket.pyd -exclude aiohttp/_websocket.*.pyd +global-exclude *.pyd +global-exclude *.so +global-exclude *.lib +global-exclude *.dll +global-exclude *.a +global-exclude *.obj +exclude aiohttp/*.html prune docs/_build diff --git a/Makefile b/Makefile index 09048eeba76..13cd76487eb 100644 --- a/Makefile +++ b/Makefile @@ -1,63 +1,102 @@ # Some simple testing tasks (sorry, UNIX only). -.install-deps: requirements-dev.txt - @pip install -U -r requirements-dev.txt +to-hash-one = $(dir $1).hash/$(addsuffix .hash,$(notdir $1)) +to-hash = $(foreach fname,$1,$(call to-hash-one,$(fname))) + +CYS := $(wildcard aiohttp/*.pyx) $(wildcard aiohttp/*.pyi) $(wildcard aiohttp/*.pxd) +PYXS := $(wildcard aiohttp/*.pyx) +CS := $(wildcard aiohttp/*.c) +PYS := $(wildcard aiohttp/*.py) +REQS := $(wildcard requirements/*.txt) +ALLS := $(sort $(CYS) $(CS) $(PYS) $(REQS)) + +.PHONY: all +all: test + +tst: + @echo $(call to-hash,requirements/cython.txt) + @echo $(call to-hash,aiohttp/%.pyx) + + +# Recipe from https://www.cmcrossroads.com/article/rebuilding-when-files-checksum-changes +FORCE: + +# check_sum.py works perfectly fine but slow when called for every file from $(ALLS) +# (perhaps even several times for each file). +# That is why much less readable but faster solution exists +ifneq (, $(shell which sha256sum)) +%.hash: FORCE + $(eval $@_ABS := $(abspath $@)) + $(eval $@_NAME := $($@_ABS)) + $(eval $@_HASHDIR := $(dir $($@_ABS))) + $(eval $@_TMP := $($@_HASHDIR)../$(notdir $($@_ABS))) + $(eval $@_ORIG := $(subst /.hash/../,/,$(basename $($@_TMP)))) + @#echo ==== $($@_ABS) $($@_HASHDIR) $($@_NAME) $($@_TMP) $($@_ORIG) + @if ! (sha256sum --check $($@_ABS) 1>/dev/null 2>/dev/null); then \ + mkdir -p $($@_HASHDIR); \ + echo re-hash $($@_ORIG); \ + sha256sum $($@_ORIG) > $($@_ABS); \ + fi +else +%.hash: FORCE + @./tools/check_sum.py $@ # --debug +endif + +# Enumerate intermediate files to don't remove them automatically. +.SECONDARY: $(call to-hash,$(ALLS)) + + +.install-cython: $(call to-hash,requirements/cython.txt) + pip install -r requirements/cython.txt + @touch .install-cython + +aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py) + ./tools/gen.py + +# _find_headers generator creates _headers.pyi as well +aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c + cython -3 -o $@ $< -I aiohttp + + +.PHONY: cythonize +cythonize: .install-cython $(PYXS:.pyx=.c) + +.install-deps: .install-cython $(PYXS:.pyx=.c) $(call to-hash,$(CYS) $(REQS)) + pip install -r requirements/dev.txt @touch .install-deps -isort: - isort -rc aiohttp - isort -rc tests - isort -rc examples - isort -rc demos - -flake: .flake - -.flake: .install-deps $(shell find aiohttp -type f) \ - $(shell find tests -type f) \ - $(shell find benchmark -type f) \ - $(shell find examples -type f) \ - $(shell find demos -type f) - @flake8 aiohttp --exclude=aiohttp/backport_cookies.py - @if python -c "import sys; sys.exit(sys.version_info < (3,5))"; then \ - flake8 examples tests demos benchmark && \ - python setup.py check -rms; \ - fi - @if ! 
isort -c -rc aiohttp tests examples; then \ - echo "Import sort errors, run 'make isort' to fix them!!!"; \ - isort --diff -rc aiohttp tests benchmark examples; \ - false; \ - fi - @touch .flake +.PHONY: lint +lint: fmt mypy +.PHONY: fmt format +fmt format: + python -m pre_commit run --all-files --show-diff-on-failure -.develop: .install-deps $(shell find aiohttp -type f) .flake - @pip install -e . +.PHONY: mypy +mypy: + mypy aiohttp + +.develop: .install-deps $(call to-hash,$(PYS) $(CYS) $(CS)) + pip install -e . @touch .develop +.PHONY: test test: .develop - @py.test -q ./tests + @pytest -q +.PHONY: vtest vtest: .develop - @py.test -s -v ./tests - -cov cover coverage: - tox - -cov-dev: .develop - @py.test --cov=aiohttp --cov-report=term --cov-report=html tests - @echo "open file://`pwd`/coverage/index.html" + @pytest -s -v -cov-dev-full: .develop - @echo "Run without extensions" - @AIOHTTP_NO_EXTENSIONS=1 py.test --cov=aiohttp tests - @echo "Run in debug mode" - @PYTHONASYNCIODEBUG=1 py.test --cov=aiohttp --cov-append tests - @echo "Regular run" - @py.test --cov=aiohttp --cov-report=term --cov-report=html --cov-append tests - @echo "open file://`pwd`/coverage/index.html" +.PHONY: vvtest +vvtest: .develop + @pytest -vv +.PHONY: clean clean: @rm -rf `find . -name __pycache__` + @rm -rf `find . -name .hash` + @rm -rf `find . -name .md5` # old styling @rm -f `find . -type f -name '*.py[co]' ` @rm -f `find . -type f -name '*~' ` @rm -f `find . -type f -name '.*~' ` @@ -65,35 +104,41 @@ clean: @rm -f `find . -type f -name '#*#' ` @rm -f `find . -type f -name '*.orig' ` @rm -f `find . -type f -name '*.rej' ` + @rm -f `find . -type f -name '*.md5' ` # old styling @rm -f .coverage - @rm -rf coverage + @rm -rf htmlcov @rm -rf build @rm -rf cover @make -C docs clean @python setup.py clean - @rm -f aiohttp/_multidict.html - @rm -f aiohttp/_multidict.c - @rm -f aiohttp/_multidict.*.so - @rm -f aiohttp/_multidict.*.pyd - @rm -f aiohttp/_websocket.html + @rm -f aiohttp/*.so + @rm -f aiohttp/*.pyd + @rm -f aiohttp/*.html + @rm -f aiohttp/_frozenlist.c + @rm -f aiohttp/_find_header.c + @rm -f aiohttp/_http_parser.c + @rm -f aiohttp/_http_writer.c @rm -f aiohttp/_websocket.c - @rm -f aiohttp/_websocket.*.so - @rm -f aiohttp/_websocket.*.pyd - @rm -f aiohttp/_parser.html - @rm -f aiohttp/_parser.c - @rm -f aiohttp/_parser.*.so - @rm -f aiohttp/_parser.*.pyd @rm -rf .tox + @rm -f .develop + @rm -f .flake + @rm -rf aiohttp.egg-info + @rm -f .install-deps + @rm -f .install-cython +.PHONY: doc doc: @make -C docs html SPHINXOPTS="-W -E" @echo "open file://`pwd`/docs/_build/html/index.html" +.PHONY: doc-spelling doc-spelling: @make -C docs spelling SPHINXOPTS="-W -E" +.PHONY: install install: - @pip install -U pip - @pip install -Ur requirements-dev.txt + @pip install -U 'pip' + @pip install -Ur requirements/dev.txt -.PHONY: all build flake test vtest cov clean doc +.PHONY: install-dev +install-dev: .develop diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 084e7ad2a4e..00000000000 --- a/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,33 +0,0 @@ - - -## What do these changes do? - - - -## Are there changes in behavior for the user? - - - -## Related issue number - - - -## Checklist - -- [ ] I think the code is well written -- [ ] Unit tests for the changes exist -- [ ] Documentation reflects the changes -- [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` - * The format is <Name> <Surname>. 
- * Please keep alphabetical order, the file is sorted by names. -- [ ] Add a new entry to `CHANGES.rst` - * Choose any open position to avoid merge conflicts with other PRs. - * Add a link to the issue you are fixing (if any) using `#issue_number` format at the end of changelog message. Use Pull Request number if there are no issues for PR or PR covers the issue only partially. diff --git a/README.rst b/README.rst index 31b4f95c06f..338adbcae24 100644 --- a/README.rst +++ b/README.rst @@ -1,85 +1,93 @@ +================================== Async http client/server framework ================================== .. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/_static/aiohttp-icon-128x128.png - :height: 64px - :width: 64px - :alt: aiohttp logo + :height: 64px + :width: 64px + :alt: aiohttp logo + +| -.. image:: https://travis-ci.org/aio-libs/aiohttp.svg?branch=master - :target: https://travis-ci.org/aio-libs/aiohttp - :align: right +.. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg + :target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI + :alt: GitHub Actions status for master branch .. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg - :target: https://codecov.io/gh/aio-libs/aiohttp + :target: https://codecov.io/gh/aio-libs/aiohttp + :alt: codecov.io status for master branch .. image:: https://badge.fury.io/py/aiohttp.svg - :target: https://badge.fury.io/py/aiohttp - - -aiohttp 2.0 release! --------------------- - -For this release we completely refactored low-level implementation of http handling. -Finally `uvloop` gives performance improvement. Overall performance improvement -should be around 70-90% compared to 1.x version. - -We took opportunity to refactor long standing api design problems across whole package. -Client exceptions handling has been cleaned up and now much more straight forward. Client payload -management simplified and allows to extend with any custom type. Client connection pool -implementation has been redesigned as well, now there is no need for actively releasing response objects, -aiohttp handles connection release automatically. + :target: https://pypi.org/project/aiohttp + :alt: Latest PyPI package version -Another major change, we moved aiohttp development to public organization https://github.com/aio-libs +.. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest + :target: https://docs.aiohttp.org/ + :alt: Latest Read The Docs -With this amount of api changes we had to make backward incompatible changes. Please check this migration document http://aiohttp.readthedocs.io/en/latest/migration.html +.. image:: https://img.shields.io/discourse/status?server=https%3A%2F%2Faio-libs.discourse.group + :target: https://aio-libs.discourse.group + :alt: Discourse status -Please report problems or annoyance with with api to https://github.com/aio-libs/aiohttp +.. image:: https://badges.gitter.im/Join%20Chat.svg + :target: https://gitter.im/aio-libs/Lobby + :alt: Chat on Gitter -Features --------- +Key Features +============ - Supports both client and server side of HTTP protocol. -- Supports both client and server Web-Sockets out-of-the-box. -- Web-server has middlewares and pluggable routing. +- Supports both client and server Web-Sockets out-of-the-box and avoids + Callback Hell. +- Provides Web-server with middlewares and pluggable routing.
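A short sketch of the middleware feature from the list above, in the decorator form current aiohttp uses rather than the original factory style; the handler and error page are illustrative. Basic client and server usage follows in Getting started:

.. code-block:: python

    from aiohttp import web

    @web.middleware
    async def error_pages(request, handler):
        try:
            return await handler(request)
        except web.HTTPNotFound:
            # Replace bare 404 responses with a friendlier body.
            return web.Response(text="Sorry, nothing here.", status=404)

    app = web.Application(middlewares=[error_pages])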
Getting started ---------------- +=============== Client -^^^^^^ +------ -To retrieve something from the web: +To get something from the web: .. code-block:: python import aiohttp import asyncio - async def fetch(session, url): - with aiohttp.Timeout(10, loop=session.loop): - async with session.get(url) as response: - return await response.text() + async def main(): + + async with aiohttp.ClientSession() as session: + async with session.get('http://python.org') as response: + + print("Status:", response.status) + print("Content-type:", response.headers['content-type']) + + html = await response.text() + print("Body:", html[:15], "...") + + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) + +This prints: - async def main(loop): - async with aiohttp.ClientSession(loop=loop) as session: - html = await fetch(session, 'http://python.org') - print(html) +.. code-block:: - if __name__ == '__main__': - loop = asyncio.get_event_loop() - loop.run_until_complete(main(loop)) + Status: 200 + Content-type: text/html; charset=utf-8 + Body: ... +Coming from `requests `_ ? Read `why we need so many lines `_. Server -^^^^^^ +------ -This is simple usage example: +An example using a simple server: .. code-block:: python + # examples/server_simple.py from aiohttp import web async def handle(request): @@ -87,57 +95,72 @@ This is simple usage example: text = "Hello, " + name return web.Response(text=text) - async def wshandler(request): + async def wshandle(request): ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: - if msg.type == web.MsgType.text: + if msg.type == web.WSMsgType.text: await ws.send_str("Hello, {}".format(msg.data)) - elif msg.type == web.MsgType.binary: + elif msg.type == web.WSMsgType.binary: await ws.send_bytes(msg.data) - elif msg.type == web.MsgType.close: + elif msg.type == web.WSMsgType.close: break return ws app = web.Application() - app.router.add_get('/echo', wshandler) - app.router.add_get('/', handle) - app.router.add_get('/{name}', handle) + app.add_routes([web.get('/', handle), + web.get('/echo', wshandle), + web.get('/{name}', handle)]) - web.run_app(app) + if __name__ == '__main__': + web.run_app(app) -Note: examples are written for Python 3.5+ and utilize PEP-492 aka -async/await. If you are using Python 3.4 please replace ``await`` with -``yield from`` and ``async def`` with ``@coroutine`` e.g.:: +Documentation +============= - async def coro(...): - ret = await f() +https://aiohttp.readthedocs.io/ -should be replaced by:: - @asyncio.coroutine - def coro(...): - ret = yield from f() +Demos +===== -Documentation -------------- +https://github.com/aio-libs/aiohttp-demos -https://aiohttp.readthedocs.io/ -Discussion list ---------------- +External links +============== + +* `Third party libraries + `_ +* `Built with aiohttp + `_ +* `Powered by aiohttp + `_ + +Feel free to make a Pull Request for adding your link to these pages! + + +Communication channels +====================== + +*aio-libs discourse group*: https://aio-libs.discourse.group + +*gitter chat* https://gitter.im/aio-libs/Lobby -*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs +We support `Stack Overflow +`_. +Please add *aiohttp* tag to your question there. Requirements ------------- +============ -- Python >= 3.4.2 +- Python >= 3.6 - async-timeout_ +- attrs_ - chardet_ - multidict_ - yarl_ @@ -147,32 +170,35 @@ recommended for sake of speed). .. _chardet: https://pypi.python.org/pypi/chardet .. _aiodns: https://pypi.python.org/pypi/aiodns +.. 
_attrs: https://github.com/python-attrs/attrs .. _multidict: https://pypi.python.org/pypi/multidict .. _yarl: https://pypi.python.org/pypi/yarl .. _async-timeout: https://pypi.python.org/pypi/async_timeout .. _cChardet: https://pypi.python.org/pypi/cchardet License -------- +======= ``aiohttp`` is offered under the Apache 2 license. Keepsafe --------- +======== -The aiohttp community would like to thank Keepsafe (https://www.getkeepsafe.com) for it's support in the early days of the project. +The aiohttp community would like to thank Keepsafe +(https://www.getkeepsafe.com) for its support in the early days of +the project. Source code ------------- +=========== -The latest developer version is available in a github repository: +The latest developer version is available in a GitHub repository: https://github.com/aio-libs/aiohttp Benchmarks ----------- +========== -If you are interested in by efficiency, AsyncIO community maintains a +If you are interested in efficiency, the AsyncIO community maintains a list of benchmarks on the official wiki: https://github.com/python/asyncio/wiki/Benchmarks diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 516270b3099..23cd5c9d6de 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -1,35 +1,217 @@ -__version__ = '2.0.7' +__version__ = "3.7.4" -# This relies on each of the submodules having an __all__ variable. +from typing import Tuple -from . import hdrs # noqa -from .client import * # noqa -from .formdata import * # noqa -from .helpers import * # noqa -from .http import (HttpVersion, HttpVersion10, HttpVersion11, # noqa - WSMsgType, WSCloseCode, WSMessage, WebSocketError) # noqa -from .streams import * # noqa -from .multipart import * # noqa -from .cookiejar import CookieJar # noqa -from .payload import * # noqa -from .payload_streamer import * # noqa -from .resolver import * # noqa +from . 
import hdrs as hdrs +from .client import ( + BaseConnector as BaseConnector, + ClientConnectionError as ClientConnectionError, + ClientConnectorCertificateError as ClientConnectorCertificateError, + ClientConnectorError as ClientConnectorError, + ClientConnectorSSLError as ClientConnectorSSLError, + ClientError as ClientError, + ClientHttpProxyError as ClientHttpProxyError, + ClientOSError as ClientOSError, + ClientPayloadError as ClientPayloadError, + ClientProxyConnectionError as ClientProxyConnectionError, + ClientRequest as ClientRequest, + ClientResponse as ClientResponse, + ClientResponseError as ClientResponseError, + ClientSession as ClientSession, + ClientSSLError as ClientSSLError, + ClientTimeout as ClientTimeout, + ClientWebSocketResponse as ClientWebSocketResponse, + ContentTypeError as ContentTypeError, + Fingerprint as Fingerprint, + InvalidURL as InvalidURL, + NamedPipeConnector as NamedPipeConnector, + RequestInfo as RequestInfo, + ServerConnectionError as ServerConnectionError, + ServerDisconnectedError as ServerDisconnectedError, + ServerFingerprintMismatch as ServerFingerprintMismatch, + ServerTimeoutError as ServerTimeoutError, + TCPConnector as TCPConnector, + TooManyRedirects as TooManyRedirects, + UnixConnector as UnixConnector, + WSServerHandshakeError as WSServerHandshakeError, + request as request, +) +from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar +from .formdata import FormData as FormData +from .helpers import BasicAuth as BasicAuth, ChainMapProxy as ChainMapProxy +from .http import ( + HttpVersion as HttpVersion, + HttpVersion10 as HttpVersion10, + HttpVersion11 as HttpVersion11, + WebSocketError as WebSocketError, + WSCloseCode as WSCloseCode, + WSMessage as WSMessage, + WSMsgType as WSMsgType, +) +from .multipart import ( + BadContentDispositionHeader as BadContentDispositionHeader, + BadContentDispositionParam as BadContentDispositionParam, + BodyPartReader as BodyPartReader, + MultipartReader as MultipartReader, + MultipartWriter as MultipartWriter, + content_disposition_filename as content_disposition_filename, + parse_content_disposition as parse_content_disposition, +) +from .payload import ( + PAYLOAD_REGISTRY as PAYLOAD_REGISTRY, + AsyncIterablePayload as AsyncIterablePayload, + BufferedReaderPayload as BufferedReaderPayload, + BytesIOPayload as BytesIOPayload, + BytesPayload as BytesPayload, + IOBasePayload as IOBasePayload, + JsonPayload as JsonPayload, + Payload as Payload, + StringIOPayload as StringIOPayload, + StringPayload as StringPayload, + TextIOPayload as TextIOPayload, + get_payload as get_payload, + payload_type as payload_type, +) +from .payload_streamer import streamer as streamer +from .resolver import ( + AsyncResolver as AsyncResolver, + DefaultResolver as DefaultResolver, + ThreadedResolver as ThreadedResolver, +) +from .signals import Signal as Signal +from .streams import ( + EMPTY_PAYLOAD as EMPTY_PAYLOAD, + DataQueue as DataQueue, + EofStream as EofStream, + FlowControlDataQueue as FlowControlDataQueue, + StreamReader as StreamReader, +) +from .tracing import ( + TraceConfig as TraceConfig, + TraceConnectionCreateEndParams as TraceConnectionCreateEndParams, + TraceConnectionCreateStartParams as TraceConnectionCreateStartParams, + TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams, + TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams, + TraceConnectionReuseconnParams as TraceConnectionReuseconnParams, + TraceDnsCacheHitParams as TraceDnsCacheHitParams, + 
TraceDnsCacheMissParams as TraceDnsCacheMissParams, + TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams, + TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams, + TraceRequestChunkSentParams as TraceRequestChunkSentParams, + TraceRequestEndParams as TraceRequestEndParams, + TraceRequestExceptionParams as TraceRequestExceptionParams, + TraceRequestRedirectParams as TraceRequestRedirectParams, + TraceRequestStartParams as TraceRequestStartParams, + TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams, +) -try: - from .worker import GunicornWebWorker, GunicornUVLoopWebWorker # noqa - workers = ('GunicornWebWorker', 'GunicornUVLoopWebWorker') -except ImportError: - workers = () +__all__: Tuple[str, ...] = ( + "hdrs", + # client + "BaseConnector", + "ClientConnectionError", + "ClientConnectorCertificateError", + "ClientConnectorError", + "ClientConnectorSSLError", + "ClientError", + "ClientHttpProxyError", + "ClientOSError", + "ClientPayloadError", + "ClientProxyConnectionError", + "ClientResponse", + "ClientRequest", + "ClientResponseError", + "ClientSSLError", + "ClientSession", + "ClientTimeout", + "ClientWebSocketResponse", + "ContentTypeError", + "Fingerprint", + "InvalidURL", + "RequestInfo", + "ServerConnectionError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ServerTimeoutError", + "TCPConnector", + "TooManyRedirects", + "UnixConnector", + "NamedPipeConnector", + "WSServerHandshakeError", + "request", + # cookiejar + "CookieJar", + "DummyCookieJar", + # formdata + "FormData", + # helpers + "BasicAuth", + "ChainMapProxy", + # http + "HttpVersion", + "HttpVersion10", + "HttpVersion11", + "WSMsgType", + "WSCloseCode", + "WSMessage", + "WebSocketError", + # multipart + "BadContentDispositionHeader", + "BadContentDispositionParam", + "BodyPartReader", + "MultipartReader", + "MultipartWriter", + "content_disposition_filename", + "parse_content_disposition", + # payload + "AsyncIterablePayload", + "BufferedReaderPayload", + "BytesIOPayload", + "BytesPayload", + "IOBasePayload", + "JsonPayload", + "PAYLOAD_REGISTRY", + "Payload", + "StringIOPayload", + "StringPayload", + "TextIOPayload", + "get_payload", + "payload_type", + # payload_streamer + "streamer", + # resolver + "AsyncResolver", + "DefaultResolver", + "ThreadedResolver", + # signals + "Signal", + "DataQueue", + "EMPTY_PAYLOAD", + "EofStream", + "FlowControlDataQueue", + "StreamReader", + # tracing + "TraceConfig", + "TraceConnectionCreateEndParams", + "TraceConnectionCreateStartParams", + "TraceConnectionQueuedEndParams", + "TraceConnectionQueuedStartParams", + "TraceConnectionReuseconnParams", + "TraceDnsCacheHitParams", + "TraceDnsCacheMissParams", + "TraceDnsResolveHostEndParams", + "TraceDnsResolveHostStartParams", + "TraceRequestChunkSentParams", + "TraceRequestEndParams", + "TraceRequestExceptionParams", + "TraceRequestRedirectParams", + "TraceRequestStartParams", + "TraceResponseChunkReceivedParams", +) +try: + from .worker import GunicornUVLoopWebWorker, GunicornWebWorker -__all__ = (client.__all__ + # noqa - formdata.__all__ + # noqa - helpers.__all__ + # noqa - multipart.__all__ + # noqa - payload.__all__ + # noqa - payload_streamer.__all__ + # noqa - streams.__all__ + # noqa - ('hdrs', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'WSMsgType', 'WSCloseCode', - 'WebSocketError', 'WSMessage', 'CookieJar', - ) + workers) + __all__ += ("GunicornWebWorker", "GunicornUVLoopWebWorker") +except ImportError: # pragma: no cover + pass diff --git a/aiohttp/_find_header.h 
b/aiohttp/_find_header.h new file mode 100644 index 00000000000..99b7b4f8282 --- /dev/null +++ b/aiohttp/_find_header.h @@ -0,0 +1,14 @@ +#ifndef _FIND_HEADERS_H +#define _FIND_HEADERS_H + +#ifdef __cplusplus +extern "C" { +#endif + +int find_header(const char *str, int size); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/aiohttp/_find_header.pxd b/aiohttp/_find_header.pxd new file mode 100644 index 00000000000..37a6c37268e --- /dev/null +++ b/aiohttp/_find_header.pxd @@ -0,0 +1,2 @@ +cdef extern from "_find_header.h": + int find_header(char *, int) diff --git a/aiohttp/_frozenlist.pyx b/aiohttp/_frozenlist.pyx new file mode 100644 index 00000000000..b1305772f4b --- /dev/null +++ b/aiohttp/_frozenlist.pyx @@ -0,0 +1,108 @@ +from collections.abc import MutableSequence + + +cdef class FrozenList: + + cdef readonly bint frozen + cdef list _items + + def __init__(self, items=None): + self.frozen = False + if items is not None: + items = list(items) + else: + items = [] + self._items = items + + cdef object _check_frozen(self): + if self.frozen: + raise RuntimeError("Cannot modify frozen list.") + + cdef inline object _fast_len(self): + return len(self._items) + + def freeze(self): + self.frozen = True + + def __getitem__(self, index): + return self._items[index] + + def __setitem__(self, index, value): + self._check_frozen() + self._items[index] = value + + def __delitem__(self, index): + self._check_frozen() + del self._items[index] + + def __len__(self): + return self._fast_len() + + def __iter__(self): + return self._items.__iter__() + + def __reversed__(self): + return self._items.__reversed__() + + def __richcmp__(self, other, op): + if op == 0: # < + return list(self) < other + if op == 1: # <= + return list(self) <= other + if op == 2: # == + return list(self) == other + if op == 3: # != + return list(self) != other + if op == 4: # > + return list(self) > other + if op == 5: # => + return list(self) >= other + + def insert(self, pos, item): + self._check_frozen() + self._items.insert(pos, item) + + def __contains__(self, item): + return item in self._items + + def __iadd__(self, items): + self._check_frozen() + self._items += list(items) + return self + + def index(self, item): + return self._items.index(item) + + def remove(self, item): + self._check_frozen() + self._items.remove(item) + + def clear(self): + self._check_frozen() + self._items.clear() + + def extend(self, items): + self._check_frozen() + self._items += list(items) + + def reverse(self): + self._check_frozen() + self._items.reverse() + + def pop(self, index=-1): + self._check_frozen() + return self._items.pop(index) + + def append(self, item): + self._check_frozen() + return self._items.append(item) + + def count(self, item): + return self._items.count(item) + + def __repr__(self): + return ''.format(self.frozen, + self._items) + + +MutableSequence.register(FrozenList) diff --git a/aiohttp/_helpers.pyi b/aiohttp/_helpers.pyi new file mode 100644 index 00000000000..1e358937024 --- /dev/null +++ b/aiohttp/_helpers.pyi @@ -0,0 +1,6 @@ +from typing import Any + +class reify: + def __init__(self, wrapped: Any) -> None: ... + def __get__(self, inst: Any, owner: Any) -> Any: ... + def __set__(self, inst: Any, value: Any) -> None: ... diff --git a/aiohttp/_helpers.pyx b/aiohttp/_helpers.pyx new file mode 100644 index 00000000000..665f367c5de --- /dev/null +++ b/aiohttp/_helpers.pyx @@ -0,0 +1,35 @@ +cdef class reify: + """Use as a class method decorator. 
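# ---------------------------------------------------------------------
# [Editor's aside, not part of the patch] The ``reify`` descriptor whose
# docstring begins above caches the wrapped method's result in the
# instance's ``_cache`` dict on first access. A minimal sketch of that
# contract, assuming the pure-Python fallback exported as
# ``aiohttp.helpers.reify`` behaves like this Cython version:

from aiohttp.helpers import reify

class Example:
    def __init__(self):
        self._cache = {}  # reify stores computed values here

    @reify
    def answer(self):
        print("computed once")
        return 42

ex = Example()
ex.answer  # prints "computed once", stores 42 in ex._cache
ex.answer  # second access hits ex._cache; nothing is recomputed
# ---------------------------------------------------------------------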
It operates almost exactly like + the Python `@property` decorator, but it puts the result of the + method it decorates into the instance dict after the first call, + effectively replacing the function it decorates with an instance + variable. It is, in Python parlance, a data descriptor. + + """ + + cdef object wrapped + cdef object name + + def __init__(self, wrapped): + self.wrapped = wrapped + self.name = wrapped.__name__ + + @property + def __doc__(self): + return self.wrapped.__doc__ + + def __get__(self, inst, owner): + try: + try: + return inst._cache[self.name] + except KeyError: + val = self.wrapped(inst) + inst._cache[self.name] = val + return val + except AttributeError: + if inst is None: + return self + raise + + def __set__(self, inst, value): + raise AttributeError("reified property is read-only") diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index b00c0e0d782..c24e31057a8 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -2,27 +2,270 @@ # # Based on https://github.com/MagicStack/httptools # -from __future__ import print_function -from cpython.mem cimport PyMem_Malloc, PyMem_Free -from cpython cimport PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE, \ - Py_buffer, PyBytes_AsString - -import yarl -from multidict import CIMultiDict +from __future__ import absolute_import, print_function + +from cpython cimport ( + Py_buffer, + PyBUF_SIMPLE, + PyBuffer_Release, + PyBytes_AsString, + PyBytes_AsStringAndSize, + PyObject_GetBuffer, +) +from cpython.mem cimport PyMem_Free, PyMem_Malloc +from libc.limits cimport ULLONG_MAX +from libc.string cimport memcpy + +from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy +from yarl import URL as _URL from aiohttp import hdrs + from .http_exceptions import ( - BadHttpMessage, BadStatusLine, InvalidHeader, LineTooLong, InvalidURLError, - PayloadEncodingError, ContentLengthError, TransferEncodingError) -from .http_writer import HttpVersion, HttpVersion10, HttpVersion11, URL -from .http_parser import RawRequestMessage, RawResponseMessage, DeflateBuffer -from .streams import EMPTY_PAYLOAD, FlowControlStreamReader + BadHttpMessage, + BadStatusLine, + ContentLengthError, + InvalidHeader, + InvalidURLError, + LineTooLong, + PayloadEncodingError, + TransferEncodingError, +) +from .http_parser import DeflateBuffer as _DeflateBuffer +from .http_writer import ( + HttpVersion as _HttpVersion, + HttpVersion10 as _HttpVersion10, + HttpVersion11 as _HttpVersion11, +) +from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader cimport cython -from . 
cimport _cparser as cparser +from aiohttp cimport _cparser as cparser + +include "_headers.pxi" + +from aiohttp cimport _find_header + +DEF DEFAULT_FREELIST_SIZE = 250 + +cdef extern from "Python.h": + int PyByteArray_Resize(object, Py_ssize_t) except -1 + Py_ssize_t PyByteArray_Size(object) except -1 + char* PyByteArray_AsString(object) + +__all__ = ('HttpRequestParser', 'HttpResponseParser', + 'RawRequestMessage', 'RawResponseMessage') -__all__ = ('HttpRequestParserC', 'HttpResponseMessageC', 'parse_url') +cdef object URL = _URL +cdef object URL_build = URL.build +cdef object CIMultiDict = _CIMultiDict +cdef object CIMultiDictProxy = _CIMultiDictProxy +cdef object HttpVersion = _HttpVersion +cdef object HttpVersion10 = _HttpVersion10 +cdef object HttpVersion11 = _HttpVersion11 +cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1 +cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING +cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD +cdef object StreamReader = _StreamReader +cdef object DeflateBuffer = _DeflateBuffer + + +cdef inline object extend(object buf, const char* at, size_t length): + cdef Py_ssize_t s + cdef char* ptr + s = PyByteArray_Size(buf) + PyByteArray_Resize(buf, s + length) + ptr = PyByteArray_AsString(buf) + memcpy(ptr + s, at, length) + + +DEF METHODS_COUNT = 34; + +cdef list _http_method = [] + +for i in range(METHODS_COUNT): + _http_method.append( + cparser.http_method_str( i).decode('ascii')) + + +cdef inline str http_method_str(int i): + if i < METHODS_COUNT: + return _http_method[i] + else: + return "" + +cdef inline object find_header(bytes raw_header): + cdef Py_ssize_t size + cdef char *buf + cdef int idx + PyBytes_AsStringAndSize(raw_header, &buf, &size) + idx = _find_header.find_header(buf, size) + if idx == -1: + return raw_header.decode('utf-8', 'surrogateescape') + return headers[idx] + + +@cython.freelist(DEFAULT_FREELIST_SIZE) +cdef class RawRequestMessage: + cdef readonly str method + cdef readonly str path + cdef readonly object version # HttpVersion + cdef readonly object headers # CIMultiDict + cdef readonly object raw_headers # tuple + cdef readonly object should_close + cdef readonly object compression + cdef readonly object upgrade + cdef readonly object chunked + cdef readonly object url # yarl.URL + + def __init__(self, method, path, version, headers, raw_headers, + should_close, compression, upgrade, chunked, url): + self.method = method + self.path = path + self.version = version + self.headers = headers + self.raw_headers = raw_headers + self.should_close = should_close + self.compression = compression + self.upgrade = upgrade + self.chunked = chunked + self.url = url + + def __repr__(self): + info = [] + info.append(("method", self.method)) + info.append(("path", self.path)) + info.append(("version", self.version)) + info.append(("headers", self.headers)) + info.append(("raw_headers", self.raw_headers)) + info.append(("should_close", self.should_close)) + info.append(("compression", self.compression)) + info.append(("upgrade", self.upgrade)) + info.append(("chunked", self.chunked)) + info.append(("url", self.url)) + sinfo = ', '.join(name + '=' + repr(val) for name, val in info) + return '' + + def _replace(self, **dct): + cdef RawRequestMessage ret + ret = _new_request_message(self.method, + self.path, + self.version, + self.headers, + self.raw_headers, + self.should_close, + self.compression, + self.upgrade, + self.chunked, + self.url) + if "method" in dct: + ret.method = dct["method"] + if "path" in dct: + ret.path = dct["path"] + if 
"version" in dct: + ret.version = dct["version"] + if "headers" in dct: + ret.headers = dct["headers"] + if "raw_headers" in dct: + ret.raw_headers = dct["raw_headers"] + if "should_close" in dct: + ret.should_close = dct["should_close"] + if "compression" in dct: + ret.compression = dct["compression"] + if "upgrade" in dct: + ret.upgrade = dct["upgrade"] + if "chunked" in dct: + ret.chunked = dct["chunked"] + if "url" in dct: + ret.url = dct["url"] + return ret + +cdef _new_request_message(str method, + str path, + object version, + object headers, + object raw_headers, + bint should_close, + object compression, + bint upgrade, + bint chunked, + object url): + cdef RawRequestMessage ret + ret = RawRequestMessage.__new__(RawRequestMessage) + ret.method = method + ret.path = path + ret.version = version + ret.headers = headers + ret.raw_headers = raw_headers + ret.should_close = should_close + ret.compression = compression + ret.upgrade = upgrade + ret.chunked = chunked + ret.url = url + return ret + + +@cython.freelist(DEFAULT_FREELIST_SIZE) +cdef class RawResponseMessage: + cdef readonly object version # HttpVersion + cdef readonly int code + cdef readonly str reason + cdef readonly object headers # CIMultiDict + cdef readonly object raw_headers # tuple + cdef readonly object should_close + cdef readonly object compression + cdef readonly object upgrade + cdef readonly object chunked + + def __init__(self, version, code, reason, headers, raw_headers, + should_close, compression, upgrade, chunked): + self.version = version + self.code = code + self.reason = reason + self.headers = headers + self.raw_headers = raw_headers + self.should_close = should_close + self.compression = compression + self.upgrade = upgrade + self.chunked = chunked + + def __repr__(self): + info = [] + info.append(("version", self.version)) + info.append(("code", self.code)) + info.append(("reason", self.reason)) + info.append(("headers", self.headers)) + info.append(("raw_headers", self.raw_headers)) + info.append(("should_close", self.should_close)) + info.append(("compression", self.compression)) + info.append(("upgrade", self.upgrade)) + info.append(("chunked", self.chunked)) + sinfo = ', '.join(name + '=' + repr(val) for name, val in info) + return '' + + +cdef _new_response_message(object version, + int code, + str reason, + object headers, + object raw_headers, + bint should_close, + object compression, + bint upgrade, + bint chunked): + cdef RawResponseMessage ret + ret = RawResponseMessage.__new__(RawResponseMessage) + ret.version = version + ret.code = code + ret.reason = reason + ret.headers = headers + ret.raw_headers = raw_headers + ret.should_close = should_close + ret.compression = compression + ret.upgrade = upgrade + ret.chunked = chunked + return ret @cython.internal @@ -32,10 +275,9 @@ cdef class HttpParser: cparser.http_parser* _cparser cparser.http_parser_settings* _csettings - str _header_name - str _header_value - bytes _raw_header_name - bytes _raw_header_value + bytearray _raw_name + bytearray _raw_value + bint _has_value object _protocol object _loop @@ -45,18 +287,25 @@ cdef class HttpParser: size_t _max_field_size size_t _max_headers bint _response_with_body + bint _read_until_eof + bint _started object _url + bytearray _buf str _path str _reason - list _headers + object _headers list _raw_headers bint _upgraded list _messages object _payload bint _payload_error object _payload_exception - object _last_error + object _last_error + bint _auto_decompress + int _limit + + str _content_encoding 
Py_buffer py_buf @@ -76,10 +325,12 @@ cdef class HttpParser: PyMem_Free(self._csettings) cdef _init(self, cparser.http_parser_type mode, - object protocol, object loop, object timer=None, + object protocol, object loop, int limit, + object timer=None, size_t max_line_size=8190, size_t max_headers=32768, size_t max_field_size=8190, payload_exception=None, - response_with_body=True): + bint response_with_body=True, bint read_until_eof=False, + bint auto_decompress=True): cparser.http_parser_init(self._cparser, mode) self._cparser.data = self self._cparser.content_length = 0 @@ -90,21 +341,24 @@ cdef class HttpParser: self._loop = loop self._timer = timer + self._buf = bytearray() self._payload = None self._payload_error = 0 self._payload_exception = payload_exception self._messages = [] - self._header_name = None - self._header_value = None - self._raw_header_name = None - self._raw_header_value = None + self._raw_name = bytearray() + self._raw_value = bytearray() + self._has_value = False self._max_line_size = max_line_size self._max_headers = max_headers self._max_field_size = max_field_size self._response_with_body = response_with_body + self._read_until_eof = read_until_eof self._upgraded = False + self._auto_decompress = auto_decompress + self._content_encoding = None self._csettings.on_url = cb_on_url self._csettings.on_status = cb_on_status @@ -114,52 +368,61 @@ cdef class HttpParser: self._csettings.on_body = cb_on_body self._csettings.on_message_begin = cb_on_message_begin self._csettings.on_message_complete = cb_on_message_complete + self._csettings.on_chunk_header = cb_on_chunk_header + self._csettings.on_chunk_complete = cb_on_chunk_complete self._last_error = None + self._limit = limit cdef _process_header(self): - if self._header_name is not None: - name = self._header_name - value = self._header_value + if self._raw_name: + raw_name = bytes(self._raw_name) + raw_value = bytes(self._raw_value) + + name = find_header(raw_name) + value = raw_value.decode('utf-8', 'surrogateescape') - self._header_name = self._header_value = None - self._headers.append((name, value)) + self._headers.add(name, value) - raw_name = self._raw_header_name - raw_value = self._raw_header_value + if name is CONTENT_ENCODING: + self._content_encoding = value - self._raw_header_name = self._raw_header_value = None + PyByteArray_Resize(self._raw_name, 0) + PyByteArray_Resize(self._raw_value, 0) + self._has_value = False self._raw_headers.append((raw_name, raw_value)) - cdef _on_header_field(self, str field, bytes raw_field): - self._process_header() - self._header_name = field - self._raw_header_name = raw_field + cdef _on_header_field(self, char* at, size_t length): + cdef Py_ssize_t size + cdef char *buf + if self._has_value: + self._process_header() - cdef _on_header_value(self, str val, bytes raw_val): - if self._header_value is None: - self._header_value = val - self._raw_header_value = raw_val - else: - # This is unlikely, as mostly HTTP headers are one-line - self._header_value += val - self._raw_header_value += raw_val - - cdef _on_headers_complete(self, - ENCODING='utf-8', - ENCODING_ERR='surrogateescape', - CONTENT_ENCODING=hdrs.CONTENT_ENCODING, - SEC_WEBSOCKET_KEY1=hdrs.SEC_WEBSOCKET_KEY1, - SUPPORTED=('gzip', 'deflate')): + size = PyByteArray_Size(self._raw_name) + PyByteArray_Resize(self._raw_name, size + length) + buf = PyByteArray_AsString(self._raw_name) + memcpy(buf + size, at, length) + + cdef _on_header_value(self, char* at, size_t length): + cdef Py_ssize_t size + cdef char *buf + + size 
= PyByteArray_Size(self._raw_value) + PyByteArray_Resize(self._raw_value, size + length) + buf = PyByteArray_AsString(self._raw_value) + memcpy(buf + size, at, length) + self._has_value = True + + cdef _on_headers_complete(self): self._process_header() - method = cparser.http_method_str( self._cparser.method) - should_close = not bool(cparser.http_should_keep_alive(self._cparser)) - upgrade = bool(self._cparser.upgrade) - chunked = bool(self._cparser.flags & cparser.F_CHUNKED) + method = http_method_str(self._cparser.method) + should_close = not cparser.http_should_keep_alive(self._cparser) + upgrade = self._cparser.upgrade + chunked = self._cparser.flags & cparser.F_CHUNKED raw_headers = tuple(self._raw_headers) - headers = CIMultiDict(self._headers) + headers = CIMultiDictProxy(self._headers) if upgrade or self._cparser.method == 5: # cparser.CONNECT: self._upgraded = True @@ -168,32 +431,39 @@ cdef class HttpParser: if SEC_WEBSOCKET_KEY1 in headers: raise InvalidHeader(SEC_WEBSOCKET_KEY1) - encoding = headers.get(CONTENT_ENCODING) - if encoding: - encoding = encoding.lower() - if encoding not in SUPPORTED: - encoding = None + encoding = None + enc = self._content_encoding + if enc is not None: + self._content_encoding = None + enc = enc.lower() + if enc in ('gzip', 'deflate', 'br'): + encoding = enc if self._cparser.type == cparser.HTTP_REQUEST: - msg = RawRequestMessage( - method.decode(ENCODING, ENCODING_ERR), self._path, + msg = _new_request_message( + method, self._path, self.http_version(), headers, raw_headers, should_close, encoding, upgrade, chunked, self._url) else: - msg = RawResponseMessage( + msg = _new_response_message( self.http_version(), self._cparser.status_code, self._reason, headers, raw_headers, should_close, encoding, upgrade, chunked) - if (self._cparser.content_length > 0 or chunked or - self._cparser.method == 5): # CONNECT: 5 - payload = FlowControlStreamReader( - self._protocol, timer=self._timer, loop=self._loop) + if (ULLONG_MAX > self._cparser.content_length > 0 or chunked or + self._cparser.method == 5 or # CONNECT: 5 + (self._cparser.status_code >= 199 and + self._cparser.content_length == ULLONG_MAX and + self._read_until_eof) + ): + payload = StreamReader( + self._protocol, timer=self._timer, loop=self._loop, + limit=self._limit) else: payload = EMPTY_PAYLOAD self._payload = payload - if encoding is not None: + if encoding is not None and self._auto_decompress: self._payload = DeflateBuffer(payload, encoding) if not self._response_with_body: @@ -205,10 +475,16 @@ cdef class HttpParser: self._payload.feed_eof() self._payload = None + cdef _on_chunk_header(self): + self._payload.begin_http_chunk_receiving() - ### Public API ### + cdef _on_chunk_complete(self): + self._payload.end_http_chunk_receiving() - def http_version(self): + cdef object _on_status_complete(self): + pass + + cdef inline http_version(self): cdef cparser.http_parser* parser = self._cparser if parser.http_major == 1: @@ -219,6 +495,8 @@ cdef class HttpParser: return HttpVersion(parser.http_major, parser.http_minor) + ### Public API ### + def feed_eof(self): cdef bytes desc @@ -235,6 +513,10 @@ cdef class HttpParser: raise PayloadEncodingError(desc.decode('latin-1')) else: self._payload.feed_eof() + elif self._started: + self._on_headers_complete() + if self._messages: + return self._messages[-1][0] def feed_data(self, data): cdef: @@ -252,12 +534,7 @@ cdef class HttpParser: PyBuffer_Release(&self.py_buf) - # i am not sure about cparser.HPE_INVALID_METHOD, - # seems get err for valid 
request - # test_client_functional.py::test_post_data_with_bytesio_file - if (self._cparser.http_errno != cparser.HPE_OK and - (self._cparser.http_errno != cparser.HPE_INVALID_METHOD or - self._cparser.method == 0)): + if (self._cparser.http_errno != cparser.HPE_OK): if self._payload_error == 0: if self._last_error is not None: ex = self._last_error @@ -279,53 +556,78 @@ cdef class HttpParser: else: return messages, False, b'' + def set_upgraded(self, val): + self._upgraded = val -cdef class HttpRequestParserC(HttpParser): - def __init__(self, protocol, loop, timer=None, +cdef class HttpRequestParser(HttpParser): + + def __init__(self, protocol, loop, int limit, timer=None, size_t max_line_size=8190, size_t max_headers=32768, size_t max_field_size=8190, payload_exception=None, - response_with_body=True, read_until_eof=False): - self._init(cparser.HTTP_REQUEST, protocol, loop, timer, + bint response_with_body=True, bint read_until_eof=False, + ): + self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer, max_line_size, max_headers, max_field_size, - payload_exception, response_with_body) - - -cdef class HttpResponseParserC(HttpParser): - - def __init__(self, protocol, loop, timer=None, + payload_exception, response_with_body, read_until_eof) + + cdef object _on_status_complete(self): + cdef Py_buffer py_buf + if not self._buf: + return + self._path = self._buf.decode('utf-8', 'surrogateescape') + if self._cparser.method == 5: # CONNECT + self._url = URL(self._path) + else: + PyObject_GetBuffer(self._buf, &py_buf, PyBUF_SIMPLE) + try: + self._url = _parse_url(py_buf.buf, + py_buf.len) + finally: + PyBuffer_Release(&py_buf) + PyByteArray_Resize(self._buf, 0) + + +cdef class HttpResponseParser(HttpParser): + + def __init__(self, protocol, loop, int limit, timer=None, size_t max_line_size=8190, size_t max_headers=32768, size_t max_field_size=8190, payload_exception=None, - response_with_body=True, read_until_eof=False): - self._init(cparser.HTTP_RESPONSE, protocol, loop, timer, + bint response_with_body=True, bint read_until_eof=False, + bint auto_decompress=True + ): + self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer, max_line_size, max_headers, max_field_size, - payload_exception, response_with_body) + payload_exception, response_with_body, read_until_eof, + auto_decompress) + cdef object _on_status_complete(self): + if self._buf: + self._reason = self._buf.decode('utf-8', 'surrogateescape') + PyByteArray_Resize(self._buf, 0) + else: + self._reason = self._reason or '' cdef int cb_on_message_begin(cparser.http_parser* parser) except -1: cdef HttpParser pyparser = parser.data - pyparser._headers = [] + pyparser._started = True + pyparser._headers = CIMultiDict() pyparser._raw_headers = [] + PyByteArray_Resize(pyparser._buf, 0) + pyparser._path = None + pyparser._reason = None return 0 cdef int cb_on_url(cparser.http_parser* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data - cdef str path = None try: if length > pyparser._max_line_size: raise LineTooLong( - 'Status line is too long', pyparser._max_line_size) - - path = at[:length].decode('utf-8', 'surrogateescape') - pyparser._path = path - - if pyparser._cparser.method == 5: # CONNECT - pyparser._url = yarl.URL(path) - else: - pyparser._url = _parse_url(at[:length], length) + 'Status line is too long', pyparser._max_line_size, length) + extend(pyparser._buf, at, length) except BaseException as ex: pyparser._last_error = ex return -1 @@ -336,11 +638,12 @@ cdef int 
cb_on_url(cparser.http_parser* parser, cdef int cb_on_status(cparser.http_parser* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data + cdef str reason try: if length > pyparser._max_line_size: raise LineTooLong( - 'Status line is too long', pyparser._max_line_size) - pyparser._reason = at[:length].decode('utf-8', 'surrogateescape') + 'Status line is too long', pyparser._max_line_size, length) + extend(pyparser._buf, at, length) except BaseException as ex: pyparser._last_error = ex return -1 @@ -351,12 +654,14 @@ cdef int cb_on_status(cparser.http_parser* parser, cdef int cb_on_header_field(cparser.http_parser* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data + cdef Py_ssize_t size try: - if length > pyparser._max_field_size: + pyparser._on_status_complete() + size = len(pyparser._raw_name) + length + if size > pyparser._max_field_size: raise LineTooLong( - 'Header name is too long', pyparser._max_field_size) - pyparser._on_header_field( - at[:length].decode('utf-8', 'surrogateescape'), at[:length]) + 'Header name is too long', pyparser._max_field_size, size) + pyparser._on_header_field(at, length) except BaseException as ex: pyparser._last_error = ex return -1 @@ -367,12 +672,13 @@ cdef int cb_on_header_field(cparser.http_parser* parser, cdef int cb_on_header_value(cparser.http_parser* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data + cdef Py_ssize_t size try: - if length > pyparser._max_field_size: + size = len(pyparser._raw_value) + length + if size > pyparser._max_field_size: raise LineTooLong( - 'Header value is too long', pyparser._max_field_size) - pyparser._on_header_value( - at[:length].decode('utf-8', 'surrogateescape'), at[:length]) + 'Header value is too long', pyparser._max_field_size, size) + pyparser._on_header_value(at, length) except BaseException as ex: pyparser._last_error = ex return -1 @@ -383,6 +689,7 @@ cdef int cb_on_header_value(cparser.http_parser* parser, cdef int cb_on_headers_complete(cparser.http_parser* parser) except -1: cdef HttpParser pyparser = parser.data try: + pyparser._on_status_complete() pyparser._on_headers_complete() except BaseException as exc: pyparser._last_error = exc @@ -414,6 +721,7 @@ cdef int cb_on_body(cparser.http_parser* parser, cdef int cb_on_message_complete(cparser.http_parser* parser) except -1: cdef HttpParser pyparser = parser.data try: + pyparser._started = False pyparser._on_message_complete() except BaseException as exc: pyparser._last_error = exc @@ -422,6 +730,28 @@ cdef int cb_on_message_complete(cparser.http_parser* parser) except -1: return 0 +cdef int cb_on_chunk_header(cparser.http_parser* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._on_chunk_header() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + return 0 + + +cdef int cb_on_chunk_complete(cparser.http_parser* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._on_chunk_complete() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + return 0 + + cdef parser_error_from_errno(cparser.http_errno errno): cdef bytes desc = cparser.http_errno_description(errno) @@ -465,7 +795,7 @@ def parse_url(url): PyBuffer_Release(&py_buf) -def _parse_url(char* buf_data, size_t length): +cdef _parse_url(char* buf_data, size_t length): cdef: cparser.http_parser_url* parsed int res @@ -475,6 +805,8 @@ def _parse_url(char* buf_data, size_t length): str 
path = None str query = None str fragment = None + str user = None + str password = None str userinfo = None object result = None int off @@ -482,6 +814,8 @@ def _parse_url(char* buf_data, size_t length): parsed = \ PyMem_Malloc(sizeof(cparser.http_parser_url)) + if parsed is NULL: + raise MemoryError() cparser.http_parser_url_init(parsed) try: res = cparser.http_parser_parse_url(buf_data, length, 0, parsed) @@ -530,7 +864,11 @@ def _parse_url(char* buf_data, size_t length): ln = parsed.field_data[cparser.UF_USERINFO].len userinfo = buf_data[off:off+ln].decode('utf-8', 'surrogateescape') - return URL(schema, host, port, path, query, fragment, userinfo) + user, sep, password = userinfo.partition(':') + + return URL_build(scheme=schema, + user=user, password=password, host=host, port=port, + path=path, query_string=query, fragment=fragment, encoded=True) else: raise InvalidURLError("invalid url {!r}".format(buf_data)) finally: diff --git a/aiohttp/_http_writer.pyx b/aiohttp/_http_writer.pyx new file mode 100644 index 00000000000..84b42fa1c35 --- /dev/null +++ b/aiohttp/_http_writer.pyx @@ -0,0 +1,151 @@ +from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.exc cimport PyErr_NoMemory +from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc +from cpython.object cimport PyObject_Str +from libc.stdint cimport uint8_t, uint64_t +from libc.string cimport memcpy + +from multidict import istr + +DEF BUF_SIZE = 16 * 1024 # 16KiB +cdef char BUFFER[BUF_SIZE] + +cdef object _istr = istr + + +# ----------------- writer --------------------------- + +cdef struct Writer: + char *buf + Py_ssize_t size + Py_ssize_t pos + + +cdef inline void _init_writer(Writer* writer): + writer.buf = &BUFFER[0] + writer.size = BUF_SIZE + writer.pos = 0 + + +cdef inline void _release_writer(Writer* writer): + if writer.buf != BUFFER: + PyMem_Free(writer.buf) + + +cdef inline int _write_byte(Writer* writer, uint8_t ch): + cdef char * buf + cdef Py_ssize_t size + + if writer.pos == writer.size: + # reallocate + size = writer.size + BUF_SIZE + if writer.buf == BUFFER: + buf = PyMem_Malloc(size) + if buf == NULL: + PyErr_NoMemory() + return -1 + memcpy(buf, writer.buf, writer.size) + else: + buf = PyMem_Realloc(writer.buf, size) + if buf == NULL: + PyErr_NoMemory() + return -1 + writer.buf = buf + writer.size = size + writer.buf[writer.pos] = ch + writer.pos += 1 + return 0 + + +cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol): + cdef uint64_t utf = symbol + + if utf < 0x80: + return _write_byte(writer, utf) + elif utf < 0x800: + if _write_byte(writer, (0xc0 | (utf >> 6))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + elif 0xD800 <= utf <= 0xDFFF: + # surogate pair, ignored + return 0 + elif utf < 0x10000: + if _write_byte(writer, (0xe0 | (utf >> 12))) < 0: + return -1 + if _write_byte(writer, (0x80 | ((utf >> 6) & 0x3f))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + elif utf > 0x10FFFF: + # symbol is too large + return 0 + else: + if _write_byte(writer, (0xf0 | (utf >> 18))) < 0: + return -1 + if _write_byte(writer, + (0x80 | ((utf >> 12) & 0x3f))) < 0: + return -1 + if _write_byte(writer, + (0x80 | ((utf >> 6) & 0x3f))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + + +cdef inline int _write_str(Writer* writer, str s): + cdef Py_UCS4 ch + for ch in s: + if _write_utf8(writer, ch) < 0: + return -1 + + +# --------------- _serialize_headers ---------------------- + +cdef str to_str(object s): + typ = type(s) + if typ is 
str: + return s + elif typ is _istr: + return PyObject_Str(s) + elif not isinstance(s, str): + raise TypeError("Cannot serialize non-str key {!r}".format(s)) + else: + return str(s) + + +def _serialize_headers(str status_line, headers): + cdef Writer writer + cdef object key + cdef object val + cdef bytes ret + + _init_writer(&writer) + + try: + if _write_str(&writer, status_line) < 0: + raise + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + for key, val in headers.items(): + if _write_str(&writer, to_str(key)) < 0: + raise + if _write_byte(&writer, b':') < 0: + raise + if _write_byte(&writer, b' ') < 0: + raise + if _write_str(&writer, to_str(val)) < 0: + raise + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + return PyBytes_FromStringAndSize(writer.buf, writer.pos) + finally: + _release_writer(&writer) diff --git a/aiohttp/_websocket.pyx b/aiohttp/_websocket.pyx index 4fc33971889..94318d2b1be 100644 --- a/aiohttp/_websocket.pyx +++ b/aiohttp/_websocket.pyx @@ -1,13 +1,15 @@ from cpython cimport PyBytes_AsString + #from cpython cimport PyByteArray_AsString # cython still not exports that cdef extern from "Python.h": char* PyByteArray_AsString(bytearray ba) except NULL from libc.stdint cimport uint32_t, uint64_t, uintmax_t -def _websocket_mask_cython(bytes mask, bytearray data): - """Note, this function mutates it's `data` argument + +def _websocket_mask_cython(object mask, object data): + """Note, this function mutates its `data` argument """ cdef: Py_ssize_t data_len, i @@ -19,6 +21,14 @@ def _websocket_mask_cython(bytes mask, bytearray data): assert len(mask) == 4 + if not isinstance(mask, bytes): + mask = bytes(mask) + + if isinstance(data, bytearray): + data = data + else: + data = bytearray(data) + data_len = len(data) in_buf = PyByteArray_AsString(data) mask_buf = PyBytes_AsString(mask) @@ -44,5 +54,3 @@ def _websocket_mask_cython(bytes mask, bytearray data): for i in range(0, data_len): in_buf[i] ^= mask_buf[i] - - return data diff --git a/aiohttp/abc.py b/aiohttp/abc.py index dc343259fd3..4abfd798d7d 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -1,17 +1,42 @@ import asyncio -import sys +import logging from abc import ABC, abstractmethod -from collections.abc import Iterable, Sized - -PY_35 = sys.version_info >= (3, 5) +from collections.abc import Sized +from http.cookies import BaseCookie, Morsel +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +from multidict import CIMultiDict +from yarl import URL + +from .helpers import get_running_loop +from .typedefs import LooseCookies + +if TYPE_CHECKING: # pragma: no cover + from .web_app import Application + from .web_exceptions import HTTPException + from .web_request import BaseRequest, Request + from .web_response import StreamResponse +else: + BaseRequest = Request = Application = StreamResponse = None + HTTPException = None class AbstractRouter(ABC): - - def __init__(self): + def __init__(self) -> None: self._frozen = False - def post_init(self, app): + def post_init(self, app: Application) -> None: """Post init stage. 
Not an abstract method for sake of backward compatibility, @@ -20,43 +45,41 @@ def post_init(self, app): """ @property - def frozen(self): + def frozen(self) -> bool: return self._frozen - def freeze(self): + def freeze(self) -> None: """Freeze router.""" self._frozen = True - @asyncio.coroutine # pragma: no branch @abstractmethod - def resolve(self, request): + async def resolve(self, request: Request) -> "AbstractMatchInfo": """Return MATCH_INFO for given request""" class AbstractMatchInfo(ABC): - - @asyncio.coroutine # pragma: no branch + @property # pragma: no branch @abstractmethod - def handler(self, request): + def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: """Execute matched request handler""" - @asyncio.coroutine # pragma: no branch + @property @abstractmethod - def expect_handler(self, request): + def expect_handler(self) -> Callable[[Request], Awaitable[None]]: """Expect handler for 100-continue processing""" @property # pragma: no branch @abstractmethod - def http_exception(self): + def http_exception(self) -> Optional[HTTPException]: """HTTPException instance raised on router's resolving, or None""" @abstractmethod # pragma: no branch - def get_info(self): + def get_info(self) -> Dict[str, Any]: """Return a dict with additional info useful for introspection""" @property # pragma: no branch @abstractmethod - def apps(self): + def apps(self) -> Tuple[Application, ...]: """Stack of nested applications. Top level application is left-most element. @@ -64,11 +87,11 @@ def apps(self): """ @abstractmethod - def add_app(self, app): + def add_app(self, app: Application) -> None: """Add application to the nested apps stack.""" @abstractmethod - def freeze(self): + def freeze(self) -> None: """Freeze the match info. The method is called after route resolution. 
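# ---------------------------------------------------------------------
# [Editor's aside, not part of the patch] This hunk turns
# ``AbstractRouter.resolve`` from an ``@asyncio.coroutine``-decorated
# method into a native, type-annotated ``async def``. Third-party
# routers must follow suit; a hypothetical minimal subclass (the
# ``table`` mapping of request paths to match-info objects is an
# assumption for illustration only):

from aiohttp.abc import AbstractRouter

class TableRouter(AbstractRouter):
    """Resolve requests from a plain ``{path: match_info}`` mapping."""

    def __init__(self, table):
        super().__init__()
        self._table = table  # e.g. built once at application startup

    async def resolve(self, request):  # native coroutine, no yield-from
        return self._table[request.path]
# ---------------------------------------------------------------------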
@@ -79,69 +102,99 @@ def freeze(self): class AbstractView(ABC): + """Abstract class based view.""" - def __init__(self, request): + def __init__(self, request: Request) -> None: self._request = request @property - def request(self): + def request(self) -> Request: + """Request instance.""" return self._request - @asyncio.coroutine # pragma: no branch @abstractmethod - def __iter__(self): - while False: # pragma: no cover - yield None - - if PY_35: # pragma: no branch - @abstractmethod - def __await__(self): - return # pragma: no cover + def __await__(self) -> Generator[Any, None, StreamResponse]: + """Execute the view handler.""" class AbstractResolver(ABC): + """Abstract DNS resolver.""" - @asyncio.coroutine # pragma: no branch @abstractmethod - def resolve(self, hostname): + async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]: """Return IP address for given hostname""" - @asyncio.coroutine # pragma: no branch @abstractmethod - def close(self): + async def close(self) -> None: """Release resolver""" -class AbstractCookieJar(Sized, Iterable): +if TYPE_CHECKING: # pragma: no cover + IterableBase = Iterable[Morsel[str]] +else: + IterableBase = Iterable + + +class AbstractCookieJar(Sized, IterableBase): + """Abstract Cookie Jar.""" - def __init__(self, *, loop=None): - self._loop = loop or asyncio.get_event_loop() + def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self._loop = get_running_loop(loop) @abstractmethod - def clear(self): + def clear(self) -> None: """Clear all cookies.""" @abstractmethod - def update_cookies(self, cookies, response_url=None): + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: """Update cookies.""" @abstractmethod - def filter_cookies(self, request_url): + def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": """Return the jar's cookies filtered by their attributes.""" -class AbstractPayloadWriter(ABC): +class AbstractStreamWriter(ABC): + """Abstract stream writer.""" + + buffer_size = 0 + output_size = 0 + length = 0 # type: Optional[int] @abstractmethod - def write(self, chunk): - """Write chunk into stream""" + async def write(self, chunk: bytes) -> None: + """Write chunk into stream.""" - @asyncio.coroutine @abstractmethod - def write_eof(self, chunk=b''): - """Write last chunk""" + async def write_eof(self, chunk: bytes = b"") -> None: + """Write last chunk.""" - @asyncio.coroutine @abstractmethod - def drain(self): + async def drain(self) -> None: """Flush the write buffer.""" + + @abstractmethod + def enable_compression(self, encoding: str = "deflate") -> None: + """Enable HTTP body compression""" + + @abstractmethod + def enable_chunking(self) -> None: + """Enable HTTP chunked mode""" + + @abstractmethod + async def write_headers( + self, status_line: str, headers: "CIMultiDict[str]" + ) -> None: + """Write HTTP headers""" + + +class AbstractAccessLogger(ABC): + """Abstract writer to access log.""" + + def __init__(self, logger: logging.Logger, log_format: str) -> None: + self.logger = logger + self.log_format = log_format + + @abstractmethod + def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: + """Emit log to logger.""" diff --git a/aiohttp/backport_cookies.py b/aiohttp/backport_cookies.py deleted file mode 100644 index e523e04ba99..00000000000 --- a/aiohttp/backport_cookies.py +++ /dev/null @@ -1,367 +0,0 @@ -#### -# Copyright 2000 by Timothy O'Malley -# -# All Rights Reserved -# -# Permission to use, copy, 
modify, and distribute this software -# and its documentation for any purpose and without fee is hereby -# granted, provided that the above copyright notice appear in all -# copies and that both that copyright notice and this permission -# notice appear in supporting documentation, and that the name of -# Timothy O'Malley not be used in advertising or publicity -# pertaining to distribution of the software without specific, written -# prior permission. -# -# Timothy O'Malley DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS -# SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY -# AND FITNESS, IN NO EVENT SHALL Timothy O'Malley BE LIABLE FOR -# ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, -# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS -# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR -# PERFORMANCE OF THIS SOFTWARE. -# -#### -# -# Id: Cookie.py,v 2.29 2000/08/23 05:28:49 timo Exp -# by Timothy O'Malley -# -# Cookie.py is a Python module for the handling of HTTP -# cookies as a Python dictionary. See RFC 2109 for more -# information on cookies. -# -# The original idea to treat Cookies as a dictionary came from -# Dave Mitchell (davem@magnet.com) in 1995, when he released the -# first version of nscookie.py. -# -#### - -import re # pragma: no cover -import string # pragma: no cover -from http.cookies import CookieError, Morsel # pragma: no cover - -__all__ = ["CookieError", "BaseCookie", "SimpleCookie"] # pragma: no cover - -_nulljoin = ''.join # pragma: no cover -_semispacejoin = '; '.join # pragma: no cover -_spacejoin = ' '.join # pragma: no cover - -# These quoting routines conform to the RFC2109 specification, which in -# turn references the character definitions from RFC2068. They provide -# a two-way quoting algorithm. Any non-text character is translated -# into a 4 character sequence: a forward-slash followed by the -# three-digit octal equivalent of the character. Any '\' or '"' is -# quoted with a preceeding '\' slash. -# -# These are taken from RFC2068 and RFC2109. 
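# ---------------------------------------------------------------------
# [Editor's aside, not part of the patch] The note above describes the
# two-way quoting the deleted module used: all-"legal" strings pass
# through unchanged; anything else is double-quoted, with '"' and '\'
# escaped by a backslash and control/8-bit bytes (plus ',' and ';')
# written as a backslash followed by three octal digits. A compact
# restatement of that rule as a standalone function:

import string

_LEGAL = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:"

def quote_cookie_value(value: str) -> str:
    if all(c in _LEGAL for c in value):
        return value  # nothing to escape, no surrounding quotes needed
    out = []
    for c in value:
        if c in '"\\':
            out.append('\\' + c)                 # \" and \\
        elif ord(c) < 0x20 or 0x7F <= ord(c) <= 0xFF or c in ',;':
            out.append('\\%03o' % ord(c))        # e.g. ';' -> \073
        else:
            out.append(c)
    return '"' + ''.join(out) + '"'

assert quote_cookie_value("abc") == "abc"
assert quote_cookie_value("a;b") == '"a\\073b"'
# ---------------------------------------------------------------------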
-# _LegalChars is the list of chars which don't require "'s -# _Translator hash-table for fast quoting -# -_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:" # pragma: no cover -_Translator = { # pragma: no cover - '\000' : '\\000', '\001' : '\\001', '\002' : '\\002', - '\003' : '\\003', '\004' : '\\004', '\005' : '\\005', - '\006' : '\\006', '\007' : '\\007', '\010' : '\\010', - '\011' : '\\011', '\012' : '\\012', '\013' : '\\013', - '\014' : '\\014', '\015' : '\\015', '\016' : '\\016', - '\017' : '\\017', '\020' : '\\020', '\021' : '\\021', - '\022' : '\\022', '\023' : '\\023', '\024' : '\\024', - '\025' : '\\025', '\026' : '\\026', '\027' : '\\027', - '\030' : '\\030', '\031' : '\\031', '\032' : '\\032', - '\033' : '\\033', '\034' : '\\034', '\035' : '\\035', - '\036' : '\\036', '\037' : '\\037', - - # Because of the way browsers really handle cookies (as opposed - # to what the RFC says) we also encode , and ; - - ',' : '\\054', ';' : '\\073', - - '"' : '\\"', '\\' : '\\\\', - - '\177' : '\\177', '\200' : '\\200', '\201' : '\\201', - '\202' : '\\202', '\203' : '\\203', '\204' : '\\204', - '\205' : '\\205', '\206' : '\\206', '\207' : '\\207', - '\210' : '\\210', '\211' : '\\211', '\212' : '\\212', - '\213' : '\\213', '\214' : '\\214', '\215' : '\\215', - '\216' : '\\216', '\217' : '\\217', '\220' : '\\220', - '\221' : '\\221', '\222' : '\\222', '\223' : '\\223', - '\224' : '\\224', '\225' : '\\225', '\226' : '\\226', - '\227' : '\\227', '\230' : '\\230', '\231' : '\\231', - '\232' : '\\232', '\233' : '\\233', '\234' : '\\234', - '\235' : '\\235', '\236' : '\\236', '\237' : '\\237', - '\240' : '\\240', '\241' : '\\241', '\242' : '\\242', - '\243' : '\\243', '\244' : '\\244', '\245' : '\\245', - '\246' : '\\246', '\247' : '\\247', '\250' : '\\250', - '\251' : '\\251', '\252' : '\\252', '\253' : '\\253', - '\254' : '\\254', '\255' : '\\255', '\256' : '\\256', - '\257' : '\\257', '\260' : '\\260', '\261' : '\\261', - '\262' : '\\262', '\263' : '\\263', '\264' : '\\264', - '\265' : '\\265', '\266' : '\\266', '\267' : '\\267', - '\270' : '\\270', '\271' : '\\271', '\272' : '\\272', - '\273' : '\\273', '\274' : '\\274', '\275' : '\\275', - '\276' : '\\276', '\277' : '\\277', '\300' : '\\300', - '\301' : '\\301', '\302' : '\\302', '\303' : '\\303', - '\304' : '\\304', '\305' : '\\305', '\306' : '\\306', - '\307' : '\\307', '\310' : '\\310', '\311' : '\\311', - '\312' : '\\312', '\313' : '\\313', '\314' : '\\314', - '\315' : '\\315', '\316' : '\\316', '\317' : '\\317', - '\320' : '\\320', '\321' : '\\321', '\322' : '\\322', - '\323' : '\\323', '\324' : '\\324', '\325' : '\\325', - '\326' : '\\326', '\327' : '\\327', '\330' : '\\330', - '\331' : '\\331', '\332' : '\\332', '\333' : '\\333', - '\334' : '\\334', '\335' : '\\335', '\336' : '\\336', - '\337' : '\\337', '\340' : '\\340', '\341' : '\\341', - '\342' : '\\342', '\343' : '\\343', '\344' : '\\344', - '\345' : '\\345', '\346' : '\\346', '\347' : '\\347', - '\350' : '\\350', '\351' : '\\351', '\352' : '\\352', - '\353' : '\\353', '\354' : '\\354', '\355' : '\\355', - '\356' : '\\356', '\357' : '\\357', '\360' : '\\360', - '\361' : '\\361', '\362' : '\\362', '\363' : '\\363', - '\364' : '\\364', '\365' : '\\365', '\366' : '\\366', - '\367' : '\\367', '\370' : '\\370', '\371' : '\\371', - '\372' : '\\372', '\373' : '\\373', '\374' : '\\374', - '\375' : '\\375', '\376' : '\\376', '\377' : '\\377' - } - -def _quote(str, LegalChars=_LegalChars): # pragma: no cover - r"""Quote a string for use in a cookie header. 
- - If the string does not need to be double-quoted, then just return the - string. Otherwise, surround the string in doublequotes and quote - (with a \) special characters. - """ - if all(c in LegalChars for c in str): - return str - else: - return '"' + _nulljoin(_Translator.get(s, s) for s in str) + '"' - - -_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]") # pragma: no cover -_QuotePatt = re.compile(r"[\\].") # pragma: no cover - -def _unquote(str): # pragma: no cover - # If there aren't any doublequotes, - # then there can't be any special characters. See RFC 2109. - if len(str) < 2: - return str - if str[0] != '"' or str[-1] != '"': - return str - - # We have to assume that we must decode this string. - # Down to work. - - # Remove the "s - str = str[1:-1] - - # Check for special sequences. Examples: - # \012 --> \n - # \" --> " - # - i = 0 - n = len(str) - res = [] - while 0 <= i < n: - o_match = _OctalPatt.search(str, i) - q_match = _QuotePatt.search(str, i) - if not o_match and not q_match: # Neither matched - res.append(str[i:]) - break - # else: - j = k = -1 - if o_match: - j = o_match.start(0) - if q_match: - k = q_match.start(0) - if q_match and (not o_match or k < j): # QuotePatt matched - res.append(str[i:k]) - res.append(str[k+1]) - i = k + 2 - else: # OctalPatt matched - res.append(str[i:j]) - res.append(chr(int(str[j+1:j+4], 8))) - i = j + 4 - return _nulljoin(res) - -# The _getdate() routine is used to set the expiration time in the cookie's HTTP -# header. By default, _getdate() returns the current time in the appropriate -# "expires" format for a Set-Cookie header. The one optional argument is an -# offset from now, in seconds. For example, an offset of -3600 means "one hour -# ago". The offset may be a floating point number. -# - -_weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] # pragma: no cover - -_monthname = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] # pragma: no cover - -def _getdate(future=0, weekdayname=_weekdayname, monthname=_monthname): # pragma: no cover - from time import gmtime, time - now = time() - year, month, day, hh, mm, ss, wd, y, z = gmtime(now + future) - return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % \ - (weekdayname[wd], day, monthname[month], year, hh, mm, ss) - - -# -# Pattern for finding cookie -# -# This used to be strict parsing based on the RFC2109 and RFC2068 -# specifications. I have since discovered that MSIE 3.0x doesn't -# follow the character rules outlined in those specs. As a -# result, the parsing rules here are less strict. -# - -_LegalKeyChars = r"\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=" # pragma: no cover -_LegalValueChars = _LegalKeyChars + r'\[\]' # pragma: no cover -_CookiePattern = re.compile(r""" - (?x) # This is a verbose pattern - \s* # Optional whitespace at start of cookie - (?P # Start of group 'key' - [""" + _LegalKeyChars + r"""]+? # Any word of at least one letter - ) # End of group 'key' - ( # Optional group: there may not be a value. - \s*=\s* # Equal Sign - (?P # Start of group 'val' - "(?:[^\\"]|\\.)*" # Any doublequoted string - | # or - \w{3},\s[\w\d\s-]{9,11}\s[\d:]{8}\sGMT # Special case for "expires" attr - | # or - [""" + _LegalValueChars + r"""]* # Any word or empty string - ) # End of group 'val' - )? # End of optional value group - \s* # Any number of spaces. - (\s+|;|$) # Ending either at space, semicolon, or EOS. - """, re.ASCII) # pragma: no cover - - -# At long last, here is the cookie class. 
Using this class is almost just like -# using a dictionary. See this module's docstring for example usage. -# -class BaseCookie(dict): # pragma: no cover - """A container class for a set of Morsels.""" - - def value_decode(self, val): - """real_value, coded_value = value_decode(STRING) - Called prior to setting a cookie's value from the network - representation. The VALUE is the value read from HTTP - header. - Override this function to modify the behavior of cookies. - """ - return val, val - - def value_encode(self, val): - """real_value, coded_value = value_encode(VALUE) - Called prior to setting a cookie's value from the dictionary - representation. The VALUE is the value being assigned. - Override this function to modify the behavior of cookies. - """ - strval = str(val) - return strval, strval - - def __init__(self, input=None): - if input: - self.load(input) - - def __set(self, key, real_value, coded_value): - """Private method for setting a cookie's value""" - M = self.get(key, Morsel()) - M.set(key, real_value, coded_value) - dict.__setitem__(self, key, M) - - def __setitem__(self, key, value): - """Dictionary style assignment.""" - if isinstance(value, Morsel): - # allow assignment of constructed Morsels (e.g. for pickling) - dict.__setitem__(self, key, value) - else: - rval, cval = self.value_encode(value) - self.__set(key, rval, cval) - - def output(self, attrs=None, header="Set-Cookie:", sep="\015\012"): - """Return a string suitable for HTTP.""" - result = [] - items = sorted(self.items()) - for key, value in items: - result.append(value.output(attrs, header)) - return sep.join(result) - - __str__ = output - - def __repr__(self): - l = [] - items = sorted(self.items()) - for key, value in items: - l.append('%s=%s' % (key, repr(value.value))) - return '<%s: %s>' % (self.__class__.__name__, _spacejoin(l)) - - def js_output(self, attrs=None): - """Return a string suitable for JavaScript.""" - result = [] - items = sorted(self.items()) - for key, value in items: - result.append(value.js_output(attrs)) - return _nulljoin(result) - - def load(self, rawdata): - """Load cookies from a string (presumably HTTP_COOKIE) or - from a dictionary. Loading cookies from a dictionary 'd' - is equivalent to calling: - map(Cookie.__setitem__, d.keys(), d.values()) - """ - if isinstance(rawdata, str): - self.__parse_string(rawdata) - else: - # self.update() wouldn't call our custom __setitem__ - for key, value in rawdata.items(): - self[key] = value - return - - def __parse_string(self, str, patt=_CookiePattern): - i = 0 # Our starting point - n = len(str) # Length of string - M = None # current morsel - - while 0 <= i < n: - # Start looking for a cookie - match = patt.match(str, i) - if not match: - # No more cookies - break - - key, value = match.group("key"), match.group("val") - i = match.end(0) - - # Parse the key, value in case it's metainfo - if key[0] == "$": - # We ignore attributes which pertain to the cookie - # mechanism as a whole. See RFC 2109. - # (Does anyone care?) - if M: - M[key[1:]] = value - elif key.lower() in Morsel._reserved: - if M: - if value is None: - if key.lower() in Morsel._flags: - M[key] = True - else: - M[key] = _unquote(value) - elif value is not None: - rval, cval = self.value_decode(value) - self.__set(key, rval, cval) - M = self[key] - - -class SimpleCookie(BaseCookie): # pragma: no cover - """ - SimpleCookie supports strings as cookie values. 
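# ---------------------------------------------------------------------
# [Editor's aside, not part of the patch] This whole backported module
# is deleted because the package now requires Python >= 3.6 (see the
# README hunk earlier in this diff), whose standard library ships the
# same ``SimpleCookie``. The stdlib replacement in action; the exact
# escaping shown is indicative and may vary slightly across Python
# versions:

from http.cookies import SimpleCookie

jar = SimpleCookie()
jar["session"] = "a;b"   # ';' is not a "legal" char, so the value is quoted
print(jar.output())      # e.g.: Set-Cookie: session="a\073b"

received = SimpleCookie()
received.load('session="a\\073b"')
print(received["session"].value)  # unquoted back to 'a;b'
# ---------------------------------------------------------------------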
When setting - the value using the dictionary assignment notation, `SimpleCookie` - calls the builtin `str()` to convert the value to a string. Values - received from HTTP are kept as strings. - """ - def value_decode(self, val): # pragma: no cover - return _unquote(val), val - - def value_encode(self, val): # pragma: no cover - strval = str(val) - return strval, _quote(strval) diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py new file mode 100644 index 00000000000..01e18310b47 --- /dev/null +++ b/aiohttp/base_protocol.py @@ -0,0 +1,87 @@ +import asyncio +from typing import Optional, cast + +from .tcp_helpers import tcp_nodelay + + +class BaseProtocol(asyncio.Protocol): + __slots__ = ( + "_loop", + "_paused", + "_drain_waiter", + "_connection_lost", + "_reading_paused", + "transport", + ) + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop # type: asyncio.AbstractEventLoop + self._paused = False + self._drain_waiter = None # type: Optional[asyncio.Future[None]] + self._connection_lost = False + self._reading_paused = False + + self.transport = None # type: Optional[asyncio.Transport] + + def pause_writing(self) -> None: + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def pause_reading(self) -> None: + if not self._reading_paused and self.transport is not None: + try: + self.transport.pause_reading() + except (AttributeError, NotImplementedError, RuntimeError): + pass + self._reading_paused = True + + def resume_reading(self) -> None: + if self._reading_paused and self.transport is not None: + try: + self.transport.resume_reading() + except (AttributeError, NotImplementedError, RuntimeError): + pass + self._reading_paused = False + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + tr = cast(asyncio.Transport, transport) + tcp_nodelay(tr, True) + self.transport = tr + + def connection_lost(self, exc: Optional[BaseException]) -> None: + self._connection_lost = True + # Wake up the writer if currently paused. + self.transport = None + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + async def _drain_helper(self) -> None: + if self._connection_lost: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self._loop.create_future() + self._drain_waiter = waiter + await waiter diff --git a/aiohttp/client.py b/aiohttp/client.py index d1154ebfbaa..a9da8e155d5 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -8,106 +8,289 @@ import sys import traceback import warnings - +from types import SimpleNamespace, TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + FrozenSet, + Generator, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr from yarl import URL -from . import connector as connector_mod -from . 
import client_exceptions, client_reqrep, hdrs, http, payload -from .client_exceptions import * # noqa -from .client_exceptions import (ClientError, ClientOSError, - ClientResponseError, ServerTimeoutError, - WSServerHandshakeError) -from .client_reqrep import * # noqa -from .client_reqrep import ClientRequest, ClientResponse -from .client_ws import ClientWebSocketResponse -from .connector import * # noqa -from .connector import TCPConnector +from . import hdrs, http, payload +from .abc import AbstractCookieJar +from .client_exceptions import ( + ClientConnectionError as ClientConnectionError, + ClientConnectorCertificateError as ClientConnectorCertificateError, + ClientConnectorError as ClientConnectorError, + ClientConnectorSSLError as ClientConnectorSSLError, + ClientError as ClientError, + ClientHttpProxyError as ClientHttpProxyError, + ClientOSError as ClientOSError, + ClientPayloadError as ClientPayloadError, + ClientProxyConnectionError as ClientProxyConnectionError, + ClientResponseError as ClientResponseError, + ClientSSLError as ClientSSLError, + ContentTypeError as ContentTypeError, + InvalidURL as InvalidURL, + ServerConnectionError as ServerConnectionError, + ServerDisconnectedError as ServerDisconnectedError, + ServerFingerprintMismatch as ServerFingerprintMismatch, + ServerTimeoutError as ServerTimeoutError, + TooManyRedirects as TooManyRedirects, + WSServerHandshakeError as WSServerHandshakeError, +) +from .client_reqrep import ( + ClientRequest as ClientRequest, + ClientResponse as ClientResponse, + Fingerprint as Fingerprint, + RequestInfo as RequestInfo, + _merge_ssl_params, +) +from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse +from .connector import ( + BaseConnector as BaseConnector, + NamedPipeConnector as NamedPipeConnector, + TCPConnector as TCPConnector, + UnixConnector as UnixConnector, +) from .cookiejar import CookieJar -from .helpers import PY_35, CeilTimeout, TimeoutHandle, deprecated_noop -from .http import WS_KEY, WebSocketReader, WebSocketWriter +from .helpers import ( + DEBUG, + PY_36, + BasicAuth, + CeilTimeout, + TimeoutHandle, + get_running_loop, + proxies_from_env, + sentinel, + strip_auth_from_url, +) +from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter +from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue - -__all__ = (client_exceptions.__all__ + # noqa - client_reqrep.__all__ + # noqa - connector_mod.__all__ + # noqa - ('ClientSession', 'ClientWebSocketResponse', 'request')) - - -# 5 Minute default read and connect timeout -DEFAULT_TIMEOUT = 5 * 60 +from .tracing import Trace, TraceConfig +from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL + +__all__ = ( + # client_exceptions + "ClientConnectionError", + "ClientConnectorCertificateError", + "ClientConnectorError", + "ClientConnectorSSLError", + "ClientError", + "ClientHttpProxyError", + "ClientOSError", + "ClientPayloadError", + "ClientProxyConnectionError", + "ClientResponseError", + "ClientSSLError", + "ContentTypeError", + "InvalidURL", + "ServerConnectionError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ServerTimeoutError", + "TooManyRedirects", + "WSServerHandshakeError", + # client_reqrep + "ClientRequest", + "ClientResponse", + "Fingerprint", + "RequestInfo", + # connector + "BaseConnector", + "TCPConnector", + "UnixConnector", + "NamedPipeConnector", + # client_ws + "ClientWebSocketResponse", + # client + "ClientSession", + 
"ClientTimeout", + "request", +) + + +try: + from ssl import SSLContext +except ImportError: # pragma: no cover + SSLContext = object # type: ignore + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ClientTimeout: + total: Optional[float] = None + connect: Optional[float] = None + sock_read: Optional[float] = None + sock_connect: Optional[float] = None + + # pool_queue_timeout: Optional[float] = None + # dns_resolution_timeout: Optional[float] = None + # socket_connect_timeout: Optional[float] = None + # connection_acquiring_timeout: Optional[float] = None + # new_connection_timeout: Optional[float] = None + # http_header_timeout: Optional[float] = None + # response_body_timeout: Optional[float] = None + + # to create a timeout specific for a single request, either + # - create a completely new one to overwrite the default + # - or use http://www.attrs.org/en/stable/api.html#attr.evolve + # to overwrite the defaults + + +# 5 Minute default read timeout +DEFAULT_TIMEOUT = ClientTimeout(total=5 * 60) + +_RetType = TypeVar("_RetType") class ClientSession: """First-class interface for making HTTP requests.""" + ATTRS = frozenset( + [ + "_source_traceback", + "_connector", + "requote_redirect_url", + "_loop", + "_cookie_jar", + "_connector_owner", + "_default_auth", + "_version", + "_json_serialize", + "_requote_redirect_url", + "_timeout", + "_raise_for_status", + "_auto_decompress", + "_trust_env", + "_default_headers", + "_skip_auto_headers", + "_request_class", + "_response_class", + "_ws_response_class", + "_trace_configs", + "_read_bufsize", + ] + ) + _source_traceback = None - _connector = None - - def __init__(self, *, connector=None, loop=None, cookies=None, - headers=None, skip_auto_headers=None, - auth=None, json_serialize=json.dumps, - request_class=ClientRequest, response_class=ClientResponse, - ws_response_class=ClientWebSocketResponse, - version=http.HttpVersion11, - cookie_jar=None, connector_owner=True, raise_for_status=False, - read_timeout=None, conn_timeout=None): - - implicit_loop = False + + def __init__( + self, + *, + connector: Optional[BaseConnector] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + cookies: Optional[LooseCookies] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + json_serialize: JSONEncoder = json.dumps, + request_class: Type[ClientRequest] = ClientRequest, + response_class: Type[ClientResponse] = ClientResponse, + ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse, + version: HttpVersion = http.HttpVersion11, + cookie_jar: Optional[AbstractCookieJar] = None, + connector_owner: bool = True, + raise_for_status: bool = False, + read_timeout: Union[float, object] = sentinel, + conn_timeout: Optional[float] = None, + timeout: Union[object, ClientTimeout] = sentinel, + auto_decompress: bool = True, + trust_env: bool = False, + requote_redirect_url: bool = True, + trace_configs: Optional[List[TraceConfig]] = None, + read_bufsize: int = 2 ** 16, + ) -> None: + if loop is None: if connector is not None: loop = connector._loop - else: - implicit_loop = True - loop = asyncio.get_event_loop() + + loop = get_running_loop(loop) if connector is None: connector = TCPConnector(loop=loop) if connector._loop is not loop: - raise RuntimeError( - "Session and connector has to use same event loop") + raise RuntimeError("Session and connector has to use same event loop") self._loop = loop if loop.get_debug(): self._source_traceback = 
traceback.extract_stack(sys._getframe(1)) - if implicit_loop and not loop.is_running(): - warnings.warn("Creating a client session outside of coroutine is " - "a very dangerous idea", ResourceWarning, - stacklevel=2) - context = {'client_session': self, - 'message': 'Creating a client session outside ' - 'of coroutine'} - if self._source_traceback is not None: - context['source_traceback'] = self._source_traceback - loop.call_exception_handler(context) - if cookie_jar is None: cookie_jar = CookieJar(loop=loop) self._cookie_jar = cookie_jar if cookies is not None: self._cookie_jar.update_cookies(cookies) - self._connector = connector + + self._connector = connector # type: Optional[BaseConnector] self._connector_owner = connector_owner self._default_auth = auth self._version = version self._json_serialize = json_serialize - self._read_timeout = read_timeout - self._conn_timeout = conn_timeout + if timeout is sentinel: + self._timeout = DEFAULT_TIMEOUT + if read_timeout is not sentinel: + warnings.warn( + "read_timeout is deprecated, " "use timeout argument instead", + DeprecationWarning, + stacklevel=2, + ) + self._timeout = attr.evolve(self._timeout, total=read_timeout) + if conn_timeout is not None: + self._timeout = attr.evolve(self._timeout, connect=conn_timeout) + warnings.warn( + "conn_timeout is deprecated, " "use timeout argument instead", + DeprecationWarning, + stacklevel=2, + ) + else: + self._timeout = timeout # type: ignore + if read_timeout is not sentinel: + raise ValueError( + "read_timeout and timeout parameters " + "conflict, please setup " + "timeout.read" + ) + if conn_timeout is not None: + raise ValueError( + "conn_timeout and timeout parameters " + "conflict, please setup " + "timeout.connect" + ) self._raise_for_status = raise_for_status + self._auto_decompress = auto_decompress + self._trust_env = trust_env + self._requote_redirect_url = requote_redirect_url + self._read_bufsize = read_bufsize # Convert to list of tuples if headers: - headers = CIMultiDict(headers) + real_headers = CIMultiDict(headers) # type: CIMultiDict[str] else: - headers = CIMultiDict() - self._default_headers = headers + real_headers = CIMultiDict() + self._default_headers = real_headers # type: CIMultiDict[str] if skip_auto_headers is not None: - self._skip_auto_headers = frozenset([istr(i) - for i in skip_auto_headers]) + self._skip_auto_headers = frozenset([istr(i) for i in skip_auto_headers]) else: self._skip_auto_headers = frozenset() @@ -115,63 +298,99 @@ def __init__(self, *, connector=None, loop=None, cookies=None, self._response_class = response_class self._ws_response_class = ws_response_class - def __del__(self, _warnings=warnings): + self._trace_configs = trace_configs or [] + for trace_config in self._trace_configs: + trace_config.freeze() + + def __init_subclass__(cls: Type["ClientSession"]) -> None: + warnings.warn( + "Inheritance class {} from ClientSession " + "is discouraged".format(cls.__name__), + DeprecationWarning, + stacklevel=2, + ) + + if DEBUG: + + def __setattr__(self, name: str, val: Any) -> None: + if name not in self.ATTRS: + warnings.warn( + "Setting custom ClientSession.{} attribute " + "is discouraged".format(name), + DeprecationWarning, + stacklevel=2, + ) + super().__setattr__(name, val) + + def __del__(self, _warnings: Any = warnings) -> None: if not self.closed: - self.close() - - _warnings.warn("Unclosed client session {!r}".format(self), - ResourceWarning) - context = {'client_session': self, - 'message': 'Unclosed client session'} + if PY_36: + kwargs = 
{"source": self} + else: + kwargs = {} + _warnings.warn( + f"Unclosed client session {self!r}", ResourceWarning, **kwargs + ) + context = {"client_session": self, "message": "Unclosed client session"} if self._source_traceback is not None: - context['source_traceback'] = self._source_traceback + context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def request(self, method, url, **kwargs): + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP request.""" return _RequestContextManager(self._request(method, url, **kwargs)) - @asyncio.coroutine - def _request(self, method, url, *, - params=None, - data=None, - json=None, - headers=None, - skip_auto_headers=None, - auth=None, - allow_redirects=True, - max_redirects=10, - encoding=None, - compress=None, - chunked=None, - expect100=False, - read_until_eof=True, - proxy=None, - proxy_auth=None, - timeout=DEFAULT_TIMEOUT): + async def _request( + self, + method: str, + str_or_url: StrOrURL, + *, + params: Optional[Mapping[str, str]] = None, + data: Any = None, + json: Any = None, + cookies: Optional[LooseCookies] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + allow_redirects: bool = True, + max_redirects: int = 10, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + raise_for_status: Optional[bool] = None, + read_until_eof: bool = True, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + timeout: Union[ClientTimeout, object] = sentinel, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None, + proxy_headers: Optional[LooseHeaders] = None, + trace_request_ctx: Optional[SimpleNamespace] = None, + read_bufsize: Optional[int] = None, + ) -> ClientResponse: # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants # to use the existing timeouts by setting timeout to None. 
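Stepping outside the patch for a moment: the timeout plumbing above replaces the old scalar read_timeout/conn_timeout pair with a single frozen ClientTimeout attrs instance, folding the deprecated arguments in via attr.evolve. A minimal usage sketch of the resulting API, not part of the patch itself (the URL is a placeholder; aiohttp and attrs must be installed):

```python
import asyncio

import attr
import aiohttp


async def main() -> None:
    # Session-wide default: a 30s total budget, 5s to establish a connection.
    timeout = aiohttp.ClientTimeout(total=30, connect=5)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # ClientTimeout is frozen, so a per-request variant is derived
        # with attr.evolve() instead of mutating the shared default.
        slow = attr.evolve(timeout, total=300)
        async with session.get("http://example.com/", timeout=slow) as resp:
            print(resp.status)


asyncio.run(main())
```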
- if encoding is not None: - warnings.warn( - "encoding parameter is not supported, " - "please use FormData(charset='utf-8') instead", - DeprecationWarning) - if self.closed: - raise RuntimeError('Session is closed') + raise RuntimeError("Session is closed") + + ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if data is not None and json is not None: raise ValueError( - 'data and json parameters can not be used at the same time') + "data and json parameters can not be used at the same time" + ) elif json is not None: data = payload.JsonPayload(json, dumps=self._json_serialize) if not isinstance(chunked, bool) and chunked is not None: - warnings.warn( - 'Chunk size is deprecated #1615', DeprecationWarning) + warnings.warn("Chunk size is deprecated #1615", DeprecationWarning) redirects = 0 history = [] @@ -179,15 +398,12 @@ def _request(self, method, url, *, # Merge with default headers and transform to CIMultiDict headers = self._prepare_headers(headers) - if auth is None: - auth = self._default_auth - # It would be confusing if we support explicit Authorization header - # with `auth` argument - if (headers is not None and - auth is not None and - hdrs.AUTHORIZATION in headers): - raise ValueError("Can't combine `Authorization` header with " - "`auth` argument") + proxy_headers = self._prepare_headers(proxy_headers) + + try: + url = URL(str_or_url) + except ValueError as e: + raise InvalidURL(str_or_url) from e skip_headers = set(self._skip_auto_headers) if skip_auto_headers is not None: @@ -195,99 +411,207 @@ def _request(self, method, url, *, skip_headers.add(istr(i)) if proxy is not None: - proxy = URL(proxy) + try: + proxy = URL(proxy) + except ValueError as e: + raise InvalidURL(proxy) from e + if timeout is sentinel: + real_timeout = self._timeout # type: ClientTimeout + else: + if not isinstance(timeout, ClientTimeout): + real_timeout = ClientTimeout(total=timeout) # type: ignore + else: + real_timeout = timeout # timeout is cumulative for all request operations # (request, redirects, responses, data consuming) - tm = TimeoutHandle( - self._loop, timeout if timeout is not None else self._read_timeout) + tm = TimeoutHandle(self._loop, real_timeout.total) handle = tm.start() + if read_bufsize is None: + read_bufsize = self._read_bufsize + + traces = [ + Trace( + self, + trace_config, + trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx), + ) + for trace_config in self._trace_configs + ] + + for trace in traces: + await trace.send_request_start(method, url, headers) + timer = tm.timer() try: with timer: while True: - url = URL(url).with_fragment(None) - cookies = self._cookie_jar.filter_cookies(url) + url, auth_from_url = strip_auth_from_url(url) + if auth and auth_from_url: + raise ValueError( + "Cannot combine AUTH argument with " + "credentials encoded in URL" + ) + + if auth is None: + auth = auth_from_url + if auth is None: + auth = self._default_auth + # It would be confusing if we support explicit + # Authorization header with auth argument + if ( + headers is not None + and auth is not None + and hdrs.AUTHORIZATION in headers + ): + raise ValueError( + "Cannot combine AUTHORIZATION header " + "with AUTH argument or credentials " + "encoded in URL" + ) + + all_cookies = self._cookie_jar.filter_cookies(url) + + if cookies is not None: + tmp_cookie_jar = CookieJar() + tmp_cookie_jar.update_cookies(cookies) + req_cookies = tmp_cookie_jar.filter_cookies(url) + if req_cookies: + all_cookies.load(req_cookies) + + if proxy is not None: + proxy = 
URL(proxy) + elif self._trust_env: + for scheme, proxy_info in proxies_from_env().items(): + if scheme == url.scheme: + proxy = proxy_info.proxy + proxy_auth = proxy_info.proxy_auth + break req = self._request_class( - method, url, params=params, headers=headers, - skip_auto_headers=skip_headers, data=data, - cookies=cookies, auth=auth, version=version, - compress=compress, chunked=chunked, - expect100=expect100, loop=self._loop, + method, + url, + params=params, + headers=headers, + skip_auto_headers=skip_headers, + data=data, + cookies=all_cookies, + auth=auth, + version=version, + compress=compress, + chunked=chunked, + expect100=expect100, + loop=self._loop, response_class=self._response_class, - proxy=proxy, proxy_auth=proxy_auth, timer=timer) + proxy=proxy, + proxy_auth=proxy_auth, + timer=timer, + session=self, + ssl=ssl, + proxy_headers=proxy_headers, + traces=traces, + ) # connection timeout try: - with CeilTimeout(self._conn_timeout, loop=self._loop): - conn = yield from self._connector.connect(req) + with CeilTimeout(real_timeout.connect, loop=self._loop): + assert self._connector is not None + conn = await self._connector.connect( + req, traces=traces, timeout=real_timeout + ) except asyncio.TimeoutError as exc: raise ServerTimeoutError( - 'Connection timeout ' - 'to host {0}'.format(url)) from exc + "Connection timeout " "to host {}".format(url) + ) from exc + + assert conn.transport is not None + + assert conn.protocol is not None + conn.protocol.set_response_params( + timer=timer, + skip_payload=method.upper() == "HEAD", + read_until_eof=read_until_eof, + auto_decompress=self._auto_decompress, + read_timeout=real_timeout.sock_read, + read_bufsize=read_bufsize, + ) - conn.writer.set_tcp_nodelay(True) try: - resp = req.send(conn) try: - yield from resp.start(conn, read_until_eof) - except: - resp.close() + resp = await req.send(conn) + try: + await resp.start(conn) + except BaseException: + resp.close() + raise + except BaseException: conn.close() raise except ClientError: raise - except http.HttpProcessingError as exc: - raise ClientResponseError( - code=exc.code, - message=exc.message, headers=exc.headers) from exc except OSError as exc: raise ClientOSError(*exc.args) from exc self._cookie_jar.update_cookies(resp.cookies, resp.url) # redirects - if resp.status in (301, 302, 303, 307) and allow_redirects: + if resp.status in (301, 302, 303, 307, 308) and allow_redirects: + + for trace in traces: + await trace.send_request_redirect( + method, url, headers, resp + ) + redirects += 1 history.append(resp) if max_redirects and redirects >= max_redirects: resp.close() - break - else: - resp.release() + raise TooManyRedirects( + history[0].request_info, tuple(history) + ) # For 301 and 302, mimic IE, now changed in RFC # https://github.com/kennethreitz/requests/pull/269 - if (resp.status == 303 and - resp.method != hdrs.METH_HEAD) \ - or (resp.status in (301, 302) and - resp.method == hdrs.METH_POST): + if (resp.status == 303 and resp.method != hdrs.METH_HEAD) or ( + resp.status in (301, 302) and resp.method == hdrs.METH_POST + ): method = hdrs.METH_GET data = None if headers.get(hdrs.CONTENT_LENGTH): headers.pop(hdrs.CONTENT_LENGTH) - r_url = (resp.headers.get(hdrs.LOCATION) or - resp.headers.get(hdrs.URI)) + r_url = resp.headers.get(hdrs.LOCATION) or resp.headers.get( + hdrs.URI + ) if r_url is None: - raise RuntimeError( - "{0.method} {0.url} returns " - "a redirect [{0.status}] status " - "but response lacks a Location " - "or URI HTTP header".format(resp)) - r_url = URL(r_url) 
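An aside on the redirect rules this hunk rewrites: 308 Permanent Redirect joins the handled set, overrunning max_redirects now raises TooManyRedirects instead of silently stopping, and a 303 (or a 301/302 answering a POST, mimicking browsers) downgrades the follow-up request to a bodyless GET. A standalone sketch of just that method-rewriting decision, under the simplifying assumption that only status and method matter:

```python
def redirected_method(status: int, method: str) -> str:
    # 303 See Other always switches to GET, except for HEAD requests;
    # 301/302 switch POST to GET for browser compatibility (see the
    # requests PR #269 referenced above); 307/308 keep the method.
    if status == 303 and method != "HEAD":
        return "GET"
    if status in (301, 302) and method == "POST":
        return "GET"
    return method


assert redirected_method(303, "PUT") == "GET"
assert redirected_method(302, "POST") == "GET"
assert redirected_method(308, "POST") == "POST"
```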
- - scheme = r_url.scheme - if scheme not in ('http', 'https', ''): + # see github.com/aio-libs/aiohttp/issues/2022 + break + else: + # reading from correct redirection + # response is forbidden + resp.release() + + try: + parsed_url = URL( + r_url, encoded=not self._requote_redirect_url + ) + + except ValueError as e: + raise InvalidURL(r_url) from e + + scheme = parsed_url.scheme + if scheme not in ("http", "https", ""): resp.close() - raise ValueError( - 'Can redirect only to http or https') + raise ValueError("Can redirect only to http or https") elif not scheme: - r_url = url.join(r_url) + parsed_url = url.join(parsed_url) + + if url.origin() != parsed_url.origin(): + auth = None + headers.pop(hdrs.AUTHORIZATION, None) - url = r_url + url = parsed_url params = None resp.release() continue @@ -295,7 +619,9 @@ def _request(self, method, url, *, break # check response status - if self._raise_for_status: + if raise_for_status is None: + raise_for_status = self._raise_for_status + if raise_for_status: resp.raise_for_status() # register connection @@ -306,158 +632,254 @@ def _request(self, method, url, *, handle.cancel() resp._history = tuple(history) + + for trace in traces: + await trace.send_request_end(method, url, headers, resp) return resp - except: + except BaseException as e: # cleanup timer tm.close() if handle: handle.cancel() handle = None + for trace in traces: + await trace.send_request_exception(method, url, headers, e) raise - def ws_connect(self, url, *, - protocols=(), - timeout=10.0, - receive_timeout=None, - autoclose=True, - autoping=True, - heartbeat=None, - auth=None, - origin=None, - headers=None, - proxy=None, - proxy_auth=None): + def ws_connect( + self, + url: StrOrURL, + *, + method: str = hdrs.METH_GET, + protocols: Iterable[str] = (), + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + ssl: Union[SSLContext, bool, None, Fingerprint] = None, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + proxy_headers: Optional[LooseHeaders] = None, + compress: int = 0, + max_msg_size: int = 4 * 1024 * 1024, + ) -> "_WSRequestContextManager": """Initiate websocket connection.""" return _WSRequestContextManager( - self._ws_connect(url, - protocols=protocols, - timeout=timeout, - receive_timeout=receive_timeout, - autoclose=autoclose, - autoping=autoping, - heartbeat=heartbeat, - auth=auth, - origin=origin, - headers=headers, - proxy=proxy, - proxy_auth=proxy_auth)) - - @asyncio.coroutine - def _ws_connect(self, url, *, - protocols=(), - timeout=10.0, - receive_timeout=None, - autoclose=True, - autoping=True, - heartbeat=None, - auth=None, - origin=None, - headers=None, - proxy=None, - proxy_auth=None): + self._ws_connect( + url, + method=method, + protocols=protocols, + timeout=timeout, + receive_timeout=receive_timeout, + autoclose=autoclose, + autoping=autoping, + heartbeat=heartbeat, + auth=auth, + origin=origin, + headers=headers, + proxy=proxy, + proxy_auth=proxy_auth, + ssl=ssl, + verify_ssl=verify_ssl, + fingerprint=fingerprint, + ssl_context=ssl_context, + proxy_headers=proxy_headers, + compress=compress, + max_msg_size=max_msg_size, + ) + ) + + async def _ws_connect( + self, + url: StrOrURL, + *, + method: str 
= hdrs.METH_GET, + protocols: Iterable[str] = (), + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + ssl: Union[SSLContext, bool, None, Fingerprint] = None, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + proxy_headers: Optional[LooseHeaders] = None, + compress: int = 0, + max_msg_size: int = 4 * 1024 * 1024, + ) -> ClientWebSocketResponse: if headers is None: - headers = CIMultiDict() + real_headers = CIMultiDict() # type: CIMultiDict[str] + else: + real_headers = CIMultiDict(headers) default_headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_VERSION: '13', + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_VERSION: "13", } for key, value in default_headers.items(): - if key not in headers: - headers[key] = value + real_headers.setdefault(key, value) sec_key = base64.b64encode(os.urandom(16)) - headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() + real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() if protocols: - headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) + real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ",".join(protocols) if origin is not None: - headers[hdrs.ORIGIN] = origin + real_headers[hdrs.ORIGIN] = origin + if compress: + extstr = ws_ext_gen(compress=compress) + real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + + ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) # send request - resp = yield from self.get(url, headers=headers, - read_until_eof=False, - auth=auth, - proxy=proxy, - proxy_auth=proxy_auth) + resp = await self.request( + method, + url, + headers=real_headers, + read_until_eof=False, + auth=auth, + proxy=proxy, + proxy_auth=proxy_auth, + ssl=ssl, + proxy_headers=proxy_headers, + ) try: # check handshake if resp.status != 101: raise WSServerHandshakeError( - message='Invalid response status', - code=resp.status, - headers=resp.headers) - - if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket': + resp.request_info, + resp.history, + message="Invalid response status", + status=resp.status, + headers=resp.headers, + ) + + if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket": raise WSServerHandshakeError( - message='Invalid upgrade header', - code=resp.status, - headers=resp.headers) - - if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade': + resp.request_info, + resp.history, + message="Invalid upgrade header", + status=resp.status, + headers=resp.headers, + ) + + if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade": raise WSServerHandshakeError( - message='Invalid connection header', - code=resp.status, - headers=resp.headers) + resp.request_info, + resp.history, + message="Invalid connection header", + status=resp.status, + headers=resp.headers, + ) # key calculation - key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') - match = base64.b64encode( - hashlib.sha1(sec_key + WS_KEY).digest()).decode() - if key != match: + r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "") + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() + if r_key != match: raise WSServerHandshakeError( - message='Invalid challenge response', - code=resp.status, - 
headers=resp.headers) + resp.request_info, + resp.history, + message="Invalid challenge response", + status=resp.status, + headers=resp.headers, + ) # websocket protocol protocol = None if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: resp_protocols = [ - proto.strip() for proto in - resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] + proto.strip() + for proto in resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") + ] for proto in resp_protocols: if proto in protocols: protocol = proto break - proto = resp.connection.protocol + # websocket compress + notakeover = False + if compress: + compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) + if compress_hdrs: + try: + compress, notakeover = ws_ext_parse(compress_hdrs) + except WSHandshakeError as exc: + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message=exc.args[0], + status=resp.status, + headers=resp.headers, + ) from exc + else: + compress = 0 + notakeover = False + + conn = resp.connection + assert conn is not None + conn_proto = conn.protocol + assert conn_proto is not None + transport = conn.transport + assert transport is not None reader = FlowControlDataQueue( - proto, limit=2 ** 16, loop=self._loop) - proto.set_parser(WebSocketReader(reader), reader) - resp.connection.writer.set_tcp_nodelay(True) - writer = WebSocketWriter(resp.connection.writer, use_mask=True) - except Exception: + conn_proto, 2 ** 16, loop=self._loop + ) # type: FlowControlDataQueue[WSMessage] + conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + writer = WebSocketWriter( + conn_proto, + transport, + use_mask=True, + compress=compress, + notakeover=notakeover, + ) + except BaseException: resp.close() raise else: - return self._ws_response_class(reader, - writer, - protocol, - resp, - timeout, - autoclose, - autoping, - self._loop, - receive_timeout=receive_timeout, - heartbeat=heartbeat) - - def _prepare_headers(self, headers): - """ Add default headers and transform it to CIMultiDict - """ + return self._ws_response_class( + reader, + writer, + protocol, + resp, + timeout, + autoclose, + autoping, + self._loop, + receive_timeout=receive_timeout, + heartbeat=heartbeat, + compress=compress, + client_notakeover=notakeover, + ) + + def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]": + """Add default headers and transform it to CIMultiDict""" # Convert headers to MultiDict result = CIMultiDict(self._default_headers) if headers: if not isinstance(headers, (MultiDictProxy, MultiDict)): headers = CIMultiDict(headers) - added_names = set() + added_names = set() # type: Set[str] for key, value in headers.items(): if key in added_names: result.add(key, value) @@ -466,68 +888,74 @@ def _prepare_headers(self, headers): added_names.add(key) return result - def get(self, url, *, allow_redirects=True, **kwargs): + def get( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP GET request.""" return _RequestContextManager( - self._request(hdrs.METH_GET, url, - allow_redirects=allow_redirects, - **kwargs)) + self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs) + ) - def options(self, url, *, allow_redirects=True, **kwargs): + def options( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP OPTIONS request.""" return _RequestContextManager( - self._request(hdrs.METH_OPTIONS, url, - allow_redirects=allow_redirects, - 
**kwargs)) - - def head(self, url, *, allow_redirects=False, **kwargs): + self._request( + hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + ) + ) + + def head( + self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP HEAD request.""" return _RequestContextManager( - self._request(hdrs.METH_HEAD, url, - allow_redirects=allow_redirects, - **kwargs)) - - def post(self, url, *, data=None, **kwargs): + self._request( + hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + ) + ) + + def post( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP POST request.""" return _RequestContextManager( - self._request(hdrs.METH_POST, url, - data=data, - **kwargs)) + self._request(hdrs.METH_POST, url, data=data, **kwargs) + ) - def put(self, url, *, data=None, **kwargs): + def put( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP PUT request.""" return _RequestContextManager( - self._request(hdrs.METH_PUT, url, - data=data, - **kwargs)) + self._request(hdrs.METH_PUT, url, data=data, **kwargs) + ) - def patch(self, url, *, data=None, **kwargs): + def patch( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": """Perform HTTP PATCH request.""" return _RequestContextManager( - self._request(hdrs.METH_PATCH, url, - data=data, - **kwargs)) + self._request(hdrs.METH_PATCH, url, data=data, **kwargs) + ) - def delete(self, url, **kwargs): + def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": """Perform HTTP DELETE request.""" - return _RequestContextManager( - self._request(hdrs.METH_DELETE, url, - **kwargs)) + return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs)) - def close(self): + async def close(self) -> None: """Close underlying connector. Release all acquired resources. """ if not self.closed: - if self._connector_owner: - self._connector.close() + if self._connector is not None and self._connector_owner: + await self._connector.close() self._connector = None - return deprecated_noop('ClientSession.close() is not coroutine') - @property - def closed(self): + def closed(self) -> bool: """Is client session closed. A readonly property. 
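All of the verb helpers above are thin wrappers that hand the _request() coroutine to _RequestContextManager, which is what lets a single call be either awaited directly or used as an async context manager. A short sketch of both spellings (example.com is a placeholder, not taken from the patch):

```python
import asyncio

import aiohttp


async def fetch(url: str) -> str:
    async with aiohttp.ClientSession() as session:
        # Context-manager form: the response is released on exit.
        async with session.get(url) as resp:
            resp.raise_for_status()
            return await resp.text()


async def fetch_awaited(url: str) -> str:
    async with aiohttp.ClientSession() as session:
        # Awaited form: releasing the connection is the caller's job.
        resp = await session.get(url)
        try:
            return await resp.text()
        finally:
            resp.release()


print(asyncio.run(fetch("http://example.com/")))
```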
@@ -535,173 +963,245 @@ def closed(self): return self._connector is None or self._connector.closed @property - def connector(self): + def connector(self) -> Optional[BaseConnector]: """Connector instance used for the session.""" return self._connector @property - def cookie_jar(self): + def cookie_jar(self) -> AbstractCookieJar: """The session cookies.""" return self._cookie_jar @property - def version(self): + def version(self) -> Tuple[int, int]: """The session HTTP protocol version.""" return self._version @property - def loop(self): + def requote_redirect_url(self) -> bool: + """Do URL requoting on redirection handling.""" + return self._requote_redirect_url + + @requote_redirect_url.setter + def requote_redirect_url(self, val: bool) -> None: + """Do URL requoting on redirection handling.""" + warnings.warn( + "session.requote_redirect_url modification " "is deprecated #2778", + DeprecationWarning, + stacklevel=2, + ) + self._requote_redirect_url = val + + @property + def loop(self) -> asyncio.AbstractEventLoop: """Session's loop.""" + warnings.warn( + "client.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) return self._loop - def detach(self): - """Detach connector from session without closing the former. + @property + def timeout(self) -> Union[object, ClientTimeout]: + """Timeout for the session.""" + return self._timeout - Session is switched to closed state anyway. + @property + def headers(self) -> "CIMultiDict[str]": + """The default headers of the client session.""" + return self._default_headers + + @property + def skip_auto_headers(self) -> FrozenSet[istr]: + """Headers for which autogeneration should be skipped""" + return self._skip_auto_headers + + @property + def auth(self) -> Optional[BasicAuth]: + """An object that represents HTTP Basic Authorization""" + return self._default_auth + + @property + def json_serialize(self) -> JSONEncoder: + """Json serializer callable""" + return self._json_serialize + + @property + def connector_owner(self) -> bool: + """Should connector be closed on session closing""" + return self._connector_owner + + @property + def raise_for_status( + self, + ) -> Union[bool, Callable[[ClientResponse], Awaitable[None]]]: """ - self._connector = None + Should `ClientResponse.raise_for_status()` + be called for each response + """ + return self._raise_for_status - def __enter__(self): - warnings.warn("Use async with instead", DeprecationWarning) - return self + @property + def auto_decompress(self) -> bool: + """Should the body response be automatically decompressed""" + return self._auto_decompress - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + @property + def trust_env(self) -> bool: + """ + Should get proxies information + from HTTP_PROXY / HTTPS_PROXY environment variables + or ~/.netrc file if present + """ + return self._trust_env - if PY_35: - @asyncio.coroutine - def __aenter__(self): - return self + @property + def trace_configs(self) -> List[TraceConfig]: + """A list of TraceConfig instances used for client tracing""" + return self._trace_configs - @asyncio.coroutine - def __aexit__(self, exc_type, exc_val, exc_tb): - self.close() + def detach(self) -> None: + """Detach connector from session without closing the former. + Session is switched to closed state anyway. 
+ """ + self._connector = None -if PY_35: - from collections.abc import Coroutine - base = Coroutine -else: - base = object + def __enter__(self) -> None: + raise TypeError("Use async with instead") + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # __exit__ should exist in pair with __enter__ but never executed + pass # pragma: no cover -class _BaseRequestContextManager(base): + async def __aenter__(self) -> "ClientSession": + return self - __slots__ = ('_coro', '_resp', 'send', 'throw', 'close') + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() - def __init__(self, coro): - self._coro = coro - self._resp = None - self.send = coro.send - self.throw = coro.throw - self.close = coro.close - @property - def gi_frame(self): - return self._coro.gi_frame +class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): - @property - def gi_running(self): - return self._coro.gi_running + __slots__ = ("_coro", "_resp") - @property - def gi_code(self): - return self._coro.gi_code + def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: + self._coro = coro - def __next__(self): - return self.send(None) + def send(self, arg: None) -> "asyncio.Future[Any]": + return self._coro.send(arg) - @asyncio.coroutine - def __iter__(self): - resp = yield from self._coro - return resp + def throw(self, arg: BaseException) -> None: # type: ignore + self._coro.throw(arg) - if PY_35: - def __await__(self): - resp = yield from self._coro - return resp + def close(self) -> None: + return self._coro.close() - @asyncio.coroutine - def __aenter__(self): - self._resp = yield from self._coro - return self._resp + def __await__(self) -> Generator[Any, None, _RetType]: + ret = self._coro.__await__() + return ret + def __iter__(self) -> Generator[Any, None, _RetType]: + return self.__await__() -if not PY_35: - try: - from asyncio import coroutines - coroutines._COROUTINE_TYPES += (_BaseRequestContextManager,) - except: # pragma: no cover - pass # Python 3.4.2 and 3.4.3 has no coroutines._COROUTINE_TYPES + async def __aenter__(self) -> _RetType: + self._resp = await self._coro + return self._resp -class _RequestContextManager(_BaseRequestContextManager): - if PY_35: - @asyncio.coroutine - def __aexit__(self, exc_type, exc, tb): - # We're basing behavior on the exception as it can be caused by - # user code unrelated to the status of the connection. If you - # would like to close a connection you must do that - # explicitly. Otherwise connection error handling should kick in - # and close/recycle the connection as required. - self._resp.release() +class _RequestContextManager(_BaseRequestContextManager[ClientResponse]): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + # We're basing behavior on the exception as it can be caused by + # user code unrelated to the status of the connection. If you + # would like to close a connection you must do that + # explicitly. Otherwise connection error handling should kick in + # and close/recycle the connection as required. 
+ self._resp.release() -class _WSRequestContextManager(_BaseRequestContextManager): - if PY_35: - @asyncio.coroutine - def __aexit__(self, exc_type, exc, tb): - yield from self._resp.close() +class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self._resp.close() -class _SessionRequestContextManager(_RequestContextManager): +class _SessionRequestContextManager: - __slots__ = _RequestContextManager.__slots__ + ('_session', ) + __slots__ = ("_coro", "_resp", "_session") - def __init__(self, coro, session): - super().__init__(coro) + def __init__( + self, + coro: Coroutine["asyncio.Future[Any]", None, ClientResponse], + session: ClientSession, + ) -> None: + self._coro = coro + self._resp = None # type: Optional[ClientResponse] self._session = session - @asyncio.coroutine - def __iter__(self): + async def __aenter__(self) -> ClientResponse: try: - return (yield from self._coro) - except: - self._session.close() + self._resp = await self._coro + except BaseException: + await self._session.close() raise + else: + return self._resp - if PY_35: - def __await__(self): - try: - return (yield from self._coro) - except: - self._session.close() - raise - - def __del__(self): - self._session.close() - - -def request(method, url, *, - params=None, - data=None, - json=None, - headers=None, - skip_auto_headers=None, - cookies=None, - auth=None, - allow_redirects=True, - max_redirects=10, - encoding=None, - version=http.HttpVersion11, - compress=None, - chunked=None, - expect100=False, - connector=None, - loop=None, - read_until_eof=True, - proxy=None, - proxy_auth=None): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + assert self._resp is not None + self._resp.close() + await self._session.close() + + +def request( + method: str, + url: StrOrURL, + *, + params: Optional[Mapping[str, str]] = None, + data: Any = None, + json: Any = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + allow_redirects: bool = True, + max_redirects: int = 10, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + raise_for_status: Optional[bool] = None, + read_until_eof: bool = True, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + timeout: Union[ClientTimeout, object] = sentinel, + cookies: Optional[LooseCookies] = None, + version: HttpVersion = http.HttpVersion11, + connector: Optional[BaseConnector] = None, + read_bufsize: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> _SessionRequestContextManager: """Constructs and sends a request. Returns response object. method - HTTP method url - request url @@ -709,7 +1209,7 @@ def request(method, url, *, string of the new request data - (optional) Dictionary, bytes, or file-like object to send in the body of the request - json - (optional) Any json compatibile python object + json - (optional) Any json compatible python object headers - (optional) Dictionary of HTTP Headers to send with the request cookies - (optional) Dict object to send with the request @@ -727,12 +1227,14 @@ def request(method, url, *, read_until_eof - Read response until eof if response does not have Content-Length header. 
loop - Optional event loop. + timeout - Optional ClientTimeout settings structure, 5min + total timeout by default. Usage:: >>> import aiohttp - >>> resp = yield from aiohttp.request('GET', 'http://python.org/') + >>> resp = await aiohttp.request('GET', 'http://python.org/') >>> resp - >>> data = yield from resp.read() + >>> data = await resp.read() """ connector_owner = False if connector is None: @@ -740,24 +1242,34 @@ def request(method, url, *, connector = TCPConnector(loop=loop, force_close=True) session = ClientSession( - loop=loop, cookies=cookies, version=version, - connector=connector, connector_owner=connector_owner) + loop=loop, + cookies=cookies, + version=version, + timeout=timeout, + connector=connector, + connector_owner=connector_owner, + ) return _SessionRequestContextManager( - session._request(method, url, - params=params, - data=data, - json=json, - headers=headers, - skip_auto_headers=skip_auto_headers, - auth=auth, - allow_redirects=allow_redirects, - max_redirects=max_redirects, - encoding=encoding, - compress=compress, - chunked=chunked, - expect100=expect100, - read_until_eof=read_until_eof, - proxy=proxy, - proxy_auth=proxy_auth,), - session=session) + session._request( + method, + url, + params=params, + data=data, + json=json, + headers=headers, + skip_auto_headers=skip_auto_headers, + auth=auth, + allow_redirects=allow_redirects, + max_redirects=max_redirects, + compress=compress, + chunked=chunked, + expect100=expect100, + raise_for_status=raise_for_status, + read_until_eof=read_until_eof, + proxy=proxy, + proxy_auth=proxy_auth, + read_bufsize=read_bufsize, + ), + session, + ) diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index dab42ee93e4..f4be3bfb5e2 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -1,18 +1,44 @@ """HTTP related errors.""" -from asyncio import TimeoutError +import asyncio +import warnings +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union -__all__ = ( - 'ClientError', +from .typedefs import LooseHeaders + +try: + import ssl + + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = SSLContext = None # type: ignore - 'ClientConnectionError', - 'ClientOSError', 'ClientConnectorError', 'ClientProxyConnectionError', - 'ServerConnectionError', 'ServerTimeoutError', 'ServerDisconnectedError', - 'ServerFingerprintMismatch', +if TYPE_CHECKING: # pragma: no cover + from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo +else: + RequestInfo = ClientResponse = ConnectionKey = None - 'ClientResponseError', 'ClientPayloadError', - 'ClientHttpProxyError', 'WSServerHandshakeError') +__all__ = ( + "ClientError", + "ClientConnectionError", + "ClientOSError", + "ClientConnectorError", + "ClientProxyConnectionError", + "ClientSSLError", + "ClientConnectorSSLError", + "ClientConnectorCertificateError", + "ServerConnectionError", + "ServerTimeoutError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ClientResponseError", + "ClientHttpProxyError", + "WSServerHandshakeError", + "ContentTypeError", + "ClientPayloadError", + "InvalidURL", +) class ClientError(Exception): @@ -20,23 +46,82 @@ class ClientError(Exception): class ClientResponseError(ClientError): - """Connection error during reading response.""" + """Connection error during reading response. 
- code = 0 - message = '' - headers = None + request_info: instance of RequestInfo + """ - def __init__(self, *, code=None, message='', headers=None): + def __init__( + self, + request_info: RequestInfo, + history: Tuple[ClientResponse, ...], + *, + code: Optional[int] = None, + status: Optional[int] = None, + message: str = "", + headers: Optional[LooseHeaders] = None, + ) -> None: + self.request_info = request_info if code is not None: - self.code = code - self.message = message - self.headers = headers - - super().__init__("%s, message='%s'" % (self.code, message)) - - -class ClientPayloadError(ClientError): - """Response payload error.""" + if status is not None: + raise ValueError( + "Both code and status arguments are provided; " + "code is deprecated, use status instead" + ) + warnings.warn( + "code argument is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + if status is not None: + self.status = status + elif code is not None: + self.status = code + else: + self.status = 0 + self.message = message + self.headers = headers + self.history = history + self.args = (request_info, history) + + def __str__(self) -> str: + return "{}, message={!r}, url={!r}".format( + self.status, + self.message, + self.request_info.real_url, + ) + + def __repr__(self) -> str: + args = f"{self.request_info!r}, {self.history!r}" + if self.status != 0: + args += f", status={self.status!r}" + if self.message != "": + args += f", message={self.message!r}" + if self.headers is not None: + args += f", headers={self.headers!r}" + return "{}({})".format(type(self).__name__, args) + + @property + def code(self) -> int: + warnings.warn( + "code property is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + return self.status + + @code.setter + def code(self, value: int) -> None: + warnings.warn( + "code property is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + self.status = value + + +class ContentTypeError(ClientResponseError): + """ContentType found is not valid.""" class WSServerHandshakeError(ClientResponseError): @@ -52,6 +137,10 @@ class ClientHttpProxyError(ClientResponseError): """ +class TooManyRedirects(ClientResponseError): + """Client was redirected too many times.""" + + class ClientConnectionError(ClientError): """Base class for client socket errors.""" @@ -67,6 +156,36 @@ class ClientConnectorError(ClientOSError): connection to proxy can not be established. """ + def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None: + self._conn_key = connection_key + self._os_error = os_error + super().__init__(os_error.errno, os_error.strerror) + self.args = (connection_key, os_error) + + @property + def os_error(self) -> OSError: + return self._os_error + + @property + def host(self) -> str: + return self._conn_key.host + + @property + def port(self) -> Optional[int]: + return self._conn_key.port + + @property + def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]: + return self._conn_key.ssl + + def __str__(self) -> str: + return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format( + self, self.ssl if self.ssl is not None else "default", self.strerror + ) + + # OSError.__reduce__ does too much black magick + __reduce__ = BaseException.__reduce__ + class ClientProxyConnectionError(ClientConnectorError): """Proxy connection error. 
@@ -83,21 +202,116 @@ class ServerConnectionError(ClientConnectionError): class ServerDisconnectedError(ServerConnectionError): """Server disconnected.""" + def __init__(self, message: Optional[str] = None) -> None: + if message is None: + message = "Server disconnected" + + self.args = (message,) + self.message = message + -class ServerTimeoutError(ServerConnectionError, TimeoutError): +class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError): """Server timeout error.""" class ServerFingerprintMismatch(ServerConnectionError): """SSL certificate does not match expected fingerprint.""" - def __init__(self, expected, got, host, port): + def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None: self.expected = expected self.got = got self.host = host self.port = port + self.args = (expected, got, host, port) + + def __repr__(self) -> str: + return "<{} expected={!r} got={!r} host={!r} port={!r}>".format( + self.__class__.__name__, self.expected, self.got, self.host, self.port + ) + + +class ClientPayloadError(ClientError): + """Response payload error.""" + + +class InvalidURL(ClientError, ValueError): + """Invalid URL. + + URL used for fetching is malformed, e.g. it doesn't contains host + part.""" + + # Derive from ValueError for backward compatibility + + def __init__(self, url: Any) -> None: + # The type of url is not yarl.URL because the exception can be raised + # on URL(url) call + super().__init__(url) + + @property + def url(self) -> Any: + return self.args[0] + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.url}>" + + +class ClientSSLError(ClientConnectorError): + """Base error for ssl.*Errors.""" + + +if ssl is not None: + cert_errors = (ssl.CertificateError,) + cert_errors_bases = ( + ClientSSLError, + ssl.CertificateError, + ) + + ssl_errors = (ssl.SSLError,) + ssl_error_bases = (ClientSSLError, ssl.SSLError) +else: # pragma: no cover + cert_errors = tuple() + cert_errors_bases = ( + ClientSSLError, + ValueError, + ) + + ssl_errors = tuple() + ssl_error_bases = (ClientSSLError,) + + +class ClientConnectorSSLError(*ssl_error_bases): # type: ignore + """Response ssl error.""" + + +class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore + """Response certificate error.""" + + def __init__( + self, connection_key: ConnectionKey, certificate_error: Exception + ) -> None: + self._conn_key = connection_key + self._certificate_error = certificate_error + self.args = (connection_key, certificate_error) + + @property + def certificate_error(self) -> Exception: + return self._certificate_error + + @property + def host(self) -> str: + return self._conn_key.host + + @property + def port(self) -> Optional[int]: + return self._conn_key.port + + @property + def ssl(self) -> bool: + return self._conn_key.is_ssl - def __repr__(self): - return '<{} expected={} got={} host={} port={}>'.format( - self.__class__.__name__, self.expected, self.got, - self.host, self.port) + def __str__(self) -> str: + return ( + "Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} " + "[{0.certificate_error.__class__.__name__}: " + "{0.certificate_error.args}]".format(self) + ) diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 718d49a4261..2973342e440 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -1,146 +1,195 @@ import asyncio -import asyncio.streams - -from .client_exceptions import (ClientOSError, ClientPayloadError, - ClientResponseError, ServerDisconnectedError) -from .http import 
HttpResponseParser, StreamWriter -from .streams import EMPTY_PAYLOAD, DataQueue - - -class ResponseHandler(DataQueue, asyncio.streams.FlowControlMixin): +from contextlib import suppress +from typing import Any, Optional, Tuple + +from .base_protocol import BaseProtocol +from .client_exceptions import ( + ClientOSError, + ClientPayloadError, + ServerDisconnectedError, + ServerTimeoutError, +) +from .helpers import BaseTimerContext +from .http import HttpResponseParser, RawResponseMessage +from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader + + +class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]): """Helper class to adapt between Protocol and StreamReader.""" - def __init__(self, *, loop=None, **kwargs): - asyncio.streams.FlowControlMixin.__init__(self, loop=loop) - DataQueue.__init__(self, loop=loop) + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + BaseProtocol.__init__(self, loop=loop) + DataQueue.__init__(self, loop) - self.paused = False - self.transport = None - self.writer = None self._should_close = False - self._message = None self._payload = None + self._skip_payload = False self._payload_parser = None - self._reading_paused = False self._timer = None - self._skip_status = () - self._tail = b'' + self._tail = b"" self._upgraded = False - self._parser = None + self._parser = None # type: Optional[HttpResponseParser] + + self._read_timeout = None # type: Optional[float] + self._read_timeout_handle = None # type: Optional[asyncio.TimerHandle] @property - def upgraded(self): + def upgraded(self) -> bool: return self._upgraded @property - def should_close(self): - if (self._payload is not None and - not self._payload.is_eof() or self._upgraded): + def should_close(self) -> bool: + if self._payload is not None and not self._payload.is_eof() or self._upgraded: return True - return (self._should_close or self._upgraded or - self.exception() is not None or - self._payload_parser is not None or - len(self) or self._tail) + return ( + self._should_close + or self._upgraded + or self.exception() is not None + or self._payload_parser is not None + or len(self) > 0 + or bool(self._tail) + ) + + def force_close(self) -> None: + self._should_close = True - def close(self): + def close(self) -> None: transport = self.transport if transport is not None: transport.close() self.transport = None self._payload = None - return transport + self._drop_timeout() - def is_connected(self): - return self.transport is not None + def is_connected(self) -> bool: + return self.transport is not None and not self.transport.is_closing() - def connection_made(self, transport): - self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) + def connection_lost(self, exc: Optional[BaseException]) -> None: + self._drop_timeout() - def connection_lost(self, exc): if self._payload_parser is not None: - try: + with suppress(Exception): self._payload_parser.feed_eof() - except Exception: - pass - try: - self._parser.feed_eof() - except Exception as e: - if self._payload is not None: - self._payload.set_exception( - ClientPayloadError('Response payload is not completed')) + uncompleted = None + if self._parser is not None: + try: + uncompleted = self._parser.feed_eof() + except Exception: + if self._payload is not None: + self._payload.set_exception( + ClientPayloadError("Response payload is not completed") + ) if not self.is_eof(): if isinstance(exc, OSError): exc = ClientOSError(*exc.args) if exc is None: - exc = ServerDisconnectedError() - 
DataQueue.set_exception(self, exc)
+            exc = ServerDisconnectedError(uncompleted)
+            # assigns self._should_close to True as side effect,
+            # we do it anyway below
+            self.set_exception(exc)

-        self.transport = self.writer = None
         self._should_close = True
         self._parser = None
-        self._message = None
         self._payload = None
         self._payload_parser = None
         self._reading_paused = False

         super().connection_lost(exc)

-    def eof_received(self):
-        pass
+    def eof_received(self) -> None:
+        # should call parser.feed_eof() most likely
+        self._drop_timeout()

-    def pause_reading(self):
-        if not self._reading_paused:
-            try:
-                self.transport.pause_reading()
-            except (AttributeError, NotImplementedError, RuntimeError):
-                pass
-            self._reading_paused = True
+    def pause_reading(self) -> None:
+        super().pause_reading()
+        self._drop_timeout()

-    def resume_reading(self):
-        if self._reading_paused:
-            try:
-                self.transport.resume_reading()
-            except (AttributeError, NotImplementedError, RuntimeError):
-                pass
-            self._reading_paused = False
+    def resume_reading(self) -> None:
+        super().resume_reading()
+        self._reschedule_timeout()

-    def set_exception(self, exc):
+    def set_exception(self, exc: BaseException) -> None:
         self._should_close = True
-
+        self._drop_timeout()
         super().set_exception(exc)

-    def set_parser(self, parser, payload):
+    def set_parser(self, parser: Any, payload: Any) -> None:
+        # TODO: actual types are:
+        #   parser: WebSocketReader
+        #   payload: FlowControlDataQueue
+        # but they are not generic enough
+        # Need an ABC for both types
         self._payload = payload
         self._payload_parser = parser

+        self._drop_timeout()
+
         if self._tail:
-            data, self._tail = self._tail, None
+            data, self._tail = self._tail, b""
             self.data_received(data)

-    def set_response_params(self, *, timer=None,
-                            skip_payload=False,
-                            skip_status_codes=(),
-                            read_until_eof=False):
+    def set_response_params(
+        self,
+        *,
+        timer: Optional[BaseTimerContext] = None,
+        skip_payload: bool = False,
+        read_until_eof: bool = False,
+        auto_decompress: bool = True,
+        read_timeout: Optional[float] = None,
+        read_bufsize: int = 2 ** 16
+    ) -> None:
         self._skip_payload = skip_payload
-        self._skip_status_codes = skip_status_codes
-        self._read_until_eof = read_until_eof
+
+        self._read_timeout = read_timeout
+        self._reschedule_timeout()
+
         self._parser = HttpResponseParser(
-            self, self._loop, timer=timer,
+            self,
+            self._loop,
+            read_bufsize,
+            timer=timer,
             payload_exception=ClientPayloadError,
-            read_until_eof=read_until_eof)
+            response_with_body=not skip_payload,
+            read_until_eof=read_until_eof,
+            auto_decompress=auto_decompress,
+        )

         if self._tail:
-            data, self._tail = self._tail, b''
+            data, self._tail = self._tail, b""
             self.data_received(data)

-    def data_received(self, data):
+    def _drop_timeout(self) -> None:
+        if self._read_timeout_handle is not None:
+            self._read_timeout_handle.cancel()
+            self._read_timeout_handle = None
+
+    def _reschedule_timeout(self) -> None:
+        timeout = self._read_timeout
+        if self._read_timeout_handle is not None:
+            self._read_timeout_handle.cancel()
+
+        if timeout:
+            self._read_timeout_handle = self._loop.call_later(
+                timeout, self._on_read_timeout
+            )
+        else:
+            self._read_timeout_handle = None
+
+    def _on_read_timeout(self) -> None:
+        exc = ServerTimeoutError("Timeout on reading data from socket")
+        self.set_exception(exc)
+        if self._payload is not None:
+            self._payload.set_exception(exc)
+
+    def data_received(self, data: bytes) -> None:
+        self._reschedule_timeout()
+
         if not data:
             return

@@ -163,28 +212,40 @@ def data_received(self, data):
         try:
messages, upgraded, tail = self._parser.feed_data(data) except BaseException as exc: - self._should_close = True - self.set_exception( - ClientResponseError(code=400, message=str(exc))) - self.transport.close() + if self.transport is not None: + # connection.release() could be called BEFORE + # data_received(), the transport is already + # closed in this case + self.transport.close() + # should_close is True after the call + self.set_exception(exc) return self._upgraded = upgraded + payload = None for message, payload in messages: if message.should_close: self._should_close = True - self._message = message self._payload = payload - if (self._skip_payload or - message.code in self._skip_status_codes): - self.feed_data((message, EMPTY_PAYLOAD), 0) + if self._skip_payload or message.code in (204, 304): + self.feed_data((message, EMPTY_PAYLOAD), 0) # type: ignore else: self.feed_data((message, payload), 0) + if payload is not None: + # new message(s) was processed + # register timeout handler unsubscribing + # either on end-of-stream or immediately for + # EMPTY_PAYLOAD + if payload is not EMPTY_PAYLOAD: + payload.on_eof(self._drop_timeout) + else: + self._drop_timeout() - if upgraded: - self.data_received(tail) - else: - self._tail = tail + if tail: + if upgraded: + self.data_received(tail) + else: + self._tail = tail diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 1aafc877fc1..d826bfeb7e5 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -1,48 +1,240 @@ import asyncio +import codecs +import functools import io -import json +import re import sys import traceback import warnings -from http.cookies import CookieError, Morsel - +from hashlib import md5, sha1, sha256 +from http.cookies import CookieError, Morsel, SimpleCookie +from types import MappingProxyType, TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Type, + Union, + cast, +) + +import attr from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL -from . import hdrs, helpers, http, payload -from .client_exceptions import (ClientConnectionError, ClientOSError, - ClientResponseError) +from . 
import hdrs, helpers, http, multipart, payload +from .abc import AbstractStreamWriter +from .client_exceptions import ( + ClientConnectionError, + ClientOSError, + ClientResponseError, + ContentTypeError, + InvalidURL, + ServerFingerprintMismatch, +) from .formdata import FormData -from .helpers import PY_35, HeadersMixin, SimpleCookie, TimerNoop, noop -from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, PayloadWriter +from .helpers import ( + PY_36, + BaseTimerContext, + BasicAuth, + HeadersMixin, + TimerNoop, + noop, + reify, + set_result, +) +from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter from .log import client_logger -from .streams import FlowControlStreamReader +from .streams import StreamReader +from .typedefs import ( + DEFAULT_JSON_DECODER, + JSONDecoder, + LooseCookies, + LooseHeaders, + RawHeaders, +) + +try: + import ssl + from ssl import SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore + SSLContext = object # type: ignore try: import cchardet as chardet except ImportError: # pragma: no cover - import chardet + import chardet # type: ignore -__all__ = ('ClientRequest', 'ClientResponse') +__all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint") -class ClientRequest: +if TYPE_CHECKING: # pragma: no cover + from .client import ClientSession + from .connector import Connection + from .tracing import Trace + + +json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json") + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ContentDisposition: + type: Optional[str] + parameters: "MappingProxyType[str, str]" + filename: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RequestInfo: + url: URL + method: str + headers: "CIMultiDictProxy[str]" + real_url: URL = attr.ib() + + @real_url.default + def real_url_default(self) -> URL: + return self.url + + +class Fingerprint: + HASHFUNC_BY_DIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, + } + + def __init__(self, fingerprint: bytes) -> None: + digestlen = len(fingerprint) + hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(digestlen) + if not hashfunc: + raise ValueError("fingerprint has invalid length") + elif hashfunc is md5 or hashfunc is sha1: + raise ValueError( + "md5 and sha1 are insecure and " "not supported. Use sha256." 
+ ) + self._hashfunc = hashfunc + self._fingerprint = fingerprint - GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS} + @property + def fingerprint(self) -> bytes: + return self._fingerprint + + def check(self, transport: asyncio.Transport) -> None: + if not transport.get_extra_info("sslcontext"): + return + sslobj = transport.get_extra_info("ssl_object") + cert = sslobj.getpeercert(binary_form=True) + got = self._hashfunc(cert).digest() + if got != self._fingerprint: + host, port, *_ = transport.get_extra_info("peername") + raise ServerFingerprintMismatch(self._fingerprint, got, host, port) + + +if ssl is not None: + SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None)) +else: # pragma: no cover + SSL_ALLOWED_TYPES = type(None) + + +def _merge_ssl_params( + ssl: Union["SSLContext", bool, Fingerprint, None], + verify_ssl: Optional[bool], + ssl_context: Optional["SSLContext"], + fingerprint: Optional[bytes], +) -> Union["SSLContext", bool, Fingerprint, None]: + if verify_ssl is not None and not verify_ssl: + warnings.warn( + "verify_ssl is deprecated, use ssl=False instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not None: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = False + if ssl_context is not None: + warnings.warn( + "ssl_context is deprecated, use ssl=context instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not None: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = ssl_context + if fingerprint is not None: + warnings.warn( + "fingerprint is deprecated, " "use ssl=Fingerprint(fingerprint) instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not None: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = Fingerprint(fingerprint) + if not isinstance(ssl, SSL_ALLOWED_TYPES): + raise TypeError( + "ssl should be SSLContext, bool, Fingerprint or None, " + "got {!r} instead.".format(ssl) + ) + return ssl + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class ConnectionKey: + # the key should contain an information about used proxy / TLS + # to prevent reusing wrong connections from a pool + host: str + port: Optional[int] + is_ssl: bool + ssl: Union[SSLContext, None, bool, Fingerprint] + proxy: Optional[URL] + proxy_auth: Optional[BasicAuth] + proxy_headers_hash: Optional[int] # hash(CIMultiDict) + + +def _is_expected_content_type( + response_content_type: str, expected_content_type: str +) -> bool: + if expected_content_type == "application/json": + return json_re.match(response_content_type) is not None + return expected_content_type in response_content_type + + +class ClientRequest: + GET_METHODS = { + hdrs.METH_GET, + hdrs.METH_HEAD, + hdrs.METH_OPTIONS, + hdrs.METH_TRACE, + } POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} - ALL_METHODS = GET_METHODS.union(POST_METHODS).union( - {hdrs.METH_DELETE, hdrs.METH_TRACE}) + ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE}) DEFAULT_HEADERS = { - hdrs.ACCEPT: '*/*', - hdrs.ACCEPT_ENCODING: 'gzip, deflate', + hdrs.ACCEPT: "*/*", + hdrs.ACCEPT_ENCODING: "gzip, deflate", } - body = b'' + body = b"" auth = None response = None - response_class = None _writer = None # async task for streaming data _continue = None # waiter future for '100 Continue' response @@ -52,34 +244,59 @@ class ClientRequest: # 
because _writer is instance method, thus it keeps a reference to self. # Until writer has finished finalizer will not be called. - def __init__(self, method, url, *, - params=None, headers=None, skip_auto_headers=frozenset(), - data=None, cookies=None, - auth=None, version=http.HttpVersion11, compress=None, - chunked=None, expect100=False, - loop=None, response_class=None, - proxy=None, proxy_auth=None, timer=None): + def __init__( + self, + method: str, + url: URL, + *, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Iterable[str] = frozenset(), + data: Any = None, + cookies: Optional[LooseCookies] = None, + auth: Optional[BasicAuth] = None, + version: http.HttpVersion = http.HttpVersion11, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + response_class: Optional[Type["ClientResponse"]] = None, + proxy: Optional[URL] = None, + proxy_auth: Optional[BasicAuth] = None, + timer: Optional[BaseTimerContext] = None, + session: Optional["ClientSession"] = None, + ssl: Union[SSLContext, bool, Fingerprint, None] = None, + proxy_headers: Optional[LooseHeaders] = None, + traces: Optional[List["Trace"]] = None, + ): if loop is None: loop = asyncio.get_event_loop() assert isinstance(url, URL), url assert isinstance(proxy, (URL, type(None))), proxy - + # FIXME: session is None in tests only, need to fix tests + # assert session is not None + self._session = cast("ClientSession", session) if params: q = MultiDict(url.query) url2 = url.with_query(params) q.extend(url2.query) url = url.with_query(q) - self.url = url.with_fragment(None) self.original_url = url + self.url = url.with_fragment(None) self.method = method.upper() self.chunked = chunked self.compress = compress self.loop = loop self.length = None - self.response_class = response_class or ClientResponse + if response_class is None: + real_response_class = ClientResponse + else: + real_response_class = response_class + self.response_class = real_response_class # type: Type[ClientResponse] self._timer = timer if timer is not None else TimerNoop() + self._ssl = ssl if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) @@ -91,140 +308,185 @@ def __init__(self, method, url, *, self.update_cookies(cookies) self.update_content_encoding(data) self.update_auth(auth) - self.update_proxy(proxy, proxy_auth) + self.update_proxy(proxy, proxy_auth, proxy_headers) - self.update_body_from_data(data, skip_auto_headers) - self.update_transfer_encoding() + self.update_body_from_data(data) + if data or self.method not in self.GET_METHODS: + self.update_transfer_encoding() self.update_expect_continue(expect100) + if traces is None: + traces = [] + self._traces = traces + + def is_ssl(self) -> bool: + return self.url.scheme in ("https", "wss") @property - def host(self): - return self.url.host + def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]: + return self._ssl @property - def port(self): + def connection_key(self) -> ConnectionKey: + proxy_headers = self.proxy_headers + if proxy_headers: + h = hash( + tuple((k, v) for k, v in proxy_headers.items()) + ) # type: Optional[int] + else: + h = None + return ConnectionKey( + self.host, + self.port, + self.is_ssl(), + self.ssl, + self.proxy, + self.proxy_auth, + h, + ) + + @property + def host(self) -> str: + ret = self.url.raw_host + assert ret is not None + return ret + + @property + def port(self) -> Optional[int]: 
return self.url.port

-    def update_host(self, url):
+    @property
+    def request_info(self) -> RequestInfo:
+        headers = CIMultiDictProxy(self.headers)  # type: CIMultiDictProxy[str]
+        return RequestInfo(self.url, self.method, headers, self.original_url)
+
+    def update_host(self, url: URL) -> None:
         """Update destination host, port and connection type (ssl)."""
         # get host/port
-        if not url.host:
-            raise ValueError(
-                "Could not parse hostname from URL '{}'".format(url))
+        if not url.raw_host:
+            raise InvalidURL(url)

         # basic auth info
         username, password = url.user, url.password
         if username:
-            self.auth = helpers.BasicAuth(username, password or '')
+            self.auth = helpers.BasicAuth(username, password or "")

-        # Record entire netloc for usage in host header
-
-        scheme = url.scheme
-        self.ssl = scheme in ('https', 'wss')
-
-    def update_version(self, version):
+    def update_version(self, version: Union[http.HttpVersion, str]) -> None:
         """Convert request version to a two-element tuple.

        Parse HTTP version: '1.1' => (1, 1)
        """
         if isinstance(version, str):
-            v = [l.strip() for l in version.split('.', 1)]
+            v = [part.strip() for part in version.split(".", 1)]
             try:
-                version = int(v[0]), int(v[1])
+                version = http.HttpVersion(int(v[0]), int(v[1]))
             except ValueError:
                 raise ValueError(
-                    'Can not parse http version number: {}'
-                    .format(version)) from None
+                    f"Can not parse http version number: {version}"
+                ) from None
         self.version = version

-    def update_headers(self, headers):
+    def update_headers(self, headers: Optional[LooseHeaders]) -> None:
         """Update request headers."""
-        self.headers = CIMultiDict()
+        self.headers = CIMultiDict()  # type: CIMultiDict[str]
+
+        # add host
+        netloc = cast(str, self.url.raw_host)
+        if helpers.is_ipv6_address(netloc):
+            netloc = f"[{netloc}]"
+        if self.url.port is not None and not self.url.is_default_port():
+            netloc += ":" + str(self.url.port)
+        self.headers[hdrs.HOST] = netloc
+
         if headers:
             if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
-                headers = headers.items()
+                headers = headers.items()  # type: ignore

-            for key, value in headers:
-                self.headers.add(key, value)
+            for key, value in headers:  # type: ignore
+                # A special case for Host header
+                if key.lower() == "host":
+                    self.headers[key] = value
+                else:
+                    self.headers.add(key, value)

-    def update_auto_headers(self, skip_auto_headers):
-        self.skip_auto_headers = skip_auto_headers
-        used_headers = set(self.headers) | skip_auto_headers
+    def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
+        self.skip_auto_headers = CIMultiDict(
+            (hdr, None) for hdr in sorted(skip_auto_headers)
+        )
+        used_headers = self.headers.copy()
+        used_headers.extend(self.skip_auto_headers)  # type: ignore

         for hdr, val in self.DEFAULT_HEADERS.items():
             if hdr not in used_headers:
                 self.headers.add(hdr, val)

-        # add host
-        if hdrs.HOST not in used_headers:
-            netloc = self.url.raw_host
-            if not self.url.is_default_port():
-                netloc += ':' + str(self.url.port)
-            self.headers[hdrs.HOST] = netloc
-
         if hdrs.USER_AGENT not in used_headers:
             self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

-    def update_cookies(self, cookies):
+    def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
         """Update request cookies header."""
         if not cookies:
             return

-        c = SimpleCookie()
+        c = SimpleCookie()  # type: SimpleCookie[str]
         if hdrs.COOKIE in self.headers:
-            c.load(self.headers.get(hdrs.COOKIE, ''))
+            c.load(self.headers.get(hdrs.COOKIE, ""))
             del self.headers[hdrs.COOKIE]

-        for name, value in cookies.items():
+        if isinstance(cookies, Mapping):
+ iter_cookies = cookies.items() + else: + iter_cookies = cookies # type: ignore + for name, value in iter_cookies: if isinstance(value, Morsel): # Preserve coded_value mrsl_val = value.get(value.key, Morsel()) mrsl_val.set(value.key, value.value, value.coded_value) c[name] = mrsl_val else: - c[name] = value + c[name] = value # type: ignore - self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip() + self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() - def update_content_encoding(self, data): + def update_content_encoding(self, data: Any) -> None: """Set request content encoding.""" if not data: return - enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower() + enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower() if enc: if self.compress: raise ValueError( - 'compress can not be set ' - 'if Content-Encoding header is set') + "compress can not be set " "if Content-Encoding header is set" + ) elif self.compress: if not isinstance(self.compress, str): - self.compress = 'deflate' + self.compress = "deflate" self.headers[hdrs.CONTENT_ENCODING] = self.compress self.chunked = True # enable chunked, no need to deal with length - def update_transfer_encoding(self): + def update_transfer_encoding(self) -> None: """Analyze transfer-encoding header.""" - te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower() + te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() - if 'chunked' in te: + if "chunked" in te: if self.chunked: raise ValueError( - 'chunked can not be set ' - 'if "Transfer-Encoding: chunked" header is set') + "chunked can not be set " + 'if "Transfer-Encoding: chunked" header is set' + ) elif self.chunked: if hdrs.CONTENT_LENGTH in self.headers: raise ValueError( - 'chunked can not be set ' - 'if Content-Length header is set') + "chunked can not be set " "if Content-Length header is set" + ) - self.headers[hdrs.TRANSFER_ENCODING] = 'chunked' + self.headers[hdrs.TRANSFER_ENCODING] = "chunked" else: if hdrs.CONTENT_LENGTH not in self.headers: self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) - def update_auth(self, auth): + def update_auth(self, auth: Optional[BasicAuth]) -> None: """Set basic auth.""" if auth is None: auth = self.auth @@ -232,11 +494,11 @@ def update_auth(self, auth): return if not isinstance(auth, helpers.BasicAuth): - raise TypeError('BasicAuth() tuple is required instead') + raise TypeError("BasicAuth() tuple is required instead") self.headers[hdrs.AUTHORIZATION] = auth.encode() - def update_body_from_data(self, body, skip_auto_headers): + def update_body_from_data(self, body: Any) -> None: if not body: return @@ -261,94 +523,116 @@ def update_body_from_data(self, body, skip_auto_headers): if hdrs.CONTENT_LENGTH not in self.headers: self.headers[hdrs.CONTENT_LENGTH] = str(size) - # set content-type - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = body.content_type - # copy payload headers - if body.headers: - for (key, value) in body.headers.items(): - if key not in self.headers: - self.headers[key] = value - - def update_expect_continue(self, expect=False): + assert body.headers + for (key, value) in body.headers.items(): + if key in self.headers: + continue + if key in self.skip_auto_headers: + continue + self.headers[key] = value + + def update_expect_continue(self, expect: bool = False) -> None: if expect: - self.headers[hdrs.EXPECT] = '100-continue' - elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue': + self.headers[hdrs.EXPECT] = 
"100-continue" + elif self.headers.get(hdrs.EXPECT, "").lower() == "100-continue": expect = True if expect: - self._continue = helpers.create_future(self.loop) - - def update_proxy(self, proxy, proxy_auth): - if proxy and not proxy.scheme == 'http': + self._continue = self.loop.create_future() + + def update_proxy( + self, + proxy: Optional[URL], + proxy_auth: Optional[BasicAuth], + proxy_headers: Optional[LooseHeaders], + ) -> None: + if proxy and not proxy.scheme == "http": raise ValueError("Only http proxies are supported") if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy = proxy self.proxy_auth = proxy_auth + self.proxy_headers = proxy_headers - def keep_alive(self): + def keep_alive(self) -> bool: if self.version < HttpVersion10: # keep alive not supported at all return False if self.version == HttpVersion10: - if self.headers.get(hdrs.CONNECTION) == 'keep-alive': + if self.headers.get(hdrs.CONNECTION) == "keep-alive": return True else: # no headers means we close for Http 1.0 return False - elif self.headers.get(hdrs.CONNECTION) == 'close': + elif self.headers.get(hdrs.CONNECTION) == "close": return False return True - @asyncio.coroutine - def write_bytes(self, writer, conn): + async def write_bytes( + self, writer: AbstractStreamWriter, conn: "Connection" + ) -> None: """Support coroutines that yields bytes objects.""" # 100 response if self._continue is not None: - yield from writer.drain() - yield from self._continue + await writer.drain() + await self._continue + protocol = conn.protocol + assert protocol is not None try: if isinstance(self.body, payload.Payload): - yield from self.body.write(writer) + await self.body.write(writer) else: if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) + self.body = (self.body,) # type: ignore for chunk in self.body: - writer.write(chunk) + await writer.write(chunk) # type: ignore - yield from writer.write_eof() + await writer.write_eof() except OSError as exc: new_exc = ClientOSError( - exc.errno, - 'Can not write request body for %s' % self.url) + exc.errno, "Can not write request body for %s" % self.url + ) new_exc.__context__ = exc new_exc.__cause__ = exc - conn.protocol.set_exception(new_exc) + protocol.set_exception(new_exc) + except asyncio.CancelledError as exc: + if not conn.closed: + protocol.set_exception(exc) except Exception as exc: - conn.protocol.set_exception(exc) + protocol.set_exception(exc) finally: self._writer = None - def send(self, conn): + async def send(self, conn: "Connection") -> "ClientResponse": # Specify request target: # - CONNECT request must send authority form URI # - not CONNECT proxy must send absolute form URI # - most common is origin form URI if self.method == hdrs.METH_CONNECT: - path = '{}:{}'.format(self.url.raw_host, self.url.port) - elif self.proxy and not self.ssl: + connect_host = self.url.raw_host + assert connect_host is not None + if helpers.is_ipv6_address(connect_host): + connect_host = f"[{connect_host}]" + path = f"{connect_host}:{self.url.port}" + elif self.proxy and not self.is_ssl(): path = str(self.url) else: path = self.url.raw_path if self.url.raw_query_string: - path += '?' + self.url.raw_query_string - - writer = PayloadWriter(conn.writer, self.loop) + path += "?" 
+ self.url.raw_query_string
+
+        protocol = conn.protocol
+        assert protocol is not None
+        writer = StreamWriter(
+            protocol,
+            self.loop,
+            on_chunk_sent=functools.partial(
+                self._on_chunk_request_sent, self.method, self.url
+            ),
+        )

         if self.compress:
             writer.enable_compression(self.compress)
@@ -357,116 +641,161 @@ def send(self, conn):
             writer.enable_chunking()

         # set default content-type
-        if (self.method in self.POST_METHODS and
-                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
-                hdrs.CONTENT_TYPE not in self.headers):
-            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
+        if (
+            self.method in self.POST_METHODS
+            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
+            and hdrs.CONTENT_TYPE not in self.headers
+        ):
+            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

         # set the connection header
         connection = self.headers.get(hdrs.CONNECTION)
         if not connection:
             if self.keep_alive():
                 if self.version == HttpVersion10:
-                    connection = 'keep-alive'
+                    connection = "keep-alive"
             else:
                 if self.version == HttpVersion11:
-                    connection = 'close'
+                    connection = "close"

         if connection is not None:
             self.headers[hdrs.CONNECTION] = connection

         # status + headers
-        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
-            self.method, path, self.version)
-        writer.write_headers(status_line, self.headers)
-
-        self._writer = helpers.ensure_future(
-            self.write_bytes(writer, conn), loop=self.loop)
-
-        self.response = self.response_class(
-            self.method, self.original_url,
-            writer=self._writer, continue100=self._continue, timer=self._timer)
-
-        self.response._post_init(self.loop)
+        status_line = "{0} {1} HTTP/{2[0]}.{2[1]}".format(
+            self.method, path, self.version
+        )
+        await writer.write_headers(status_line, self.headers)
+
+        self._writer = self.loop.create_task(self.write_bytes(writer, conn))
+
+        response_class = self.response_class
+        assert response_class is not None
+        self.response = response_class(
+            self.method,
+            self.original_url,
+            writer=self._writer,
+            continue100=self._continue,
+            timer=self._timer,
+            request_info=self.request_info,
+            traces=self._traces,
+            loop=self.loop,
+            session=self._session,
+        )
         return self.response

-    @asyncio.coroutine
-    def close(self):
+    async def close(self) -> None:
         if self._writer is not None:
             try:
-                yield from self._writer
+                await self._writer
             finally:
                 self._writer = None

-    def terminate(self):
+    def terminate(self) -> None:
         if self._writer is not None:
             if not self.loop.is_closed():
                 self._writer.cancel()
             self._writer = None

+    async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
+        for trace in self._traces:
+            await trace.send_request_chunk_sent(method, url, chunk)
+

 class ClientResponse(HeadersMixin):

     # from the Status-Line of the response
     version = None  # HTTP-Version
-    status = None  # Status-Code
-    reason = None  # Reason-Phrase
+    status = None  # type: int  # Status-Code
+    reason = None  # Reason-Phrase

-    content = None  # Payload stream
-    headers = None  # Response headers, CIMultiDictProxy
-    raw_headers = None  # Response raw headers, a sequence of pairs
+    content = None  # type: StreamReader  # Payload stream
+    _headers = None  # type: CIMultiDictProxy[str]  # Response headers
+    _raw_headers = None  # type: RawHeaders  # Response raw headers

     _connection = None  # current connection
-    flow_control_class = FlowControlStreamReader  # reader flow control
-    _reader = None  # input stream
     _source_traceback = None
     # set up by ClientRequest after ClientResponse object creation
    # post-init stage allows not changing the ctor signature
-    _loop
= None _closed = True # to allow __del__ for non-initialized properly response - - def __init__(self, method, url, *, - writer=None, continue100=None, timer=None): + _released = False + + def __init__( + self, + method: str, + url: URL, + *, + writer: "asyncio.Task[None]", + continue100: Optional["asyncio.Future[bool]"], + timer: BaseTimerContext, + request_info: RequestInfo, + traces: List["Trace"], + loop: asyncio.AbstractEventLoop, + session: "ClientSession", + ) -> None: assert isinstance(url, URL) self.method = method - self.headers = None - self.cookies = SimpleCookie() + self.cookies = SimpleCookie() # type: SimpleCookie[str] - self._url = url - self._content = None - self._writer = writer - self._continue = continue100 + self._real_url = url + self._url = url.with_fragment(None) + self._body = None # type: Any + self._writer = writer # type: Optional[asyncio.Task[None]] + self._continue = continue100 # None by default self._closed = True - self._history = () + self._history = () # type: Tuple[ClientResponse, ...] + self._request_info = request_info self._timer = timer if timer is not None else TimerNoop() + self._cache = {} # type: Dict[str, Any] + self._traces = traces + self._loop = loop + # store a reference to session #1985 + self._session = session # type: Optional[ClientSession] + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) - @property - def url(self): + @reify + def url(self) -> URL: return self._url - @property - def url_obj(self): - warnings.warn( - "Deprecated, use .url #1654", DeprecationWarning, stacklevel=2) + @reify + def url_obj(self) -> URL: + warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2) return self._url - @property - def host(self): + @reify + def real_url(self) -> URL: + return self._real_url + + @reify + def host(self) -> str: + assert self._url.host is not None return self._url.host - @property - def _headers(self): - return self.headers + @reify + def headers(self) -> "CIMultiDictProxy[str]": + return self._headers - def _post_init(self, loop): - self._loop = loop - if loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + @reify + def raw_headers(self) -> RawHeaders: + return self._raw_headers + + @reify + def request_info(self) -> RequestInfo: + return self._request_info - def __del__(self, _warnings=warnings): - if self._loop is None: - return # not started + @reify + def content_disposition(self) -> Optional[ContentDisposition]: + raw = self._headers.get(hdrs.CONTENT_DISPOSITION) + if raw is None: + return None + disposition_type, params_dct = multipart.parse_content_disposition(raw) + params = MappingProxyType(params_dct) + filename = multipart.content_disposition_filename(params) + return ContentDisposition(disposition_type, params, filename) + + def __del__(self, _warnings: Any = warnings) -> None: if self._closed: return @@ -474,64 +803,105 @@ def __del__(self, _warnings=warnings): self._connection.release() self._cleanup_writer() - # warn - if __debug__: - if self._loop.get_debug(): - _warnings.warn("Unclosed response {!r}".format(self), - ResourceWarning) - context = {'client_response': self, - 'message': 'Unclosed response'} - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - - def __repr__(self): + if self._loop.get_debug(): + if PY_36: + kwargs = {"source": self} + else: + kwargs = {} + _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs) + 
context = {"client_response": self, "message": "Unclosed response"} + if self._source_traceback: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) + + def __repr__(self) -> str: out = io.StringIO() ascii_encodable_url = str(self.url) if self.reason: - ascii_encodable_reason = self.reason.encode('ascii', - 'backslashreplace') \ - .decode('ascii') + ascii_encodable_reason = self.reason.encode( + "ascii", "backslashreplace" + ).decode("ascii") else: ascii_encodable_reason = self.reason - print(''.format( - ascii_encodable_url, self.status, ascii_encodable_reason), - file=out) + print( + "".format( + ascii_encodable_url, self.status, ascii_encodable_reason + ), + file=out, + ) print(self.headers, file=out) return out.getvalue() @property - def connection(self): + def connection(self) -> Optional["Connection"]: return self._connection - @property - def history(self): + @reify + def history(self) -> Tuple["ClientResponse", ...]: """A sequence of of responses, if redirects occurred.""" return self._history - @asyncio.coroutine - def start(self, connection, read_until_eof=False): + @reify + def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": + links_str = ", ".join(self.headers.getall("link", [])) + + if not links_str: + return MultiDictProxy(MultiDict()) + + links = MultiDict() # type: MultiDict[MultiDictProxy[Union[str, URL]]] + + for val in re.split(r",(?=\s*<)", links_str): + match = re.match(r"\s*<(.*)>(.*)", val) + if match is None: # pragma: no cover + # the check exists to suppress mypy error + continue + url, params_str = match.groups() + params = params_str.split(";")[1:] + + link = MultiDict() # type: MultiDict[Union[str, URL]] + + for param in params: + match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) + if match is None: # pragma: no cover + # the check exists to suppress mypy error + continue + key, _, value, _ = match.groups() + + link.add(key, value) + + key = link.get("rel", url) # type: ignore + + link.add("url", self.url.join(URL(url))) + + links.add(key, MultiDictProxy(link)) + + return MultiDictProxy(links) + + async def start(self, connection: "Connection") -> "ClientResponse": """Start response processing.""" self._closed = False self._protocol = connection.protocol self._connection = connection - connection.protocol.set_response_params( - timer=self._timer, - skip_payload=self.method.lower() == 'head', - skip_status_codes=(204, 304), - read_until_eof=read_until_eof) - with self._timer: while True: # read response - (message, payload) = yield from self._protocol.read() - if (message.code < 100 or - message.code > 199 or message.code == 101): + try: + message, payload = await self._protocol.read() # type: ignore + except http.HttpProcessingError as exc: + raise ClientResponseError( + self.request_info, + self.history, + status=exc.code, + message=exc.message, + headers=exc.headers, + ) from exc + + if message.code < 100 or message.code > 199 or message.code == 101: break - if self._continue is not None and not self._continue.done(): - self._continue.set_result(True) + if self._continue is not None: + set_result(self._continue, True) self._continue = None # payload eof handler @@ -543,8 +913,8 @@ def start(self, connection, read_until_eof=False): self.reason = message.reason # headers - self.headers = CIMultiDictProxy(message.headers) - self.raw_headers = tuple(message.raw_headers) + self._headers = message.headers # type is CIMultiDictProxy + self._raw_headers = message.raw_headers # 
type is Tuple[bytes, bytes] # payload self.content = payload @@ -554,19 +924,20 @@ def start(self, connection, read_until_eof=False): try: self.cookies.load(hdr) except CookieError as exc: - client_logger.warning( - 'Can not load response cookies: %s', exc) + client_logger.warning("Can not load response cookies: %s", exc) return self - def _response_eof(self): + def _response_eof(self) -> None: if self._closed: return if self._connection is not None: # websocket, protocol could be None because # connection could be detached - if (self._connection.protocol is not None and - self._connection.protocol.upgraded): + if ( + self._connection.protocol is not None + and self._connection.protocol.upgraded + ): return self._connection.release() @@ -576,10 +947,12 @@ def _response_eof(self): self._cleanup_writer() @property - def closed(self): + def closed(self) -> bool: return self._closed - def close(self): + def close(self) -> None: + if not self._released: + self._notify_content() if self._closed: return @@ -591,9 +964,10 @@ def close(self): self._connection.close() self._connection = None self._cleanup_writer() - self._notify_content() - def release(self): + def release(self) -> Any: + if not self._released: + self._notify_content() if self._closed: return noop() @@ -603,107 +977,151 @@ def release(self): self._connection = None self._cleanup_writer() - self._notify_content() return noop() - def raise_for_status(self): + @property + def ok(self) -> bool: + """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not. + + This is **not** a check for ``200 OK`` but a check that the response + status is under 400. + """ + try: + self.raise_for_status() + except ClientResponseError: + return False + return True + + def raise_for_status(self) -> None: if 400 <= self.status: + # reason should always be not None for a started response + assert self.reason is not None + self.release() raise ClientResponseError( - code=self.status, + self.request_info, + self.history, + status=self.status, message=self.reason, - headers=self.headers) + headers=self.headers, + ) - def _cleanup_writer(self): - if self._writer is not None and not self._writer.done(): + def _cleanup_writer(self) -> None: + if self._writer is not None: self._writer.cancel() self._writer = None + self._session = None - def _notify_content(self): + def _notify_content(self) -> None: content = self.content - if content and content.exception() is None and not content.is_eof(): - content.set_exception( - ClientConnectionError('Connection closed')) + if content and content.exception() is None: + content.set_exception(ClientConnectionError("Connection closed")) + self._released = True - @asyncio.coroutine - def wait_for_close(self): + async def wait_for_close(self) -> None: if self._writer is not None: try: - yield from self._writer + await self._writer finally: self._writer = None self.release() - @asyncio.coroutine - def read(self): + async def read(self) -> bytes: """Read response payload.""" - if self._content is None: + if self._body is None: try: - self._content = yield from self.content.read() - except: + self._body = await self.content.read() + for trace in self._traces: + await trace.send_response_chunk_received( + self.method, self.url, self._body + ) + except BaseException: self.close() raise + elif self._released: + raise ClientConnectionError("Connection closed") - return self._content + return self._body - def _get_encoding(self): - ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower() - mtype, stype, _, params = 
helpers.parse_mimetype(ctype) + def get_encoding(self) -> str: + ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() + mimetype = helpers.parse_mimetype(ctype) - encoding = params.get('charset') + encoding = mimetype.parameters.get("charset") + if encoding: + try: + codecs.lookup(encoding) + except LookupError: + encoding = None if not encoding: - if mtype == 'application' and stype == 'json': + if mimetype.type == "application" and ( + mimetype.subtype == "json" or mimetype.subtype == "rdap" + ): # RFC 7159 states that the default encoding is UTF-8. - encoding = 'utf-8' + # RFC 7483 defines application/rdap+json + encoding = "utf-8" + elif self._body is None: + raise RuntimeError( + "Cannot guess the encoding of " "a not yet read body" + ) else: - encoding = chardet.detect(self._content)['encoding'] + encoding = chardet.detect(self._body)["encoding"] if not encoding: - encoding = 'utf-8' + encoding = "utf-8" return encoding - @asyncio.coroutine - def text(self, encoding=None, errors='strict'): + async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: """Read response payload and decode.""" - if self._content is None: - yield from self.read() + if self._body is None: + await self.read() if encoding is None: - encoding = self._get_encoding() + encoding = self.get_encoding() - return self._content.decode(encoding, errors=errors) + return self._body.decode(encoding, errors=errors) # type: ignore - @asyncio.coroutine - def json(self, *, encoding=None, loads=json.loads, - content_type='application/json'): + async def json( + self, + *, + encoding: Optional[str] = None, + loads: JSONDecoder = DEFAULT_JSON_DECODER, + content_type: Optional[str] = "application/json", + ) -> Any: """Read and decodes JSON response.""" - if self._content is None: - yield from self.read() + if self._body is None: + await self.read() if content_type: - ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower() - if content_type not in ctype: - raise ClientResponseError( - message=('Attempt to decode JSON with ' - 'unexpected mimetype: %s' % ctype), - headers=self.headers) - - stripped = self._content.strip() + ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() + if not _is_expected_content_type(ctype, content_type): + raise ContentTypeError( + self.request_info, + self.history, + message=( + "Attempt to decode JSON with " "unexpected mimetype: %s" % ctype + ), + headers=self.headers, + ) + + stripped = self._body.strip() # type: ignore if not stripped: return None if encoding is None: - encoding = self._get_encoding() + encoding = self.get_encoding() return loads(stripped.decode(encoding)) - if PY_35: - @asyncio.coroutine - def __aenter__(self): - return self + async def __aenter__(self) -> "ClientResponse": + return self - @asyncio.coroutine - def __aexit__(self, exc_type, exc_val, exc_tb): - # similar to _RequestContextManager, we do not need to check - # for exceptions, response object can closes connection - # is state is broken - self.release() + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # similar to _RequestContextManager, we do not need to check + # for exceptions, response object can close connection + # if state is broken + self.release() diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index acf19094988..28fa371cce9 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -1,19 +1,47 @@ """WebSocket client for asyncio.""" import asyncio -import 
json +from typing import Any, Optional + +import async_timeout from .client_exceptions import ClientError -from .helpers import PY_35, PY_352, Timeout, call_later, create_future -from .http import (WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, - WebSocketError, WSMessage, WSMsgType) +from .client_reqrep import ClientResponse +from .helpers import call_later, set_result +from .http import ( + WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE, + WebSocketError, + WSMessage, + WSMsgType, +) +from .http_websocket import WebSocketWriter # WSMessage +from .streams import EofStream, FlowControlDataQueue +from .typedefs import ( + DEFAULT_JSON_DECODER, + DEFAULT_JSON_ENCODER, + JSONDecoder, + JSONEncoder, +) class ClientWebSocketResponse: - - def __init__(self, reader, writer, protocol, - response, timeout, autoclose, autoping, loop, *, - receive_timeout=None, heartbeat=None): + def __init__( + self, + reader: "FlowControlDataQueue[WSMessage]", + writer: WebSocketWriter, + protocol: Optional[str], + response: ClientResponse, + timeout: float, + autoclose: bool, + autoping: bool, + loop: asyncio.AbstractEventLoop, + *, + receive_timeout: Optional[float] = None, + heartbeat: Optional[float] = None, + compress: int = 0, + client_notakeover: bool = False, + ) -> None: self._response = response self._conn = response.connection @@ -22,7 +50,7 @@ def __init__(self, reader, writer, protocol, self._protocol = protocol self._closed = False self._closing = False - self._close_code = None + self._close_code = None # type: Optional[int] self._timeout = timeout self._receive_timeout = receive_timeout self._autoclose = autoclose @@ -30,15 +58,17 @@ def __init__(self, reader, writer, protocol, self._heartbeat = heartbeat self._heartbeat_cb = None if heartbeat is not None: - self._pong_heartbeat = heartbeat/2.0 + self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb = None self._loop = loop - self._waiting = None - self._exception = None + self._waiting = None # type: Optional[asyncio.Future[bool]] + self._exception = None # type: Optional[BaseException] + self._compress = compress + self._client_notakeover = client_notakeover self._reset_heartbeat() - def _cancel_heartbeat(self): + def _cancel_heartbeat(self) -> None: if self._pong_response_cb is not None: self._pong_response_cb.cancel() self._pong_response_cb = None @@ -47,23 +77,28 @@ def _cancel_heartbeat(self): self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self): + def _reset_heartbeat(self) -> None: self._cancel_heartbeat() if self._heartbeat is not None: self._heartbeat_cb = call_later( - self._send_heartbeat, self._heartbeat, self._loop) + self._send_heartbeat, self._heartbeat, self._loop + ) - def _send_heartbeat(self): + def _send_heartbeat(self) -> None: if self._heartbeat is not None and not self._closed: - self.ping() + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. 
+ self._loop.create_task(self._writer.ping()) if self._pong_response_cb is not None: self._pong_response_cb.cancel() self._pong_response_cb = call_later( - self._pong_not_received, self._pong_heartbeat, self._loop) + self._pong_not_received, self._pong_heartbeat, self._loop + ) - def _pong_not_received(self): + def _pong_not_received(self) -> None: if not self._closed: self._closed = True self._close_code = 1006 @@ -71,61 +106,75 @@ def _pong_not_received(self): self._response.close() @property - def closed(self): + def closed(self) -> bool: return self._closed @property - def close_code(self): + def close_code(self) -> Optional[int]: return self._close_code @property - def protocol(self): + def protocol(self) -> Optional[str]: return self._protocol - def get_extra_info(self, name, default=None): + @property + def compress(self) -> int: + return self._compress + + @property + def client_notakeover(self) -> bool: + return self._client_notakeover + + def get_extra_info(self, name: str, default: Any = None) -> Any: """extra info from connection transport""" - try: - return self._response.connection.transport.get_extra_info( - name, default) - except: + conn = self._response.connection + if conn is None: return default + transport = conn.transport + if transport is None: + return default + return transport.get_extra_info(name, default) - def exception(self): + def exception(self) -> Optional[BaseException]: return self._exception - def ping(self, message='b'): - self._writer.ping(message) + async def ping(self, message: bytes = b"") -> None: + await self._writer.ping(message) - def pong(self, message='b'): - self._writer.pong(message) + async def pong(self, message: bytes = b"") -> None: + await self._writer.pong(message) - def send_str(self, data): + async def send_str(self, data: str, compress: Optional[int] = None) -> None: if not isinstance(data, str): - raise TypeError('data argument must be str (%r)' % type(data)) - return self._writer.send(data, binary=False) + raise TypeError("data argument must be str (%r)" % type(data)) + await self._writer.send(data, binary=False, compress=compress) - def send_bytes(self, data): + async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)' % - type(data)) - return self._writer.send(data, binary=True) - - def send_json(self, data, *, dumps=json.dumps): - return self.send_str(dumps(data)) - - @asyncio.coroutine - def close(self, *, code=1000, message=b''): + raise TypeError("data argument must be byte-ish (%r)" % type(data)) + await self._writer.send(data, binary=True, compress=compress) + + async def send_json( + self, + data: Any, + compress: Optional[int] = None, + *, + dumps: JSONEncoder = DEFAULT_JSON_ENCODER, + ) -> None: + await self.send_str(dumps(data), compress=compress) + + async def close(self, *, code: int = 1000, message: bytes = b"") -> bool: # we need to break `receive()` cycle first, # `close()` may be called from different task if self._waiting is not None and not self._closed: self._reader.feed_data(WS_CLOSING_MESSAGE, 0) - yield from self._waiting + await self._waiting if not self._closed: self._cancel_heartbeat() self._closed = True try: - self._writer.close(code, message) + await self._writer.close(code, message) except asyncio.CancelledError: self._close_code = 1006 self._response.close() @@ -142,8 +191,8 @@ def close(self, *, code=1000, message=b''): while True: try: - with Timeout(self._timeout, 
loop=self._loop): - msg = yield from self._reader.read() + with async_timeout.timeout(self._timeout, loop=self._loop): + msg = await self._reader.read() except asyncio.CancelledError: self._close_code = 1006 self._response.close() @@ -161,99 +210,92 @@ def close(self, *, code=1000, message=b''): else: return False - @asyncio.coroutine - def receive(self, timeout=None): + async def receive(self, timeout: Optional[float] = None) -> WSMessage: while True: if self._waiting is not None: - raise RuntimeError( - 'Concurrent call to receive() is not allowed') + raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: return WS_CLOSED_MESSAGE elif self._closing: - yield from self.close() + await self.close() return WS_CLOSED_MESSAGE try: - self._waiting = create_future(self._loop) + self._waiting = self._loop.create_future() try: - with Timeout( - timeout or self._receive_timeout, - loop=self._loop): - msg = yield from self._reader.read() + with async_timeout.timeout( + timeout or self._receive_timeout, loop=self._loop + ): + msg = await self._reader.read() self._reset_heartbeat() finally: waiter = self._waiting self._waiting = None - waiter.set_result(True) + set_result(waiter, True) except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = 1006 raise + except EofStream: + self._close_code = 1000 + await self.close() + return WSMessage(WSMsgType.CLOSED, None, None) except ClientError: self._closed = True self._close_code = 1006 return WS_CLOSED_MESSAGE except WebSocketError as exc: self._close_code = exc.code - yield from self.close(code=exc.code) + await self.close(code=exc.code) return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc self._closing = True self._close_code = 1006 - yield from self.close() + await self.close() return WSMessage(WSMsgType.ERROR, exc, None) if msg.type == WSMsgType.CLOSE: self._closing = True self._close_code = msg.data if not self._closed and self._autoclose: - yield from self.close() + await self.close() elif msg.type == WSMsgType.CLOSING: self._closing = True elif msg.type == WSMsgType.PING and self._autoping: - self.pong(msg.data) + await self.pong(msg.data) continue elif msg.type == WSMsgType.PONG and self._autoping: continue return msg - @asyncio.coroutine - def receive_str(self, *, timeout=None): - msg = yield from self.receive(timeout) + async def receive_str(self, *, timeout: Optional[float] = None) -> str: + msg = await self.receive(timeout) if msg.type != WSMsgType.TEXT: - raise TypeError( - "Received message {}:{!r} is not str".format(msg.type, - msg.data)) + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str") return msg.data - @asyncio.coroutine - def receive_bytes(self, *, timeout=None): - msg = yield from self.receive(timeout) + async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + msg = await self.receive(timeout) if msg.type != WSMsgType.BINARY: - raise TypeError( - "Received message {}:{!r} is not bytes".format(msg.type, - msg.data)) + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") return msg.data - @asyncio.coroutine - def receive_json(self, *, loads=json.loads, timeout=None): - data = yield from self.receive_str(timeout=timeout) + async def receive_json( + self, + *, + loads: JSONDecoder = DEFAULT_JSON_DECODER, + timeout: Optional[float] = None, + ) -> Any: + data = await self.receive_str(timeout=timeout) return loads(data) - if PY_35: - def __aiter__(self): - return self + def __aiter__(self) -> 
"ClientWebSocketResponse": + return self - if not PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) - - @asyncio.coroutine - def __anext__(self): - msg = yield from self.receive() - if msg.type in (WSMsgType.CLOSE, - WSMsgType.CLOSING, - WSMsgType.CLOSED): - raise StopAsyncIteration # NOQA - return msg + async def __anext__(self) -> WSMessage: + msg = await self.receive() + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + raise StopAsyncIteration + return msg diff --git a/aiohttp/connector.py b/aiohttp/connector.py index fb1b87d679e..748b22a4228 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1,29 +1,90 @@ import asyncio import functools -import ssl +import random import sys import traceback import warnings -from collections import defaultdict -from hashlib import md5, sha1, sha256 -from types import MappingProxyType +from collections import defaultdict, deque +from contextlib import suppress +from http.cookies import SimpleCookie +from itertools import cycle, islice +from time import monotonic +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + DefaultDict, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +import attr from . import hdrs, helpers -from .client_exceptions import (ClientConnectorError, ClientHttpProxyError, - ClientProxyConnectionError, - ServerFingerprintMismatch) +from .abc import AbstractResolver +from .client_exceptions import ( + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientHttpProxyError, + ClientProxyConnectionError, + ServerFingerprintMismatch, + cert_errors, + ssl_errors, +) from .client_proto import ResponseHandler -from .client_reqrep import ClientRequest -from .helpers import SimpleCookie, is_ip_address, noop, sentinel +from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params +from .helpers import PY_36, CeilTimeout, get_running_loop, is_ip_address, noop, sentinel +from .http import RESPONSES +from .locks import EventResultOrError from .resolver import DefaultResolver -__all__ = ('BaseConnector', 'TCPConnector', 'UnixConnector') +try: + import ssl -HASHFUNC_BY_DIGESTLEN = { - 16: md5, - 20: sha1, - 32: sha256, -} + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore + SSLContext = object # type: ignore + + +__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") + + +if TYPE_CHECKING: # pragma: no cover + from .client import ClientTimeout + from .client_reqrep import ConnectionKey + from .tracing import Trace + + +class _DeprecationWaiter: + __slots__ = ("_awaitable", "_awaited") + + def __init__(self, awaitable: Awaitable[Any]) -> None: + self._awaitable = awaitable + self._awaited = False + + def __await__(self) -> Any: + self._awaited = True + return self._awaitable.__await__() + + def __del__(self) -> None: + if not self._awaited: + warnings.warn( + "Connector.close() is a coroutine, " + "please use await connector.close()", + DeprecationWarning, + ) class Connection: @@ -31,101 +92,99 @@ class Connection: _source_traceback = None _transport = None - def __init__(self, connector, key, protocol, loop): + def __init__( + self, + connector: "BaseConnector", + key: "ConnectionKey", + protocol: ResponseHandler, + loop: asyncio.AbstractEventLoop, + ) -> None: self._key = key self._connector = connector self._loop = loop - self._protocol = protocol - 
self._callbacks = [] + self._protocol = protocol # type: Optional[ResponseHandler] + self._callbacks = [] # type: List[Callable[[], None]] if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) - def __repr__(self): - return 'Connection<{}>'.format(self._key) + def __repr__(self) -> str: + return f"Connection<{self._key}>" - def __del__(self, _warnings=warnings): + def __del__(self, _warnings: Any = warnings) -> None: if self._protocol is not None: - _warnings.warn('Unclosed connection {!r}'.format(self), - ResourceWarning) + if PY_36: + kwargs = {"source": self} + else: + kwargs = {} + _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs) if self._loop.is_closed(): return - self._connector._release( - self._key, self._protocol, should_close=True) + self._connector._release(self._key, self._protocol, should_close=True) - context = {'client_connection': self, - 'message': 'Unclosed connection'} + context = {"client_connection": self, "message": "Unclosed connection"} if self._source_traceback is not None: - context['source_traceback'] = self._source_traceback + context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) @property - def loop(self): + def loop(self) -> asyncio.AbstractEventLoop: + warnings.warn( + "connector.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) return self._loop @property - def transport(self): + def transport(self) -> Optional[asyncio.Transport]: + if self._protocol is None: + return None return self._protocol.transport @property - def protocol(self): + def protocol(self) -> Optional[ResponseHandler]: return self._protocol - @property - def writer(self): - return self._protocol.writer - - def add_callback(self, callback): + def add_callback(self, callback: Callable[[], None]) -> None: if callback is not None: self._callbacks.append(callback) - def _notify_release(self): + def _notify_release(self) -> None: callbacks, self._callbacks = self._callbacks[:], [] for cb in callbacks: - try: + with suppress(Exception): cb() - except: - pass - def close(self): + def close(self) -> None: self._notify_release() if self._protocol is not None: - self._connector._release( - self._key, self._protocol, should_close=True) + self._connector._release(self._key, self._protocol, should_close=True) self._protocol = None - def release(self): + def release(self) -> None: self._notify_release() if self._protocol is not None: self._connector._release( - self._key, self._protocol, - should_close=self._protocol.should_close) + self._key, self._protocol, should_close=self._protocol.should_close + ) self._protocol = None - def detach(self): - self._notify_release() - - if self._protocol is not None: - self._connector._release_acquired(self._protocol) - self._protocol = None - @property - def closed(self): + def closed(self) -> bool: return self._protocol is None or not self._protocol.is_connected() class _TransportPlaceholder: """ placeholder for BaseConnector.connect function """ - def close(self): + def close(self) -> None: pass -class BaseConnector(object): +class BaseConnector: """Base connector class. keepalive_timeout - (optional) Keep-alive timeout. @@ -133,7 +192,8 @@ class BaseConnector(object): after each request (and between redirects). limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. - disable_cleanup_closed - Disable clean-up closed ssl transports. 
+ enable_cleanup_closed - Enables clean-up closed ssl transports. + Disabled by default. loop - Optional event loop. """ @@ -143,39 +203,51 @@ class BaseConnector(object): # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 - def __init__(self, *, keepalive_timeout=sentinel, - force_close=False, limit=100, limit_per_host=0, - enable_cleanup_closed=False, loop=None): + def __init__( + self, + *, + keepalive_timeout: Union[object, None, float] = sentinel, + force_close: bool = False, + limit: int = 100, + limit_per_host: int = 0, + enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: if force_close: - if keepalive_timeout is not None and \ - keepalive_timeout is not sentinel: - raise ValueError('keepalive_timeout cannot ' - 'be set if force_close is True') + if keepalive_timeout is not None and keepalive_timeout is not sentinel: + raise ValueError( + "keepalive_timeout cannot " "be set if force_close is True" + ) else: if keepalive_timeout is sentinel: keepalive_timeout = 15.0 - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) self._closed = False if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) - self._conns = {} + self._conns = ( + {} + ) # type: Dict[ConnectionKey, List[Tuple[ResponseHandler, float]]] self._limit = limit self._limit_per_host = limit_per_host - self._acquired = set() - self._acquired_per_host = defaultdict(set) - self._keepalive_timeout = keepalive_timeout + self._acquired = set() # type: Set[ResponseHandler] + self._acquired_per_host = defaultdict( + set + ) # type: DefaultDict[ConnectionKey, Set[ResponseHandler]] + self._keepalive_timeout = cast(float, keepalive_timeout) self._force_close = force_close - self._waiters = defaultdict(list) + + # {host_key: FIFO list of waiters} + self._waiters = defaultdict(deque) # type: ignore self._loop = loop self._factory = functools.partial(ResponseHandler, loop=loop) - self.cookies = SimpleCookie() + self.cookies = SimpleCookie() # type: SimpleCookie[str] # start keep-alive connection cleanup task self._cleanup_handle = None @@ -183,10 +255,10 @@ def __init__(self, *, keepalive_timeout=sentinel, # start cleanup closed transports task self._cleanup_closed_handle = None self._cleanup_closed_disabled = not enable_cleanup_closed - self._cleanup_closed_transports = [] + self._cleanup_closed_transports = [] # type: List[Optional[asyncio.Transport]] self._cleanup_closed() - def __del__(self, _warnings=warnings): + def __del__(self, _warnings: Any = warnings) -> None: if self._closed: return if not self._conns: @@ -194,30 +266,51 @@ def __del__(self, _warnings=warnings): conns = [repr(c) for c in self._conns.values()] - self.close() + self._close() - _warnings.warn("Unclosed connector {!r}".format(self), - ResourceWarning) - context = {'connector': self, - 'connections': conns, - 'message': 'Unclosed connector'} + if PY_36: + kwargs = {"source": self} + else: + kwargs = {} + _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs) + context = { + "connector": self, + "connections": conns, + "message": "Unclosed connector", + } if self._source_traceback is not None: - context['source_traceback'] = self._source_traceback + context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def __enter__(self): + def __enter__(self) -> "BaseConnector": + warnings.warn( + '"with Connector():" is deprecated, ' + 'use "async with 
Connector():" instead', + DeprecationWarning, + ) return self - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: self.close() + async def __aenter__(self) -> "BaseConnector": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + @property - def force_close(self): + def force_close(self) -> bool: """Ultimately close connection on releasing if True.""" return self._force_close @property - def limit(self): + def limit(self) -> int: """The total number for simultaneous connections. If limit is 0 the connector has no limit. @@ -226,7 +319,7 @@ def limit(self): return self._limit @property - def limit_per_host(self): + def limit_per_host(self) -> int: """The limit_per_host for simultaneous connections to the same endpoint. @@ -236,10 +329,13 @@ def limit_per_host(self): """ return self._limit_per_host - def _cleanup(self): + def _cleanup(self) -> None: """Cleanup unused transports.""" if self._cleanup_handle: self._cleanup_handle.cancel() + # _cleanup_handle should be unset, otherwise _release() will not + # recreate it ever! + self._cleanup_handle = None now = self._loop.time() timeout = self._keepalive_timeout @@ -252,12 +348,17 @@ def _cleanup(self): for proto, use_time in conns: if proto.is_connected(): if use_time - deadline < 0: - transport = proto.close() - if (key[-1] and not self._cleanup_closed_disabled): - self._cleanup_closed_transports.append( - transport) + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) else: alive.append((proto, use_time)) + else: + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) if alive: connections[key] = alive @@ -266,9 +367,21 @@ def _cleanup(self): if self._conns: self._cleanup_handle = helpers.weakref_handle( - self, '_cleanup', timeout, self._loop) + self, "_cleanup", timeout, self._loop + ) + + def _drop_acquired_per_host( + self, key: "ConnectionKey", val: ResponseHandler + ) -> None: + acquired_per_host = self._acquired_per_host + if key not in acquired_per_host: + return + conns = acquired_per_host[key] + conns.remove(val) + if not conns: + del self._acquired_per_host[key] - def _cleanup_closed(self): + def _cleanup_closed(self) -> None: """Double confirmation for transport close. Some broken ssl servers may leave socket open without proper close. 
""" @@ -283,11 +396,15 @@ def _cleanup_closed(self): if not self._cleanup_closed_disabled: self._cleanup_closed_handle = helpers.weakref_handle( - self, '_cleanup_closed', - self._cleanup_closed_period, self._loop) + self, "_cleanup_closed", self._cleanup_closed_period, self._loop + ) - def close(self): + def close(self) -> Awaitable[None]: """Close all opened transports.""" + self._close() + return _DeprecationWaiter(noop()) + + def _close(self) -> None: if self._closed: return @@ -295,13 +412,13 @@ def close(self): try: if self._loop.is_closed(): - return noop() + return - # cacnel cleanup task + # cancel cleanup task if self._cleanup_handle: self._cleanup_handle.cancel() - # cacnel cleanup close task + # cancel cleanup close task if self._cleanup_closed_handle: self._cleanup_closed_handle.cancel() @@ -325,68 +442,124 @@ def close(self): self._cleanup_closed_handle = None @property - def closed(self): + def closed(self) -> bool: """Is connector closed. A readonly property. """ return self._closed - @asyncio.coroutine - def connect(self, req): - """Get from pool or create new connection.""" - key = (req.host, req.port, req.ssl) + def _available_connections(self, key: "ConnectionKey") -> int: + """ + Return number of available connections taking into account + the limit, limit_per_host and the connection key. + + If it returns less than 1 means that there is no connections + availables. + """ if self._limit: # total calc available connections - available = self._limit - len(self._waiters) - len(self._acquired) + available = self._limit - len(self._acquired) # check limit per host - if (self._limit_per_host and available > 0 and - key in self._acquired_per_host): - available = self._limit_per_host - len( - self._acquired_per_host.get(key)) + if ( + self._limit_per_host + and available > 0 + and key in self._acquired_per_host + ): + acquired = self._acquired_per_host.get(key) + assert acquired is not None + available = self._limit_per_host - len(acquired) elif self._limit_per_host and key in self._acquired_per_host: # check limit per host - available = self._limit_per_host - len( - self._acquired_per_host.get(key)) + acquired = self._acquired_per_host.get(key) + assert acquired is not None + available = self._limit_per_host - len(acquired) else: available = 1 - # Wait if there are no available connections. - if available <= 0: - fut = helpers.create_future(self._loop) + return available + + async def connect( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> Connection: + """Get from pool or create new connection.""" + key = req.connection_key + available = self._available_connections(key) + + # Wait if there are no available connections or if there are/were + # waiters (i.e. don't steal connection from a waiter about to wake up) + if available <= 0 or key in self._waiters: + fut = self._loop.create_future() # This connection will now count towards the limit. 
+ async def connect( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> Connection: + """Get from pool or create new connection.""" + key = req.connection_key + available = self._available_connections(key) + + # Wait if there are no available connections or if there are/were + # waiters (i.e. don't steal connection from a waiter about to wake up) + if available <= 0 or key in self._waiters: + fut = self._loop.create_future() # This connection will now count towards the limit. - waiters = self._waiters[key] - waiters.append(fut) - yield from fut - waiters.remove(fut) - if not waiters: - del self._waiters[key] + self._waiters[key].append(fut) + + if traces: + for trace in traces: + await trace.send_connection_queued_start() + + try: + await fut + except BaseException as e: + if key in self._waiters: + # remove a waiter even if it was cancelled, normally it's + # removed when it's notified + try: + self._waiters[key].remove(fut) + except ValueError: # fut may no longer be in list + pass + + raise e + finally: + if key in self._waiters and not self._waiters[key]: + del self._waiters[key] + + if traces: + for trace in traces: + await trace.send_connection_queued_end() proto = self._get(key) if proto is None: - placeholder = _TransportPlaceholder() + placeholder = cast(ResponseHandler, _TransportPlaceholder()) self._acquired.add(placeholder) self._acquired_per_host[key].add(placeholder) + + if traces: + for trace in traces: + await trace.send_connection_create_start() + try: - proto = yield from self._create_connection(req) - except OSError as exc: - raise ClientConnectorError( - exc.errno, - 'Cannot connect to host {0[0]}:{0[1]} ssl:{0[2]} [{1}]' - .format(key, exc.strerror)) from exc - finally: - self._acquired.remove(placeholder) - self._acquired_per_host[key].remove(placeholder) + proto = await self._create_connection(req, traces, timeout) + if self._closed: + proto.close() + raise ClientConnectionError("Connector is closed.") + except BaseException: + if not self._closed: + self._acquired.remove(placeholder) + self._drop_acquired_per_host(key, placeholder) + self._release_waiter() + raise + else: + if not self._closed: + self._acquired.remove(placeholder) + self._drop_acquired_per_host(key, placeholder) + + if traces: + for trace in traces: + await trace.send_connection_create_end() + else: + if traces: + for trace in traces: + await trace.send_connection_reuseconn() self._acquired.add(proto) self._acquired_per_host[key].add(proto) return Connection(self, key, proto, self._loop) - def _get(self, key): + def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]: try: conns = self._conns[key] except KeyError: @@ -397,51 +570,58 @@ def _get(self, key): proto, t0 = conns.pop() if proto.is_connected(): if t1 - t0 > self._keepalive_timeout: - transport = proto.close() + transport = proto.transport + proto.close() # only for SSL transports - if key[-1] and not self._cleanup_closed_disabled: + if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) else: if not conns: # The very last connection was reclaimed: drop the key del self._conns[key] return proto + else: + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) # No more connections: drop the key del self._conns[key] return None - def _release_waiter(self): - # always release only one waiter + def _release_waiter(self) -> None: + """ + Iterates over all waiters until it finds one that is not finished and + belongs to a host that has available connections.
+ """ + if not self._waiters: + return - if self._limit: - # if we have limit and we have available - if self._limit - len(self._acquired) > 0: - for key, waiters in self._waiters.items(): - if waiters: - if not waiters[0].done(): - waiters[0].set_result(None) - break - - elif self._limit_per_host: - # if we have dont have limit but have limit per host - # then release first available - for key, waiters in self._waiters.items(): - if waiters: - if not waiters[0].done(): - waiters[0].set_result(None) - break - - def _release_acquired(self, key, proto): + # Having the dict keys ordered this avoids to iterate + # at the same order at each call. + queues = list(self._waiters.keys()) + random.shuffle(queues) + + for key in queues: + if self._available_connections(key) < 1: + continue + + waiters = self._waiters[key] + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + + def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None: if self._closed: # acquired connection is already released on connector closing return try: self._acquired.remove(proto) - self._acquired_per_host[key].remove(proto) - if not self._acquired_per_host[key]: - del self._acquired_per_host[key] + self._drop_acquired_per_host(key, proto) except KeyError: # pragma: no cover # this may be result of undetermenistic order of objects # finalization due garbage collection. @@ -449,7 +629,13 @@ def _release_acquired(self, key, proto): else: self._release_waiter() - def _release(self, key, protocol, *, should_close=False): + def _release( + self, + key: "ConnectionKey", + protocol: ResponseHandler, + *, + should_close: bool = False, + ) -> None: if self._closed: # acquired connection is already released on connector closing return @@ -460,9 +646,10 @@ def _release(self, key, protocol, *, should_close=False): should_close = True if should_close or protocol.should_close: - transport = protocol.close() + transport = protocol.transport + protocol.close() - if key[-1] and not self._cleanup_closed_disabled: + if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) else: conns = self._conns.get(key) @@ -472,29 +659,68 @@ def _release(self, key, protocol, *, should_close=False): if self._cleanup_handle is None: self._cleanup_handle = helpers.weakref_handle( - self, '_cleanup', self._keepalive_timeout, self._loop) + self, "_cleanup", self._keepalive_timeout, self._loop + ) - @asyncio.coroutine - def _create_connection(self, req): + async def _create_connection( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: raise NotImplementedError() -_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0) +class _DNSCacheTable: + def __init__(self, ttl: Optional[float] = None) -> None: + self._addrs_rr = ( + {} + ) # type: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] + self._timestamps = {} # type: Dict[Tuple[str, int], float] + self._ttl = ttl + + def __contains__(self, host: object) -> bool: + return host in self._addrs_rr + + def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None: + self._addrs_rr[key] = (cycle(addrs), len(addrs)) + + if self._ttl: + self._timestamps[key] = monotonic() + + def remove(self, key: Tuple[str, int]) -> None: + self._addrs_rr.pop(key, None) + + if self._ttl: + self._timestamps.pop(key, None) + + def clear(self) -> None: + self._addrs_rr.clear() + self._timestamps.clear() + + def next_addrs(self, key: Tuple[str, 
class TCPConnector(BaseConnector): """TCP connector. verify_ssl - Set to True to check ssl certifications. - fingerprint - Pass the binary md5, sha1, or sha256 + fingerprint - Pass the binary sha256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. See also https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning - resolve - (Deprecated) Set to True to do DNS lookup for - host name. resolver - Enable DNS lookups and use this resolver use_dns_cache - Use memory cache for DNS lookups. + ttl_dns_cache - Max seconds to keep a cached DNS entry; None means cache forever. family - socket address family local_addr - local tuple of (host, port) to bind socket to @@ -503,207 +729,365 @@ class TCPConnector(BaseConnector): after each request (and between redirects). limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. + enable_cleanup_closed - Enables clean-up closed ssl transports. + Disabled by default. loop - Optional event loop. """ - def __init__(self, *, verify_ssl=True, fingerprint=None, - resolve=sentinel, use_dns_cache=True, - family=0, ssl_context=None, local_addr=None, - resolver=None, keepalive_timeout=sentinel, - force_close=False, limit=100, limit_per_host=0, - enable_cleanup_closed=False, loop=None): - super().__init__(keepalive_timeout=keepalive_timeout, - force_close=force_close, - limit=limit, limit_per_host=limit_per_host, - enable_cleanup_closed=enable_cleanup_closed, - loop=loop) - - if not verify_ssl and ssl_context is not None: - raise ValueError( - "Either disable ssl certificate validation by " - "verify_ssl=False or specify ssl_context, not both.") - - self._verify_ssl = verify_ssl - - if fingerprint: - digestlen = len(fingerprint) - hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen) - if not hashfunc: - raise ValueError('fingerprint has invalid length') - elif hashfunc is md5 or hashfunc is sha1: - warnings.simplefilter('always') - warnings.warn('md5 and sha1 are insecure and deprecated. 
' - 'Use sha256.', - DeprecationWarning, stacklevel=2) - self._hashfunc = hashfunc - self._fingerprint = fingerprint - + def __init__( + self, + *, + verify_ssl: bool = True, + fingerprint: Optional[bytes] = None, + use_dns_cache: bool = True, + ttl_dns_cache: Optional[int] = 10, + family: int = 0, + ssl_context: Optional[SSLContext] = None, + ssl: Union[None, bool, Fingerprint, SSLContext] = None, + local_addr: Optional[Tuple[str, int]] = None, + resolver: Optional[AbstractResolver] = None, + keepalive_timeout: Union[None, float, object] = sentinel, + force_close: bool = False, + limit: int = 100, + limit_per_host: int = 0, + enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + super().__init__( + keepalive_timeout=keepalive_timeout, + force_close=force_close, + limit=limit, + limit_per_host=limit_per_host, + enable_cleanup_closed=enable_cleanup_closed, + loop=loop, + ) + + self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if resolver is None: resolver = DefaultResolver(loop=self._loop) self._resolver = resolver self._use_dns_cache = use_dns_cache - self._cached_hosts = {} - self._ssl_context = ssl_context + self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) + self._throttle_dns_events = ( + {} + ) # type: Dict[Tuple[str, int], EventResultOrError] self._family = family self._local_addr = local_addr - @property - def verify_ssl(self): - """Do check for ssl certifications?""" - return self._verify_ssl + def close(self) -> Awaitable[None]: + """Close all ongoing DNS calls.""" + for ev in self._throttle_dns_events.values(): + ev.cancel() - @property - def fingerprint(self): - """Expected ssl certificate fingerprint.""" - return self._fingerprint + return super().close() @property - def ssl_context(self): - """SSLContext instance for https requests. - - Lazy property, creates context on demand. 
- """ - if self._ssl_context is None: - if not self._verify_ssl: - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.options |= _SSL_OP_NO_COMPRESSION - sslcontext.set_default_verify_paths() - else: - sslcontext = ssl.create_default_context() - self._ssl_context = sslcontext - return self._ssl_context - - @property - def family(self): + def family(self) -> int: """Socket family like AF_INET.""" return self._family @property - def use_dns_cache(self): + def use_dns_cache(self) -> bool: """True if local DNS caching is enabled.""" return self._use_dns_cache - @property - def cached_hosts(self): - """Read-only dict of cached DNS record.""" - return MappingProxyType(self._cached_hosts) - - def clear_dns_cache(self, host=None, port=None): + def clear_dns_cache( + self, host: Optional[str] = None, port: Optional[int] = None + ) -> None: """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: - self._cached_hosts.pop((host, port), None) + self._cached_hosts.remove((host, port)) elif host is not None or port is not None: - raise ValueError("either both host and port " - "or none of them are allowed") + raise ValueError("either both host and port " "or none of them are allowed") else: self._cached_hosts.clear() - @asyncio.coroutine - def _resolve_host(self, host, port): + async def _resolve_host( + self, host: str, port: int, traces: Optional[List["Trace"]] = None + ) -> List[Dict[str, Any]]: if is_ip_address(host): - return [{'hostname': host, 'host': host, 'port': port, - 'family': self._family, 'proto': 0, 'flags': 0}] + return [ + { + "hostname": host, + "host": host, + "port": port, + "family": self._family, + "proto": 0, + "flags": 0, + } + ] - if self._use_dns_cache: - key = (host, port) + if not self._use_dns_cache: - if key not in self._cached_hosts: - self._cached_hosts[key] = yield from \ - self._resolver.resolve(host, port, family=self._family) + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start(host) + + res = await self._resolver.resolve(host, port, family=self._family) + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end(host) - return self._cached_hosts[key] - else: - res = yield from self._resolver.resolve( - host, port, family=self._family) return res - @asyncio.coroutine - def _create_connection(self, req): + key = (host, port) + + if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)): + # get result early, before any await (#4014) + result = self._cached_hosts.next_addrs(key) + + if traces: + for trace in traces: + await trace.send_dns_cache_hit(host) + return result + + if key in self._throttle_dns_events: + # get event early, before any await (#4014) + event = self._throttle_dns_events[key] + if traces: + for trace in traces: + await trace.send_dns_cache_hit(host) + await event.wait() + else: + # update dict early, before any await (#4014) + self._throttle_dns_events[key] = EventResultOrError(self._loop) + if traces: + for trace in traces: + await trace.send_dns_cache_miss(host) + try: + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start(host) + + addrs = await self._resolver.resolve(host, port, family=self._family) + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end(host) + + self._cached_hosts.add(key, addrs) + self._throttle_dns_events[key].set() + except BaseException as e: + # any DNS exception, independently of the 
implementation + # is set for the waiters to raise the same exception. + self._throttle_dns_events[key].set(exc=e) + raise + finally: + self._throttle_dns_events.pop(key) + + return self._cached_hosts.next_addrs(key) + + async def _create_connection( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: """Create connection. Has same keyword arguments as BaseEventLoop.create_connection. """ if req.proxy: - _, proto = yield from self._create_proxy_connection(req) + _, proto = await self._create_proxy_connection(req, traces, timeout) else: - _, proto = yield from self._create_direct_connection(req) + _, proto = await self._create_direct_connection(req, traces, timeout) return proto - @asyncio.coroutine - def _create_direct_connection(self, req): - if req.ssl: - sslcontext = self.ssl_context + @staticmethod + @functools.lru_cache(None) + def _make_ssl_context(verified: bool) -> SSLContext: + if verified: + return ssl.create_default_context() + else: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + try: + sslcontext.options |= ssl.OP_NO_COMPRESSION + except AttributeError as attr_err: + warnings.warn( + "{!s}: The Python interpreter is compiled " + "against OpenSSL < 1.0.0. Ref: " + "https://docs.python.org/3/library/ssl.html" + "#ssl.OP_NO_COMPRESSION".format(attr_err), + ) + sslcontext.set_default_verify_paths() + return sslcontext + + def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]: + """Logic to get the correct SSL context + + 0. if req.ssl is false, return None + + 1. if ssl_context is specified in req, use it + 2. if _ssl_context is specified in self, use it + 3. otherwise: + 1. if verify_ssl is not specified in req, use self.ssl_context + (will generate a default context according to self.verify_ssl) + 2. if verify_ssl is True in req, generate a default SSL context + 3. 
if verify_ssl is False in req, generate an SSL context that + won't verify + """ + if req.is_ssl(): + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + sslcontext = req.ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not None: + # not verified or fingerprinted + return self._make_ssl_context(False) + sslcontext = self._ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not None: + # not verified or fingerprinted + return self._make_ssl_context(False) + return self._make_ssl_context(True) else: - sslcontext = None + return None - hosts = yield from self._resolve_host(req.url.raw_host, req.port) - exc = None + def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]: + ret = req.ssl + if isinstance(ret, Fingerprint): + return ret + ret = self._ssl + if isinstance(ret, Fingerprint): + return ret + return None + + async def _wrap_create_connection( + self, + *args: Any, + req: "ClientRequest", + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + **kwargs: Any, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + try: + with CeilTimeout(timeout.sock_connect): + return await self._loop.create_connection(*args, **kwargs) # type: ignore # noqa + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + raise client_error(req.connection_key, exc) from exc + + async def _create_direct_connection( + self, + req: "ClientRequest", + traces: List["Trace"], + timeout: "ClientTimeout", + *, + client_error: Type[Exception] = ClientConnectorError, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + sslcontext = self._get_ssl_context(req) + fingerprint = self._get_fingerprint(req) + + host = req.url.raw_host + assert host is not None + port = req.port + assert port is not None + host_resolved = asyncio.ensure_future( + self._resolve_host(host, port, traces=traces), loop=self._loop + ) + try: + # Cancelling this lookup should not cancel the underlying lookup + # or else the cancel event will get broadcast to all the waiters + # across all connections. 
+ hosts = await asyncio.shield(host_resolved) + except asyncio.CancelledError: + + def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + with suppress(Exception, asyncio.CancelledError): + fut.result() + + host_resolved.add_done_callback(drop_exception) + raise + except OSError as exc: + # in case of a proxy this is not ClientProxyConnectionError: + # it is a problem of resolving the proxy IP itself + raise ClientConnectorError(req.connection_key, exc) from exc + + last_exc = None # type: Optional[Exception] for hinfo in hosts: + host = hinfo["host"] + port = hinfo["port"] + try: - host = hinfo['host'] - port = hinfo['port'] - transp, proto = yield from self._loop.create_connection( - self._factory, host, port, - ssl=sslcontext, family=hinfo['family'], - proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None, - local_addr=self._local_addr) - has_cert = transp.get_extra_info('sslcontext') - if has_cert and self._fingerprint: - sock = transp.get_extra_info('socket') - if not hasattr(sock, 'getpeercert'): - # Workaround for asyncio 3.5.0 - # Starting from 3.5.1 version - # there is 'ssl_object' extra info in transport - sock = transp._ssl_protocol._sslpipe.ssl_object - # gives DER-encoded cert as a sequence of bytes (or None) - cert = sock.getpeercert(binary_form=True) - assert cert - got = self._hashfunc(cert).digest() - expected = self._fingerprint - if got != expected: - transp.close() - if not self._cleanup_closed_disabled: - self._cleanup_closed_transports.append(transp) - raise ServerFingerprintMismatch( - expected, got, host, port) - return transp, proto - except OSError as e: - exc = e + transp, proto = await self._wrap_create_connection( + self._factory, + host, + port, + timeout=timeout, + ssl=sslcontext, + family=hinfo["family"], + proto=hinfo["proto"], + flags=hinfo["flags"], + server_hostname=hinfo["hostname"] if sslcontext else None, + local_addr=self._local_addr, + req=req, + client_error=client_error, + ) + except ClientConnectorError as exc: + last_exc = exc + continue + + if req.is_ssl() and fingerprint: + try: + fingerprint.check(transp) + except ServerFingerprintMismatch as exc: + transp.close() + if not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transp) + last_exc = exc + continue + + return transp, proto else: - raise ClientConnectorError( - exc.errno, - 'Can not connect to %s:%s [%s]' % - (req.host, req.port, exc.strerror)) from exc - - @asyncio.coroutine - def _create_proxy_connection(self, req): + assert last_exc is not None + raise last_exc + + async def _create_proxy_connection( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> Tuple[asyncio.Transport, ResponseHandler]: + headers = {} # type: Dict[str, str] + if req.proxy_headers is not None: + headers = req.proxy_headers # type: ignore + headers[hdrs.HOST] = req.headers[hdrs.HOST] + + url = req.proxy + assert url is not None proxy_req = ClientRequest( - hdrs.METH_GET, req.proxy, - headers={hdrs.HOST: req.headers[hdrs.HOST]}, + hdrs.METH_GET, + url, + headers=headers, auth=req.proxy_auth, - loop=self._loop) - try: - # create connection to proxy server - transport, proto = yield from self._create_direct_connection( - proxy_req) - except OSError as exc: - raise ClientProxyConnectionError(*exc.args) from exc - - if hdrs.AUTHORIZATION in proxy_req.headers: - auth = proxy_req.headers[hdrs.AUTHORIZATION] - del proxy_req.headers[hdrs.AUTHORIZATION] - if not req.ssl: + loop=self._loop, + ssl=req.ssl, + 
) + + # create connection to proxy server + transport, proto = await self._create_direct_connection( + proxy_req, [], timeout, client_error=ClientProxyConnectionError + ) + + # Many HTTP proxies have buggy keepalive support. Let's not + # reuse the connection but close it after processing every + # response. + proto.force_close() + + auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) + if auth is not None: + if not req.is_ssl(): req.headers[hdrs.PROXY_AUTHORIZATION] = auth else: proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth - if req.ssl: + if req.is_ssl(): + sslcontext = self._get_ssl_context(req) # For HTTPS requests over HTTP proxy # we must notify proxy to tunnel connection # so we send CONNECT command: @@ -715,12 +1099,17 @@ def _create_proxy_connection(self, req): # asyncio handles this perfectly proxy_req.method = hdrs.METH_CONNECT proxy_req.url = req.url - key = (req.host, req.port, req.ssl) + key = attr.evolve( + req.connection_key, proxy=None, proxy_auth=None, proxy_headers_hash=None + ) conn = Connection(self, key, proto, self._loop) - proxy_resp = proxy_req.send(conn) + proxy_resp = await proxy_req.send(conn) try: - resp = yield from proxy_resp.start(conn, True) - except: + protocol = conn._protocol + assert protocol is not None + protocol.set_response_params() + resp = await proxy_resp.start(conn) + except BaseException: proxy_resp.close() conn.close() raise @@ -729,21 +1118,32 @@ def _create_proxy_connection(self, req): conn._transport = None try: if resp.status != 200: - raise ClientHttpProxyError(code=resp.status, - message=resp.reason, - headers=resp.headers) - rawsock = transport.get_extra_info('socket', default=None) + message = resp.reason + if message is None: + message = RESPONSES[resp.status][0] + raise ClientHttpProxyError( + proxy_resp.request_info, + resp.history, + status=resp.status, + message=message, + headers=resp.headers, + ) + rawsock = transport.get_extra_info("socket", default=None) if rawsock is None: - raise RuntimeError( - "Transport does not expose socket instance") + raise RuntimeError("Transport does not expose socket instance") # Duplicate the socket, so now we can close proxy transport rawsock = rawsock.dup() finally: transport.close() - transport, proto = yield from self._loop.create_connection( - self._factory, ssl=self.ssl_context, sock=rawsock, - server_hostname=req.host) + transport, proto = await self._wrap_create_connection( + self._factory, + timeout=timeout, + ssl=sslcontext, + sock=rawsock, + server_hostname=req.host, + req=req, + ) finally: proxy_resp.close() @@ -760,29 +1160,103 @@ class UnixConnector(BaseConnector): limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. loop - Optional event loop. + """
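For reference, an async/await rendering of the doctest-style usage that the hunk below removes; the socket path and URL are illustrative:

    from aiohttp import ClientSession, UnixConnector

    async def main() -> None:
        # Route all requests through a local Unix domain socket.
        conn = UnixConnector(path='/path/to/socket')
        async with ClientSession(connector=conn) as session:
            async with session.get('http://python.org') as resp:
                body = await resp.text()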
+ """ - Usage: + def __init__( + self, + path: str, + force_close: bool = False, + keepalive_timeout: Union[object, float, None] = sentinel, + limit: int = 100, + limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + super().__init__( + force_close=force_close, + keepalive_timeout=keepalive_timeout, + limit=limit, + limit_per_host=limit_per_host, + loop=loop, + ) + self._path = path - >>> conn = UnixConnector(path='/path/to/socket') - >>> session = ClientSession(connector=conn) - >>> resp = yield from session.get('http://python.org') + @property + def path(self) -> str: + """Path to unix socket.""" + return self._path + async def _create_connection( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + try: + with CeilTimeout(timeout.sock_connect): + _, proto = await self._loop.create_unix_connection( + self._factory, self._path + ) + except OSError as exc: + raise ClientConnectorError(req.connection_key, exc) from exc + + return cast(ResponseHandler, proto) + + +class NamedPipeConnector(BaseConnector): + """Named pipe connector. + + Only supported by the proactor event loop. + See also: https://docs.python.org/3.7/library/asyncio-eventloop.html + + path - Windows named pipe path. + keepalive_timeout - (optional) Keep-alive timeout. + force_close - Set to True to force close and do reconnect + after each request (and between redirects). + limit - The total number of simultaneous connections. + limit_per_host - Number of simultaneous connections to one host. + loop - Optional event loop. """ - def __init__(self, path, force_close=False, keepalive_timeout=sentinel, - limit=100, limit_per_host=0, loop=None): - super().__init__(force_close=force_close, - keepalive_timeout=keepalive_timeout, - limit=limit, limit_per_host=limit_per_host, loop=loop) + def __init__( + self, + path: str, + force_close: bool = False, + keepalive_timeout: Union[object, float, None] = sentinel, + limit: int = 100, + limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + super().__init__( + force_close=force_close, + keepalive_timeout=keepalive_timeout, + limit=limit, + limit_per_host=limit_per_host, + loop=loop, + ) + if not isinstance(self._loop, asyncio.ProactorEventLoop): # type: ignore + raise RuntimeError( + "Named Pipes only available in proactor " "loop under windows" + ) self._path = path @property - def path(self): - """Path to unix socket.""" + def path(self) -> str: + """Path to the named pipe.""" return self._path - @asyncio.coroutine - def _create_connection(self, req): - _, proto = yield from self._loop.create_unix_connection( - self._factory, self._path) - return proto + async def _create_connection( + self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + try: + with CeilTimeout(timeout.sock_connect): + _, proto = await self._loop.create_pipe_connection( # type: ignore + self._factory, self._path + ) + # the drain is required so that the connection_made is called + # and transport is set otherwise it is not set before the + # `assert conn.transport is not None` + # in client.py's _request method + await asyncio.sleep(0) + # other option is to manually set transport like + # `proto.transport = trans` + except OSError as exc: + raise ClientConnectorError(req.connection_key, exc) from exc + + return cast(ResponseHandler, proto) diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index 27c2baede94..b6b59d62894 100644 --- 
a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -1,15 +1,34 @@ +import asyncio import datetime +import os # noqa import pathlib import pickle import re from collections import defaultdict -from collections.abc import Mapping -from http.cookies import Morsel -from math import ceil +from http.cookies import BaseCookie, Morsel, SimpleCookie +from typing import ( # noqa + DefaultDict, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, +) + from yarl import URL from .abc import AbstractCookieJar -from .helpers import SimpleCookie, is_ip_address +from .helpers import is_ip_address, next_whole_second +from .typedefs import LooseCookies, PathLike + +__all__ = ("CookieJar", "DummyCookieJar") + + +CookieItem = Union[str, "Morsel[str]"] class CookieJar(AbstractCookieJar): @@ -17,63 +36,83 @@ class CookieJar(AbstractCookieJar): DATE_TOKENS_RE = re.compile( r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" - r"(?P[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)") + r"(?P[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" + ) DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") - DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" - "(aug)|(sep)|(oct)|(nov)|(dec)", re.I) + DATE_MONTH_RE = re.compile( + "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)", + re.I, + ) DATE_YEAR_RE = re.compile(r"(\d{2,4})") - MAX_TIME = 2051215261.0 # so far in future (2035-01-01) + MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc) + + MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2 ** 31 - 1) - def __init__(self, *, unsafe=False, loop=None): + def __init__( + self, + *, + unsafe: bool = False, + quote_cookie: bool = True, + loop: Optional[asyncio.AbstractEventLoop] = None + ) -> None: super().__init__(loop=loop) - self._cookies = defaultdict(SimpleCookie) - self._host_only_cookies = set() + self._cookies = defaultdict( + SimpleCookie + ) # type: DefaultDict[str, SimpleCookie[str]] + self._host_only_cookies = set() # type: Set[Tuple[str, str]] self._unsafe = unsafe - self._next_expiration = ceil(self._loop.time()) - self._expirations = {} - - def save(self, file_path): + self._quote_cookie = quote_cookie + self._next_expiration = next_whole_second() + self._expirations = {} # type: Dict[Tuple[str, str], datetime.datetime] + # #4515: datetime.max may not be representable on 32-bit platforms + self._max_time = self.MAX_TIME + try: + self._max_time.timestamp() + except OverflowError: + self._max_time = self.MAX_32BIT_TIME + + def save(self, file_path: PathLike) -> None: file_path = pathlib.Path(file_path) - with file_path.open(mode='wb') as f: + with file_path.open(mode="wb") as f: pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL) - def load(self, file_path): + def load(self, file_path: PathLike) -> None: file_path = pathlib.Path(file_path) - with file_path.open(mode='rb') as f: + with file_path.open(mode="rb") as f: self._cookies = pickle.load(f) - def clear(self): + def clear(self) -> None: self._cookies.clear() self._host_only_cookies.clear() - self._next_expiration = ceil(self._loop.time()) + self._next_expiration = next_whole_second() self._expirations.clear() - def __iter__(self): + def __iter__(self) -> "Iterator[Morsel[str]]": self._do_expiration() for val in self._cookies.values(): yield from val.values() - def __len__(self): + def __len__(self) -> int: return sum(1 for i in self) - def _do_expiration(self): - now = self._loop.time() + def _do_expiration(self) 
-> None: + now = datetime.datetime.now(datetime.timezone.utc) if self._next_expiration > now: return if not self._expirations: return - next_expiration = self.MAX_TIME + next_expiration = self._max_time to_del = [] cookies = self._cookies expirations = self._expirations for (domain, name), when in expirations.items(): - if when < now: + if when <= now: cookies[domain].pop(name, None) to_del.append((domain, name)) self._host_only_cookies.discard((domain, name)) @@ -82,13 +121,18 @@ def _do_expiration(self): for key in to_del: del expirations[key] - self._next_expiration = ceil(next_expiration) + try: + self._next_expiration = next_expiration.replace( + microsecond=0 + ) + datetime.timedelta(seconds=1) + except OverflowError: + self._next_expiration = self._max_time - def _expire_cookie(self, when, domain, name): + def _expire_cookie(self, when: datetime.datetime, domain: str, name: str) -> None: self._next_expiration = min(self._next_expiration, when) self._expirations[(domain, name)] = when - def update_cookies(self, cookies, response_url=URL()): + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: """Update cookies.""" hostname = response_url.raw_host @@ -101,14 +145,14 @@ def update_cookies(self, cookies, response_url=URL()): for name, cookie in cookies: if not isinstance(cookie, Morsel): - tmp = SimpleCookie() - tmp[name] = cookie + tmp = SimpleCookie() # type: SimpleCookie[str] + tmp[name] = cookie # type: ignore cookie = tmp[name] domain = cookie["domain"] # ignore domains with trailing dots - if domain.endswith('.'): + if domain.endswith("."): domain = "" del cookie["domain"] @@ -135,15 +179,20 @@ def update_cookies(self, cookies, response_url=URL()): path = "/" else: # Cut everything from the last slash to the end - path = "/" + path[1:path.rfind("/")] + path = "/" + path[1 : path.rfind("/")] cookie["path"] = path max_age = cookie["max-age"] if max_age: try: delta_seconds = int(max_age) - self._expire_cookie(self._loop.time() + delta_seconds, - domain, name) + try: + max_age_expiration = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(seconds=delta_seconds) + except OverflowError: + max_age_expiration = self._max_time + self._expire_cookie(max_age_expiration, domain, name) except ValueError: cookie["max-age"] = "" @@ -152,22 +201,23 @@ def update_cookies(self, cookies, response_url=URL()): if expires: expire_time = self._parse_date(expires) if expire_time: - self._expire_cookie(expire_time.timestamp(), - domain, name) + self._expire_cookie(expire_time, domain, name) else: cookie["expires"] = "" - # use dict method because SimpleCookie class modifies value - # before Python 3.4.3 - dict.__setitem__(self._cookies[domain], name, cookie) + self._cookies[domain][name] = cookie self._do_expiration() - def filter_cookies(self, request_url=URL()): + def filter_cookies( + self, request_url: URL = URL() + ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]: """Returns this jar's cookies filtered by their attributes.""" self._do_expiration() request_url = URL(request_url) - filtered = SimpleCookie() + filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = ( + SimpleCookie() if self._quote_cookie else BaseCookie() + ) hostname = request_url.raw_host or "" is_not_secure = request_url.scheme not in ("https", "wss") @@ -197,14 +247,14 @@ def filter_cookies(self, request_url=URL()): # It's critical we use the Morsel so the coded_value # (based on cookie version) is preserved - mrsl_val = cookie.get(cookie.key, Morsel()) + mrsl_val = 
cast("Morsel[str]", cookie.get(cookie.key, Morsel())) mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) filtered[name] = mrsl_val return filtered @staticmethod - def _is_domain_match(domain, hostname): + def _is_domain_match(domain: str, hostname: str) -> bool: """Implements domain matching adhering to RFC 6265.""" if hostname == domain: return True @@ -212,7 +262,7 @@ def _is_domain_match(domain, hostname): if not hostname.endswith(domain): return False - non_matching = hostname[:-len(domain)] + non_matching = hostname[: -len(domain)] if not non_matching.endswith("."): return False @@ -220,7 +270,7 @@ def _is_domain_match(domain, hostname): return not is_ip_address(hostname) @staticmethod - def _is_path_match(req_path, cookie_path): + def _is_path_match(req_path: str, cookie_path: str) -> bool: """Implements path matching adhering to RFC 6265.""" if not req_path.startswith("/"): req_path = "/" @@ -234,15 +284,15 @@ def _is_path_match(req_path, cookie_path): if cookie_path.endswith("/"): return True - non_matching = req_path[len(cookie_path):] + non_matching = req_path[len(cookie_path) :] return non_matching.startswith("/") @classmethod - def _parse_date(cls, date_str): + def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]: """Implements date string parsing adhering to RFC 6265.""" if not date_str: - return + return None found_time = False found_day = False @@ -262,8 +312,7 @@ def _parse_date(cls, date_str): time_match = cls.DATE_HMS_TIME_RE.match(token) if time_match: found_time = True - hour, minute, second = [ - int(s) for s in time_match.groups()] + hour, minute, second = [int(s) for s in time_match.groups()] continue if not found_day: @@ -277,6 +326,7 @@ def _parse_date(cls, date_str): month_match = cls.DATE_MONTH_RE.match(token) if month_match: found_month = True + assert month_match.lastindex is not None month = month_match.lastindex continue @@ -292,14 +342,41 @@ def _parse_date(cls, date_str): year += 2000 if False in (found_day, found_month, found_year, found_time): - return + return None if not 1 <= day <= 31: - return + return None if year < 1601 or hour > 23 or minute > 59 or second > 59: - return + return None + + return datetime.datetime( + year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc + ) + + +class DummyCookieJar(AbstractCookieJar): + """Implements a dummy cookie storage. + + It can be used with the ClientSession when no cookie processing is needed. + + """ + + def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(loop=loop) + + def __iter__(self) -> "Iterator[Morsel[str]]": + while False: + yield None + + def __len__(self) -> int: + return 0 + + def clear(self) -> None: + pass + + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + pass - return datetime.datetime(year, month, day, - hour, minute, second, - tzinfo=datetime.timezone.utc) + def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": + return SimpleCookie() diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index b3a845cd873..900716b72a6 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -1,22 +1,30 @@ import io +from typing import Any, Iterable, List, Optional from urllib.parse import urlencode from multidict import MultiDict, MultiDictProxy from . 
import hdrs, multipart, payload from .helpers import guess_filename +from .payload import Payload -__all__ = ('FormData',) +__all__ = ("FormData",) class FormData: """Helper class for multipart/form-data and application/x-www-form-urlencoded body generation.""" - def __init__(self, fields=(), quote_fields=True, charset=None): - self._writer = multipart.MultipartWriter('form-data') - self._fields = [] + def __init__( + self, + fields: Iterable[Any] = (), + quote_fields: bool = True, + charset: Optional[str] = None, + ) -> None: + self._writer = multipart.MultipartWriter("form-data") + self._fields = [] # type: List[Any] self._is_multipart = False + self._is_processed = False self._quote_fields = quote_fields self._charset = charset @@ -27,11 +35,18 @@ def __init__(self, fields=(), quote_fields=True, charset=None): self.add_fields(*fields) @property - def is_multipart(self): + def is_multipart(self) -> bool: return self._is_multipart - def add_field(self, name, value, *, content_type=None, filename=None, - content_transfer_encoding=None): + def add_field( + self, + name: str, + value: Any, + *, + content_type: Optional[str] = None, + filename: Optional[str] = None, + content_transfer_encoding: Optional[str] = None + ) -> None: if isinstance(value, io.IOBase): self._is_multipart = True @@ -39,103 +54,116 @@ def add_field(self, name, value, *, content_type=None, filename=None, if filename is None and content_transfer_encoding is None: filename = name - type_options = MultiDict({'name': name}) + type_options = MultiDict({"name": name}) # type: MultiDict[str] if filename is not None and not isinstance(filename, str): - raise TypeError('filename must be an instance of str. ' - 'Got: %s' % filename) + raise TypeError( + "filename must be an instance of str. " "Got: %s" % filename + ) if filename is None and isinstance(value, io.IOBase): filename = guess_filename(value, name) if filename is not None: - type_options['filename'] = filename + type_options["filename"] = filename self._is_multipart = True headers = {} if content_type is not None: if not isinstance(content_type, str): - raise TypeError('content_type must be an instance of str. ' - 'Got: %s' % content_type) + raise TypeError( + "content_type must be an instance of str. " "Got: %s" % content_type + ) headers[hdrs.CONTENT_TYPE] = content_type self._is_multipart = True if content_transfer_encoding is not None: if not isinstance(content_transfer_encoding, str): - raise TypeError('content_transfer_encoding must be an instance' - ' of str. Got: %s' % content_transfer_encoding) + raise TypeError( + "content_transfer_encoding must be an instance" + " of str. 
Got: %s" % content_transfer_encoding + ) headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding self._is_multipart = True self._fields.append((type_options, headers, value)) - def add_fields(self, *fields): + def add_fields(self, *fields: Any) -> None: to_add = list(fields) while to_add: rec = to_add.pop(0) if isinstance(rec, io.IOBase): - k = guess_filename(rec, 'unknown') - self.add_field(k, rec) + k = guess_filename(rec, "unknown") + self.add_field(k, rec) # type: ignore elif isinstance(rec, (MultiDictProxy, MultiDict)): to_add.extend(rec.items()) elif isinstance(rec, (list, tuple)) and len(rec) == 2: k, fp = rec - self.add_field(k, fp) + self.add_field(k, fp) # type: ignore else: - raise TypeError('Only io.IOBase, multidict and (name, file) ' - 'pairs allowed, use .add_field() for passing ' - 'more complex parameters, got {!r}' - .format(rec)) + raise TypeError( + "Only io.IOBase, multidict and (name, file) " + "pairs allowed, use .add_field() for passing " + "more complex parameters, got {!r}".format(rec) + ) - def _gen_form_urlencoded(self): + def _gen_form_urlencoded(self) -> payload.BytesPayload: # form data (x-www-form-urlencoded) data = [] for type_options, _, value in self._fields: - data.append((type_options['name'], value)) + data.append((type_options["name"], value)) - charset = self._charset if self._charset is not None else 'utf-8' + charset = self._charset if self._charset is not None else "utf-8" - if charset == 'utf-8': - content_type = 'application/x-www-form-urlencoded' + if charset == "utf-8": + content_type = "application/x-www-form-urlencoded" else: - content_type = ('application/x-www-form-urlencoded; ' - 'charset=%s' % charset) + content_type = "application/x-www-form-urlencoded; " "charset=%s" % charset return payload.BytesPayload( urlencode(data, doseq=True, encoding=charset).encode(), - content_type=content_type) + content_type=content_type, + ) - def _gen_form_data(self): + def _gen_form_data(self) -> multipart.MultipartWriter: """Encode a list of fields using the multipart/form-data MIME format""" + if self._is_processed: + raise RuntimeError("Form data has been processed already") for dispparams, headers, value in self._fields: try: if hdrs.CONTENT_TYPE in headers: part = payload.get_payload( - value, content_type=headers[hdrs.CONTENT_TYPE], - headers=headers, encoding=self._charset) + value, + content_type=headers[hdrs.CONTENT_TYPE], + headers=headers, + encoding=self._charset, + ) else: part = payload.get_payload( - value, headers=headers, encoding=self._charset) + value, headers=headers, encoding=self._charset + ) except Exception as exc: raise TypeError( - 'Can not serialize value type: %r\n ' - 'headers: %r\n value: %r' % ( - type(value), headers, value)) from exc + "Can not serialize value type: %r\n " + "headers: %r\n value: %r" % (type(value), headers, value) + ) from exc if dispparams: part.set_content_disposition( - 'form-data', quote_fields=self._quote_fields, **dispparams + "form-data", quote_fields=self._quote_fields, **dispparams ) # FIXME cgi.FieldStorage doesn't likes body parts with # Content-Length which were sent via chunked transfer encoding - part.headers.pop(hdrs.CONTENT_LENGTH, None) + assert part.headers is not None + part.headers.popall(hdrs.CONTENT_LENGTH, None) self._writer.append_payload(part) + self._is_processed = True return self._writer - def __call__(self): + def __call__(self) -> Payload: if self._is_multipart: return self._gen_form_data() else: diff --git a/aiohttp/frozenlist.py b/aiohttp/frozenlist.py new file 
mode 100644
index 00000000000..46b26108cfa
--- /dev/null
+++ b/aiohttp/frozenlist.py
@@ -0,0 +1,72 @@
+from collections.abc import MutableSequence
+from functools import total_ordering
+
+from .helpers import NO_EXTENSIONS
+
+
+@total_ordering
+class FrozenList(MutableSequence):
+
+    __slots__ = ("_frozen", "_items")
+
+    def __init__(self, items=None):
+        self._frozen = False
+        if items is not None:
+            items = list(items)
+        else:
+            items = []
+        self._items = items
+
+    @property
+    def frozen(self):
+        return self._frozen
+
+    def freeze(self):
+        self._frozen = True
+
+    def __getitem__(self, index):
+        return self._items[index]
+
+    def __setitem__(self, index, value):
+        if self._frozen:
+            raise RuntimeError("Cannot modify frozen list.")
+        self._items[index] = value
+
+    def __delitem__(self, index):
+        if self._frozen:
+            raise RuntimeError("Cannot modify frozen list.")
+        del self._items[index]
+
+    def __len__(self):
+        return self._items.__len__()
+
+    def __iter__(self):
+        return self._items.__iter__()
+
+    def __reversed__(self):
+        return self._items.__reversed__()
+
+    def __eq__(self, other):
+        return list(self) == other
+
+    def __le__(self, other):
+        return list(self) <= other
+
+    def insert(self, pos, item):
+        if self._frozen:
+            raise RuntimeError("Cannot modify frozen list.")
+        self._items.insert(pos, item)
+
+    def __repr__(self):
+        return f"<FrozenList(frozen={self._frozen}, {self._items!r})>"
+
+
+PyFrozenList = FrozenList
+
+try:
+    from aiohttp._frozenlist import FrozenList as CFrozenList  # type: ignore
+
+    if not NO_EXTENSIONS:
+        FrozenList = CFrozenList  # type: ignore
+except ImportError:  # pragma: no cover
+    pass
diff --git a/aiohttp/frozenlist.pyi b/aiohttp/frozenlist.pyi
new file mode 100644
index 00000000000..72ab086715b
--- /dev/null
+++ b/aiohttp/frozenlist.pyi
@@ -0,0 +1,46 @@
+from typing import (
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    MutableSequence,
+    Optional,
+    TypeVar,
+    Union,
+    overload,
+)
+
+_T = TypeVar("_T")
+_Arg = Union[List[_T], Iterable[_T]]
+
+class FrozenList(MutableSequence[_T], Generic[_T]):
+    def __init__(self, items: Optional[_Arg[_T]] = ...) -> None: ...
+    @property
+    def frozen(self) -> bool: ...
+    def freeze(self) -> None: ...
+    @overload
+    def __getitem__(self, i: int) -> _T: ...
+    @overload
+    def __getitem__(self, s: slice) -> FrozenList[_T]: ...
+    @overload
+    def __setitem__(self, i: int, o: _T) -> None: ...
+    @overload
+    def __setitem__(self, s: slice, o: Iterable[_T]) -> None: ...
+    @overload
+    def __delitem__(self, i: int) -> None: ...
+    @overload
+    def __delitem__(self, i: slice) -> None: ...
+    def __len__(self) -> int: ...
+    def __iter__(self) -> Iterator[_T]: ...
+    def __reversed__(self) -> Iterator[_T]: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __le__(self, other: FrozenList[_T]) -> bool: ...
+    def __ne__(self, other: object) -> bool: ...
+    def __lt__(self, other: FrozenList[_T]) -> bool: ...
+    def __ge__(self, other: FrozenList[_T]) -> bool: ...
+    def __gt__(self, other: FrozenList[_T]) -> bool: ...
+    def insert(self, pos: int, item: _T) -> None: ...
+    def __repr__(self) -> str: ...
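
# A minimal usage sketch of the FrozenList type introduced above (assuming
# this branch of aiohttp is importable): once frozen, every mutating method
# of the MutableSequence API raises RuntimeError.
from aiohttp.frozenlist import FrozenList

fl = FrozenList([1, 2])
fl.append(3)  # mutation is allowed while the list is unfrozen
fl.freeze()
assert fl.frozen
try:
    fl.append(4)
except RuntimeError:
    pass  # "Cannot modify frozen list."
# Unlike the old helpers.FrozenList removed later in this diff, freeze()
# no longer copies the backing storage into a tuple.
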
+ +# types for C accelerators are the same +CFrozenList = PyFrozenList = FrozenList diff --git a/aiohttp/hdrs.py b/aiohttp/hdrs.py index f994319e48d..f04a5457f9f 100644 --- a/aiohttp/hdrs.py +++ b/aiohttp/hdrs.py @@ -1,91 +1,108 @@ """HTTP Headers constants.""" + +# After changing the file content call ./tools/gen.py +# to regenerate the headers parser + from multidict import istr -METH_ANY = '*' -METH_CONNECT = 'CONNECT' -METH_HEAD = 'HEAD' -METH_GET = 'GET' -METH_DELETE = 'DELETE' -METH_OPTIONS = 'OPTIONS' -METH_PATCH = 'PATCH' -METH_POST = 'POST' -METH_PUT = 'PUT' -METH_TRACE = 'TRACE' +METH_ANY = "*" +METH_CONNECT = "CONNECT" +METH_HEAD = "HEAD" +METH_GET = "GET" +METH_DELETE = "DELETE" +METH_OPTIONS = "OPTIONS" +METH_PATCH = "PATCH" +METH_POST = "POST" +METH_PUT = "PUT" +METH_TRACE = "TRACE" -METH_ALL = {METH_CONNECT, METH_HEAD, METH_GET, METH_DELETE, - METH_OPTIONS, METH_PATCH, METH_POST, METH_PUT, METH_TRACE} +METH_ALL = { + METH_CONNECT, + METH_HEAD, + METH_GET, + METH_DELETE, + METH_OPTIONS, + METH_PATCH, + METH_POST, + METH_PUT, + METH_TRACE, +} -ACCEPT = istr('ACCEPT') -ACCEPT_CHARSET = istr('ACCEPT-CHARSET') -ACCEPT_ENCODING = istr('ACCEPT-ENCODING') -ACCEPT_LANGUAGE = istr('ACCEPT-LANGUAGE') -ACCEPT_RANGES = istr('ACCEPT-RANGES') -ACCESS_CONTROL_MAX_AGE = istr('ACCESS-CONTROL-MAX-AGE') -ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('ACCESS-CONTROL-ALLOW-CREDENTIALS') -ACCESS_CONTROL_ALLOW_HEADERS = istr('ACCESS-CONTROL-ALLOW-HEADERS') -ACCESS_CONTROL_ALLOW_METHODS = istr('ACCESS-CONTROL-ALLOW-METHODS') -ACCESS_CONTROL_ALLOW_ORIGIN = istr('ACCESS-CONTROL-ALLOW-ORIGIN') -ACCESS_CONTROL_EXPOSE_HEADERS = istr('ACCESS-CONTROL-EXPOSE-HEADERS') -ACCESS_CONTROL_REQUEST_HEADERS = istr('ACCESS-CONTROL-REQUEST-HEADERS') -ACCESS_CONTROL_REQUEST_METHOD = istr('ACCESS-CONTROL-REQUEST-METHOD') -AGE = istr('AGE') -ALLOW = istr('ALLOW') -AUTHORIZATION = istr('AUTHORIZATION') -CACHE_CONTROL = istr('CACHE-CONTROL') -CONNECTION = istr('CONNECTION') -CONTENT_DISPOSITION = istr('CONTENT-DISPOSITION') -CONTENT_ENCODING = istr('CONTENT-ENCODING') -CONTENT_LANGUAGE = istr('CONTENT-LANGUAGE') -CONTENT_LENGTH = istr('CONTENT-LENGTH') -CONTENT_LOCATION = istr('CONTENT-LOCATION') -CONTENT_MD5 = istr('CONTENT-MD5') -CONTENT_RANGE = istr('CONTENT-RANGE') -CONTENT_TRANSFER_ENCODING = istr('CONTENT-TRANSFER-ENCODING') -CONTENT_TYPE = istr('CONTENT-TYPE') -COOKIE = istr('COOKIE') -DATE = istr('DATE') -DESTINATION = istr('DESTINATION') -DIGEST = istr('DIGEST') -ETAG = istr('ETAG') -EXPECT = istr('EXPECT') -EXPIRES = istr('EXPIRES') -FROM = istr('FROM') -HOST = istr('HOST') -IF_MATCH = istr('IF-MATCH') -IF_MODIFIED_SINCE = istr('IF-MODIFIED-SINCE') -IF_NONE_MATCH = istr('IF-NONE-MATCH') -IF_RANGE = istr('IF-RANGE') -IF_UNMODIFIED_SINCE = istr('IF-UNMODIFIED-SINCE') -KEEP_ALIVE = istr('KEEP-ALIVE') -LAST_EVENT_ID = istr('LAST-EVENT-ID') -LAST_MODIFIED = istr('LAST-MODIFIED') -LINK = istr('LINK') -LOCATION = istr('LOCATION') -MAX_FORWARDS = istr('MAX-FORWARDS') -ORIGIN = istr('ORIGIN') -PRAGMA = istr('PRAGMA') -PROXY_AUTHENTICATE = istr('PROXY_AUTHENTICATE') -PROXY_AUTHORIZATION = istr('PROXY-AUTHORIZATION') -RANGE = istr('RANGE') -REFERER = istr('REFERER') -RETRY_AFTER = istr('RETRY-AFTER') -SEC_WEBSOCKET_ACCEPT = istr('SEC-WEBSOCKET-ACCEPT') -SEC_WEBSOCKET_VERSION = istr('SEC-WEBSOCKET-VERSION') -SEC_WEBSOCKET_PROTOCOL = istr('SEC-WEBSOCKET-PROTOCOL') -SEC_WEBSOCKET_KEY = istr('SEC-WEBSOCKET-KEY') -SEC_WEBSOCKET_KEY1 = istr('SEC-WEBSOCKET-KEY1') -SERVER = istr('SERVER') -SET_COOKIE = istr('SET-COOKIE') -TE = 
istr('TE') -TRAILER = istr('TRAILER') -TRANSFER_ENCODING = istr('TRANSFER-ENCODING') -UPGRADE = istr('UPGRADE') -WEBSOCKET = istr('WEBSOCKET') -URI = istr('URI') -USER_AGENT = istr('USER-AGENT') -VARY = istr('VARY') -VIA = istr('VIA') -WANT_DIGEST = istr('WANT-DIGEST') -WARNING = istr('WARNING') -WWW_AUTHENTICATE = istr('WWW-AUTHENTICATE') +ACCEPT = istr("Accept") +ACCEPT_CHARSET = istr("Accept-Charset") +ACCEPT_ENCODING = istr("Accept-Encoding") +ACCEPT_LANGUAGE = istr("Accept-Language") +ACCEPT_RANGES = istr("Accept-Ranges") +ACCESS_CONTROL_MAX_AGE = istr("Access-Control-Max-Age") +ACCESS_CONTROL_ALLOW_CREDENTIALS = istr("Access-Control-Allow-Credentials") +ACCESS_CONTROL_ALLOW_HEADERS = istr("Access-Control-Allow-Headers") +ACCESS_CONTROL_ALLOW_METHODS = istr("Access-Control-Allow-Methods") +ACCESS_CONTROL_ALLOW_ORIGIN = istr("Access-Control-Allow-Origin") +ACCESS_CONTROL_EXPOSE_HEADERS = istr("Access-Control-Expose-Headers") +ACCESS_CONTROL_REQUEST_HEADERS = istr("Access-Control-Request-Headers") +ACCESS_CONTROL_REQUEST_METHOD = istr("Access-Control-Request-Method") +AGE = istr("Age") +ALLOW = istr("Allow") +AUTHORIZATION = istr("Authorization") +CACHE_CONTROL = istr("Cache-Control") +CONNECTION = istr("Connection") +CONTENT_DISPOSITION = istr("Content-Disposition") +CONTENT_ENCODING = istr("Content-Encoding") +CONTENT_LANGUAGE = istr("Content-Language") +CONTENT_LENGTH = istr("Content-Length") +CONTENT_LOCATION = istr("Content-Location") +CONTENT_MD5 = istr("Content-MD5") +CONTENT_RANGE = istr("Content-Range") +CONTENT_TRANSFER_ENCODING = istr("Content-Transfer-Encoding") +CONTENT_TYPE = istr("Content-Type") +COOKIE = istr("Cookie") +DATE = istr("Date") +DESTINATION = istr("Destination") +DIGEST = istr("Digest") +ETAG = istr("Etag") +EXPECT = istr("Expect") +EXPIRES = istr("Expires") +FORWARDED = istr("Forwarded") +FROM = istr("From") +HOST = istr("Host") +IF_MATCH = istr("If-Match") +IF_MODIFIED_SINCE = istr("If-Modified-Since") +IF_NONE_MATCH = istr("If-None-Match") +IF_RANGE = istr("If-Range") +IF_UNMODIFIED_SINCE = istr("If-Unmodified-Since") +KEEP_ALIVE = istr("Keep-Alive") +LAST_EVENT_ID = istr("Last-Event-ID") +LAST_MODIFIED = istr("Last-Modified") +LINK = istr("Link") +LOCATION = istr("Location") +MAX_FORWARDS = istr("Max-Forwards") +ORIGIN = istr("Origin") +PRAGMA = istr("Pragma") +PROXY_AUTHENTICATE = istr("Proxy-Authenticate") +PROXY_AUTHORIZATION = istr("Proxy-Authorization") +RANGE = istr("Range") +REFERER = istr("Referer") +RETRY_AFTER = istr("Retry-After") +SEC_WEBSOCKET_ACCEPT = istr("Sec-WebSocket-Accept") +SEC_WEBSOCKET_VERSION = istr("Sec-WebSocket-Version") +SEC_WEBSOCKET_PROTOCOL = istr("Sec-WebSocket-Protocol") +SEC_WEBSOCKET_EXTENSIONS = istr("Sec-WebSocket-Extensions") +SEC_WEBSOCKET_KEY = istr("Sec-WebSocket-Key") +SEC_WEBSOCKET_KEY1 = istr("Sec-WebSocket-Key1") +SERVER = istr("Server") +SET_COOKIE = istr("Set-Cookie") +TE = istr("TE") +TRAILER = istr("Trailer") +TRANSFER_ENCODING = istr("Transfer-Encoding") +UPGRADE = istr("Upgrade") +URI = istr("URI") +USER_AGENT = istr("User-Agent") +VARY = istr("Vary") +VIA = istr("Via") +WANT_DIGEST = istr("Want-Digest") +WARNING = istr("Warning") +WWW_AUTHENTICATE = istr("WWW-Authenticate") +X_FORWARDED_FOR = istr("X-Forwarded-For") +X_FORWARDED_HOST = istr("X-Forwarded-Host") +X_FORWARDED_PROTO = istr("X-Forwarded-Proto") diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 85b3d9ffd5f..bbf5f1298fb 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -6,420 +6,400 @@ import cgi import datetime import 
functools +import inspect +import netrc import os +import platform import re import sys import time import warnings import weakref -from collections import MutableSequence, namedtuple -from functools import total_ordering +from collections import namedtuple +from contextlib import suppress from math import ceil from pathlib import Path -from time import gmtime +from types import TracebackType +from typing import ( + Any, + Callable, + Dict, + Generator, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) from urllib.parse import quote +from urllib.request import getproxies -from async_timeout import timeout +import async_timeout +import attr +from multidict import MultiDict, MultiDictProxy +from typing_extensions import Protocol +from yarl import URL from . import hdrs +from .log import client_logger, internal_logger +from .typedefs import PathLike # noqa -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async +__all__ = ("BasicAuth", "ChainMapProxy") -PY_34 = sys.version_info < (3, 5) -PY_35 = sys.version_info >= (3, 5) -PY_352 = sys.version_info >= (3, 5, 2) +PY_36 = sys.version_info >= (3, 6) +PY_37 = sys.version_info >= (3, 7) +PY_38 = sys.version_info >= (3, 8) -if sys.version_info >= (3, 4, 3): - from http.cookies import SimpleCookie # noqa -else: - from .backport_cookies import SimpleCookie # noqa +if not PY_37: + import idna_ssl + idna_ssl.patch_match_hostname() -__all__ = ('BasicAuth', 'create_future', 'parse_mimetype', - 'Timeout', 'ensure_future', 'noop') +try: + from typing import ContextManager +except ImportError: + from typing_extensions import ContextManager + + +def all_tasks( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Set["asyncio.Task[Any]"]: + tasks = list(asyncio.Task.all_tasks(loop)) + return {t for t in tasks if not t.done()} + + +if PY_37: + all_tasks = getattr(asyncio, "all_tasks") + + +_T = TypeVar("_T") +_S = TypeVar("_S") + + +sentinel = object() # type: Any +NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) # type: bool + +# N.B. 
sys.flags.dev_mode is available on Python 3.7+, use getattr +# for compatibility with older versions +DEBUG = getattr(sys.flags, "dev_mode", False) or ( + not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG")) +) # type: bool + + +CHAR = {chr(i) for i in range(0, 128)} +CTL = {chr(i) for i in range(0, 32)} | { + chr(127), +} +SEPARATORS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", + " ", + chr(9), +} +TOKEN = CHAR ^ CTL ^ SEPARATORS -sentinel = object() -Timeout = timeout -NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) +class noop: + def __await__(self) -> Generator[None, None, None]: + yield -CHAR = set(chr(i) for i in range(0, 128)) -CTL = set(chr(i) for i in range(0, 32)) | {chr(127), } -SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', - '?', '=', '{', '}', ' ', chr(9)} -TOKEN = CHAR ^ CTL ^ SEPARATORS +class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])): + """Http basic authentication helper.""" -if sys.version_info < (3, 5): - noop = tuple + def __new__( + cls, login: str, password: str = "", encoding: str = "latin1" + ) -> "BasicAuth": + if login is None: + raise ValueError("None is not allowed as login value") - coroutines = asyncio.coroutines - old_debug = coroutines._DEBUG - coroutines._DEBUG = False + if password is None: + raise ValueError("None is not allowed as password value") - @asyncio.coroutine - def deprecated_noop(message): - warnings.warn(message, DeprecationWarning, stacklevel=3) + if ":" in login: + raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)') - coroutines._DEBUG = old_debug + return super().__new__(cls, login, password, encoding) -else: - coroutines = asyncio.coroutines - old_debug = coroutines._DEBUG - coroutines._DEBUG = False + @classmethod + def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": + """Create a BasicAuth object from an Authorization HTTP header.""" + try: + auth_type, encoded_credentials = auth_header.split(" ", 1) + except ValueError: + raise ValueError("Could not parse authorization header.") - @asyncio.coroutine - def noop(*args, **kwargs): - return + if auth_type.lower() != "basic": + raise ValueError("Unknown authorization method %s" % auth_type) - @asyncio.coroutine - def deprecated_noop(message): - warnings.warn(message, DeprecationWarning, stacklevel=3) + try: + decoded = base64.b64decode( + encoded_credentials.encode("ascii"), validate=True + ).decode(encoding) + except binascii.Error: + raise ValueError("Invalid base64 encoding.") - coroutines._DEBUG = old_debug + try: + # RFC 2617 HTTP Authentication + # https://www.ietf.org/rfc/rfc2617.txt + # the colon must be present, but the username and password may be + # otherwise blank. + username, password = decoded.split(":", 1) + except ValueError: + raise ValueError("Invalid credentials.") + return cls(username, password, encoding=encoding) -class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])): - """Http basic authentication helper. 
+ @classmethod + def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: + """Create BasicAuth from url.""" + if not isinstance(url, URL): + raise TypeError("url should be yarl.URL instance") + if url.user is None: + return None + return cls(url.user, url.password or "", encoding=encoding) - :param str login: Login - :param str password: Password - :param str encoding: (optional) encoding ('latin1' by default) - """ + def encode(self) -> str: + """Encode credentials.""" + creds = (f"{self.login}:{self.password}").encode(self.encoding) + return "Basic %s" % base64.b64encode(creds).decode(self.encoding) - def __new__(cls, login, password='', encoding='latin1'): - if login is None: - raise ValueError('None is not allowed as login value') - if password is None: - raise ValueError('None is not allowed as password value') +def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: + auth = BasicAuth.from_url(url) + if auth is None: + return url, None + else: + return url.with_user(None), auth - if ':' in login: - raise ValueError( - 'A ":" is not allowed in login (RFC 1945#section-11.1)') - return super().__new__(cls, login, password, encoding) +def netrc_from_env() -> Optional[netrc.netrc]: + """Attempt to load the netrc file from the path specified by the env-var + NETRC or in the default location in the user's home directory. - @classmethod - def decode(cls, auth_header, encoding='latin1'): - """Create a :class:`BasicAuth` object from an ``Authorization`` HTTP - header.""" - split = auth_header.strip().split(' ') - if len(split) == 2: - if split[0].strip().lower() != 'basic': - raise ValueError('Unknown authorization method %s' % split[0]) - to_decode = split[1] - else: - raise ValueError('Could not parse authorization header.') + Returns None if it couldn't be found or fails to parse. + """ + netrc_env = os.environ.get("NETRC") + if netrc_env is not None: + netrc_path = Path(netrc_env) + else: try: - username, _, password = base64.b64decode( - to_decode.encode('ascii') - ).decode(encoding).partition(':') - except binascii.Error: - raise ValueError('Invalid base64 encoding.') + home_dir = Path.home() + except RuntimeError as e: # pragma: no cover + # if pathlib can't resolve home, it may raise a RuntimeError + client_logger.debug( + "Could not resolve home directory when " + "trying to look for .netrc file: %s", + e, + ) + return None - return cls(username, password, encoding=encoding) + netrc_path = home_dir / ( + "_netrc" if platform.system() == "Windows" else ".netrc" + ) - def encode(self): - """Encode credentials.""" - creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding) - return 'Basic %s' % base64.b64encode(creds).decode(self.encoding) + try: + return netrc.netrc(str(netrc_path)) + except netrc.NetrcParseError as e: + client_logger.warning("Could not parse .netrc file: %s", e) + except OSError as e: + # we couldn't read the file (doesn't exist, permissions, etc.) 
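
# A quick round-trip sketch for the BasicAuth helper above (credentials are
# illustrative): encode() builds an Authorization header value and decode()
# parses it back into the named tuple, defaulting to latin1 encoding.
from aiohttp import BasicAuth

auth = BasicAuth("user", "pass")
header = auth.encode()  # 'Basic dXNlcjpwYXNz'
assert BasicAuth.decode(header) == BasicAuth("user", "pass")
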
+ if netrc_env or netrc_path.is_file(): + # only warn if the environment wanted us to load it, + # or it appears like the default file does actually exist + client_logger.warning("Could not read .netrc file: %s", e) + + return None + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ProxyInfo: + proxy: URL + proxy_auth: Optional[BasicAuth] + + +def proxies_from_env() -> Dict[str, ProxyInfo]: + proxy_urls = {k: URL(v) for k, v in getproxies().items() if k in ("http", "https")} + netrc_obj = netrc_from_env() + stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()} + ret = {} + for proto, val in stripped.items(): + proxy, auth = val + if proxy.scheme == "https": + client_logger.warning("HTTPS proxies %s are not supported, ignoring", proxy) + continue + if netrc_obj and auth is None: + auth_from_netrc = None + if proxy.host is not None: + auth_from_netrc = netrc_obj.authenticators(proxy.host) + if auth_from_netrc is not None: + # auth_from_netrc is a (`user`, `account`, `password`) tuple, + # `user` and `account` both can be username, + # if `user` is None, use `account` + *logins, password = auth_from_netrc + login = logins[0] if logins[0] else logins[-1] + auth = BasicAuth(cast(str, login), cast(str, password)) + ret[proto] = ProxyInfo(proxy, auth) + return ret + + +def current_task( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> "Optional[asyncio.Task[Any]]": + if PY_37: + return asyncio.current_task(loop=loop) + else: + return asyncio.Task.current_task(loop=loop) + + +def get_running_loop( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> asyncio.AbstractEventLoop: + if loop is None: + loop = asyncio.get_event_loop() + if not loop.is_running(): + warnings.warn( + "The object should be created within an async function", + DeprecationWarning, + stacklevel=3, + ) + if loop.get_debug(): + internal_logger.warning( + "The object should be created within an async function", stack_info=True + ) + return loop -if PY_352: - def create_future(loop): - return loop.create_future() -else: - def create_future(loop): - """Compatibility wrapper for the loop.create_future() call introduced in - 3.5.2.""" - return asyncio.Future(loop=loop) +def isasyncgenfunction(obj: Any) -> bool: + func = getattr(inspect, "isasyncgenfunction", None) + if func is not None: + return func(obj) + else: + return False -def parse_mimetype(mimetype): +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MimeType: + type: str + subtype: str + suffix: str + parameters: "MultiDictProxy[str]" + + +@functools.lru_cache(maxsize=56) +def parse_mimetype(mimetype: str) -> MimeType: """Parses a MIME type into its components. - :param str mimetype: MIME type + mimetype is a MIME type string. - :returns: 4 element tuple for MIME type, subtype, suffix and parameters - :rtype: tuple + Returns a MimeType object. 
Example: >>> parse_mimetype('text/html; charset=utf-8') - ('text', 'html', '', {'charset': 'utf-8'}) + MimeType(type='text', subtype='html', suffix='', + parameters={'charset': 'utf-8'}) """ if not mimetype: - return '', '', '', {} + return MimeType( + type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict()) + ) - parts = mimetype.split(';') - params = [] + parts = mimetype.split(";") + params = MultiDict() # type: MultiDict[str] for item in parts[1:]: if not item: continue - key, value = item.split('=', 1) if '=' in item else (item, '') - params.append((key.lower().strip(), value.strip(' "'))) - params = dict(params) + key, value = cast( + Tuple[str, str], item.split("=", 1) if "=" in item else (item, "") + ) + params.add(key.lower().strip(), value.strip(' "')) fulltype = parts[0].strip().lower() - if fulltype == '*': - fulltype = '*/*' - - mtype, stype = fulltype.split('/', 1) \ - if '/' in fulltype else (fulltype, '') - stype, suffix = stype.split('+', 1) if '+' in stype else (stype, '') - - return mtype, stype, suffix, params - - -def guess_filename(obj, default=None): - name = getattr(obj, 'name', None) - if name and name[0] != '<' and name[-1] != '>': + if fulltype == "*": + fulltype = "*/*" + + mtype, stype = ( + cast(Tuple[str, str], fulltype.split("/", 1)) + if "/" in fulltype + else (fulltype, "") + ) + stype, suffix = ( + cast(Tuple[str, str], stype.split("+", 1)) if "+" in stype else (stype, "") + ) + + return MimeType( + type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params) + ) + + +def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]: + name = getattr(obj, "name", None) + if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": return Path(name).name return default -def content_disposition_header(disptype, quote_fields=True, **params): +def content_disposition_header( + disptype: str, quote_fields: bool = True, **params: str +) -> str: """Sets ``Content-Disposition`` header. - :param str disptype: Disposition type: inline, attachment, form-data. - Should be valid extension token (see RFC 2183) - :param dict params: Disposition params + disptype is a disposition type: inline, attachment, form-data. + Should be valid extension token (see RFC 2183) + + params is a dict with disposition params. """ if not disptype or not (TOKEN > set(disptype)): - raise ValueError('bad content disposition type {!r}' - ''.format(disptype)) + raise ValueError("bad content disposition type {!r}" "".format(disptype)) value = disptype if params: lparams = [] for key, val in params.items(): if not key or not (TOKEN > set(key)): - raise ValueError('bad content disposition parameter' - ' {!r}={!r}'.format(key, val)) - qval = quote(val, '') if quote_fields else val + raise ValueError( + "bad content disposition parameter" " {!r}={!r}".format(key, val) + ) + qval = quote(val, "") if quote_fields else val lparams.append((key, '"%s"' % qval)) - if key == 'filename': - lparams.append(('filename*', "utf-8''" + qval)) - sparams = '; '.join('='.join(pair) for pair in lparams) - value = '; '.join((value, sparams)) + if key == "filename": + lparams.append(("filename*", "utf-8''" + qval)) + sparams = "; ".join("=".join(pair) for pair in lparams) + value = "; ".join((value, sparams)) return value -class AccessLogger: - """Helper object to log access. 
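
# A short check of the parse_mimetype() helper above, mirroring its docstring
# example; the returned MimeType is an attrs class with named components and
# a MultiDictProxy of parameters.
from aiohttp.helpers import parse_mimetype

mt = parse_mimetype("text/html; charset=utf-8")
assert (mt.type, mt.subtype, mt.suffix) == ("text", "html", "")
assert mt.parameters["charset"] == "utf-8"
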
- - Usage: - log = logging.getLogger("spam") - log_format = "%a %{User-Agent}i" - access_logger = AccessLogger(log, log_format) - access_logger.log(message, environ, response, transport, time) - - Format: - %% The percent sign - %a Remote IP-address (IP-address of proxy if using reverse proxy) - %t Time when the request was started to process - %P The process ID of the child that serviced the request - %r First line of request - %s Response status code - %b Size of response in bytes, excluding HTTP headers - %O Bytes sent, including headers - %T Time taken to serve the request, in seconds - %Tf Time taken to serve the request, in seconds with floating fraction - in .06f format - %D Time taken to serve the request, in microseconds - %{FOO}i request.headers['FOO'] - %{FOO}o response.headers['FOO'] - %{FOO}e os.environ['FOO'] - - """ - LOG_FORMAT_MAP = { - 'a': 'remote_address', - 't': 'request_time', - 'P': 'process_id', - 'r': 'first_request_line', - 's': 'response_status', - 'b': 'response_size', - 'O': 'bytes_sent', - 'T': 'request_time', - 'Tf': 'request_time_frac', - 'D': 'request_time_micro', - 'i': 'request_header', - 'o': 'response_header', - 'e': 'environ' - } - - LOG_FORMAT = '%a %l %u %t "%r" %s %b "%{Referrer}i" "%{User-Agent}i"' - FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)') - CLEANUP_RE = re.compile(r'(%[^s])') - _FORMAT_CACHE = {} - - KeyMethod = namedtuple('KeyMethod', 'key method') - - def __init__(self, logger, log_format=LOG_FORMAT): - """Initialise the logger. - - :param logger: logger object to be used for logging - :param log_format: apache compatible log format - - """ - self.logger = logger - - _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format) - if not _compiled_format: - _compiled_format = self.compile_format(log_format) - AccessLogger._FORMAT_CACHE[log_format] = _compiled_format - - self._log_format, self._methods = _compiled_format - - def compile_format(self, log_format): - """Translate log_format into form usable by modulo formatting - - All known atoms will be replaced with %s - Also methods for formatting of those atoms will be added to - _methods in apropriate order - - For example we have log_format = "%a %t" - This format will be translated to "%s %s" - Also contents of _methods will be - [self._format_a, self._format_t] - These method will be called and results will be passed - to translated string format. 
- - Each _format_* method receive 'args' which is list of arguments - given to self.log - - Exceptions are _format_e, _format_i and _format_o methods which - also receive key name (by functools.partial) - - """ - - log_format = log_format.replace("%l", "-") - log_format = log_format.replace("%u", "-") - - # list of (key, method) tuples, we don't use an OrderedDict as users - # can repeat the same key more than once - methods = list() - - for atom in self.FORMAT_RE.findall(log_format): - if atom[1] == '': - format_key = self.LOG_FORMAT_MAP[atom[0]] - m = getattr(AccessLogger, '_format_%s' % atom[0]) - else: - format_key = (self.LOG_FORMAT_MAP[atom[2]], atom[1]) - m = getattr(AccessLogger, '_format_%s' % atom[2]) - m = functools.partial(m, atom[1]) - - methods.append(self.KeyMethod(format_key, m)) - - log_format = self.FORMAT_RE.sub(r'%s', log_format) - log_format = self.CLEANUP_RE.sub(r'%\1', log_format) - return log_format, methods - - @staticmethod - def _format_e(key, args): - return (args[1] or {}).get(key, '-') - - @staticmethod - def _format_i(key, args): - if not args[0]: - return '(no headers)' - - # suboptimal, make istr(key) once - return args[0].headers.get(key, '-') - - @staticmethod - def _format_o(key, args): - # suboptimal, make istr(key) once - return args[2].headers.get(key, '-') - - @staticmethod - def _format_a(args): - if args[3] is None: - return '-' - peername = args[3].get_extra_info('peername') - if isinstance(peername, (list, tuple)): - return peername[0] - else: - return peername - - @staticmethod - def _format_t(args): - return datetime.datetime.utcnow().strftime('[%d/%b/%Y:%H:%M:%S +0000]') - - @staticmethod - def _format_P(args): - return "<%s>" % os.getpid() - - @staticmethod - def _format_r(args): - msg = args[0] - if not msg: - return '-' - return '%s %s HTTP/%s.%s' % tuple((msg.method, - msg.path) + msg.version) - - @staticmethod - def _format_s(args): - return args[2].status - - @staticmethod - def _format_b(args): - return args[2].body_length - - @staticmethod - def _format_O(args): - return args[2].body_length - - @staticmethod - def _format_T(args): - return round(args[4]) - - @staticmethod - def _format_Tf(args): - return '%06f' % args[4] - - @staticmethod - def _format_D(args): - return round(args[4] * 1000000) - - def _format_line(self, args): - return ((key, method(args)) for key, method in self._methods) - - def log(self, message, environ, response, transport, time): - """Log access. - - :param message: Request object. May be None. - :param environ: Environment dict. May be None. - :param response: Response object. - :param transport: Tansport object. May be None - :param float time: Time taken to serve the request. - """ - try: - fmt_info = self._format_line( - [message, environ, response, transport, time]) - - values = list() - extra = dict() - for key, value in fmt_info: - values.append(value) - - if key.__class__ is str: - extra[key] = value - else: - extra[key[0]] = {key[1]: value} +class _TSelf(Protocol): + _cache: Dict[str, Any] - self.logger.info(self._log_format % tuple(values), extra=extra) - except Exception: - self.logger.exception("Error in logging") - -class reify: +class reify(Generic[_T]): """Use as a class method decorator. 
It operates almost exactly like the Python `@property` decorator, but it puts the result of the method it decorates into the instance dict after the first call, @@ -428,15 +408,12 @@ class reify: """ - def __init__(self, wrapped): + def __init__(self, wrapped: Callable[..., _T]) -> None: self.wrapped = wrapped - try: - self.__doc__ = wrapped.__doc__ - except: # pragma: no cover - self.__doc__ = "" + self.__doc__ = wrapped.__doc__ self.name = wrapped.__name__ - def __get__(self, inst, owner, _sentinel=sentinel): + def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T: try: try: return inst._cache[self.name] @@ -449,236 +426,215 @@ def __get__(self, inst, owner, _sentinel=sentinel): return self raise - def __set__(self, inst, value): + def __set__(self, inst: _TSelf, value: _T) -> None: raise AttributeError("reified property is read-only") -_ipv4_pattern = (r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}' - r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$') +reify_py = reify + +try: + from ._helpers import reify as reify_c + + if not NO_EXTENSIONS: + reify = reify_c # type: ignore +except ImportError: + pass + +_ipv4_pattern = ( + r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}" + r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" +) _ipv6_pattern = ( - r'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}' - r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)' - r'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})' - r'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}' - r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}' - r'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)' - r'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}' - r':|:(:[A-F0-9]{1,4}){7})$') + r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}" + r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)" + r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})" + r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}" + r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}" + r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)" + r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}" + r":|:(:[A-F0-9]{1,4}){7})$" +) _ipv4_regex = re.compile(_ipv4_pattern) _ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE) -_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii')) -_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE) +_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii")) +_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE) -def is_ip_address(host): +def _is_ip_address( + regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]] +) -> bool: if host is None: return False if isinstance(host, str): - if _ipv4_regex.match(host) or _ipv6_regex.match(host): - return True - else: - return False + return bool(regex.match(host)) elif isinstance(host, (bytes, bytearray, memoryview)): - if _ipv4_regexb.match(host) or _ipv6_regexb.match(host): - return True - else: - return False + return bool(regexb.match(host)) else: - raise TypeError("{} [{}] is not a str or bytes" - .format(host, type(host))) - - -@total_ordering -class FrozenList(MutableSequence): - - __slots__ = ('_frozen', '_items') - - def __init__(self, items=None): - self._frozen = False - if items is not None: - items = list(items) - else: - items = [] - self._items = items - - def freeze(self): - self._frozen = True - self._items = tuple(self._items) 
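
# Sketch of the reify descriptor rewritten above (class and attribute names
# are illustrative). Note the _TSelf protocol: this implementation caches
# into an instance ``_cache`` dict, not into ``__dict__``, so the host class
# must provide one.
from aiohttp.helpers import reify

class Holder:
    def __init__(self):
        self._cache = {}

    @reify
    def value(self):
        print("computed once")
        return 42

h = Holder()
h.value  # prints "computed once"; the result lands in h._cache["value"]
h.value  # served from the cache, the wrapped method is not called again
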
+ raise TypeError("{} [{}] is not a str or bytes".format(host, type(host))) - def __getitem__(self, index): - return self._items[index] - def __setitem__(self, index, value): - if self._frozen: - raise RuntimeError("Cannot modify frozen list.") - self._items[index] = value +is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb) +is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb) - def __delitem__(self, index): - if self._frozen: - raise RuntimeError("Cannot modify frozen list.") - del self._items[index] - def __len__(self): - return self._items.__len__() +def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool: + return is_ipv4_address(host) or is_ipv6_address(host) - def __iter__(self): - return self._items.__iter__() - def __reversed__(self): - return self._items.__reversed__() +def next_whole_second() -> datetime.datetime: + """Return current time rounded up to the next whole second.""" + return datetime.datetime.now(datetime.timezone.utc).replace( + microsecond=0 + ) + datetime.timedelta(seconds=0) - def __eq__(self, other): - return list(self) == other - def __le__(self, other): - return list(self) <= other +_cached_current_datetime = None # type: Optional[int] +_cached_formatted_datetime = "" - def insert(self, pos, item): - if self._frozen: - raise RuntimeError("Cannot modify frozen list.") - self._items.insert(pos, item) +def rfc822_formatted_time() -> str: + global _cached_current_datetime + global _cached_formatted_datetime -class TimeService: - - def __init__(self, loop, *, interval=1.0): - self._loop = loop - self._interval = interval - self._time = time.time() - self._loop_time = loop.time() - self._count = 0 - self._strtime = None - self._cb = loop.call_at(self._loop_time + self._interval, self._on_cb) - - def close(self): - if self._cb: - self._cb.cancel() - - self._cb = None - self._loop = None - - def _on_cb(self, reset_count=10*60): - if self._count >= reset_count: - # reset timer every 10 minutes - self._count = 0 - self._time = time.time() - else: - self._time += self._interval - - self._strtime = None - self._loop_time = ceil(self._loop.time()) - self._cb = self._loop.call_at( - self._loop_time + self._interval, self._on_cb) - - def _format_date_time(self): + now = int(time.time()) + if now != _cached_current_datetime: # Weekday and month names for HTTP date/time formatting; # always English! - # Tuples are contants stored in codeobject! + # Tuples are constants stored in codeobject! 
_weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") - _monthname = (None, # Dummy so we can use 1-based month numbers - "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec") - - year, month, day, hh, mm, ss, wd, y, z = gmtime(self._time) - return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - _weekdayname[wd], day, _monthname[month], year, hh, mm, ss + _monthname = ( + "", # Dummy so we can use 1-based month numbers + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", ) - def time(self): - return self._time - - def strtime(self): - s = self._strtime - if s is None: - self._strtime = s = self._format_date_time() - return self._strtime - - @property - def loop_time(self): - return self._loop_time + year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now) + _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + _weekdayname[wd], + day, + _monthname[month], + year, + hh, + mm, + ss, + ) + _cached_current_datetime = now + return _cached_formatted_datetime -def _weakref_handle(info): +def _weakref_handle(info): # type: ignore ref, name = info ob = ref() if ob is not None: - try: + with suppress(Exception): getattr(ob, name)() - except: - pass -def weakref_handle(ob, name, timeout, loop, ceil_timeout=True): +def weakref_handle(ob, name, timeout, loop): # type: ignore if timeout is not None and timeout > 0: when = loop.time() + timeout - if ceil_timeout: + if timeout >= 5: when = ceil(when) return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name)) -def call_later(cb, timeout, loop): +def call_later(cb, timeout, loop): # type: ignore if timeout is not None and timeout > 0: - when = ceil(loop.time() + timeout) + when = loop.time() + timeout + if timeout > 5: + when = ceil(when) return loop.call_at(when, cb) class TimeoutHandle: """ Timeout handle """ - def __init__(self, loop, timeout): + def __init__( + self, loop: asyncio.AbstractEventLoop, timeout: Optional[float] + ) -> None: self._timeout = timeout self._loop = loop - self._callbacks = [] + self._callbacks = ( + [] + ) # type: List[Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]] - def register(self, callback, *args, **kwargs): + def register( + self, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> None: self._callbacks.append((callback, args, kwargs)) - def close(self): + def close(self) -> None: self._callbacks.clear() - def start(self): - if self._timeout is not None and self._timeout > 0: - at = ceil(self._loop.time() + self._timeout) - return self._loop.call_at(at, self.__call__) + def start(self) -> Optional[asyncio.Handle]: + timeout = self._timeout + if timeout is not None and timeout > 0: + when = self._loop.time() + timeout + if timeout >= 5: + when = ceil(when) + return self._loop.call_at(when, self.__call__) + else: + return None - def timer(self): - timer = TimerContext(self._loop) - self.register(timer.timeout) - return timer + def timer(self) -> "BaseTimerContext": + if self._timeout is not None and self._timeout > 0: + timer = TimerContext(self._loop) + self.register(timer.timeout) + return timer + else: + return TimerNoop() - def __call__(self): + def __call__(self) -> None: for cb, args, kwargs in self._callbacks: - try: + with suppress(Exception): cb(*args, **kwargs) - except: - pass self._callbacks.clear() -class TimerNoop: +class BaseTimerContext(ContextManager["BaseTimerContext"]): + pass + - def __enter__(self): +class TimerNoop(BaseTimerContext): + def __enter__(self) 
-> BaseTimerContext: return self - def __exit__(self, exc_type, exc_val, exc_tb): - return False + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return -class TimerContext: +class TimerContext(BaseTimerContext): """ Low resolution timeout context manager """ - def __init__(self, loop): + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop - self._tasks = [] + self._tasks = [] # type: List[asyncio.Task[Any]] self._cancelled = False - def __enter__(self): - task = asyncio.Task.current_task(loop=self._loop) + def __enter__(self) -> BaseTimerContext: + task = current_task(loop=self._loop) + if task is None: - raise RuntimeError('Timeout context manager should be used ' - 'inside a task') + raise RuntimeError( + "Timeout context manager should be used " "inside a task" + ) if self._cancelled: task.cancel() @@ -687,77 +643,138 @@ def __enter__(self): self._tasks.append(task) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: if self._tasks: - task = self._tasks.pop() - else: - task = None + self._tasks.pop() if exc_type is asyncio.CancelledError and self._cancelled: - for task in self._tasks: - task.cancel() raise asyncio.TimeoutError from None + return None - if exc_type is None and self._cancelled and task is not None: - task.cancel() - - def timeout(self): + def timeout(self) -> None: if not self._cancelled: - for task in self._tasks: + for task in set(self._tasks): task.cancel() self._cancelled = True -class CeilTimeout(Timeout): - - def __enter__(self): +class CeilTimeout(async_timeout.timeout): + def __enter__(self) -> async_timeout.timeout: if self._timeout is not None: - self._task = asyncio.Task.current_task(loop=self._loop) + self._task = current_task(loop=self._loop) if self._task is None: raise RuntimeError( - 'Timeout context manager should be used inside a task') - self._cancel_handler = self._loop.call_at( - ceil(self._loop.time() + self._timeout), self._cancel_task) + "Timeout context manager should be used inside a task" + ) + now = self._loop.time() + delay = self._timeout + when = now + delay + if delay > 5: + when = ceil(when) + self._cancel_handler = self._loop.call_at(when, self._cancel_task) return self class HeadersMixin: - _content_type = None - _content_dict = None + ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"]) + + _content_type = None # type: Optional[str] + _content_dict = None # type: Optional[Dict[str, str]] _stored_content_type = sentinel - def _parse_content_type(self, raw): + def _parse_content_type(self, raw: str) -> None: self._stored_content_type = raw if raw is None: # default value according to RFC 2616 - self._content_type = 'application/octet-stream' + self._content_type = "application/octet-stream" self._content_dict = {} else: self._content_type, self._content_dict = cgi.parse_header(raw) @property - def content_type(self, *, _CONTENT_TYPE=hdrs.CONTENT_TYPE): + def content_type(self) -> str: """The value of content part for Content-Type HTTP header.""" - raw = self._headers.get(_CONTENT_TYPE) + raw = self._headers.get(hdrs.CONTENT_TYPE) # type: ignore if self._stored_content_type != raw: self._parse_content_type(raw) - return self._content_type + return self._content_type # type: ignore @property - def charset(self, *, 
_CONTENT_TYPE=hdrs.CONTENT_TYPE): + def charset(self) -> Optional[str]: """The value of charset part for Content-Type HTTP header.""" - raw = self._headers.get(_CONTENT_TYPE) + raw = self._headers.get(hdrs.CONTENT_TYPE) # type: ignore if self._stored_content_type != raw: self._parse_content_type(raw) - return self._content_dict.get('charset') + return self._content_dict.get("charset") # type: ignore @property - def content_length(self, *, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH): + def content_length(self) -> Optional[int]: """The value of Content-Length HTTP header.""" - l = self._headers.get(_CONTENT_LENGTH) - if l is None: - return None + content_length = self._headers.get(hdrs.CONTENT_LENGTH) # type: ignore + + if content_length is not None: + return int(content_length) else: - return int(l) + return None + + +def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: + if not fut.done(): + fut.set_result(result) + + +def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None: + if not fut.done(): + fut.set_exception(exc) + + +class ChainMapProxy(Mapping[str, Any]): + __slots__ = ("_maps",) + + def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None: + self._maps = tuple(maps) + + def __init_subclass__(cls) -> None: + raise TypeError( + "Inheritance class {} from ChainMapProxy " + "is forbidden".format(cls.__name__) + ) + + def __getitem__(self, key: str) -> Any: + for mapping in self._maps: + try: + return mapping[key] + except KeyError: + pass + raise KeyError(key) + + def get(self, key: str, default: Any = None) -> Any: + return self[key] if key in self else default + + def __len__(self) -> int: + # reuses stored hash values if possible + return len(set().union(*self._maps)) # type: ignore + + def __iter__(self) -> Iterator[str]: + d = {} # type: Dict[str, Any] + for mapping in reversed(self._maps): + # reuses stored hash values if possible + d.update(mapping) + return iter(d) + + def __contains__(self, key: object) -> bool: + return any(key in m for m in self._maps) + + def __bool__(self) -> bool: + return any(self._maps) + + def __repr__(self) -> str: + content = ", ".join(map(repr, self._maps)) + return f"ChainMapProxy({content})" diff --git a/aiohttp/http.py b/aiohttp/http.py index 7ee7e76795a..415ffbf563b 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -1,37 +1,72 @@ import http.server import sys - -from yarl import URL # noqa +from typing import Mapping, Tuple from . 
import __version__ -from .http_exceptions import HttpProcessingError -from .http_parser import (HttpParser, HttpRequestParser, HttpResponseParser, - RawRequestMessage, RawResponseMessage) -from .http_websocket import (WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, WS_KEY, - WebSocketError, WebSocketReader, WebSocketWriter, - WSCloseCode, WSMessage, WSMsgType, do_handshake) -from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11, - PayloadWriter, StreamWriter) +from .http_exceptions import HttpProcessingError as HttpProcessingError +from .http_parser import ( + HeadersParser as HeadersParser, + HttpParser as HttpParser, + HttpRequestParser as HttpRequestParser, + HttpResponseParser as HttpResponseParser, + RawRequestMessage as RawRequestMessage, + RawResponseMessage as RawResponseMessage, +) +from .http_websocket import ( + WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE, + WS_KEY as WS_KEY, + WebSocketError as WebSocketError, + WebSocketReader as WebSocketReader, + WebSocketWriter as WebSocketWriter, + WSCloseCode as WSCloseCode, + WSMessage as WSMessage, + WSMsgType as WSMsgType, + ws_ext_gen as ws_ext_gen, + ws_ext_parse as ws_ext_parse, +) +from .http_writer import ( + HttpVersion as HttpVersion, + HttpVersion10 as HttpVersion10, + HttpVersion11 as HttpVersion11, + StreamWriter as StreamWriter, +) __all__ = ( - 'HttpProcessingError', 'RESPONSES', 'SERVER_SOFTWARE', - + "HttpProcessingError", + "RESPONSES", + "SERVER_SOFTWARE", # .http_writer - 'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter', - + "StreamWriter", + "HttpVersion", + "HttpVersion10", + "HttpVersion11", # .http_parser - 'HttpParser', 'HttpRequestParser', 'HttpResponseParser', - 'RawRequestMessage', 'RawResponseMessage', - + "HeadersParser", + "HttpParser", + "HttpRequestParser", + "HttpResponseParser", + "RawRequestMessage", + "RawResponseMessage", # .http_websocket - 'WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY', - 'WebSocketReader', 'WebSocketWriter', 'do_handshake', - 'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode', + "WS_CLOSED_MESSAGE", + "WS_CLOSING_MESSAGE", + "WS_KEY", + "WebSocketReader", + "WebSocketWriter", + "ws_ext_gen", + "ws_ext_parse", + "WSMessage", + "WebSocketError", + "WSMsgType", + "WSCloseCode", ) -SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format( - sys.version_info, __version__) +SERVER_SOFTWARE = "Python/{0[0]}.{0[1]} aiohttp/{1}".format( + sys.version_info, __version__ +) # type: str -RESPONSES = http.server.BaseHTTPRequestHandler.responses +RESPONSES = ( + http.server.BaseHTTPRequestHandler.responses +) # type: Mapping[int, Tuple[str, str]] diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index dc2d1095f93..c885f80f322 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -1,6 +1,11 @@ """Low-level http related exceptions.""" -__all__ = ('HttpProcessingError',) + +from typing import Optional, Union + +from .typedefs import _CIMultiDict + +__all__ = ("HttpProcessingError",) class HttpProcessingError(Exception): @@ -8,37 +13,48 @@ class HttpProcessingError(Exception): Shortcut for raising HTTP errors with custom code, message and headers. - :param int code: HTTP Error code. - :param str message: (optional) Error message. - :param list of [tuple] headers: (optional) Headers to be sent in response. + code: HTTP Error code. + message: (optional) Error message. 
+ headers: (optional) Headers to be sent in response, a list of pairs """ code = 0 - message = '' + message = "" headers = None - def __init__(self, *, code=None, message='', headers=None): + def __init__( + self, + *, + code: Optional[int] = None, + message: str = "", + headers: Optional[_CIMultiDict] = None, + ) -> None: if code is not None: self.code = code self.headers = headers self.message = message - super().__init__("%s, message='%s'" % (self.code, message)) + def __str__(self) -> str: + return f"{self.code}, message={self.message!r}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self}>" class BadHttpMessage(HttpProcessingError): code = 400 - message = 'Bad Request' + message = "Bad Request" - def __init__(self, message, *, headers=None): + def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None: super().__init__(message=message, headers=headers) + self.args = (message,) class HttpBadRequest(BadHttpMessage): code = 400 - message = 'Bad Request' + message = "Bad Request" class PayloadEncodingError(BadHttpMessage): @@ -58,27 +74,30 @@ class ContentLengthError(PayloadEncodingError): class LineTooLong(BadHttpMessage): - - def __init__(self, line, limit='Unknown'): + def __init__( + self, line: str, limit: str = "Unknown", actual_size: str = "Unknown" + ) -> None: super().__init__( - "Got more than %s bytes when reading %s." % (limit, line)) + f"Got more than {limit} bytes ({actual_size}) when reading {line}." + ) + self.args = (line, limit, actual_size) class InvalidHeader(BadHttpMessage): - - def __init__(self, hdr): + def __init__(self, hdr: Union[bytes, str]) -> None: if isinstance(hdr, bytes): - hdr = hdr.decode('utf-8', 'surrogateescape') - super().__init__('Invalid HTTP Header: {}'.format(hdr)) + hdr = hdr.decode("utf-8", "surrogateescape") + super().__init__(f"Invalid HTTP Header: {hdr}") self.hdr = hdr + self.args = (hdr,) class BadStatusLine(BadHttpMessage): - - def __init__(self, line=''): - if not line: + def __init__(self, line: str = "") -> None: + if not isinstance(line, str): line = repr(line) - self.args = line, + super().__init__(f"Bad status line {line!r}") + self.args = (line,) self.line = line diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 5c7000c4215..71ba815ae67 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -1,39 +1,91 @@ +import abc +import asyncio import collections import re import string import zlib from enum import IntEnum +from typing import Any, List, Optional, Tuple, Type, Union -import yarl -from multidict import CIMultiDict, istr +from multidict import CIMultiDict, CIMultiDictProxy, istr +from yarl import URL from . 
import hdrs -from .helpers import NO_EXTENSIONS -from .http_exceptions import (BadStatusLine, ContentEncodingError, - ContentLengthError, InvalidHeader, LineTooLong, - TransferEncodingError) +from .base_protocol import BaseProtocol +from .helpers import NO_EXTENSIONS, BaseTimerContext +from .http_exceptions import ( + BadStatusLine, + ContentEncodingError, + ContentLengthError, + InvalidHeader, + LineTooLong, + TransferEncodingError, +) from .http_writer import HttpVersion, HttpVersion10 from .log import internal_logger -from .streams import EMPTY_PAYLOAD, FlowControlStreamReader +from .streams import EMPTY_PAYLOAD, StreamReader +from .typedefs import RawHeaders + +try: + import brotli + + HAS_BROTLI = True +except ImportError: # pragma: no cover + HAS_BROTLI = False + __all__ = ( - 'HttpParser', 'HttpRequestParser', 'HttpResponseParser', - 'RawRequestMessage', 'RawResponseMessage') + "HeadersParser", + "HttpParser", + "HttpRequestParser", + "HttpResponseParser", + "RawRequestMessage", + "RawResponseMessage", +) ASCIISET = set(string.printable) -METHRE = re.compile('[A-Z0-9$-_.]+') -VERSRE = re.compile(r'HTTP/(\d+).(\d+)') -HDRRE = re.compile(rb'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') + +# See https://tools.ietf.org/html/rfc7230#section-3.1.1 +# and https://tools.ietf.org/html/rfc7230#appendix-B +# +# method = token +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +# token = 1*tchar +METHRE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+") +VERSRE = re.compile(r"HTTP/(\d+).(\d+)") +HDRRE = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") RawRequestMessage = collections.namedtuple( - 'RawRequestMessage', - ['method', 'path', 'version', 'headers', 'raw_headers', - 'should_close', 'compression', 'upgrade', 'chunked', 'url']) + "RawRequestMessage", + [ + "method", + "path", + "version", + "headers", + "raw_headers", + "should_close", + "compression", + "upgrade", + "chunked", + "url", + ], +) RawResponseMessage = collections.namedtuple( - 'RawResponseMessage', - ['version', 'code', 'reason', 'headers', 'raw_headers', - 'should_close', 'compression', 'upgrade', 'chunked']) + "RawResponseMessage", + [ + "version", + "code", + "reason", + "headers", + "raw_headers", + "should_close", + "compression", + "upgrade", + "chunked", + ], +) class ParseState(IntEnum): @@ -52,13 +104,118 @@ class ChunkState(IntEnum): PARSE_TRAILERS = 4 -class HttpParser: +class HeadersParser: + def __init__( + self, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + ) -> None: + self.max_line_size = max_line_size + self.max_headers = max_headers + self.max_field_size = max_field_size + + def parse_headers( + self, lines: List[bytes] + ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]: + headers = CIMultiDict() # type: CIMultiDict[str] + raw_headers = [] + + lines_idx = 1 + line = lines[1] + line_count = len(lines) + + while line: + # Parse initial header name : value pair. 
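+                # Split on the first colon only: the field value may itself
+                # contain ":" (e.g. "Host: example.com:8080"); the optional
+                # whitespace around name and value is stripped just below.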
+ try: + bname, bvalue = line.split(b":", 1) + except ValueError: + raise InvalidHeader(line) from None + + bname = bname.strip(b" \t") + bvalue = bvalue.lstrip() + if HDRRE.search(bname): + raise InvalidHeader(bname) + if len(bname) > self.max_field_size: + raise LineTooLong( + "request header name {}".format( + bname.decode("utf8", "xmlcharrefreplace") + ), + str(self.max_field_size), + str(len(bname)), + ) + + header_length = len(bvalue) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line and line[0] in (32, 9) # (' ', '\t') + + if continuation: + bvalue_lst = [bvalue] + while continuation: + header_length += len(line) + if header_length > self.max_field_size: + raise LineTooLong( + "request header field {}".format( + bname.decode("utf8", "xmlcharrefreplace") + ), + str(self.max_field_size), + str(header_length), + ) + bvalue_lst.append(line) + + # next line + lines_idx += 1 + if lines_idx < line_count: + line = lines[lines_idx] + if line: + continuation = line[0] in (32, 9) # (' ', '\t') + else: + line = b"" + break + bvalue = b"".join(bvalue_lst) + else: + if header_length > self.max_field_size: + raise LineTooLong( + "request header field {}".format( + bname.decode("utf8", "xmlcharrefreplace") + ), + str(self.max_field_size), + str(header_length), + ) + + bvalue = bvalue.strip() + name = bname.decode("utf-8", "surrogateescape") + value = bvalue.decode("utf-8", "surrogateescape") + + headers.add(name, value) + raw_headers.append((bname, bvalue)) - def __init__(self, protocol=None, loop=None, - max_line_size=8190, max_headers=32768, max_field_size=8190, - timer=None, code=None, method=None, readall=False, - payload_exception=None, - response_with_body=True, read_until_eof=False): + return (CIMultiDictProxy(headers), tuple(raw_headers)) + + +class HttpParser(abc.ABC): + def __init__( + self, + protocol: Optional[BaseProtocol] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + limit: int = 2 ** 16, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + timer: Optional[BaseTimerContext] = None, + code: Optional[int] = None, + method: Optional[str] = None, + readall: bool = False, + payload_exception: Optional[Type[BaseException]] = None, + response_with_body: bool = True, + read_until_eof: bool = False, + auto_decompress: bool = True, + ) -> None: self.protocol = protocol self.loop = loop self.max_line_size = max_line_size @@ -72,27 +229,50 @@ def __init__(self, protocol=None, loop=None, self.response_with_body = response_with_body self.read_until_eof = read_until_eof - self._lines = [] - self._tail = b'' + self._lines = [] # type: List[bytes] + self._tail = b"" self._upgraded = False self._payload = None - self._payload_parser = None + self._payload_parser = None # type: Optional[HttpPayloadParser] + self._auto_decompress = auto_decompress + self._limit = limit + self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size) + + @abc.abstractmethod + def parse_message(self, lines: List[bytes]) -> Any: + pass - def feed_eof(self): + def feed_eof(self) -> Any: if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None + else: + # try to extract partial message + if self._tail: + self._lines.append(self._tail) - def feed_data(self, data, - SEP=b'\r\n', EMPTY=b'', - CONTENT_LENGTH=hdrs.CONTENT_LENGTH, - METH_CONNECT=hdrs.METH_CONNECT, - SEC_WEBSOCKET_KEY1=hdrs.SEC_WEBSOCKET_KEY1): + if self._lines: + if self._lines[-1] != "\r\n": + 
self._lines.append(b"") + try: + return self.parse_message(self._lines) + except Exception: + return None + + def feed_data( + self, + data: bytes, + SEP: bytes = b"\r\n", + EMPTY: bytes = b"", + CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, + METH_CONNECT: str = hdrs.METH_CONNECT, + SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, + ) -> Tuple[List[Any], bool, bytes]: messages = [] if self._tail: - data, self._tail = self._tail + data, b'' + data, self._tail = self._tail + data, b"" data_len = len(data) start_pos = 0 @@ -104,6 +284,11 @@ def feed_data(self, data, # and split by lines if self._payload_parser is None and not self._upgraded: pos = data.find(SEP, start_pos) + # consume \r\n + if pos == start_pos and not self._lines: + start_pos = pos + 2 + continue + if pos >= start_pos: # line found self._lines.append(data[start_pos:pos]) @@ -132,43 +317,76 @@ def feed_data(self, data, self._upgraded = msg.upgrade - method = getattr(msg, 'method', self.method) + method = getattr(msg, "method", self.method) + assert self.protocol is not None # calculate payload - if ((length is not None and length > 0) or - msg.chunked and not msg.upgrade): - payload = FlowControlStreamReader( - self.protocol, timer=self.timer, loop=loop) + if ( + (length is not None and length > 0) + or msg.chunked + and not msg.upgrade + ): + payload = StreamReader( + self.protocol, + timer=self.timer, + loop=loop, + limit=self._limit, + ) payload_parser = HttpPayloadParser( - payload, length=length, - chunked=msg.chunked, method=method, + payload, + length=length, + chunked=msg.chunked, + method=method, compression=msg.compression, - code=self.code, readall=self.readall, - response_with_body=self.response_with_body) + code=self.code, + readall=self.readall, + response_with_body=self.response_with_body, + auto_decompress=self._auto_decompress, + ) if not payload_parser.done: self._payload_parser = payload_parser elif method == METH_CONNECT: - payload = FlowControlStreamReader( - self.protocol, timer=self.timer, loop=loop) + payload = StreamReader( + self.protocol, + timer=self.timer, + loop=loop, + limit=self._limit, + ) self._upgraded = True self._payload_parser = HttpPayloadParser( - payload, method=msg.method, - compression=msg.compression, readall=True) + payload, + method=msg.method, + compression=msg.compression, + readall=True, + auto_decompress=self._auto_decompress, + ) else: - if (getattr(msg, 'code', 100) >= 199 and - length is None and self.read_until_eof): - payload = FlowControlStreamReader( - self.protocol, timer=self.timer, loop=loop) + if ( + getattr(msg, "code", 100) >= 199 + and length is None + and self.read_until_eof + ): + payload = StreamReader( + self.protocol, + timer=self.timer, + loop=loop, + limit=self._limit, + ) payload_parser = HttpPayloadParser( - payload, length=length, - chunked=msg.chunked, method=method, + payload, + length=length, + chunked=msg.chunked, + method=method, compression=msg.compression, - code=self.code, readall=True, - response_with_body=self.response_with_body) + code=self.code, + readall=True, + response_with_body=self.response_with_body, + auto_decompress=self._auto_decompress, + ) if not payload_parser.done: self._payload_parser = payload_parser else: - payload = EMPTY_PAYLOAD + payload = EMPTY_PAYLOAD # type: ignore messages.append((msg, payload)) else: @@ -184,18 +402,19 @@ def feed_data(self, data, # feed payload elif data and start_pos < data_len: assert not self._lines + assert self._payload_parser is not None try: - eof, data = self._payload_parser.feed_data( - 
data[start_pos:]) + eof, data = self._payload_parser.feed_data(data[start_pos:]) except BaseException as exc: if self.payload_exception is not None: self._payload_parser.payload.set_exception( - self.payload_exception(str(exc))) + self.payload_exception(str(exc)) + ) else: self._payload_parser.payload.set_exception(exc) eof = True - data = b'' + data = b"" if eof: start_pos = 0 @@ -212,227 +431,227 @@ def feed_data(self, data, return messages, self._upgraded, data - def parse_headers(self, lines): + def parse_headers( + self, lines: List[bytes] + ) -> Tuple[ + "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool + ]: """Parses RFC 5322 headers from a stream. Line continuations are supported. Returns list of header name and value pairs. Header name is in upper case. """ - headers = CIMultiDict() - raw_headers = [] - - lines_idx = 1 - line = lines[1] - line_count = len(lines) - - while line: - header_length = len(line) - - # Parse initial header name : value pair. - try: - bname, bvalue = line.split(b':', 1) - except ValueError: - raise InvalidHeader(line) from None - - bname = bname.strip(b' \t') - if HDRRE.search(bname): - raise InvalidHeader(bname) - - # next line - lines_idx += 1 - line = lines[lines_idx] - - # consume continuation lines - continuation = line and line[0] in (32, 9) # (' ', '\t') - - if continuation: - bvalue = [bvalue] - while continuation: - header_length += len(line) - if header_length > self.max_field_size: - raise LineTooLong( - 'request header field {}'.format( - bname.decode("utf8", "xmlcharrefreplace")), - self.max_field_size) - bvalue.append(line) - - # next line - lines_idx += 1 - if lines_idx < line_count: - line = lines[lines_idx] - if line: - continuation = line[0] in (32, 9) # (' ', '\t') - else: - line = b'' - break - bvalue = b''.join(bvalue) - else: - if header_length > self.max_field_size: - raise LineTooLong( - 'request header field {}'.format( - bname.decode("utf8", "xmlcharrefreplace")), - self.max_field_size) - - bvalue = bvalue.strip() - name = istr(bname.decode('utf-8', 'surrogateescape')) - value = bvalue.decode('utf-8', 'surrogateescape') - - headers.add(name, value) - raw_headers.append((bname, bvalue)) - + headers, raw_headers = self._headers_parser.parse_headers(lines) close_conn = None encoding = None upgrade = False chunked = False - raw_headers = tuple(raw_headers) # keep-alive conn = headers.get(hdrs.CONNECTION) if conn: v = conn.lower() - if v == 'close': + if v == "close": close_conn = True - elif v == 'keep-alive': + elif v == "keep-alive": close_conn = False - elif v == 'upgrade': + elif v == "upgrade": upgrade = True # encoding enc = headers.get(hdrs.CONTENT_ENCODING) if enc: enc = enc.lower() - if enc in ('gzip', 'deflate'): + if enc in ("gzip", "deflate", "br"): encoding = enc # chunking te = headers.get(hdrs.TRANSFER_ENCODING) - if te and 'chunked' in te.lower(): + if te and "chunked" in te.lower(): chunked = True - return headers, raw_headers, close_conn, encoding, upgrade, chunked + return (headers, raw_headers, close_conn, encoding, upgrade, chunked) + + def set_upgraded(self, val: bool) -> None: + """Set connection upgraded (to websocket) mode. + :param bool val: new state. + """ + self._upgraded = val -class HttpRequestParserPy(HttpParser): +class HttpRequestParser(HttpParser): """Read request status line. Exception .http_exceptions.BadStatusLine could be raised in case of any errors in status line. Returns RawRequestMessage. 
""" - def parse_message(self, lines): - if len(lines[0]) > self.max_line_size: - raise LineTooLong( - 'Status line is too long', self.max_line_size) - + def parse_message(self, lines: List[bytes]) -> Any: # request line - line = lines[0].decode('utf-8', 'surrogateescape') + line = lines[0].decode("utf-8", "surrogateescape") try: method, path, version = line.split(None, 2) except ValueError: raise BadStatusLine(line) from None + if len(path) > self.max_line_size: + raise LineTooLong( + "Status line is too long", str(self.max_line_size), str(len(path)) + ) + + path_part, _hash_separator, url_fragment = path.partition("#") + path_part, _question_mark_separator, qs_part = path_part.partition("?") + # method - method = method.upper() if not METHRE.match(method): raise BadStatusLine(method) # version try: - if version.startswith('HTTP/'): - n1, n2 = version[5:].split('.', 1) - version = HttpVersion(int(n1), int(n2)) + if version.startswith("HTTP/"): + n1, n2 = version[5:].split(".", 1) + version_o = HttpVersion(int(n1), int(n2)) else: raise BadStatusLine(version) - except: + except Exception: raise BadStatusLine(version) # read headers - headers, raw_headers, \ - close, compression, upgrade, chunked = self.parse_headers(lines) + ( + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) = self.parse_headers(lines) if close is None: # then the headers weren't set in the request - if version <= HttpVersion10: # HTTP 1.0 must asks to not close + if version_o <= HttpVersion10: # HTTP 1.0 must asks to not close close = True else: # HTTP 1.1 must ask to close. close = False return RawRequestMessage( - method, path, version, headers, raw_headers, - close, compression, upgrade, chunked, yarl.URL(path)) - - -class HttpResponseParserPy(HttpParser): + method, + path, + version_o, + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based + # NOTE: parser does, otherwise it results into the same + # NOTE: HTTP Request-Line input producing different + # NOTE: `yarl.URL()` objects + URL.build( + path=path_part, + query_string=qs_part, + fragment=url_fragment, + encoded=True, + ), + ) + + +class HttpResponseParser(HttpParser): """Read response status line and headers. BadStatusLine could be raised in case of any errors in status line. 
Returns RawResponseMessage""" - def parse_message(self, lines): - if len(lines[0]) > self.max_line_size: - raise LineTooLong( - 'Status line is too long', self.max_line_size) - - line = lines[0].decode('utf-8', 'surrogateescape') + def parse_message(self, lines: List[bytes]) -> Any: + line = lines[0].decode("utf-8", "surrogateescape") try: version, status = line.split(None, 1) except ValueError: raise BadStatusLine(line) from None - else: - try: - status, reason = status.split(None, 1) - except ValueError: - reason = '' + + try: + status, reason = status.split(None, 1) + except ValueError: + reason = "" + + if len(reason) > self.max_line_size: + raise LineTooLong( + "Status line is too long", str(self.max_line_size), str(len(reason)) + ) # version match = VERSRE.match(version) if match is None: raise BadStatusLine(line) - version = HttpVersion(int(match.group(1)), int(match.group(2))) + version_o = HttpVersion(int(match.group(1)), int(match.group(2))) # The status code is a three-digit number try: - status = int(status) + status_i = int(status) except ValueError: raise BadStatusLine(line) from None - if status > 999: + if status_i > 999: raise BadStatusLine(line) # read headers - headers, raw_headers, \ - close, compression, upgrade, chunked = self.parse_headers(lines) + ( + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) = self.parse_headers(lines) if close is None: - close = version <= HttpVersion10 + close = version_o <= HttpVersion10 return RawResponseMessage( - version, status, reason.strip(), - headers, raw_headers, close, compression, upgrade, chunked) + version_o, + status_i, + reason.strip(), + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) class HttpPayloadParser: - - def __init__(self, payload, - length=None, chunked=False, compression=None, - code=None, method=None, - readall=False, response_with_body=True): - self.payload = payload - + def __init__( + self, + payload: StreamReader, + length: Optional[int] = None, + chunked: bool = False, + compression: Optional[str] = None, + code: Optional[int] = None, + method: Optional[str] = None, + readall: bool = False, + response_with_body: bool = True, + auto_decompress: bool = True, + ) -> None: self._length = 0 self._type = ParseState.PARSE_NONE self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 - self._chunk_tail = b'' + self._chunk_tail = b"" + self._auto_decompress = auto_decompress self.done = False # payload decompression wrapper - if (response_with_body and compression): - payload = DeflateBuffer(payload, compression) + if response_with_body and compression and self._auto_decompress: + real_payload = DeflateBuffer( + payload, compression + ) # type: Union[StreamReader, DeflateBuffer] + else: + real_payload = payload # payload parser if not response_with_body: # don't parse payload if it's not expected to be received self._type = ParseState.PARSE_NONE - payload.feed_eof() + real_payload.feed_eof() self.done = True elif chunked: @@ -441,31 +660,36 @@ def __init__(self, payload, self._type = ParseState.PARSE_LENGTH self._length = length if self._length == 0: - payload.feed_eof() + real_payload.feed_eof() self.done = True else: if readall and code != 204: self._type = ParseState.PARSE_UNTIL_EOF - elif method in ('PUT', 'POST'): + elif method in ("PUT", "POST"): internal_logger.warning( # pragma: no cover - 'Content-Length or Transfer-Encoding header is required') + "Content-Length or Transfer-Encoding header is required" + ) self._type = ParseState.PARSE_NONE - 
payload.feed_eof()
+                real_payload.feed_eof()
                 self.done = True

-        self.payload = payload
+        self.payload = real_payload

-    def feed_eof(self):
+    def feed_eof(self) -> None:
         if self._type == ParseState.PARSE_UNTIL_EOF:
             self.payload.feed_eof()
         elif self._type == ParseState.PARSE_LENGTH:
             raise ContentLengthError(
-                "Not enough data for satisfy content length header.")
+                "Not enough data to satisfy content length header."
+            )
         elif self._type == ParseState.PARSE_CHUNKED:
             raise TransferEncodingError(
-                "Not enough data for satisfy transfer length header.")
+                "Not enough data to satisfy transfer length header."
+            )

-    def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'):
+    def feed_data(
+        self, chunk: bytes, SEP: bytes = b"\r\n", CHUNK_EXT: bytes = b";"
+    ) -> Tuple[bool, bytes]:
         # Read specified amount of bytes
         if self._type == ParseState.PARSE_LENGTH:
             required = self._length
@@ -476,7 +700,7 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'):
                 self.payload.feed_data(chunk, chunk_len)
                 if self._length == 0:
                     self.payload.feed_eof()
-                    return True, b''
+                    return True, b""
             else:
                 self._length = 0
                 self.payload.feed_data(chunk[:required], required)
@@ -487,7 +711,7 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'):
         elif self._type == ParseState.PARSE_CHUNKED:
             if self._chunk_tail:
                 chunk = self._chunk_tail + chunk
-                self._chunk_tail = b''
+                self._chunk_tail = b""

             while chunk:
@@ -497,44 +721,45 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'):
                     if pos >= 0:
                         i = chunk.find(CHUNK_EXT, 0, pos)
                         if i >= 0:
-                            size = chunk[:i]  # strip chunk-extensions
+                            size_b = chunk[:i]  # strip chunk-extensions
                         else:
-                            size = chunk[:pos]
+                            size_b = chunk[:pos]
                         try:
-                            size = int(size, 16)
+                            size = int(bytes(size_b), 16)
                         except ValueError:
-                            exc = TransferEncodingError(chunk[:pos])
+                            exc = TransferEncodingError(
+                                chunk[:pos].decode("ascii", "surrogateescape")
+                            )
                             self.payload.set_exception(exc)
                             raise exc from None

-                        chunk = chunk[pos+2:]
+                        chunk = chunk[pos + 2 :]
                         if size == 0:  # eof marker
                             self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
                         else:
                             self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
                             self._chunk_size = size
+                            self.payload.begin_http_chunk_receiving()
                     else:
                         self._chunk_tail = chunk
-                        return False, None
+                        return False, b""

                 # read chunk and feed buffer
                 if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
                     required = self._chunk_size
                     chunk_len = len(chunk)

-                    if required >= chunk_len:
+                    if required > chunk_len:
                         self._chunk_size = required - chunk_len
-                        if self._chunk_size == 0:
-                            self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
-                        self.payload.feed_data(chunk, chunk_len)
-                        return False, None
+                        return False, b""
                     else:
                         self._chunk_size = 0
                         self.payload.feed_data(chunk[:required], required)
                         chunk = chunk[required:]
                         self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
+                        self.payload.end_http_chunk_receiving()

                 # toss the CRLF at the end of the chunk
                 if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
@@ -543,80 +768,134 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'):
                         self._chunk = ChunkState.PARSE_CHUNKED_SIZE
                     else:
                         self._chunk_tail = chunk
-                        return False, None
+                        return False, b""

                 # if stream does not contain trailer, after 0\r\n
                 # we should get another \r\n otherwise
                 # trailers needs to be skiped until \r\n\r\n
                 if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
-                    if chunk[:2] == SEP:
+                    head = chunk[:2]
+                    if head == SEP:
                         # end of stream
                         self.payload.feed_eof()
                         return True, chunk[2:]
-                    else:
-                        self._chunk = ChunkState.PARSE_TRAILERS
+                    # Both CR and LF, or only LF may not be received yet.
It is + # expected that CRLF or LF will be shown at the very first + # byte next time, otherwise trailers should come. The last + # CRLF which marks the end of response might not be + # contained in the same TCP segment which delivered the + # size indicator. + if not head: + return False, b"" + if head == SEP[:1]: + self._chunk_tail = head + return False, b"" + self._chunk = ChunkState.PARSE_TRAILERS # read and discard trailer up to the CRLF terminator if self._chunk == ChunkState.PARSE_TRAILERS: pos = chunk.find(SEP) if pos >= 0: - chunk = chunk[pos+2:] + chunk = chunk[pos + 2 :] self._chunk = ChunkState.PARSE_MAYBE_TRAILERS else: self._chunk_tail = chunk - return False, None + return False, b"" # Read all bytes until eof elif self._type == ParseState.PARSE_UNTIL_EOF: self.payload.feed_data(chunk, len(chunk)) - return False, None + return False, b"" class DeflateBuffer: """DeflateStream decompress stream and feed data into specified stream.""" - def __init__(self, out, encoding): + def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: self.out = out self.size = 0 self.encoding = encoding + self._started_decoding = False + + if encoding == "br": + if not HAS_BROTLI: # pragma: no cover + raise ContentEncodingError( + "Can not decode content-encoding: brotli (br). " + "Please install `brotlipy`" + ) + self.decompressor = brotli.Decompressor() + else: + zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS + self.decompressor = zlib.decompressobj(wbits=zlib_mode) - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) - - self.zlib = zlib.decompressobj(wbits=zlib_mode) - - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: self.out.set_exception(exc) - def feed_data(self, chunk, size): + def feed_data(self, chunk: bytes, size: int) -> None: + if not size: + return + self.size += size + + # RFC1950 + # bits 0..3 = CM = 0b1000 = 8 = "deflate" + # bits 4..7 = CINFO = 1..7 = windows size. + if ( + not self._started_decoding + and self.encoding == "deflate" + and chunk[0] & 0xF != 8 + ): + # Change the decoder to decompress incorrectly compressed data + # Actually we should issue a warning about non-RFC-compliant data. 
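The first-byte test above follows RFC 1950: the low nibble of the CMF byte is the compression method, and 8 means deflate. A standalone sketch with plain `zlib` (not aiohttp-specific) showing why a zlib-wrapped stream passes the test while raw deflate, which the fallback decompressor assigned just below handles, generally does not:

```python
import zlib

# RFC 1950 wrapper: the first byte (CMF) carries CM in its low nibble, CM == 8 is deflate.
wrapped = zlib.compress(b"hello")
assert wrapped[0] & 0x0F == 8  # typically 0x78

# Raw deflate (no zlib wrapper), as some non-compliant servers send it:
c = zlib.compressobj(wbits=-zlib.MAX_WBITS)
raw = c.compress(b"hello") + c.flush()
# A raw stream usually fails the nibble test above, triggering the fallback decoder:
d = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
assert d.decompress(raw) == b"hello"
```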
+ self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + try: - chunk = self.zlib.decompress(chunk) + chunk = self.decompressor.decompress(chunk) except Exception: raise ContentEncodingError( - 'Can not decode content-encoding: %s' % self.encoding) + "Can not decode content-encoding: %s" % self.encoding + ) + + self._started_decoding = True if chunk: self.out.feed_data(chunk, len(chunk)) - def feed_eof(self): - chunk = self.zlib.flush() + def feed_eof(self) -> None: + chunk = self.decompressor.flush() if chunk or self.size > 0: self.out.feed_data(chunk, len(chunk)) - if not self.zlib.eof: - raise ContentEncodingError('deflate') + if self.encoding == "deflate" and not self.decompressor.eof: + raise ContentEncodingError("deflate") self.out.feed_eof() + def begin_http_chunk_receiving(self) -> None: + self.out.begin_http_chunk_receiving() + + def end_http_chunk_receiving(self) -> None: + self.out.end_http_chunk_receiving() + + +HttpRequestParserPy = HttpRequestParser +HttpResponseParserPy = HttpResponseParser +RawRequestMessagePy = RawRequestMessage +RawResponseMessagePy = RawResponseMessage -HttpRequestParser = HttpRequestParserPy -HttpResponseParser = HttpResponseParserPy try: - from ._http_parser import HttpRequestParserC, HttpResponseParserC - if not NO_EXTENSIONS: # pragma: no cover - HttpRequestParser = HttpRequestParserC - HttpResponseParser = HttpResponseParserC + if not NO_EXTENSIONS: + from ._http_parser import ( # type: ignore + HttpRequestParser, + HttpResponseParser, + RawRequestMessage, + RawResponseMessage, + ) + + HttpRequestParserC = HttpRequestParser + HttpResponseParserC = HttpResponseParser + RawRequestMessageC = RawRequestMessage + RawResponseMessageC = RawResponseMessage except ImportError: # pragma: no cover pass diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 6e0caf4137a..5cdaeea43c0 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -1,23 +1,31 @@ """WebSocket protocol versions 13 and 8.""" -import base64 -import binascii +import asyncio import collections -import hashlib import json import random +import re import sys +import zlib from enum import IntEnum from struct import Struct +from typing import Any, Callable, List, Optional, Tuple, Union -from . 
import hdrs -from .helpers import NO_EXTENSIONS, noop -from .http_exceptions import HttpBadRequest, HttpProcessingError -from .log import ws_logger +from .base_protocol import BaseProtocol +from .helpers import NO_EXTENSIONS +from .streams import DataQueue -__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY', - 'WebSocketReader', 'WebSocketWriter', 'do_handshake', - 'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode') +__all__ = ( + "WS_CLOSED_MESSAGE", + "WS_CLOSING_MESSAGE", + "WS_KEY", + "WebSocketReader", + "WebSocketWriter", + "WSMessage", + "WebSocketError", + "WSMsgType", + "WSCloseCode", +) class WSCloseCode(IntEnum): @@ -43,7 +51,7 @@ class WSMsgType(IntEnum): TEXT = 0x1 BINARY = 0x2 PING = 0x9 - PONG = 0xa + PONG = 0xA CLOSE = 0x8 # aiohttp specific types @@ -61,37 +69,31 @@ class WSMsgType(IntEnum): error = ERROR -WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" -UNPACK_LEN2 = Struct('!H').unpack_from -UNPACK_LEN3 = Struct('!Q').unpack_from -UNPACK_CLOSE_CODE = Struct('!H').unpack -PACK_LEN1 = Struct('!BB').pack -PACK_LEN2 = Struct('!BBH').pack -PACK_LEN3 = Struct('!BBQ').pack -PACK_CLOSE_CODE = Struct('!H').pack +UNPACK_LEN2 = Struct("!H").unpack_from +UNPACK_LEN3 = Struct("!Q").unpack_from +UNPACK_CLOSE_CODE = Struct("!H").unpack +PACK_LEN1 = Struct("!BB").pack +PACK_LEN2 = Struct("!BBH").pack +PACK_LEN3 = Struct("!BBQ").pack +PACK_CLOSE_CODE = Struct("!H").pack MSG_SIZE = 2 ** 14 DEFAULT_LIMIT = 2 ** 16 -_WSMessageBase = collections.namedtuple('_WSMessageBase', - ['type', 'data', 'extra']) +_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"]) class WSMessage(_WSMessageBase): - - def json(self, *, loads=json.loads): + def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any: """Return parsed JSON data. .. versionadded:: 0.22 """ return loads(self.data) - @property - def tp(self): - return self.type - WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None) WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None) @@ -100,21 +102,33 @@ def tp(self): class WebSocketError(Exception): """WebSocket protocol parser error.""" - def __init__(self, code, message): + def __init__(self, code: int, message: str) -> None: self.code = code - super().__init__(message) + super().__init__(code, message) + + def __str__(self) -> str: + return self.args[1] + + +class WSHandshakeError(Exception): + """WebSocket protocol handshake error.""" native_byteorder = sys.byteorder -def _websocket_mask_python(mask, data): +# Used by _websocket_mask_python +_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)] + + +def _websocket_mask_python(mask: bytes, data: bytearray) -> None: """Websocket masking function. - `mask` is a `bytes` object of length 4; `data` is a `bytes` object - of any length. Returns a `bytes` object of the same length as - `data` with the mask applied as specified in section 5.3 of RFC - 6455. + `mask` is a `bytes` object of length 4; `data` is a `bytearray` + object of any length. The contents of `data` are masked with `mask`, + as specified in section 5.3 of RFC 6455. + + Note that this function mutates the `data` argument. This pure-python implementation may be replaced by an optimized version when available. @@ -122,25 +136,111 @@ def _websocket_mask_python(mask, data): """ assert isinstance(data, bytearray), data assert len(mask) == 4, mask - datalen = len(data) - if datalen == 0: - # everything work without this, but may be changed later in Python. 
-            return bytearray()
-        data = int.from_bytes(data, native_byteorder)
-        mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4],
-                              native_byteorder)
-        return (data ^ mask).to_bytes(datalen, native_byteorder)
+
+    if data:
+        a, b, c, d = (_XOR_TABLE[n] for n in mask)
+        data[::4] = data[::4].translate(a)
+        data[1::4] = data[1::4].translate(b)
+        data[2::4] = data[2::4].translate(c)
+        data[3::4] = data[3::4].translate(d)


-if NO_EXTENSIONS:
+if NO_EXTENSIONS:  # pragma: no cover
     _websocket_mask = _websocket_mask_python
 else:
     try:
-        from ._websocket import _websocket_mask_cython
+        from ._websocket import _websocket_mask_cython  # type: ignore
+
         _websocket_mask = _websocket_mask_cython
     except ImportError:  # pragma: no cover
         _websocket_mask = _websocket_mask_python


+_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xFF, 0xFF])
+
+
+_WS_EXT_RE = re.compile(
+    r"^(?:;\s*(?:"
+    r"(server_no_context_takeover)|"
+    r"(client_no_context_takeover)|"
+    r"(server_max_window_bits(?:=(\d+))?)|"
+    r"(client_max_window_bits(?:=(\d+))?)))*$"
+)
+
+_WS_EXT_RE_SPLIT = re.compile(r"permessage-deflate([^,]+)?")
+
+
+def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
+    if not extstr:
+        return 0, False
+
+    compress = 0
+    notakeover = False
+    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
+        defext = ext.group(1)
+        # Return compress = 15 for a bare `permessage-deflate`
+        if not defext:
+            compress = 15
+            break
+        match = _WS_EXT_RE.match(defext)
+        if match:
+            compress = 15
+            if isserver:
+                # The server never fails to detect the compress handshake.
+                # The server does not need to send max window bits to the client.
+                if match.group(4):
+                    compress = int(match.group(4))
+                    # Group 3 must match if group 4 matches.
+                    # zlib does not support a compress window of 8 bits;
+                    # if the window size is not supported,
+                    # CONTINUE to the next extension.
+                    if compress > 15 or compress < 9:
+                        compress = 0
+                        continue
+                if match.group(1):
+                    notakeover = True
+                # Ignore regex group 5 & 6 for client_max_window_bits
+                break
+            else:
+                if match.group(6):
+                    compress = int(match.group(6))
+                    # Group 5 must match if group 6 matches.
+                    # zlib does not support a compress window of 8 bits;
+                    # if the window size is not supported,
+                    # FAIL the parse process.
+                    if compress > 15 or compress < 9:
+                        raise WSHandshakeError("Invalid window size")
+                if match.group(2):
+                    notakeover = True
+                # Ignore regex group 5 & 6 for client_max_window_bits
+                break
+        # On the client side, FAIL if the extension does not match
+        elif not isserver:
+            raise WSHandshakeError(
+                "Extension for deflate not supported: " + str(ext.group(1))
+            )

+    return compress, notakeover
+
+
+def ws_ext_gen(
+    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
+) -> str:
+    # client_notakeover=False is not used by the server
+    # zlib does not support a compress window of 8 bits
+    if compress < 9 or compress > 15:
+        raise ValueError(
+            "Compress wbits must be between 9 and 15; "
+            "zlib does not support wbits=8"
+        )
+    enabledext = ["permessage-deflate"]
+    if not isserver:
+        enabledext.append("client_max_window_bits")
+
+    if compress < 15:
+        enabledext.append("server_max_window_bits=" + str(compress))
+    if server_notakeover:
+        enabledext.append("server_no_context_takeover")
+    # if client_notakeover:
+    #     enabledext.append('client_no_context_takeover')
+    return "; ".join(enabledext)
+

 class WSParserState(IntEnum):
     READ_HEADER = 1
@@ -150,29 +250,34 @@ class WSParserState(IntEnum):


 class WebSocketReader:
-
-    def __init__(self, queue):
+    def __init__(
+        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
+    ) -> None:
         self.queue = queue
+
self._max_msg_size = max_msg_size - self._exc = None - self._partial = [] + self._exc = None # type: Optional[BaseException] + self._partial = bytearray() self._state = WSParserState.READ_HEADER - self._opcode = None + self._opcode = None # type: Optional[int] self._frame_fin = False - self._frame_opcode = None + self._frame_opcode = None # type: Optional[int] self._frame_payload = bytearray() - self._tail = b'' + self._tail = b"" self._has_mask = False - self._frame_mask = None + self._frame_mask = None # type: Optional[bytes] self._payload_length = 0 self._payload_length_flag = 0 + self._compressed = None # type: Optional[bool] + self._decompressobj = None # type: Any # zlib.decompressobj actually + self._compress = compress - def feed_eof(self): + def feed_eof(self) -> None: self.queue.feed_eof() - def feed_data(self, data): + def feed_data(self, data: bytes) -> Tuple[bool, bytes]: if self._exc: return True, data @@ -181,56 +286,68 @@ def feed_data(self, data): except Exception as exc: self._exc = exc self.queue.set_exception(exc) - return True, b'' + return True, b"" - def _feed_data(self, data): - for fin, opcode, payload in self.parse_frame(data): + def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: + for fin, opcode, payload, compressed in self.parse_frame(data): + if compressed and not self._decompressobj: + self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS) if opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] - if (close_code < 3000 and - close_code not in ALLOWED_CLOSE_CODES): + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, - 'Invalid close code: {}'.format(close_code)) + f"Invalid close code: {close_code}", + ) try: - close_message = payload[2:].decode('utf-8') + close_message = payload[2:].decode("utf-8") except UnicodeDecodeError as exc: raise WebSocketError( - WSCloseCode.INVALID_TEXT, - 'Invalid UTF-8 text message') from exc + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc msg = WSMessage(WSMsgType.CLOSE, close_code, close_message) elif payload: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, - 'Invalid close frame: {} {} {!r}'.format( - fin, opcode, payload)) + f"Invalid close frame: {fin} {opcode} {payload!r}", + ) else: - msg = WSMessage(WSMsgType.CLOSE, 0, '') + msg = WSMessage(WSMsgType.CLOSE, 0, "") self.queue.feed_data(msg, 0) elif opcode == WSMsgType.PING: self.queue.feed_data( - WSMessage(WSMsgType.PING, payload, ''), len(payload)) + WSMessage(WSMsgType.PING, payload, ""), len(payload) + ) elif opcode == WSMsgType.PONG: self.queue.feed_data( - WSMessage(WSMsgType.PONG, payload, ''), len(payload)) + WSMessage(WSMsgType.PONG, payload, ""), len(payload) + ) - elif opcode not in ( - WSMsgType.TEXT, WSMsgType.BINARY) and not self._opcode: + elif ( + opcode not in (WSMsgType.TEXT, WSMsgType.BINARY) + and self._opcode is None + ): raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Unexpected opcode={!r}".format(opcode)) + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) else: # load text/binary - if not fin: # got partial frame payload if opcode != WSMsgType.CONTINUATION: self._opcode = opcode - self._partial.append(payload) + self._partial.extend(payload) + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) else: # previous frame was 
non finished
                     # we should get continuation opcode
@@ -238,48 +355,78 @@ def _feed_data(self, data):
                     if opcode != WSMsgType.CONTINUATION:
                         raise WebSocketError(
                             WSCloseCode.PROTOCOL_ERROR,
-                            'The opcode in non-fin frame is expected '
-                            'to be zero, got {!r}'.format(opcode))
+                            "The opcode in non-fin frame is expected "
+                            "to be zero, got {!r}".format(opcode),
+                        )

                     if opcode == WSMsgType.CONTINUATION:
+                        assert self._opcode is not None
                         opcode = self._opcode
+                        self._opcode = None
+
+                    self._partial.extend(payload)
+                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
+                        raise WebSocketError(
+                            WSCloseCode.MESSAGE_TOO_BIG,
+                            "Message size {} exceeds limit {}".format(
+                                len(self._partial), self._max_msg_size
+                            ),
+                        )
+
+                    # Decompression must be done after all fragments of the
+                    # message have been received.
+                    if compressed:
+                        self._partial.extend(_WS_DEFLATE_TRAILING)
+                        payload_merged = self._decompressobj.decompress(
+                            self._partial, self._max_msg_size
+                        )
+                        if self._decompressobj.unconsumed_tail:
+                            left = len(self._decompressobj.unconsumed_tail)
+                            raise WebSocketError(
+                                WSCloseCode.MESSAGE_TOO_BIG,
+                                "Decompressed message size {} exceeds limit {}".format(
+                                    self._max_msg_size + left, self._max_msg_size
+                                ),
+                            )
+                    else:
+                        payload_merged = bytes(self._partial)

-                    self._partial.append(payload)
+                    self._partial.clear()

                     if opcode == WSMsgType.TEXT:
                         try:
-                            text = b''.join(self._partial).decode('utf-8')
+                            text = payload_merged.decode("utf-8")
                             self.queue.feed_data(
-                                WSMessage(WSMsgType.TEXT, text, ''), len(text))
+                                WSMessage(WSMsgType.TEXT, text, ""), len(text)
+                            )
                         except UnicodeDecodeError as exc:
                             raise WebSocketError(
-                                WSCloseCode.INVALID_TEXT,
-                                'Invalid UTF-8 text message') from exc
+                                WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
+                            ) from exc
                     else:
-                        data = b''.join(self._partial)
                         self.queue.feed_data(
-                            WSMessage(WSMsgType.BINARY, data, ''), len(data))
-
-                    self._start_opcode = None
-                    self._partial.clear()
+                            WSMessage(WSMsgType.BINARY, payload_merged, ""),
+                            len(payload_merged),
+                        )

-        return False, b''
+        return False, b""

-    def parse_frame(self, buf, continuation=False, EMPTY=b''):
+    def parse_frame(
+        self, buf: bytes
+    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
         """Return the next frame from the socket."""
         frames = []
         if self._tail:
-            buf, self._tail = self._tail + buf, EMPTY
+            buf, self._tail = self._tail + buf, b""

         start_pos = 0
         buf_length = len(buf)

         while True:
-            # read header
             if self._state == WSParserState.READ_HEADER:
                 if buf_length - start_pos >= 2:
-                    data = buf[start_pos:start_pos+2]
+                    data = buf[start_pos : start_pos + 2]
                     start_pos += 2
                     first_byte, second_byte = data

@@ -287,7 +434,7 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
                     rsv1 = (first_byte >> 6) & 1
                     rsv2 = (first_byte >> 5) & 1
                     rsv3 = (first_byte >> 4) & 1
-                    opcode = first_byte & 0xf
+                    opcode = first_byte & 0xF

                     # frame-fin = %x0 ; more frames of this message follow
                     #           / %x1 ; final frame of this message
@@ -297,39 +444,45 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
                     #    1 bit, MUST be 0 unless negotiated otherwise
                     # frame-rsv3 = %x0 ;
                     #    1 bit, MUST be 0 unless negotiated otherwise
-                    if rsv1 or rsv2 or rsv3:
+                    #
+                    # rsv1 is excluded from this check: the permessage-deflate
+                    # extension sets it on compressed frames
+                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                         raise WebSocketError(
                             WSCloseCode.PROTOCOL_ERROR,
-                            'Received frame with non-zero reserved bits')
+                            "Received frame with non-zero reserved bits",
+                        )

                     if opcode > 0x7 and fin == 0:
                         raise WebSocketError(
                             WSCloseCode.PROTOCOL_ERROR,
-                            'Received fragmented control
frame') - - continuation = not self._frame_fin - if (fin == 0 and - opcode == WSMsgType.CONTINUATION and - not continuation): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Received new fragment frame with non-zero ' - 'opcode {!r}'.format(opcode)) + "Received fragmented control frame", + ) has_mask = (second_byte >> 7) & 1 - length = (second_byte) & 0x7f + length = second_byte & 0x7F # Control frames MUST have a payload # length of 125 bytes or less if opcode > 0x7 and length > 125: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, - 'Control frame payload cannot be ' - 'larger than 125 bytes') + "Control frame payload cannot be " "larger than 125 bytes", + ) + + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) - self._frame_fin = fin + self._frame_fin = bool(fin) self._frame_opcode = opcode - self._has_mask = has_mask + self._has_mask = bool(has_mask) self._payload_length_flag = length self._state = WSParserState.READ_PAYLOAD_LENGTH else: @@ -340,26 +493,28 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''): length = self._payload_length_flag if length == 126: if buf_length - start_pos >= 2: - data = buf[start_pos:start_pos+2] + data = buf[start_pos : start_pos + 2] start_pos += 2 length = UNPACK_LEN2(data)[0] self._payload_length = length self._state = ( WSParserState.READ_PAYLOAD_MASK if self._has_mask - else WSParserState.READ_PAYLOAD) + else WSParserState.READ_PAYLOAD + ) else: break elif length > 126: if buf_length - start_pos >= 8: - data = buf[start_pos:start_pos+8] + data = buf[start_pos : start_pos + 8] start_pos += 8 length = UNPACK_LEN3(data)[0] self._payload_length = length self._state = ( WSParserState.READ_PAYLOAD_MASK if self._has_mask - else WSParserState.READ_PAYLOAD) + else WSParserState.READ_PAYLOAD + ) else: break else: @@ -367,12 +522,13 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''): self._state = ( WSParserState.READ_PAYLOAD_MASK if self._has_mask - else WSParserState.READ_PAYLOAD) + else WSParserState.READ_PAYLOAD + ) # read payload mask if self._state == WSParserState.READ_PAYLOAD_MASK: if buf_length - start_pos >= 4: - self._frame_mask = buf[start_pos:start_pos+4] + self._frame_mask = buf[start_pos : start_pos + 4] start_pos += 4 self._state = WSParserState.READ_PAYLOAD else: @@ -389,41 +545,81 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''): start_pos = buf_length else: self._payload_length = 0 - payload.extend(buf[start_pos:start_pos+length]) + payload.extend(buf[start_pos : start_pos + length]) start_pos = start_pos + length if self._payload_length == 0: if self._has_mask: - payload = _websocket_mask( - self._frame_mask, payload) + assert self._frame_mask is not None + _websocket_mask(self._frame_mask, payload) frames.append( - (self._frame_fin, self._frame_opcode, payload)) + (self._frame_fin, self._frame_opcode, payload, self._compressed) + ) self._frame_payload = bytearray() self._state = WSParserState.READ_HEADER else: break + self._tail = buf[start_pos:] + return frames class WebSocketWriter: - - def __init__(self, stream, *, - use_mask=False, limit=DEFAULT_LIMIT, random=random.Random()): - self.stream = stream - self.writer = stream.transport + def __init__( + self, + protocol: BaseProtocol, + 
transport: asyncio.Transport,
+        *,
+        use_mask: bool = False,
+        limit: int = DEFAULT_LIMIT,
+        random: Any = random.Random(),
+        compress: int = 0,
+        notakeover: bool = False,
+    ) -> None:
+        self.protocol = protocol
+        self.transport = transport
         self.use_mask = use_mask
         self.randrange = random.randrange
+        self.compress = compress
+        self.notakeover = notakeover
         self._closing = False
         self._limit = limit
         self._output_size = 0
+        self._compressobj = None  # type: Any  # actually compressobj

-    def _send_frame(self, message, opcode):
+    async def _send_frame(
+        self, message: bytes, opcode: int, compress: Optional[int] = None
+    ) -> None:
         """Send a frame over the websocket with message as its payload."""
-        if self._closing:
-            ws_logger.warning('websocket connection is closing.')
+        if self._closing and not (opcode & WSMsgType.CLOSE):
+            raise ConnectionResetError("Cannot write to closing transport")
+
+        rsv = 0
+
+        # Only compress larger packets (disabled)
+        # Do small packets need to be compressed?
+        # if self.compress and opcode < 8 and len(message) > 124:
+        if (compress or self.compress) and opcode < 8:
+            if compress:
+                # Do not set self._compress if compressing is for this frame
+                compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
+            else:  # self.compress
+                if not self._compressobj:
+                    self._compressobj = zlib.compressobj(
+                        level=zlib.Z_BEST_SPEED, wbits=-self.compress
+                    )
+                compressobj = self._compressobj
+
+            message = compressobj.compress(message)
+            message = message + compressobj.flush(
+                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
+            )
+            if message.endswith(_WS_DEFLATE_TRAILING):
+                message = message[:-4]
+            rsv = rsv | 0x40

         msg_length = len(message)
@@ -434,138 +630,69 @@ def _send_frame(self, message, opcode):
             mask_bit = 0

         if msg_length < 126:
-            header = PACK_LEN1(0x80 | opcode, msg_length | mask_bit)
+            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
         elif msg_length < (1 << 16):
-            header = PACK_LEN2(0x80 | opcode, 126 | mask_bit, msg_length)
+            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
         else:
-            header = PACK_LEN3(0x80 | opcode, 127 | mask_bit, msg_length)
+            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
         if use_mask:
-            mask = self.randrange(0, 0xffffffff)
-            mask = mask.to_bytes(4, 'big')
-            message = _websocket_mask(mask, bytearray(message))
-            self.writer.write(header + mask + message)
+            mask = self.randrange(0, 0xFFFFFFFF)
+            mask = mask.to_bytes(4, "big")
+            message = bytearray(message)
+            _websocket_mask(mask, message)
+            self._write(header + mask + message)
             self._output_size += len(header) + len(mask) + len(message)
         else:
             if len(message) > MSG_SIZE:
-                self.writer.write(header)
-                self.writer.write(message)
+                self._write(header)
+                self._write(message)
             else:
-                self.writer.write(header + message)
+                self._write(header + message)

             self._output_size += len(header) + len(message)

         if self._output_size > self._limit:
             self._output_size = 0
-            return self.stream.drain()
+            await self.protocol._drain_helper()

-        return noop()
+    def _write(self, data: bytes) -> None:
+        if self.transport is None or self.transport.is_closing():
+            raise ConnectionResetError("Cannot write to closing transport")
+        self.transport.write(data)

-    def pong(self, message=b''):
+    async def pong(self, message: bytes = b"") -> None:
         """Send pong message."""
         if isinstance(message, str):
-            message = message.encode('utf-8')
-        return self._send_frame(message, WSMsgType.PONG)
+            message = message.encode("utf-8")
+        await self._send_frame(message,
WSMsgType.PONG) - def ping(self, message=b''): + async def ping(self, message: bytes = b"") -> None: """Send ping message.""" if isinstance(message, str): - message = message.encode('utf-8') - return self._send_frame(message, WSMsgType.PING) - - def send(self, message, binary=False): + message = message.encode("utf-8") + await self._send_frame(message, WSMsgType.PING) + + async def send( + self, + message: Union[str, bytes], + binary: bool = False, + compress: Optional[int] = None, + ) -> None: """Send a frame over the websocket with message as its payload.""" if isinstance(message, str): - message = message.encode('utf-8') + message = message.encode("utf-8") if binary: - return self._send_frame(message, WSMsgType.BINARY) + await self._send_frame(message, WSMsgType.BINARY, compress) else: - return self._send_frame(message, WSMsgType.TEXT) + await self._send_frame(message, WSMsgType.TEXT, compress) - def close(self, code=1000, message=b''): + async def close(self, code: int = 1000, message: bytes = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): - message = message.encode('utf-8') + message = message.encode("utf-8") try: - return self._send_frame( - PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE) + await self._send_frame( + PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE + ) finally: self._closing = True - - -def do_handshake(method, headers, stream, - protocols=(), write_buffer_size=DEFAULT_LIMIT): - """Prepare WebSocket handshake. - - It return HTTP response code, response headers, websocket parser, - websocket writer. It does not perform any IO. - - `protocols` is a sequence of known protocols. On successful handshake, - the returned response headers contain the first protocol in this list - which the server also knows. - - `write_buffer_size` max size of write buffer before `drain()` get called. 
- """ - # WebSocket accepts only GET - if method.upper() != hdrs.METH_GET: - raise HttpProcessingError( - code=405, headers=((hdrs.ALLOW, hdrs.METH_GET),)) - - if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip(): - raise HttpBadRequest( - message='No WebSocket UPGRADE hdr: {}\n Can ' - '"Upgrade" only to "WebSocket".'.format(headers.get(hdrs.UPGRADE))) - - if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower(): - raise HttpBadRequest( - message='No CONNECTION upgrade hdr: {}'.format( - headers.get(hdrs.CONNECTION))) - - # find common sub-protocol between client and server - protocol = None - if hdrs.SEC_WEBSOCKET_PROTOCOL in headers: - req_protocols = [str(proto.strip()) for proto in - headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] - - for proto in req_protocols: - if proto in protocols: - protocol = proto - break - else: - # No overlap found: Return no protocol as per spec - ws_logger.warning( - 'Client protocols %r don’t overlap server-known ones %r', - req_protocols, protocols) - - # check supported version - version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '') - if version not in ('13', '8', '7'): - raise HttpBadRequest( - message='Unsupported version: {}'.format(version), - headers=((hdrs.SEC_WEBSOCKET_VERSION, '13'),)) - - # check client handshake for validity - key = headers.get(hdrs.SEC_WEBSOCKET_KEY) - try: - if not key or len(base64.b64decode(key)) != 16: - raise HttpBadRequest( - message='Handshake error: {!r}'.format(key)) - except binascii.Error: - raise HttpBadRequest( - message='Handshake error: {!r}'.format(key)) from None - - response_headers = [ - (hdrs.UPGRADE, 'websocket'), - (hdrs.CONNECTION, 'upgrade'), - (hdrs.TRANSFER_ENCODING, 'chunked'), - (hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode( - hashlib.sha1(key.encode() + WS_KEY).digest()).decode())] - - if protocol: - response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol)) - - # response code, headers, None, writer, protocol - return (101, - response_headers, - None, - WebSocketWriter(stream, limit=write_buffer_size), - protocol) diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 8e6756054bf..d261fc4e8d1 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -2,136 +2,34 @@ import asyncio import collections -import socket import zlib -from urllib.parse import SplitResult +from typing import Any, Awaitable, Callable, Optional, Union # noqa -import yarl +from multidict import CIMultiDict -from .abc import AbstractPayloadWriter -from .helpers import create_future, noop +from .abc import AbstractStreamWriter +from .base_protocol import BaseProtocol +from .helpers import NO_EXTENSIONS -__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter') +__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") -HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor']) +HttpVersion = collections.namedtuple("HttpVersion", ["major", "minor"]) HttpVersion10 = HttpVersion(1, 0) HttpVersion11 = HttpVersion(1, 1) -if hasattr(socket, 'TCP_CORK'): # pragma: no cover - CORK = socket.TCP_CORK -elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover - CORK = socket.TCP_NOPUSH -else: # pragma: no cover - CORK = None +_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] -class StreamWriter: - - def __init__(self, protocol, transport, loop): +class StreamWriter(AbstractStreamWriter): + def __init__( + self, + protocol: BaseProtocol, + loop: asyncio.AbstractEventLoop, + on_chunk_sent: _T_OnChunkSent = None, + ) -> 
None: self._protocol = protocol - self._loop = loop - self._tcp_nodelay = False - self._tcp_cork = False - self._socket = transport.get_extra_info('socket') - self._waiters = [] - self.available = True - self.transport = transport - - def acquire(self, writer): - if self.available: - self.available = False - writer.set_transport(self.transport) - else: - self._waiters.append(writer) - - def release(self): - if self._waiters: - self.available = False - writer = self._waiters.pop(0) - writer.set_transport(self.transport) - else: - self.available = True - - def replace(self, writer, factory): - try: - idx = self._waiters.index(writer) - writer = factory(self, self._loop, False) - self._waiters[idx] = writer - return writer - except ValueError: - self.available = True - return factory(self, self._loop) - - @property - def tcp_nodelay(self): - return self._tcp_nodelay - - def set_tcp_nodelay(self, value): - value = bool(value) - if self._tcp_nodelay == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - # socket may be closed already, on windows OSError get raised - try: - if self._tcp_cork: - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False) - self._tcp_cork = False - - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, value) - self._tcp_nodelay = value - except OSError: - pass - - @property - def tcp_cork(self): - return self._tcp_cork - - def set_tcp_cork(self, value): - value = bool(value) - if self._tcp_cork == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - try: - if self._tcp_nodelay: - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, False) - self._tcp_nodelay = False - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value) - self._tcp_cork = value - except OSError: - pass - - @asyncio.coroutine - def drain(self): - """Flush the write buffer. 
- - The intended use is to write - - w.write(data) - yield from w.drain() - """ - if self._protocol.transport is not None: - yield from self._protocol._drain_helper() - - -class PayloadWriter(AbstractPayloadWriter): - - def __init__(self, stream, loop, acquire=True): - self._stream = stream - self._transport = None + self._transport = protocol.transport self.loop = loop self.length = None @@ -140,190 +38,145 @@ def __init__(self, stream, loop, acquire=True): self.output_size = 0 self._eof = False - self._buffer = [] - self._compress = None + self._compress = None # type: Any self._drain_waiter = None - if self._stream.available: - self._transport = self._stream.transport - self._stream.available = False - elif acquire: - self._stream.acquire(self) - - def set_transport(self, transport): - self._transport = transport - - chunk = b''.join(self._buffer) - if chunk: - transport.write(chunk) - self._buffer.clear() - - if self._drain_waiter is not None: - waiter, self._drain_waiter = self._drain_waiter, None - if not waiter.done(): - waiter.set_result(None) + self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent @property - def tcp_nodelay(self): - return self._stream.tcp_nodelay - - def set_tcp_nodelay(self, value): - self._stream.set_tcp_nodelay(value) + def transport(self) -> Optional[asyncio.Transport]: + return self._transport @property - def tcp_cork(self): - return self._stream.tcp_cork - - def set_tcp_cork(self, value): - self._stream.set_tcp_cork(value) + def protocol(self) -> BaseProtocol: + return self._protocol - def enable_chunking(self): + def enable_chunking(self) -> None: self.chunked = True - def enable_compression(self, encoding='deflate'): - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) + def enable_compression(self, encoding: str = "deflate") -> None: + zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS self._compress = zlib.compressobj(wbits=zlib_mode) - def buffer_data(self, chunk): - if chunk: - size = len(chunk) - self.buffer_size += size - self.output_size += size - self._buffer.append(chunk) - - def _write(self, chunk): + def _write(self, chunk: bytes) -> None: size = len(chunk) self.buffer_size += size self.output_size += size - if self._transport is not None: - if self._buffer: - self._buffer.append(chunk) - self._transport.write(b''.join(self._buffer)) - self._buffer.clear() - else: - self._transport.write(chunk) - else: - self._buffer.append(chunk) + if self._transport is None or self._transport.is_closing(): + raise ConnectionResetError("Cannot write to closing transport") + self._transport.write(chunk) - def write(self, chunk, *, drain=True, LIMIT=64*1024): + async def write( + self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 + ) -> None: """Writes chunk of data to a stream. write_eof() indicates end of stream. writer can't be used after write_eof() method being called. write() return drain future. 
""" + if self._on_chunk_sent is not None: + await self._on_chunk_sent(chunk) + + if isinstance(chunk, memoryview): + if chunk.nbytes != len(chunk): + # just reshape it + chunk = chunk.cast("c") + if self._compress is not None: chunk = self._compress.compress(chunk) if not chunk: - return noop() + return if self.length is not None: chunk_len = len(chunk) if self.length >= chunk_len: self.length = self.length - chunk_len else: - chunk = chunk[:self.length] + chunk = chunk[: self.length] self.length = 0 if not chunk: - return noop() + return if chunk: if self.chunked: - chunk_len = ('%x\r\n' % len(chunk)).encode('ascii') - chunk = chunk_len + chunk + b'\r\n' + chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len_pre + chunk + b"\r\n" self._write(chunk) if self.buffer_size > LIMIT and drain: self.buffer_size = 0 - return self.drain() + await self.drain() - return noop() - - def write_headers(self, status_line, headers, SEP=': ', END='\r\n'): + async def write_headers( + self, status_line: str, headers: "CIMultiDict[str]" + ) -> None: """Write request/response status and headers.""" # status + headers - headers = status_line + ''.join( - [k + SEP + v + END for k, v in headers.items()]) - headers = headers.encode('utf-8') + b'\r\n' + buf = _serialize_headers(status_line, headers) + self._write(buf) - size = len(headers) - self.buffer_size += size - self.output_size += size - self._buffer.append(headers) - - @asyncio.coroutine - def write_eof(self, chunk=b''): + async def write_eof(self, chunk: bytes = b"") -> None: if self._eof: return + if chunk and self._on_chunk_sent is not None: + await self._on_chunk_sent(chunk) + if self._compress: if chunk: chunk = self._compress.compress(chunk) chunk = chunk + self._compress.flush() if chunk and self.chunked: - chunk_len = ('%x\r\n' % len(chunk)).encode('ascii') - chunk = chunk_len + chunk + b'\r\n0\r\n\r\n' + chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" else: if self.chunked: if chunk: - chunk_len = ('%x\r\n' % len(chunk)).encode('ascii') - chunk = chunk_len + chunk + b'\r\n0\r\n\r\n' + chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" else: - chunk = b'0\r\n\r\n' + chunk = b"0\r\n\r\n" if chunk: - self.buffer_data(chunk) + self._write(chunk) - yield from self.drain(True) + await self.drain() self._eof = True self._transport = None - self._stream.release() - - @asyncio.coroutine - def drain(self, last=False): - if self._transport is not None: - if self._buffer: - self._transport.write(b''.join(self._buffer)) - if not last: - self._buffer.clear() - yield from self._stream.drain() - else: - # wait for transport - if self._drain_waiter is None: - self._drain_waiter = create_future(self.loop) - yield from self._drain_waiter + async def drain(self) -> None: + """Flush the write buffer. 
+        The intended use is to write

-class URL(yarl.URL):
+          await w.write(data)
+          await w.drain()
+        """
+        if self._protocol.transport is not None:
+            await self._protocol._drain_helper()

-    def __init__(self, schema, netloc, port, path, query, fragment, userinfo):
-        self._strict = False

-        if port:
-            netloc += ':{}'.format(port)

-        if userinfo:
-            netloc = yarl.quote(
-                userinfo, safe='@:',
-                protected=':', strict=False) + '@' + netloc
+def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
+    line = (
+        status_line
+        + "\r\n"
+        + "".join([k + ": " + v + "\r\n" for k, v in headers.items()])
+    )
+    return line.encode("utf-8") + b"\r\n"

-        if path:
-            path = yarl.quote(path, safe='@:', protected='/', strict=False)

-        if query:
-            query = yarl.quote(
-                query, safe='=+&?/:@',
-                protected=yarl.PROTECT_CHARS, qs=True, strict=False)
+_serialize_headers = _py_serialize_headers

-        if fragment:
-            fragment = yarl.quote(fragment, safe='?/:@', strict=False)
+try:
+    import aiohttp._http_writer as _http_writer  # type: ignore

-        self._val = SplitResult(
-            schema or '',  # scheme
-            netloc=netloc, path=path, query=query, fragment=fragment)
-        self._cache = {}
+    _c_serialize_headers = _http_writer._serialize_headers
+    if not NO_EXTENSIONS:
+        _serialize_headers = _c_serialize_headers
+except ImportError:
+    pass
diff --git a/aiohttp/locks.py b/aiohttp/locks.py
new file mode 100644
index 00000000000..ce5b9c6f731
--- /dev/null
+++ b/aiohttp/locks.py
@@ -0,0 +1,45 @@
+import asyncio
+import collections
+from typing import Any, Optional
+
+try:
+    from typing import Deque
+except ImportError:
+    from typing_extensions import Deque
+
+
+class EventResultOrError:
+    """
+    This class wraps the asyncio Event primitive, allowing it either to
+    wake up the waiting Tasks without an error or to raise an exception
+    in them.
+
+    Thanks to @vorpalsmith for the simple design.
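+
+    A hypothetical usage sketch (names below are illustrative only)::
+
+        ev = EventResultOrError(loop)
+        # elsewhere: await ev.wait()
+        ev.set()                        # wakes waiters normally
+        ev.set(exc=RuntimeError())      # wakes waiters by raising in them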
+ """ + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._exc = None # type: Optional[BaseException] + self._event = asyncio.Event() + self._waiters = collections.deque() # type: Deque[asyncio.Future[Any]] + + def set(self, exc: Optional[BaseException] = None) -> None: + self._exc = exc + self._event.set() + + async def wait(self) -> Any: + waiter = self._loop.create_task(self._event.wait()) + self._waiters.append(waiter) + try: + val = await waiter + finally: + self._waiters.remove(waiter) + + if self._exc is not None: + raise self._exc + + return val + + def cancel(self) -> None: + """ Cancel all waiters """ + for waiter in self._waiters: + waiter.cancel() diff --git a/aiohttp/log.py b/aiohttp/log.py index cfda0e5f070..3cecea2bac1 100644 --- a/aiohttp/log.py +++ b/aiohttp/log.py @@ -1,8 +1,8 @@ import logging -access_logger = logging.getLogger('aiohttp.access') -client_logger = logging.getLogger('aiohttp.client') -internal_logger = logging.getLogger('aiohttp.internal') -server_logger = logging.getLogger('aiohttp.server') -web_logger = logging.getLogger('aiohttp.web') -ws_logger = logging.getLogger('aiohttp.websocket') +access_logger = logging.getLogger("aiohttp.access") +client_logger = logging.getLogger("aiohttp.client") +internal_logger = logging.getLogger("aiohttp.internal") +server_logger = logging.getLogger("aiohttp.server") +web_logger = logging.getLogger("aiohttp.web") +ws_logger = logging.getLogger("aiohttp.websocket") diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 0e6e2fd1167..9e1ca92d23e 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,4 +1,3 @@ -import asyncio import base64 import binascii import json @@ -6,21 +5,59 @@ import uuid import warnings import zlib -from collections import Mapping, Sequence, deque +from collections import deque +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from urllib.parse import parse_qsl, unquote, urlencode -from multidict import CIMultiDict - -from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) -from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype, reify -from .http import HttpParser -from .payload import (BytesPayload, LookupError, Payload, StringPayload, - get_payload, payload_type) - -__all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader', - 'BadContentDispositionHeader', 'BadContentDispositionParam', - 'parse_content_disposition', 'content_disposition_filename') +from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping + +from .hdrs import ( + CONTENT_DISPOSITION, + CONTENT_ENCODING, + CONTENT_LENGTH, + CONTENT_TRANSFER_ENCODING, + CONTENT_TYPE, +) +from .helpers import CHAR, TOKEN, parse_mimetype, reify +from .http import HeadersParser +from .payload import ( + JsonPayload, + LookupError, + Order, + Payload, + StringPayload, + get_payload, + payload_type, +) +from .streams import StreamReader + +__all__ = ( + "MultipartReader", + "MultipartWriter", + "BodyPartReader", + "BadContentDispositionHeader", + "BadContentDispositionParam", + "parse_content_disposition", + "content_disposition_filename", +) + + +if TYPE_CHECKING: # pragma: no cover + from .client_reqrep import ClientResponse class BadContentDispositionHeader(RuntimeWarning): @@ -31,44 +68,48 @@ class BadContentDispositionParam(RuntimeWarning): pass -def 
parse_content_disposition(header): - def is_token(string): - return string and TOKEN >= set(string) +def parse_content_disposition( + header: Optional[str], +) -> Tuple[Optional[str], Dict[str, str]]: + def is_token(string: str) -> bool: + return bool(string) and TOKEN >= set(string) - def is_quoted(string): + def is_quoted(string: str) -> bool: return string[0] == string[-1] == '"' - def is_rfc5987(string): + def is_rfc5987(string: str) -> bool: return is_token(string) and string.count("'") == 2 - def is_extended_param(string): - return string.endswith('*') + def is_extended_param(string: str) -> bool: + return string.endswith("*") - def is_continuous_param(string): - pos = string.find('*') + 1 + def is_continuous_param(string: str) -> bool: + pos = string.find("*") + 1 if not pos: return False - substring = string[pos:-1] if string.endswith('*') else string[pos:] + substring = string[pos:-1] if string.endswith("*") else string[pos:] return substring.isdigit() - def unescape(text, *, chars=''.join(map(re.escape, CHAR))): - return re.sub('\\\\([{}])'.format(chars), '\\1', text) + def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str: + return re.sub(f"\\\\([{chars}])", "\\1", text) if not header: return None, {} - disptype, *parts = header.split(';') + disptype, *parts = header.split(";") if not is_token(disptype): warnings.warn(BadContentDispositionHeader(header)) return None, {} - params = {} - for item in parts: - if '=' not in item: + params = {} # type: Dict[str, str] + while parts: + item = parts.pop(0) + + if "=" not in item: warnings.warn(BadContentDispositionHeader(header)) return None, {} - key, value = item.split('=', 1) + key, value = item.split("=", 1) key = key.lower().strip() value = value.lstrip() @@ -90,21 +131,34 @@ def unescape(text, *, chars=''.join(map(re.escape, CHAR))): elif is_extended_param(key): if is_rfc5987(value): encoding, _, value = value.split("'", 2) - encoding = encoding or 'utf-8' + encoding = encoding or "utf-8" else: warnings.warn(BadContentDispositionParam(item)) continue try: - value = unquote(value, encoding, 'strict') + value = unquote(value, encoding, "strict") except UnicodeDecodeError: # pragma: nocover warnings.warn(BadContentDispositionParam(item)) continue else: + failed = True if is_quoted(value): - value = unescape(value[1:-1].lstrip('\\/')) - elif not is_token(value): + failed = False + value = unescape(value[1:-1].lstrip("\\/")) + elif is_token(value): + failed = False + elif parts: + # maybe just ; in filename, in any case this is just + # one case fix, for proper fix we need to redesign parser + _value = "{};{}".format(value, parts[0]) + if is_quoted(_value): + parts.pop(0) + value = unescape(_value[1:-1].lstrip("\\/")) + failed = False + + if failed: warnings.warn(BadContentDispositionHeader(header)) return None, {} @@ -113,8 +167,10 @@ def unescape(text, *, chars=''.join(map(re.escape, CHAR))): return disptype.lower(), params -def content_disposition_filename(params, name='filename'): - name_suf = '%s*' % name +def content_disposition_filename( + params: Mapping[str, str], name: str = "filename" +) -> Optional[str]: + name_suf = "%s*" % name if not params: return None elif name_suf in params: @@ -123,12 +179,12 @@ def content_disposition_filename(params, name='filename'): return params[name] else: parts = [] - fnparams = sorted((key, value) - for key, value in params.items() - if key.startswith(name_suf)) + fnparams = sorted( + (key, value) for key, value in params.items() if key.startswith(name_suf) + ) for 
num, (key, value) in enumerate(fnparams):
-            _, tail = key.split('*', 1)
-            if tail.endswith('*'):
+            _, tail = key.split("*", 1)
+            if tail.endswith("*"):
                 tail = tail[:-1]
             if tail == str(num):
                 parts.append(value)
@@ -136,64 +192,67 @@ def unescape(text, *, chars=''.join(map(re.escape, CHAR))):
                 break
         if not parts:
             return None
-        value = ''.join(parts)
+        value = "".join(parts)
         if "'" in value:
             encoding, _, value = value.split("'", 2)
-            encoding = encoding or 'utf-8'
-            return unquote(value, encoding, 'strict')
+            encoding = encoding or "utf-8"
+            return unquote(value, encoding, "strict")
         return value


-class MultipartResponseWrapper(object):
-    """Wrapper around the :class:`MultipartBodyReader` to take care about
-    underlying connection and close it when it needs in."""
+class MultipartResponseWrapper:
+    """Wrapper around the MultipartReader.
+
+    It takes care of the underlying connection
+    and closes it when necessary.
+    """

-    def __init__(self, resp, stream):
+    def __init__(
+        self,
+        resp: "ClientResponse",
+        stream: "MultipartReader",
+    ) -> None:
         self.resp = resp
         self.stream = stream

-    if PY_35:
-        def __aiter__(self):
-            return self
-
-        if not PY_352:  # pragma: no cover
-            __aiter__ = asyncio.coroutine(__aiter__)
-
-    @asyncio.coroutine
-    def __anext__(self):
-        part = yield from self.next()
-        if part is None:
-            raise StopAsyncIteration  # NOQA
-        return part
+    def __aiter__(self) -> "MultipartResponseWrapper":
+        return self

-    def at_eof(self):
-        """Returns ``True`` when all response data had been read.
+    async def __anext__(
+        self,
+    ) -> Union["MultipartReader", "BodyPartReader"]:
+        part = await self.next()
+        if part is None:
+            raise StopAsyncIteration
+        return part

-        :rtype: bool
-        """
+    def at_eof(self) -> bool:
+        """Returns True when all response data has been read."""
         return self.resp.content.at_eof()

-    @asyncio.coroutine
-    def next(self):
+    async def next(
+        self,
+    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
         """Emits next multipart reader object."""
-        item = yield from self.stream.next()
+        item = await self.stream.next()
         if self.stream.at_eof():
-            yield from self.release()
+            await self.release()
         return item

-    @asyncio.coroutine
-    def release(self):
+    async def release(self) -> None:
         """Releases the connection gracefully, reading all the content
         to the void."""
-        yield from self.resp.release()
+        await self.resp.release()


-class BodyPartReader(object):
+class BodyPartReader:
    """Multipart reader for single body part."""

    chunk_size = 8192

-    def __init__(self, boundary, headers, content):
+    def __init__(
+        self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
+    ) -> None:
         self.headers = headers
         self._boundary = boundary
         self._content = content
@@ -201,236 +260,183 @@ def __init__(self, boundary, headers, content):
         length = self.headers.get(CONTENT_LENGTH, None)
         self._length = int(length) if length is not None else None
         self._read_bytes = 0
-        self._unread = deque()
-        self._prev_chunk = None
+        # TODO: typing.Deque is not supported by Python 3.5
+        self._unread = deque()  # type: Any
+        self._prev_chunk = None  # type: Optional[bytes]
         self._content_eof = 0
-        self._cache = {}
-
-    if PY_35:
-        def __aiter__(self):
-            return self
+        self._cache = {}  # type: Dict[str, Any]

-        if not PY_352:  # pragma: no cover
-            __aiter__ = asyncio.coroutine(__aiter__)
+    def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
+        return self  # type: ignore

-    @asyncio.coroutine
-    def __anext__(self):
-        part = yield from self.next()
-        if part is None:
-            raise StopAsyncIteration  # NOQA
-        return part
+    async def __anext__(self) -> bytes:
+        part = await self.next()
+        if part is None:
+            raise StopAsyncIteration
+        return part

-    @asyncio.coroutine
-    def next(self):
-        item = yield from self.read()
+    async def next(self) -> Optional[bytes]:
+        item = await self.read()
         if not item:
             return None
         return item

-    @asyncio.coroutine
-    def read(self, *, decode=False):
+    async def read(self, *, decode: bool = False) -> bytes:
         """Reads body part data.

-        :param bool decode: Decodes data following by encoding
-        method from `Content-Encoding` header. If it missed
-        data remains untouched
-
-        :rtype: bytearray
+        decode: Decodes data following the encoding
+        method from the Content-Encoding header. If it is missing,
+        the data remains untouched.
         """
         if self._at_eof:
-            return b''
+            return b""
         data = bytearray()
         while not self._at_eof:
-            data.extend((yield from self.read_chunk(self.chunk_size)))
+            data.extend(await self.read_chunk(self.chunk_size))
         if decode:
             return self.decode(data)
         return data

-    @asyncio.coroutine
-    def read_chunk(self, size=chunk_size):
+    async def read_chunk(self, size: int = chunk_size) -> bytes:
         """Reads body part content chunk of the specified size.

-        :param int size: chunk size
-
-        :rtype: bytearray
+        size: chunk size
         """
         if self._at_eof:
-            return b''
+            return b""
         if self._length:
-            chunk = yield from self._read_chunk_from_length(size)
+            chunk = await self._read_chunk_from_length(size)
         else:
-            chunk = yield from self._read_chunk_from_stream(size)
+            chunk = await self._read_chunk_from_stream(size)

         self._read_bytes += len(chunk)
         if self._read_bytes == self._length:
             self._at_eof = True
         if self._at_eof:
-            assert b'\r\n' == (yield from self._content.readline()), \
-                'reader did not read all the data or it is malformed'
+            crlf = await self._content.readline()
+            assert (
+                b"\r\n" == crlf
+            ), "reader did not read all the data or it is malformed"
         return chunk

-    @asyncio.coroutine
-    def _read_chunk_from_length(self, size):
-        """Reads body part content chunk of the specified size.
-        The body part must has `Content-Length` header with proper value.
-
-        :param int size: chunk size
-
-        :rtype: bytearray
-        """
-        assert self._length is not None, \
-            'Content-Length required for chunked read'
+    async def _read_chunk_from_length(self, size: int) -> bytes:
+        # Reads body part content chunk of the specified size.
+        # The body part must have a Content-Length header with a proper value.
+        assert self._length is not None, "Content-Length required for chunked read"
         chunk_size = min(size, self._length - self._read_bytes)
-        chunk = yield from self._content.read(chunk_size)
+        chunk = await self._content.read(chunk_size)
         return chunk

-    @asyncio.coroutine
-    def _read_chunk_from_stream(self, size):
-        """Reads content chunk of body part with unknown length.
-        The `Content-Length` header for body part is not necessary.
-
-        :param int size: chunk size
-
-        :rtype: bytearray
-        """
-        assert size >= len(self._boundary) + 2, \
-            'Chunk size must be greater or equal than boundary length + 2'
+    async def _read_chunk_from_stream(self, size: int) -> bytes:
+        # Reads content chunk of body part with unknown length.
+        # The Content-Length header for the body part is not necessary.
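+        # Illustrative example of why a two-read window is kept below: with
+        # boundary b"--bnd", the terminator b"\r\n--bnd" may arrive split as
+        #   prev_chunk = b"...body\r"    chunk = b"\n--bnd..."
+        # so the search must run over prev_chunk + chunk, not a single read.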
+        assert (
+            size >= len(self._boundary) + 2
+        ), "Chunk size must be greater than or equal to boundary length + 2"
         first_chunk = self._prev_chunk is None
         if first_chunk:
-            self._prev_chunk = yield from self._content.read(size)
+            self._prev_chunk = await self._content.read(size)

-        chunk = yield from self._content.read(size)
+        chunk = await self._content.read(size)
         self._content_eof += int(self._content.at_eof())
         assert self._content_eof < 3, "Reading after EOF"
+        assert self._prev_chunk is not None
         window = self._prev_chunk + chunk
-        sub = b'\r\n' + self._boundary
+        sub = b"\r\n" + self._boundary
         if first_chunk:
             idx = window.find(sub)
         else:
             idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
         if idx >= 0:
             # pushing boundary back to content
-            self._content.unread_data(window[idx:])
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=DeprecationWarning)
+                self._content.unread_data(window[idx:])
             if size > idx:
                 self._prev_chunk = self._prev_chunk[:idx]
-            chunk = window[len(self._prev_chunk):idx]
+            chunk = window[len(self._prev_chunk) : idx]
             if not chunk:
                 self._at_eof = True
         result = self._prev_chunk
         self._prev_chunk = chunk
         return result

-    @asyncio.coroutine
-    def readline(self):
-        """Reads body part by line by line.
-
-        :rtype: bytearray
-        """
+    async def readline(self) -> bytes:
+        """Reads the body part line by line."""
         if self._at_eof:
-            return b''
+            return b""

         if self._unread:
             line = self._unread.popleft()
         else:
-            line = yield from self._content.readline()
+            line = await self._content.readline()

         if line.startswith(self._boundary):
             # the very last boundary may not come with \r\n,
             # so set single rules for everyone
-            sline = line.rstrip(b'\r\n')
+            sline = line.rstrip(b"\r\n")
             boundary = self._boundary
-            last_boundary = self._boundary + b'--'
+            last_boundary = self._boundary + b"--"
             # ensure that we read exactly the boundary, not something alike
             if sline == boundary or sline == last_boundary:
                 self._at_eof = True
                 self._unread.append(line)
-                return b''
+                return b""
         else:
-            next_line = yield from self._content.readline()
+            next_line = await self._content.readline()
             if next_line.startswith(self._boundary):
                 line = line[:-2]  # strip CRLF but only once
             self._unread.append(next_line)

         return line

-    @asyncio.coroutine
-    def release(self):
-        """Like :meth:`read`, but reads all the data to the void.
-
-        :rtype: None
-        """
+    async def release(self) -> None:
+        """Like read(), but reads all the data to the void."""
         if self._at_eof:
             return
         while not self._at_eof:
-            yield from self.read_chunk(self.chunk_size)
+            await self.read_chunk(self.chunk_size)

-    @asyncio.coroutine
-    def text(self, *, encoding=None):
-        """Like :meth:`read`, but assumes that body part contains text data.
-
-        :param str encoding: Custom text encoding. Overrides specified
-        in charset param of `Content-Type` header
-
-        :rtype: str
-        """
-        data = yield from self.read(decode=True)
+    async def text(self, *, encoding: Optional[str] = None) -> str:
+        """Like read(), but assumes that the body part contains text data."""
+        data = await self.read(decode=True)
         # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm  # NOQA
         # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send  # NOQA
-        encoding = encoding or self.get_charset(default='utf-8')
+        encoding = encoding or self.get_charset(default="utf-8")
         return data.decode(encoding)

-    @asyncio.coroutine
-    def json(self, *, encoding=None):
-        """Like :meth:`read`, but assumes that body parts contains JSON data.
-
-        :param str encoding: Custom JSON encoding. Overrides specified
-        in charset param of `Content-Type` header
-        """
-        data = yield from self.read(decode=True)
+    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
+        """Like read(), but assumes that the body part contains JSON data."""
+        data = await self.read(decode=True)
         if not data:
             return None
-        encoding = encoding or self.get_charset(default='utf-8')
+        encoding = encoding or self.get_charset(default="utf-8")
         return json.loads(data.decode(encoding))

-    @asyncio.coroutine
-    def form(self, *, encoding=None):
-        """Like :meth:`read`, but assumes that body parts contains form
+    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
+        """Like read(), but assumes that the body part contains form
         urlencoded data.
-
-        :param str encoding: Custom form encoding. Overrides specified
-        in charset param of `Content-Type` header
         """
-        data = yield from self.read(decode=True)
+        data = await self.read(decode=True)
         if not data:
-            return None
-        encoding = encoding or self.get_charset(default='utf-8')
-        return parse_qsl(data.rstrip().decode(encoding),
-                         keep_blank_values=True,
-                         encoding=encoding)
-
-    def at_eof(self):
-        """Returns ``True`` if the boundary was reached or
-        ``False`` otherwise.
-
-        :rtype: bool
-        """
+            return []
+        if encoding is not None:
+            real_encoding = encoding
+        else:
+            real_encoding = self.get_charset(default="utf-8")
+        return parse_qsl(
+            data.rstrip().decode(real_encoding),
+            keep_blank_values=True,
+            encoding=real_encoding,
+        )
+
+    def at_eof(self) -> bool:
+        """Returns True if the boundary was reached or False otherwise."""
         return self._at_eof

-    def decode(self, data):
-        """Decodes data according the specified `Content-Encoding`
-        or `Content-Transfer-Encoding` headers value.
-
-        Supports ``gzip``, ``deflate`` and ``identity`` encodings for
-        `Content-Encoding` header.
-
-        Supports ``base64``, ``quoted-printable``, ``binary`` encodings for
-        `Content-Transfer-Encoding` header.
-
-        :param bytearray data: Data to decode.
-
-        :raises: :exc:`RuntimeError` - if encoding is unknown.
-
-        :rtype: bytes
+    def decode(self, data: bytes) -> bytes:
+        """Decodes data according to the specified Content-Encoding
+        or Content-Transfer-Encoding header value.
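+
+        For example (illustrative), a part sent with
+        ``Content-Transfer-Encoding: base64`` decodes b"aGVsbG8=" into
+        b"hello".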
""" if CONTENT_TRANSFER_ENCODING in self.headers: data = self._decode_content_transfer(data) @@ -438,80 +444,79 @@ def decode(self, data): return self._decode_content(data) return data - def _decode_content(self, data): - encoding = self.headers[CONTENT_ENCODING].lower() + def _decode_content(self, data: bytes) -> bytes: + encoding = self.headers.get(CONTENT_ENCODING, "").lower() - if encoding == 'deflate': + if encoding == "deflate": return zlib.decompress(data, -zlib.MAX_WBITS) - elif encoding == 'gzip': + elif encoding == "gzip": return zlib.decompress(data, 16 + zlib.MAX_WBITS) - elif encoding == 'identity': + elif encoding == "identity": return data else: - raise RuntimeError('unknown content encoding: {}'.format(encoding)) + raise RuntimeError(f"unknown content encoding: {encoding}") - def _decode_content_transfer(self, data): - encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower() + def _decode_content_transfer(self, data: bytes) -> bytes: + encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() - if encoding == 'base64': + if encoding == "base64": return base64.b64decode(data) - elif encoding == 'quoted-printable': + elif encoding == "quoted-printable": return binascii.a2b_qp(data) - elif encoding in ('binary', '8bit', '7bit'): + elif encoding in ("binary", "8bit", "7bit"): return data else: - raise RuntimeError('unknown content transfer encoding: {}' - ''.format(encoding)) + raise RuntimeError( + "unknown content transfer encoding: {}" "".format(encoding) + ) - def get_charset(self, default=None): - """Returns charset parameter from ``Content-Type`` header or default. - """ - ctype = self.headers.get(CONTENT_TYPE, '') - *_, params = parse_mimetype(ctype) - return params.get('charset', default) + def get_charset(self, default: str) -> str: + """Returns charset parameter from Content-Type header or default.""" + ctype = self.headers.get(CONTENT_TYPE, "") + mimetype = parse_mimetype(ctype) + return mimetype.parameters.get("charset", default) @reify - def name(self): - """Returns filename specified in Content-Disposition header or ``None`` - if missed or header is malformed.""" - _, params = parse_content_disposition( - self.headers.get(CONTENT_DISPOSITION)) - return content_disposition_filename(params, 'name') + def name(self) -> Optional[str]: + """Returns name specified in Content-Disposition header or None + if missed or header is malformed. + """ + + _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) + return content_disposition_filename(params, "name") @reify - def filename(self): - """Returns filename specified in Content-Disposition header or ``None`` - if missed or header is malformed.""" - _, params = parse_content_disposition( - self.headers.get(CONTENT_DISPOSITION)) - return content_disposition_filename(params, 'filename') + def filename(self) -> Optional[str]: + """Returns filename specified in Content-Disposition header or None + if missed or header is malformed. 
+ """ + _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) + return content_disposition_filename(params, "filename") -@payload_type(BodyPartReader) +@payload_type(BodyPartReader, order=Order.try_first) class BodyPartReaderPayload(Payload): - - def __init__(self, value, *args, **kwargs): + def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: super().__init__(value, *args, **kwargs) - params = {} + params = {} # type: Dict[str, str] if value.name is not None: - params['name'] = value.name + params["name"] = value.name if value.filename is not None: - params['filename'] = value.name + params["filename"] = value.filename if params: - self.set_content_disposition('attachment', **params) + self.set_content_disposition("attachment", True, **params) - @asyncio.coroutine - def write(self, writer): + async def write(self, writer: Any) -> None: field = self._value - chunk = yield from field.read_chunk(size=2**16) + chunk = await field.read_chunk(size=2 ** 16) while chunk: - writer.write(field.decode(chunk)) - chunk = yield from field.read_chunk(size=2**16) + await writer.write(field.decode(chunk)) + chunk = await field.read_chunk(size=2 ** 16) -class MultipartReader(object): +class MultipartReader: """Multipart body reader.""" #: Response wrapper, used when multipart readers constructs from response. @@ -522,145 +527,151 @@ class MultipartReader(object): #: Body part reader class for non multipart/* content types. part_reader_cls = BodyPartReader - def __init__(self, headers, content): + def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: self.headers = headers - self._boundary = ('--' + self._get_boundary()).encode() + self._boundary = ("--" + self._get_boundary()).encode() self._content = content - self._last_part = None + self._last_part = ( + None + ) # type: Optional[Union['MultipartReader', BodyPartReader]] self._at_eof = False self._at_bof = True - self._unread = [] - - if PY_35: - def __aiter__(self): - return self + self._unread = [] # type: List[bytes] - if not PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) + def __aiter__( + self, + ) -> AsyncIterator["BodyPartReader"]: + return self # type: ignore - @asyncio.coroutine - def __anext__(self): - part = yield from self.next() - if part is None: - raise StopAsyncIteration # NOQA - return part + async def __anext__( + self, + ) -> Optional[Union["MultipartReader", BodyPartReader]]: + part = await self.next() + if part is None: + raise StopAsyncIteration + return part @classmethod - def from_response(cls, response): + def from_response( + cls, + response: "ClientResponse", + ) -> MultipartResponseWrapper: """Constructs reader instance from HTTP response. :param response: :class:`~aiohttp.client.ClientResponse` instance """ - obj = cls.response_wrapper_cls(response, cls(response.headers, - response.content)) + obj = cls.response_wrapper_cls( + response, cls(response.headers, response.content) + ) return obj - def at_eof(self): - """Returns ``True`` if the final boundary was reached or - ``False`` otherwise. - - :rtype: bool + def at_eof(self) -> bool: + """Returns True if the final boundary was reached or + False otherwise. """ return self._at_eof - @asyncio.coroutine - def next(self): + async def next( + self, + ) -> Optional[Union["MultipartReader", BodyPartReader]]: """Emits the next multipart body part.""" # So, if we're at BOF, we need to skip till the boundary. 
if self._at_eof: - return - yield from self._maybe_release_last_part() + return None + await self._maybe_release_last_part() if self._at_bof: - yield from self._read_until_first_boundary() + await self._read_until_first_boundary() self._at_bof = False else: - yield from self._read_boundary() + await self._read_boundary() if self._at_eof: # we just read the last boundary, nothing to do there - return - self._last_part = yield from self.fetch_next_part() + return None + self._last_part = await self.fetch_next_part() return self._last_part - @asyncio.coroutine - def release(self): + async def release(self) -> None: """Reads all the body parts to the void till the final boundary.""" while not self._at_eof: - item = yield from self.next() + item = await self.next() if item is None: break - yield from item.release() + await item.release() - @asyncio.coroutine - def fetch_next_part(self): + async def fetch_next_part( + self, + ) -> Union["MultipartReader", BodyPartReader]: """Returns the next body part reader.""" - headers = yield from self._read_headers() + headers = await self._read_headers() return self._get_part_reader(headers) - def _get_part_reader(self, headers): + def _get_part_reader( + self, + headers: "CIMultiDictProxy[str]", + ) -> Union["MultipartReader", BodyPartReader]: """Dispatches the response by the `Content-Type` header, returning suitable reader instance. :param dict headers: Response headers """ - ctype = headers.get(CONTENT_TYPE, '') - mtype, *_ = parse_mimetype(ctype) - if mtype == 'multipart': + ctype = headers.get(CONTENT_TYPE, "") + mimetype = parse_mimetype(ctype) + + if mimetype.type == "multipart": if self.multipart_reader_cls is None: return type(self)(headers, self._content) return self.multipart_reader_cls(headers, self._content) else: return self.part_reader_cls(self._boundary, headers, self._content) - def _get_boundary(self): - mtype, *_, params = parse_mimetype(self.headers[CONTENT_TYPE]) + def _get_boundary(self) -> str: + mimetype = parse_mimetype(self.headers[CONTENT_TYPE]) - assert mtype == 'multipart', 'multipart/* content type expected' + assert mimetype.type == "multipart", "multipart/* content type expected" - if 'boundary' not in params: - raise ValueError('boundary missed for Content-Type: %s' - % self.headers[CONTENT_TYPE]) + if "boundary" not in mimetype.parameters: + raise ValueError( + "boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE] + ) - boundary = params['boundary'] + boundary = mimetype.parameters["boundary"] if len(boundary) > 70: - raise ValueError('boundary %r is too long (70 chars max)' - % boundary) + raise ValueError("boundary %r is too long (70 chars max)" % boundary) return boundary - @asyncio.coroutine - def _readline(self): + async def _readline(self) -> bytes: if self._unread: return self._unread.pop() - return (yield from self._content.readline()) + return await self._content.readline() - @asyncio.coroutine - def _read_until_first_boundary(self): + async def _read_until_first_boundary(self) -> None: while True: - chunk = yield from self._readline() - if chunk == b'': - raise ValueError("Could not find starting boundary %r" - % (self._boundary)) + chunk = await self._readline() + if chunk == b"": + raise ValueError( + "Could not find starting boundary %r" % (self._boundary) + ) chunk = chunk.rstrip() if chunk == self._boundary: return - elif chunk == self._boundary + b'--': + elif chunk == self._boundary + b"--": self._at_eof = True return - @asyncio.coroutine - def _read_boundary(self): - chunk = (yield from 
self._readline()).rstrip() + async def _read_boundary(self) -> None: + chunk = (await self._readline()).rstrip() if chunk == self._boundary: pass - elif chunk == self._boundary + b'--': + elif chunk == self._boundary + b"--": self._at_eof = True - epilogue = yield from self._readline() - next_line = yield from self._readline() + epilogue = await self._readline() + next_line = await self._readline() # the epilogue is expected and then either the end of input or the # parent multipart boundary, if the parent boundary is found then # it should be marked as unread and handed to the parent for # processing - if next_line[:2] == b'--': + if next_line[:2] == b"--": self._unread.append(next_line) # otherwise the request is likely missing an epilogue and both # lines should be passed to the parent for processing @@ -668,101 +679,148 @@ def _read_boundary(self): else: self._unread.extend([next_line, epilogue]) else: - raise ValueError('Invalid boundary %r, expected %r' - % (chunk, self._boundary)) + raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}") - @asyncio.coroutine - def _read_headers(self): - lines = [b''] + async def _read_headers(self) -> "CIMultiDictProxy[str]": + lines = [b""] while True: - chunk = yield from self._content.readline() + chunk = await self._content.readline() chunk = chunk.strip() lines.append(chunk) if not chunk: break - parser = HttpParser() - headers, *_ = parser.parse_headers(lines) + parser = HeadersParser() + headers, raw_headers = parser.parse_headers(lines) return headers - @asyncio.coroutine - def _maybe_release_last_part(self): + async def _maybe_release_last_part(self) -> None: """Ensures that the last read body part is read completely.""" if self._last_part is not None: if not self._last_part.at_eof(): - yield from self._last_part.release() + await self._last_part.release() self._unread.extend(self._last_part._unread) self._last_part = None +_Part = Tuple[Payload, str, str] + + class MultipartWriter(Payload): """Multipart body writer.""" - def __init__(self, subtype='mixed', boundary=None): + def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: boundary = boundary if boundary is not None else uuid.uuid4().hex + # The underlying Payload API demands a str (utf-8), not bytes, + # so we need to ensure we don't lose anything during conversion. + # As a result, require the boundary to be ASCII only. + # In both situations. 
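+        # For example (illustrative): boundary="workItems1" is accepted,
+        # while a non-ASCII value such as "граница" makes the encode()
+        # below fail and is surfaced as ValueError.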
+ try: - self._boundary = boundary.encode('us-ascii') + self._boundary = boundary.encode("ascii") except UnicodeEncodeError: - raise ValueError('boundary should contains ASCII only chars') - ctype = 'multipart/{}; boundary="{}"'.format(subtype, boundary) + raise ValueError("boundary should contain ASCII only chars") from None + ctype = f"multipart/{subtype}; boundary={self._boundary_value}" super().__init__(None, content_type=ctype) - self._parts = [] - self._headers = CIMultiDict() - self._headers[CONTENT_TYPE] = self.content_type + self._parts = [] # type: List[_Part] - def __enter__(self): + def __enter__(self) -> "MultipartWriter": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: pass - def __iter__(self): + def __iter__(self) -> Iterator[_Part]: return iter(self._parts) - def __len__(self): + def __len__(self) -> int: return len(self._parts) + def __bool__(self) -> bool: + return True + + _valid_tchar_regex = re.compile(br"\A[!#$%&'*+\-.^_`|~\w]+\Z") + _invalid_qdtext_char_regex = re.compile(br"[\x00-\x08\x0A-\x1F\x7F]") + @property - def boundary(self): - return self._boundary + def _boundary_value(self) -> str: + """Wrap boundary parameter value in quotes, if necessary. + + Reads self.boundary and returns a unicode sting. + """ + # Refer to RFCs 7231, 7230, 5234. + # + # parameter = token "=" ( token / quoted-string ) + # token = 1*tchar + # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE + # qdtext = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text + # obs-text = %x80-FF + # quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) + # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + # / "+" / "-" / "." 
/ "^" / "_" / "`" / "|" / "~" + # / DIGIT / ALPHA + # ; any VCHAR, except delimiters + # VCHAR = %x21-7E + value = self._boundary + if re.match(self._valid_tchar_regex, value): + return value.decode("ascii") # cannot fail + + if re.search(self._invalid_qdtext_char_regex, value): + raise ValueError("boundary value contains invalid characters") + + # escape %x5C and %x22 + quoted_value_content = value.replace(b"\\", b"\\\\") + quoted_value_content = quoted_value_content.replace(b'"', b'\\"') + + return '"' + quoted_value_content.decode("ascii") + '"' - def append(self, obj, headers=None): + @property + def boundary(self) -> str: + return self._boundary.decode("ascii") + + def append(self, obj: Any, headers: Optional[MultiMapping[str]] = None) -> Payload: if headers is None: headers = CIMultiDict() if isinstance(obj, Payload): - if obj.headers is not None: - obj.headers.update(headers) - else: - obj._headers = headers - self.append_payload(obj) + obj.headers.update(headers) + return self.append_payload(obj) else: try: - self.append_payload(get_payload(obj, headers=headers)) + payload = get_payload(obj, headers=headers) except LookupError: - raise TypeError + raise TypeError("Cannot create payload from %r" % obj) + else: + return self.append_payload(payload) - def append_payload(self, payload): + def append_payload(self, payload: Payload) -> Payload: """Adds a new body part to multipart writer.""" - # content-type - if CONTENT_TYPE not in payload.headers: - payload.headers[CONTENT_TYPE] = payload.content_type - # compression - encoding = payload.headers.get(CONTENT_ENCODING, '').lower() - if encoding and encoding not in ('deflate', 'gzip', 'identity'): - raise RuntimeError('unknown content encoding: {}'.format(encoding)) - if encoding == 'identity': + encoding = payload.headers.get( + CONTENT_ENCODING, + "", + ).lower() # type: Optional[str] + if encoding and encoding not in ("deflate", "gzip", "identity"): + raise RuntimeError(f"unknown content encoding: {encoding}") + if encoding == "identity": encoding = None # te encoding te_encoding = payload.headers.get( - CONTENT_TRANSFER_ENCODING, '').lower() - if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'): - raise RuntimeError('unknown content transfer encoding: {}' - ''.format(te_encoding)) - if te_encoding == 'binary': + CONTENT_TRANSFER_ENCODING, + "", + ).lower() # type: Optional[str] + if te_encoding not in ("", "base64", "quoted-printable", "binary"): + raise RuntimeError( + "unknown content transfer encoding: {}" "".format(te_encoding) + ) + if te_encoding == "binary": te_encoding = None # size @@ -770,24 +828,23 @@ def append_payload(self, payload): if size is not None and not (encoding or te_encoding): payload.headers[CONTENT_LENGTH] = str(size) - # render headers - headers = ''.join( - [k + ': ' + v + '\r\n' for k, v in payload.headers.items()] - ).encode('utf-8') + b'\r\n' - - self._parts.append((payload, headers, encoding, te_encoding)) + self._parts.append((payload, encoding, te_encoding)) # type: ignore + return payload - def append_json(self, obj, headers=None): + def append_json( + self, obj: Any, headers: Optional[MultiMapping[str]] = None + ) -> Payload: """Helper to append JSON part.""" if headers is None: headers = CIMultiDict() - data = json.dumps(obj).encode('utf-8') - self.append_payload( - BytesPayload( - data, headers=headers, content_type='application/json')) + return self.append_payload(JsonPayload(obj, headers=headers)) - def append_form(self, obj, headers=None): + def append_form( + self, + obj: 
Union[Sequence[Tuple[str, str]], Mapping[str, str]], + headers: Optional[MultiMapping[str]] = None, + ) -> Payload: """Helper to append form urlencoded part.""" assert isinstance(obj, (Sequence, Mapping)) @@ -799,38 +856,36 @@ def append_form(self, obj, headers=None): data = urlencode(obj, doseq=True) return self.append_payload( - StringPayload(data, headers=headers, - content_type='application/x-www-form-urlencoded')) + StringPayload( + data, headers=headers, content_type="application/x-www-form-urlencoded" + ) + ) @property - def size(self): + def size(self) -> Optional[int]: """Size of the payload.""" - if not self._parts: - return 0 - total = 0 - for part, headers, encoding, te_encoding in self._parts: + for part, encoding, te_encoding in self._parts: if encoding or te_encoding or part.size is None: return None - total += ( - 2 + len(self._boundary) + 2 + # b'--'+self._boundary+b'\r\n' - part.size + len(headers) + - 2 # b'\r\n' + total += int( + 2 + + len(self._boundary) + + 2 + + part.size # b'--'+self._boundary+b'\r\n' + + len(part._binary_headers) + + 2 # b'\r\n' ) total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' return total - @asyncio.coroutine - def write(self, writer): + async def write(self, writer: Any, close_boundary: bool = True) -> None: """Write body.""" - if not self._parts: - return - - for part, headers, encoding, te_encoding in self._parts: - yield from writer.write(b'--' + self._boundary + b'\r\n') - yield from writer.write(headers) + for part, encoding, te_encoding in self._parts: + await writer.write(b"--" + self._boundary + b"\r\n") + await writer.write(part._binary_headers) if encoding or te_encoding: w = MultipartPayloadWriter(writer) @@ -838,68 +893,65 @@ def write(self, writer): w.enable_compression(encoding) if te_encoding: w.enable_encoding(te_encoding) - yield from part.write(w) - yield from w.write_eof() + await part.write(w) # type: ignore + await w.write_eof() else: - yield from part.write(writer) + await part.write(writer) - yield from writer.write(b'\r\n') + await writer.write(b"\r\n") - yield from writer.write(b'--' + self._boundary + b'--\r\n') + if close_boundary: + await writer.write(b"--" + self._boundary + b"--\r\n") class MultipartPayloadWriter: - - def __init__(self, writer): + def __init__(self, writer: Any) -> None: self._writer = writer - self._encoding = None - self._compress = None + self._encoding = None # type: Optional[str] + self._compress = None # type: Any + self._encoding_buffer = None # type: Optional[bytearray] - def enable_encoding(self, encoding): - if encoding == 'base64': + def enable_encoding(self, encoding: str) -> None: + if encoding == "base64": self._encoding = encoding self._encoding_buffer = bytearray() - elif encoding == 'quoted-printable': - self._encoding = 'quoted-printable' + elif encoding == "quoted-printable": + self._encoding = "quoted-printable" - def enable_compression(self, encoding='deflate'): - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) + def enable_compression(self, encoding: str = "deflate") -> None: + zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS self._compress = zlib.compressobj(wbits=zlib_mode) - @asyncio.coroutine - def write_eof(self): + async def write_eof(self) -> None: if self._compress is not None: chunk = self._compress.flush() if chunk: self._compress = None - yield from self.write(chunk) + await self.write(chunk) - if self._encoding == 'base64': + if self._encoding == "base64": if self._encoding_buffer: - 
yield from self._writer.write(base64.b64encode( - self._encoding_buffer)) + await self._writer.write(base64.b64encode(self._encoding_buffer)) - @asyncio.coroutine - def write(self, chunk): + async def write(self, chunk: bytes) -> None: if self._compress is not None: if chunk: chunk = self._compress.compress(chunk) if not chunk: return - if self._encoding == 'base64': - self._encoding_buffer.extend(chunk) + if self._encoding == "base64": + buf = self._encoding_buffer + assert buf is not None + buf.extend(chunk) - if self._encoding_buffer: - buffer = self._encoding_buffer - div, mod = divmod(len(buffer), 3) - enc_chunk, self._encoding_buffer = ( - buffer[:div * 3], buffer[div * 3:]) + if buf: + div, mod = divmod(len(buf), 3) + enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :]) if enc_chunk: - enc_chunk = base64.b64encode(enc_chunk) - yield from self._writer.write(enc_chunk) - elif self._encoding == 'quoted-printable': - yield from self._writer.write(binascii.b2a_qp(chunk)) + b64chunk = base64.b64encode(enc_chunk) + await self._writer.write(b64chunk) + elif self._encoding == "quoted-printable": + await self._writer.write(binascii.b2a_qp(chunk)) else: - yield from self._writer.write(chunk) + await self._writer.write(chunk) diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 7ee7876d412..c63dd2204c0 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -1,42 +1,91 @@ import asyncio +import enum import io import json import mimetypes import os +import warnings from abc import ABC, abstractmethod +from itertools import chain +from typing import ( + IO, + TYPE_CHECKING, + Any, + ByteString, + Dict, + Iterable, + Optional, + Text, + TextIO, + Tuple, + Type, + Union, +) from multidict import CIMultiDict from . import hdrs -from .helpers import (content_disposition_header, guess_filename, - parse_mimetype, sentinel) -from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader - -__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload', - 'BytesPayload', 'StringPayload', 'StreamReaderPayload', - 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload', - 'TextIOPayload', 'StringIOPayload', 'JsonPayload') +from .abc import AbstractStreamWriter +from .helpers import ( + PY_36, + content_disposition_header, + guess_filename, + parse_mimetype, + sentinel, +) +from .streams import StreamReader +from .typedefs import JSONEncoder, _CIMultiDict + +__all__ = ( + "PAYLOAD_REGISTRY", + "get_payload", + "payload_type", + "Payload", + "BytesPayload", + "StringPayload", + "IOBasePayload", + "BytesIOPayload", + "BufferedReaderPayload", + "TextIOPayload", + "StringIOPayload", + "JsonPayload", + "AsyncIterablePayload", +) + +TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB + + +if TYPE_CHECKING: # pragma: no cover + from typing import List class LookupError(Exception): pass -def get_payload(data, *args, **kwargs): +class Order(str, enum.Enum): + normal = "normal" + try_first = "try_first" + try_last = "try_last" + + +def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload": return PAYLOAD_REGISTRY.get(data, *args, **kwargs) -def register_payload(factory, type): - PAYLOAD_REGISTRY.register(factory, type) +def register_payload( + factory: Type["Payload"], type: Any, *, order: Order = Order.normal +) -> None: + PAYLOAD_REGISTRY.register(factory, type, order=order) class payload_type: - - def __init__(self, type): + def __init__(self, type: Any, *, order: Order = Order.normal) -> None: self.type = type + self.order = order - def __call__(self, factory): - 
register_payload(factory, self.type) + def __call__(self, factory: Type["Payload"]) -> Type["Payload"]: + register_payload(factory, self.type, order=self.order) return factory @@ -46,213 +95,269 @@ class PayloadRegistry: note: we need zope.interface for more efficient adapter search """ - def __init__(self): - self._registry = [] + def __init__(self) -> None: + self._first = [] # type: List[Tuple[Type[Payload], Any]] + self._normal = [] # type: List[Tuple[Type[Payload], Any]] + self._last = [] # type: List[Tuple[Type[Payload], Any]] - def get(self, data, *args, **kwargs): + def get( + self, data: Any, *args: Any, _CHAIN: Any = chain, **kwargs: Any + ) -> "Payload": if isinstance(data, Payload): return data - for factory, type in self._registry: + for factory, type in _CHAIN(self._first, self._normal, self._last): if isinstance(data, type): return factory(data, *args, **kwargs) raise LookupError() - def register(self, factory, type): - self._registry.append((factory, type)) + def register( + self, factory: Type["Payload"], type: Any, *, order: Order = Order.normal + ) -> None: + if order is Order.try_first: + self._first.append((factory, type)) + elif order is Order.normal: + self._normal.append((factory, type)) + elif order is Order.try_last: + self._last.append((factory, type)) + else: + raise ValueError(f"Unsupported order {order!r}") class Payload(ABC): - _size = None - _headers = None - _content_type = 'application/octet-stream' - - def __init__(self, value, *, headers=None, content_type=sentinel, - filename=None, encoding=None, **kwargs): - self._value = value + _default_content_type = "application/octet-stream" # type: str + _size = None # type: Optional[int] + + def __init__( + self, + value: Any, + headers: Optional[ + Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]] + ] = None, + content_type: Optional[str] = sentinel, + filename: Optional[str] = None, + encoding: Optional[str] = None, + **kwargs: Any, + ) -> None: self._encoding = encoding self._filename = filename - if headers is not None: - self._headers = CIMultiDict(headers) - if content_type is sentinel and hdrs.CONTENT_TYPE in self._headers: - content_type = self._headers[hdrs.CONTENT_TYPE] - - if content_type is sentinel: - content_type = None - - self._content_type = content_type + self._headers = CIMultiDict() # type: _CIMultiDict + self._value = value + if content_type is not sentinel and content_type is not None: + self._headers[hdrs.CONTENT_TYPE] = content_type + elif self._filename is not None: + content_type = mimetypes.guess_type(self._filename)[0] + if content_type is None: + content_type = self._default_content_type + self._headers[hdrs.CONTENT_TYPE] = content_type + else: + self._headers[hdrs.CONTENT_TYPE] = self._default_content_type + self._headers.update(headers or {}) @property - def size(self): + def size(self) -> Optional[int]: """Size of the payload.""" return self._size @property - def filename(self): + def filename(self) -> Optional[str]: """Filename of the payload.""" return self._filename @property - def headers(self): + def headers(self) -> _CIMultiDict: """Custom item headers""" return self._headers @property - def encoding(self): + def _binary_headers(self) -> bytes: + return ( + "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode( + "utf-8" + ) + + b"\r\n" + ) + + @property + def encoding(self) -> Optional[str]: """Payload encoding""" return self._encoding @property - def content_type(self): + def content_type(self) -> str: """Content type""" - if 
self._content_type is not None: - return self._content_type - elif self._filename is not None: - mime = mimetypes.guess_type(self._filename)[0] - return 'application/octet-stream' if mime is None else mime - else: - return Payload._content_type - - def set_content_disposition(self, disptype, quote_fields=True, **params): - """Sets ``Content-Disposition`` header. - - :param str disptype: Disposition type: inline, attachment, form-data. - Should be valid extension token (see RFC 2183) - :param dict params: Disposition params - """ - if self._headers is None: - self._headers = CIMultiDict() + return self._headers[hdrs.CONTENT_TYPE] + def set_content_disposition( + self, disptype: str, quote_fields: bool = True, **params: Any + ) -> None: + """Sets ``Content-Disposition`` header.""" self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header( - disptype, quote_fields=quote_fields, **params) + disptype, quote_fields=quote_fields, **params + ) - @asyncio.coroutine # pragma: no branch @abstractmethod - def write(self, writer): - """Write payload + async def write(self, writer: AbstractStreamWriter) -> None: + """Write payload. - :param AbstractPayloadWriter writer: + writer is an AbstractStreamWriter instance: """ class BytesPayload(Payload): + def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None: + if not isinstance(value, (bytes, bytearray, memoryview)): + raise TypeError( + "value argument must be byte-ish, not {!r}".format(type(value)) + ) - def __init__(self, value, *args, **kwargs): - assert isinstance(value, (bytes, bytearray, memoryview)), \ - "value argument must be byte-ish (%r)" % type(value) - - if 'content_type' not in kwargs: - kwargs['content_type'] = 'application/octet-stream' + if "content_type" not in kwargs: + kwargs["content_type"] = "application/octet-stream" super().__init__(value, *args, **kwargs) - self._size = len(value) + if isinstance(value, memoryview): + self._size = value.nbytes + else: + self._size = len(value) + + if self._size > TOO_LARGE_BYTES_BODY: + if PY_36: + kwargs = {"source": self} + else: + kwargs = {} + warnings.warn( + "Sending a large body directly with raw bytes might" + " lock the event loop. 
You should probably pass an " + "io.BytesIO object instead", + ResourceWarning, + **kwargs, + ) - @asyncio.coroutine - def write(self, writer): - yield from writer.write(self._value) + async def write(self, writer: AbstractStreamWriter) -> None: + await writer.write(self._value) class StringPayload(BytesPayload): - - def __init__(self, value, *args, - encoding=None, content_type=None, **kwargs): + def __init__( + self, + value: Text, + *args: Any, + encoding: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs: Any, + ) -> None: if encoding is None: if content_type is None: - encoding = 'utf-8' - content_type = 'text/plain; charset=utf-8' + real_encoding = "utf-8" + content_type = "text/plain; charset=utf-8" else: - *_, params = parse_mimetype(content_type) - encoding = params.get('charset', 'utf-8') + mimetype = parse_mimetype(content_type) + real_encoding = mimetype.parameters.get("charset", "utf-8") else: if content_type is None: - content_type = 'text/plain; charset=%s' % encoding + content_type = "text/plain; charset=%s" % encoding + real_encoding = encoding super().__init__( - value.encode(encoding), - encoding=encoding, content_type=content_type, *args, **kwargs) + value.encode(real_encoding), + encoding=real_encoding, + content_type=content_type, + *args, + **kwargs, + ) -class IOBasePayload(Payload): +class StringIOPayload(StringPayload): + def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None: + super().__init__(value.read(), *args, **kwargs) - def __init__(self, value, disposition='attachment', *args, **kwargs): - if 'filename' not in kwargs: - kwargs['filename'] = guess_filename(value) + +class IOBasePayload(Payload): + def __init__( + self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any + ) -> None: + if "filename" not in kwargs: + kwargs["filename"] = guess_filename(value) super().__init__(value, *args, **kwargs) if self._filename is not None and disposition is not None: - self.set_content_disposition(disposition, filename=self._filename) + if hdrs.CONTENT_DISPOSITION not in self.headers: + self.set_content_disposition(disposition, filename=self._filename) - @asyncio.coroutine - def write(self, writer): + async def write(self, writer: AbstractStreamWriter) -> None: + loop = asyncio.get_event_loop() try: - chunk = self._value.read(DEFAULT_LIMIT) + chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16) while chunk: - yield from writer.write(chunk) - chunk = self._value.read(DEFAULT_LIMIT) + await writer.write(chunk) + chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16) finally: - self._value.close() + await loop.run_in_executor(None, self._value.close) class TextIOPayload(IOBasePayload): - - def __init__(self, value, *args, - encoding=None, content_type=None, **kwargs): + def __init__( + self, + value: TextIO, + *args: Any, + encoding: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs: Any, + ) -> None: if encoding is None: if content_type is None: - encoding = 'utf-8' - content_type = 'text/plain; charset=utf-8' + encoding = "utf-8" + content_type = "text/plain; charset=utf-8" else: - *_, params = parse_mimetype(content_type) - encoding = params.get('charset', 'utf-8') + mimetype = parse_mimetype(content_type) + encoding = mimetype.parameters.get("charset", "utf-8") else: if content_type is None: - content_type = 'text/plain; charset=%s' % encoding + content_type = "text/plain; charset=%s" % encoding super().__init__( value, - content_type=content_type, 
encoding=encoding, *args, **kwargs) + content_type=content_type, + encoding=encoding, + *args, + **kwargs, + ) @property - def size(self): + def size(self) -> Optional[int]: try: return os.fstat(self._value.fileno()).st_size - self._value.tell() except OSError: return None - @asyncio.coroutine - def write(self, writer): + async def write(self, writer: AbstractStreamWriter) -> None: + loop = asyncio.get_event_loop() try: - chunk = self._value.read(DEFAULT_LIMIT) + chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16) while chunk: - yield from writer.write(chunk.encode(self._encoding)) - chunk = self._value.read(DEFAULT_LIMIT) + await writer.write(chunk.encode(self._encoding)) + chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16) finally: - self._value.close() - - -class StringIOPayload(TextIOPayload): - - @property - def size(self): - return len(self._value.getvalue()) - self._value.tell() + await loop.run_in_executor(None, self._value.close) class BytesIOPayload(IOBasePayload): - @property - def size(self): - return len(self._value.getbuffer()) - self._value.tell() + def size(self) -> int: + position = self._value.tell() + end = self._value.seek(0, os.SEEK_END) + self._value.seek(position) + return end - position class BufferedReaderPayload(IOBasePayload): - @property - def size(self): + def size(self) -> Optional[int]: try: return os.fstat(self._value.fileno()).st_size - self._value.tell() except OSError: @@ -261,39 +366,72 @@ def size(self): return None -class StreamReaderPayload(Payload): +class JsonPayload(BytesPayload): + def __init__( + self, + value: Any, + encoding: str = "utf-8", + content_type: str = "application/json", + dumps: JSONEncoder = json.dumps, + *args: Any, + **kwargs: Any, + ) -> None: - @asyncio.coroutine - def write(self, writer): - chunk = yield from self._value.read(DEFAULT_LIMIT) - while chunk: - yield from writer.write(chunk) - chunk = yield from self._value.read(DEFAULT_LIMIT) + super().__init__( + dumps(value).encode(encoding), + content_type=content_type, + encoding=encoding, + *args, + **kwargs, + ) -class DataQueuePayload(Payload): +if TYPE_CHECKING: # pragma: no cover + from typing import AsyncIterable, AsyncIterator - @asyncio.coroutine - def write(self, writer): - while True: - try: - chunk = yield from self._value.read() - if not chunk: - break - yield from writer.write(chunk) - except EofStream: - break + _AsyncIterator = AsyncIterator[bytes] + _AsyncIterable = AsyncIterable[bytes] +else: + from collections.abc import AsyncIterable, AsyncIterator + _AsyncIterator = AsyncIterator + _AsyncIterable = AsyncIterable -class JsonPayload(BytesPayload): - def __init__(self, value, - encoding='utf-8', content_type='application/json', - dumps=json.dumps, *args, **kwargs): +class AsyncIterablePayload(Payload): - super().__init__( - dumps(value).encode(encoding), - content_type=content_type, encoding=encoding, *args, **kwargs) + _iter = None # type: Optional[_AsyncIterator] + + def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: + if not isinstance(value, AsyncIterable): + raise TypeError( + "value argument must support " + "collections.abc.AsyncIterablebe interface, " + "got {!r}".format(type(value)) + ) + + if "content_type" not in kwargs: + kwargs["content_type"] = "application/octet-stream" + + super().__init__(value, *args, **kwargs) + + self._iter = value.__aiter__() + + async def write(self, writer: AbstractStreamWriter) -> None: + if self._iter: + try: + # iter is not None check prevents rare cases + # 
when the same iterable is used twice + while True: + chunk = await self._iter.__anext__() + await writer.write(chunk) + except StopAsyncIteration: + self._iter = None + + +class StreamReaderPayload(AsyncIterablePayload): + def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: + super().__init__(value.iter_any(), *args, **kwargs) PAYLOAD_REGISTRY = PayloadRegistry() @@ -302,9 +440,9 @@ def __init__(self, value, PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO) PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase) PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO) -PAYLOAD_REGISTRY.register( - BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) +PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) -PAYLOAD_REGISTRY.register( - StreamReaderPayload, (asyncio.StreamReader, StreamReader)) -PAYLOAD_REGISTRY.register(DataQueuePayload, DataQueue) +PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader) +# try_last for giving a chance to more specialized async iterables like +# aiohttp.multipart.BodyPartReaderPayload to override the default +PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last) diff --git a/aiohttp/payload_streamer.py b/aiohttp/payload_streamer.py index 2813469964b..3b2de151640 100644 --- a/aiohttp/payload_streamer.py +++ b/aiohttp/payload_streamer.py @@ -3,11 +3,11 @@ As a simple case, you can upload data from file:: @aiohttp.streamer - def file_sender(writer, file_name=None): + async def file_sender(writer, file_name=None): with open(file_name, 'rb') as f: chunk = f.read(2**16) while chunk: - yield from writer.write(chunk) + await writer.write(chunk) chunk = f.read(2**16) @@ -21,48 +21,54 @@ def file_sender(writer, file_name=None): """ -import asyncio +import types +import warnings +from typing import Any, Awaitable, Callable, Dict, Tuple +from .abc import AbstractStreamWriter from .payload import Payload, payload_type -__all__ = ('streamer',) +__all__ = ("streamer",) class _stream_wrapper: - - def __init__(self, coro, args, kwargs): - self.coro = asyncio.coroutine(coro) + def __init__( + self, + coro: Callable[..., Awaitable[None]], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> None: + self.coro = types.coroutine(coro) self.args = args self.kwargs = kwargs - @asyncio.coroutine - def __call__(self, writer): - yield from self.coro(writer, *self.args, **self.kwargs) + async def __call__(self, writer: AbstractStreamWriter) -> None: + await self.coro(writer, *self.args, **self.kwargs) # type: ignore class streamer: - - def __init__(self, coro): + def __init__(self, coro: Callable[..., Awaitable[None]]) -> None: + warnings.warn( + "@streamer is deprecated, use async generators instead", + DeprecationWarning, + stacklevel=2, + ) self.coro = coro - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper: return _stream_wrapper(self.coro, args, kwargs) @payload_type(_stream_wrapper) class StreamWrapperPayload(Payload): - - @asyncio.coroutine - def write(self, writer): - yield from self._value(writer) + async def write(self, writer: AbstractStreamWriter) -> None: + await self._value(writer) @payload_type(streamer) class StreamPayload(StreamWrapperPayload): - - def __init__(self, value, *args, **kwargs): + def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None: super().__init__(value(), *args, **kwargs) - @asyncio.coroutine - def write(self, writer): - yield from self._value(writer) + async def write(self, writer: AbstractStreamWriter) -> None: + await self._value(writer)
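With `@streamer` now emitting a `DeprecationWarning`, the replacement is a plain async generator, which the `Order.try_last` registry entry above routes to `AsyncIterablePayload`. A minimal sketch of the migration (the file name and URL are placeholders)::

    import aiohttp

    async def file_sender(file_name):
        # Async generator payload: the modern replacement for @streamer.
        with open(file_name, "rb") as f:
            chunk = f.read(2 ** 16)
            while chunk:
                yield chunk
                chunk = f.read(2 ** 16)

    async def upload(session: aiohttp.ClientSession) -> None:
        # PAYLOAD_REGISTRY wraps the generator in AsyncIterablePayload.
        await session.post("http://example.com/upload", data=file_sender("big.bin"))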
diff --git a/aiohttp/py.typed b/aiohttp/py.typed new file mode 100644 index 00000000000..f5642f79f21 --- /dev/null +++ b/aiohttp/py.typed @@ -0,0 +1 @@ +Marker diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index 939923f256e..5204293410b 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -1,41 +1,153 @@ import asyncio import contextlib -import tempfile +import warnings +from collections.abc import Callable import pytest -from py import path +from aiohttp.helpers import PY_37, isasyncgenfunction from aiohttp.web import Application -from .test_utils import unused_port as _unused_port -from .test_utils import (RawTestServer, TestClient, TestServer, - loop_context, setup_test_loop, teardown_test_loop) +from .test_utils import ( + BaseTestServer, + RawTestServer, + TestClient, + TestServer, + loop_context, + setup_test_loop, + teardown_test_loop, + unused_port as _unused_port, +) try: import uvloop -except: +except ImportError: # pragma: no cover uvloop = None +try: + import tokio +except ImportError: # pragma: no cover + tokio = None + + +def pytest_addoption(parser): # type: ignore + parser.addoption( + "--aiohttp-fast", + action="store_true", + default=False, + help="run tests faster by disabling extra checks", + ) + parser.addoption( + "--aiohttp-loop", + action="store", + default="pyloop", + help="run tests with specific loop: pyloop, uvloop, tokio or all", + ) + parser.addoption( + "--aiohttp-enable-loop-debug", + action="store_true", + default=False, + help="enable event loop debug mode", + ) + + +def pytest_fixture_setup(fixturedef): # type: ignore + """ + Allow fixtures to be coroutines. Run coroutine fixtures in an event loop. + """ + func = fixturedef.func + + if isasyncgenfunction(func): + # async generator fixture + is_async_gen = True + elif asyncio.iscoroutinefunction(func): + # regular async fixture + is_async_gen = False + else: + # not an async fixture, nothing to do + return + + strip_request = False + if "request" not in fixturedef.argnames: + fixturedef.argnames += ("request",) + strip_request = True + + def wrapper(*args, **kwargs): # type: ignore + request = kwargs["request"] + if strip_request: + del kwargs["request"] + + # if neither the fixture nor the test use the 'loop' fixture, + # 'getfixturevalue' will fail because the test is not parameterized + # (this can be removed someday if 'loop' is no longer parameterized) + if "loop" not in request.fixturenames: + raise Exception( + "Asynchronous fixtures must depend on the 'loop' fixture or " + "be used in tests depending on it." + ) + + _loop = request.getfixturevalue("loop") + + if is_async_gen: + # for async generators, we need to advance the generator once, + # then advance it again in a finalizer + gen = func(*args, **kwargs) + + def finalizer(): # type: ignore + try: + return _loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + pass + + request.addfinalizer(finalizer) + return _loop.run_until_complete(gen.__anext__()) + else: + return _loop.run_until_complete(func(*args, **kwargs)) -def pytest_addoption(parser): - parser.addoption('--fast', action='store_true', default=False, - help='run tests faster by disabling extra checks') - parser.addoption('--with-uvloop-only', action='store_true', default=False, - help='run tests with uvloop only if available') - parser.addoption('--without-uvloop', action='store_true', default=False, - help='run tests without uvloop') - parser.addoption('--enable-loop-debug', action='store_true', default=False, - help='enable event loop debug mode') + fixturedef.func = wrapper
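Under this hook, coroutine and async-generator fixtures must (directly or indirectly) depend on the parametrized `loop` fixture. A sketch of a conforming fixture (the fixture name and app setup are illustrative)::

    import pytest
    from aiohttp import web

    @pytest.fixture
    async def app(loop):
        # Depending on 'loop' satisfies the check in the wrapper above.
        application = web.Application()
        yield application  # code after the yield runs via the finalizer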
@pytest.fixture -def fast(request): - """ --fast config option """ - return request.config.getoption('--fast') +def fast(request): # type: ignore + """--fast config option""" + return request.config.getoption("--aiohttp-fast") + + +@pytest.fixture +def loop_debug(request): # type: ignore + """--enable-loop-debug config option""" + return request.config.getoption("--aiohttp-enable-loop-debug") @contextlib.contextmanager -def _passthrough_loop_context(loop, fast=False): +def _runtime_warning_context(): # type: ignore + """ + Context manager which checks for RuntimeWarnings, specifically to + avoid "coroutine 'X' was never awaited" warnings being missed. + + If RuntimeWarnings occur in the context a RuntimeError is raised. + """ + with warnings.catch_warnings(record=True) as _warnings: + yield + rw = [ + "{w.filename}:{w.lineno}:{w.message}".format(w=w) + for w in _warnings + if w.category == RuntimeWarning + ] + if rw: + raise RuntimeError( + "{} Runtime Warning{},\n{}".format( + len(rw), "" if len(rw) == 1 else "s", "\n".join(rw) + ) + ) + + +@contextlib.contextmanager +def _passthrough_loop_context(loop, fast=False): # type: ignore + """ + Sets up and tears down a loop unless one is passed in via the loop + argument, in which case it is passed straight through. + """ if loop: # loop already exists, pass it straight through yield loop @@ -46,7 +158,7 @@ def _passthrough_loop_context(loop, fast=False): teardown_test_loop(loop, fast=fast) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem(collector, name, obj): # type: ignore """ Fix pytest collecting for coroutines. """ @@ -54,162 +166,215 @@ def pytest_pycollect_makeitem(collector, name, obj): return list(collector._genfunctions(name, obj)) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem): # type: ignore """ Run coroutines in an event loop instead of a normal function call. 
""" - fast = pyfuncitem.config.getoption("--fast") + fast = pyfuncitem.config.getoption("--aiohttp-fast") if asyncio.iscoroutinefunction(pyfuncitem.function): - existing_loop = pyfuncitem.funcargs.get('loop', None) - with _passthrough_loop_context(existing_loop, fast=fast) as _loop: - testargs = {arg: pyfuncitem.funcargs[arg] - for arg in pyfuncitem._fixtureinfo.argnames} - - task = _loop.create_task(pyfuncitem.obj(**testargs)) - _loop.run_until_complete(task) + existing_loop = pyfuncitem.funcargs.get( + "proactor_loop" + ) or pyfuncitem.funcargs.get("loop", None) + with _runtime_warning_context(): + with _passthrough_loop_context(existing_loop, fast=fast) as _loop: + testargs = { + arg: pyfuncitem.funcargs[arg] + for arg in pyfuncitem._fixtureinfo.argnames + } + _loop.run_until_complete(pyfuncitem.obj(**testargs)) return True -def pytest_configure(config): - fast = config.getoption('--fast') - uvloop_only = config.getoption('--with-uvloop-only') - - without_uvloop = False - if fast: - without_uvloop = True +def pytest_generate_tests(metafunc): # type: ignore + if "loop_factory" not in metafunc.fixturenames: + return - if config.getoption('--without-uvloop'): - without_uvloop = True + loops = metafunc.config.option.aiohttp_loop + avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy} - LOOP_FACTORIES.clear() - if uvloop_only and uvloop is not None: - LOOP_FACTORIES.append(uvloop.new_event_loop) - elif without_uvloop: - LOOP_FACTORIES.append(asyncio.new_event_loop) - else: - LOOP_FACTORIES.append(asyncio.new_event_loop) - if uvloop is not None: - LOOP_FACTORIES.append(uvloop.new_event_loop) + if uvloop is not None: # pragma: no cover + avail_factories["uvloop"] = uvloop.EventLoopPolicy - asyncio.set_event_loop(None) + if tokio is not None: # pragma: no cover + avail_factories["tokio"] = tokio.EventLoopPolicy + if loops == "all": + loops = "pyloop,uvloop?,tokio?" 
-LOOP_FACTORIES = [] + factories = {} # type: ignore + for name in loops.split(","): + required = not name.endswith("?") + name = name.strip(" ?") + if name not in avail_factories: # pragma: no cover + if required: + raise ValueError( + "Unknown loop '%s', available loops: %s" + % (name, list(avail_factories.keys())) + ) + else: + continue + factories[name] = avail_factories[name] + metafunc.parametrize( + "loop_factory", list(factories.values()), ids=list(factories.keys()) + ) -@pytest.yield_fixture(params=LOOP_FACTORIES) -def loop(request): +@pytest.fixture +def loop(loop_factory, fast, loop_debug): # type: ignore """Return an instance of the event loop.""" - fast = request.config.getoption('--fast') - debug = request.config.getoption('--enable-loop-debug') + policy = loop_factory() + asyncio.set_event_loop_policy(policy) + with loop_context(fast=fast) as _loop: + if loop_debug: + _loop.set_debug(True) # pragma: no cover + asyncio.set_event_loop(_loop) + yield _loop + + +@pytest.fixture +def proactor_loop(): # type: ignore + if not PY_37: + policy = asyncio.get_event_loop_policy() + policy._loop_factory = asyncio.ProactorEventLoop # type: ignore + else: + policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore + asyncio.set_event_loop_policy(policy) - with loop_context(request.param, fast=fast) as _loop: - if debug: - _loop.set_debug(True) + with loop_context(policy.new_event_loop) as _loop: + asyncio.set_event_loop(_loop) yield _loop @pytest.fixture -def unused_port(): +def unused_port(aiohttp_unused_port): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_unused_port fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_unused_port + + +@pytest.fixture +def aiohttp_unused_port(): # type: ignore """Return a port that is unused on the current host.""" return _unused_port -@pytest.yield_fixture -def test_server(loop): +@pytest.fixture +def aiohttp_server(loop): # type: ignore """Factory to create a TestServer instance, given an app. - test_server(app, **kwargs) + aiohttp_server(app, **kwargs) """ servers = [] - @asyncio.coroutine - def go(app, **kwargs): - server = TestServer(app) - yield from server.start_server(loop=loop, **kwargs) + async def go(app, *, port=None, **kwargs): # type: ignore + server = TestServer(app, port=port) + await server.start_server(loop=loop, **kwargs) servers.append(server) return server yield go - @asyncio.coroutine - def finalize(): + async def finalize(): # type: ignore while servers: - yield from servers.pop().close() + await servers.pop().close() loop.run_until_complete(finalize()) -@pytest.yield_fixture -def raw_test_server(loop): +@pytest.fixture +def test_server(aiohttp_server): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_server + + +@pytest.fixture +def aiohttp_raw_server(loop): # type: ignore """Factory to create a RawTestServer instance, given a web handler. 
- raw_test_server(handler, **kwargs) + aiohttp_raw_server(handler, **kwargs) """ servers = [] - @asyncio.coroutine - def go(handler, **kwargs): - server = RawTestServer(handler) - yield from server.start_server(loop=loop, **kwargs) + async def go(handler, *, port=None, **kwargs): # type: ignore + server = RawTestServer(handler, port=port) + await server.start_server(loop=loop, **kwargs) servers.append(server) return server yield go - @asyncio.coroutine - def finalize(): + async def finalize(): # type: ignore while servers: - yield from servers.pop().close() + await servers.pop().close() loop.run_until_complete(finalize()) -@pytest.yield_fixture -def test_client(loop): +@pytest.fixture +def raw_test_server(aiohttp_raw_server): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_raw_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_raw_server + + +@pytest.fixture +def aiohttp_client(loop): # type: ignore """Factory to create a TestClient instance. - test_client(app, **kwargs) - test_client(server, **kwargs) - test_client(raw_server, **kwargs) + aiohttp_client(app, **kwargs) + aiohttp_client(server, **kwargs) + aiohttp_client(raw_server, **kwargs) """ clients = [] - @asyncio.coroutine - def go(__param, *args, **kwargs): - if isinstance(__param, Application): - assert not args, "args should be empty" - client = TestClient(__param, loop=loop, **kwargs) - elif isinstance(__param, TestServer): - assert not args, "args should be empty" - client = TestClient(__param, loop=loop, **kwargs) - elif isinstance(__param, RawTestServer): + async def go(__param, *args, server_kwargs=None, **kwargs): # type: ignore + + if isinstance(__param, Callable) and not isinstance( # type: ignore + __param, (Application, BaseTestServer) + ): + __param = __param(loop, *args, **kwargs) + kwargs = {} + else: assert not args, "args should be empty" + + if isinstance(__param, Application): + server_kwargs = server_kwargs or {} + server = TestServer(__param, loop=loop, **server_kwargs) + client = TestClient(server, loop=loop, **kwargs) + elif isinstance(__param, BaseTestServer): client = TestClient(__param, loop=loop, **kwargs) else: - __param = __param(loop, *args, **kwargs) - client = TestClient(__param, loop=loop) + raise ValueError("Unknown argument type: %r" % type(__param)) - yield from client.start_server() + await client.start_server() clients.append(client) return client yield go - @asyncio.coroutine - def finalize(): + async def finalize(): # type: ignore while clients: - yield from clients.pop().close() + await clients.pop().close() loop.run_until_complete(finalize()) @pytest.fixture -def shorttmpdir(): - """Provides a temporary directory with a shorter file system path than the - tmpdir fixture. 
- """ - tmpdir = path.local(tempfile.mkdtemp()) - yield tmpdir - tmpdir.remove(rec=1) +def test_client(aiohttp_client): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_client fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_client diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index 102a79b3731..2974bcad7af 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -1,12 +1,15 @@ import asyncio import socket +from typing import Any, Dict, List, Optional from .abc import AbstractResolver +from .helpers import get_running_loop -__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver') +__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver") try: import aiodns + # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname') except ImportError: # pragma: no cover aiodns = None @@ -19,82 +22,127 @@ class ThreadedResolver(AbstractResolver): concurrent.futures.ThreadPoolExecutor. """ - def __init__(self, loop=None): - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self._loop = get_running_loop(loop) - @asyncio.coroutine - def resolve(self, host, port=0, family=socket.AF_INET): - infos = yield from self._loop.getaddrinfo( - host, port, type=socket.SOCK_STREAM, family=family) + async def resolve( + self, hostname: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: + infos = await self._loop.getaddrinfo( + hostname, + port, + type=socket.SOCK_STREAM, + family=family, + flags=socket.AI_ADDRCONFIG, + ) hosts = [] for family, _, proto, _, address in infos: + if family == socket.AF_INET6 and address[3]: # type: ignore + # This is essential for link-local IPv6 addresses. + # LL IPv6 is a VERY rare case. Strictly speaking, we should use + # getnameinfo() unconditionally, but performance makes sense. 
+ host, _port = socket.getnameinfo( + address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + ) + port = int(_port) + else: + host, port = address[:2] hosts.append( - {'hostname': host, - 'host': address[0], 'port': address[1], - 'family': family, 'proto': proto, - 'flags': socket.AI_NUMERICHOST}) + { + "hostname": hostname, + "host": host, + "port": port, + "family": family, + "proto": proto, + "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, + } + ) return hosts - @asyncio.coroutine - def close(self): + async def close(self) -> None: pass class AsyncResolver(AbstractResolver): """Use the `aiodns` package to make asynchronous DNS lookups""" - def __init__(self, loop=None, *args, **kwargs): - if loop is None: - loop = asyncio.get_event_loop() - + def __init__( + self, + loop: Optional[asyncio.AbstractEventLoop] = None, + *args: Any, + **kwargs: Any + ) -> None: if aiodns is None: raise RuntimeError("Resolver requires aiodns library") - self._loop = loop + self._loop = get_running_loop(loop) self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs) - if not hasattr(self._resolver, 'gethostbyname'): + if not hasattr(self._resolver, "gethostbyname"): # aiodns 1.1 is not available, fallback to DNSResolver.query - self.resolve = self.resolve_with_query - - @asyncio.coroutine - def resolve(self, host, port=0, family=socket.AF_INET): + self.resolve = self._resolve_with_query # type: ignore + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: + try: + resp = await self._resolver.gethostbyname(host, family) + except aiodns.error.DNSError as exc: + msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed" + raise OSError(msg) from exc hosts = [] - resp = yield from self._resolver.gethostbyname(host, family) - for address in resp.addresses: hosts.append( - {'hostname': host, - 'host': address, 'port': port, - 'family': family, 'proto': 0, - 'flags': socket.AI_NUMERICHOST}) + { + "hostname": host, + "host": address, + "port": port, + "family": family, + "proto": 0, + "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, + } + ) + + if not hosts: + raise OSError("DNS lookup failed") + return hosts - @asyncio.coroutine - def resolve_with_query(self, host, port=0, family=socket.AF_INET): + async def _resolve_with_query( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: if family == socket.AF_INET6: - qtype = 'AAAA' + qtype = "AAAA" else: - qtype = 'A' + qtype = "A" - hosts = [] - resp = yield from self._resolver.query(host, qtype) + try: + resp = await self._resolver.query(host, qtype) + except aiodns.error.DNSError as exc: + msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed" + raise OSError(msg) from exc + hosts = [] for rr in resp: hosts.append( - {'hostname': host, - 'host': rr.host, 'port': port, - 'family': family, 'proto': 0, - 'flags': socket.AI_NUMERICHOST}) + { + "hostname": host, + "host": rr.host, + "port": port, + "family": family, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ) + + if not hosts: + raise OSError("DNS lookup failed") return hosts - @asyncio.coroutine - def close(self): + async def close(self) -> None: return self._resolver.cancel()
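Since both resolvers now surface lookup failures as `OSError`, callers can handle DNS errors uniformly whether or not `aiodns` is installed. A usage sketch (the host name is a placeholder)::

    import asyncio
    import socket

    from aiohttp.resolver import DefaultResolver

    async def main() -> None:
        resolver = DefaultResolver()
        try:
            hosts = await resolver.resolve("example.com", 80, family=socket.AF_INET)
            print(hosts[0]["host"], hosts[0]["port"])
        except OSError as exc:
            print("DNS lookup failed:", exc)
        finally:
            await resolver.close()

    asyncio.run(main())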
("Signal",) -class BaseSignal(FrozenList): - - __slots__ = () - - @asyncio.coroutine - def _send(self, *args, **kwargs): - for receiver in self._items: - res = receiver(*args, **kwargs) - if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future): - yield from res - - -class Signal(BaseSignal): +class Signal(FrozenList): """Coroutine-based signal implementation. To connect a callback to a signal, use any list method. - Signals are fired using the :meth:`send` coroutine, which takes named + Signals are fired using the send() coroutine, which takes named arguments. """ - __slots__ = ('_app', '_name', '_pre', '_post') + __slots__ = ("_owner",) - def __init__(self, app): + def __init__(self, owner): super().__init__() - self._app = app - klass = self.__class__ - self._name = klass.__module__ + ':' + klass.__qualname__ - self._pre = app.on_pre_signal - self._post = app.on_post_signal - - @asyncio.coroutine - def send(self, *args, **kwargs): - """ - Sends data to all registered receivers. - """ - if self._items: - ordinal = None - debug = self._app._debug - if debug: - ordinal = self._pre.ordinal() - yield from self._pre.send( - ordinal, self._name, *args, **kwargs) - yield from self._send(*args, **kwargs) - if debug: - yield from self._post.send( - ordinal, self._name, *args, **kwargs) - + self._owner = owner -class FuncSignal(BaseSignal): - """Callback-based signal implementation. - - To connect a callback to a signal, use any list method. - - Signals are fired using the :meth:`send` method, which takes named - arguments. - """ + def __repr__(self): + return "".format( + self._owner, self.frozen, list(self) + ) - __slots__ = () - - def send(self, *args, **kwargs): + async def send(self, *args, **kwargs): """ Sends data to all registered receivers. """ - for receiver in self._items: - receiver(*args, **kwargs) - - -class DebugSignal(BaseSignal): - - __slots__ = () - - @asyncio.coroutine - def send(self, ordinal, name, *args, **kwargs): - yield from self._send(ordinal, name, *args, **kwargs) - - -class PreSignal(DebugSignal): - - __slots__ = ('_counter',) - - def __init__(self): - super().__init__() - self._counter = count(1) - - def ordinal(self): - return next(self._counter) - - -class PostSignal(DebugSignal): + if not self.frozen: + raise RuntimeError("Cannot send non-frozen signal.") - __slots__ = () + for receiver in self: + await receiver(*args, **kwargs) # type: ignore diff --git a/aiohttp/signals.pyi b/aiohttp/signals.pyi new file mode 100644 index 00000000000..455f8e2f227 --- /dev/null +++ b/aiohttp/signals.pyi @@ -0,0 +1,12 @@ +from typing import Any, Generic, TypeVar + +from aiohttp.frozenlist import FrozenList + +__all__ = ("Signal",) + +_T = TypeVar("_T") + +class Signal(FrozenList[_T], Generic[_T]): + def __init__(self, owner: Any) -> None: ... + def __repr__(self) -> str: ... + async def send(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/aiohttp/streams.py b/aiohttp/streams.py index f6d6b8ecf26..42970b531d0 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -1,68 +1,90 @@ import asyncio import collections -import traceback +import warnings +from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar -from . 
diff --git a/aiohttp/streams.py b/aiohttp/streams.py index f6d6b8ecf26..42970b531d0 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -1,68 +1,90 @@ import asyncio import collections -import traceback +import warnings +from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar -from . import helpers +from .base_protocol import BaseProtocol +from .helpers import BaseTimerContext, set_exception, set_result from .log import internal_logger +try: # pragma: no cover + from typing import Deque +except ImportError: + from typing_extensions import Deque + __all__ = ( - 'EMPTY_PAYLOAD', 'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue', - 'FlowControlStreamReader', - 'FlowControlDataQueue', 'FlowControlChunksQueue') + "EMPTY_PAYLOAD", + "EofStream", + "StreamReader", + "DataQueue", + "FlowControlDataQueue", +) -DEFAULT_LIMIT = 2 ** 16 +_T = TypeVar("_T") class EofStream(Exception): """eof stream indication.""" -if helpers.PY_35: - class AsyncStreamIterator: +class AsyncStreamIterator(Generic[_T]): + def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None: + self.read_func = read_func + + def __aiter__(self) -> "AsyncStreamIterator[_T]": + return self + + async def __anext__(self) -> _T: + try: + rv = await self.read_func() + except EofStream: + raise StopAsyncIteration + if rv == b"": + raise StopAsyncIteration + return rv - def __init__(self, read_func): - self.read_func = read_func - def __aiter__(self): - return self +class ChunkTupleAsyncStreamIterator: + def __init__(self, stream: "StreamReader") -> None: + self._stream = stream - if not helpers.PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) + def __aiter__(self) -> "ChunkTupleAsyncStreamIterator": + return self - @asyncio.coroutine - def __anext__(self): - try: - rv = yield from self.read_func() - except EofStream: - raise StopAsyncIteration # NOQA - if rv == b'': - raise StopAsyncIteration # NOQA - return rv + async def __anext__(self) -> Tuple[bytes, bool]: + rv = await self._stream.readchunk() + if rv == (b"", False): + raise StopAsyncIteration + return rv class AsyncStreamReaderMixin: + def __aiter__(self) -> AsyncStreamIterator[bytes]: + return AsyncStreamIterator(self.readline) # type: ignore - if helpers.PY_35: - def __aiter__(self): - return AsyncStreamIterator(self.readline) + def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]: + """Returns an asynchronous iterator that yields chunks of size n. - if not helpers.PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) + Available for Python 3.5+ only + """ + return AsyncStreamIterator(lambda: self.read(n)) # type: ignore - def iter_chunked(self, n): - """Returns an asynchronous iterator that yields chunks of size n. + def iter_any(self) -> AsyncStreamIterator[bytes]: + """Returns an asynchronous iterator that yields all the available + data as soon as it is received - Python-3.5 available for Python 3.5+ only - """ - return AsyncStreamIterator(lambda: self.read(n)) + Available for Python 3.5+ only + """ + return AsyncStreamIterator(self.readany) # type: ignore - def iter_any(self): - """Returns an asynchronous iterator that yields slices of data - as they come. + def iter_chunks(self) -> ChunkTupleAsyncStreamIterator: + """Returns an asynchronous iterator that yields chunks of data + as they are received by the server. The yielded objects are tuples + of (bytes, bool) as returned by the StreamReader.readchunk method. 
- Python-3.5 available for Python 3.5+ only - """ - return AsyncStreamIterator(self.readany) + Available for Python 3.5+ only + """ + return ChunkTupleAsyncStreamIterator(self) # type: ignore class StreamReader(AsyncStreamReaderMixin): @@ -81,120 +103,137 @@ class StreamReader(AsyncStreamReaderMixin): total_bytes = 0 - def __init__(self, limit=DEFAULT_LIMIT, timer=None, loop=None): - self._limit = limit + def __init__( + self, + protocol: BaseProtocol, + limit: int, + *, + timer: Optional[BaseTimerContext] = None, + loop: Optional[asyncio.AbstractEventLoop] = None + ) -> None: + self._protocol = protocol + self._low_water = limit + self._high_water = limit * 2 if loop is None: loop = asyncio.get_event_loop() self._loop = loop self._size = 0 - self._buffer = collections.deque() + self._cursor = 0 + self._http_chunk_splits = None # type: Optional[List[int]] + self._buffer = collections.deque() # type: Deque[bytes] self._buffer_offset = 0 self._eof = False - self._waiter = None - self._eof_waiter = None - self._exception = None + self._waiter = None # type: Optional[asyncio.Future[None]] + self._eof_waiter = None # type: Optional[asyncio.Future[None]] + self._exception = None # type: Optional[BaseException] self._timer = timer - self._eof_callbacks = [] + self._eof_callbacks = [] # type: List[Callable[[], None]] - def __repr__(self): + def __repr__(self) -> str: info = [self.__class__.__name__] if self._size: - info.append('%d bytes' % self._size) + info.append("%d bytes" % self._size) if self._eof: - info.append('eof') - if self._limit != DEFAULT_LIMIT: - info.append('l=%d' % self._limit) + info.append("eof") + if self._low_water != 2 ** 16: # default limit + info.append("low=%d high=%d" % (self._low_water, self._high_water)) if self._waiter: - info.append('w=%r' % self._waiter) + info.append("w=%r" % self._waiter) if self._exception: - info.append('e=%r' % self._exception) - return '<%s>' % ' '.join(info) + info.append("e=%r" % self._exception) + return "<%s>" % " ".join(info) + + def get_read_buffer_limits(self) -> Tuple[int, int]: + return (self._low_water, self._high_water) - def exception(self): + def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: self._exception = exc self._eof_callbacks.clear() waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.done(): - waiter.set_exception(exc) + set_exception(waiter, exc) waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None - if not waiter.done(): - waiter.set_exception(exc) + set_exception(waiter, exc) - def on_eof(self, callback): + def on_eof(self, callback: Callable[[], None]) -> None: if self._eof: try: callback() except Exception: - internal_logger.exception('Exception in eof callback') + internal_logger.exception("Exception in eof callback") else: self._eof_callbacks.append(callback) - def feed_eof(self): + def feed_eof(self) -> None: self._eof = True waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.done(): - waiter.set_result(True) + set_result(waiter, None) waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None - if not waiter.done(): - waiter.set_result(True) + set_result(waiter, None) for cb in self._eof_callbacks: try: cb() except Exception: - internal_logger.exception('Exception in eof callback') + internal_logger.exception("Exception in eof callback") self._eof_callbacks.clear() - def is_eof(self): + def is_eof(self) -> bool:
"""Return True if 'feed_eof' was called.""" return self._eof - def at_eof(self): + def at_eof(self) -> bool: """Return True if the buffer is empty and 'feed_eof' was called.""" return self._eof and not self._buffer - @asyncio.coroutine - def wait_eof(self): + async def wait_eof(self) -> None: if self._eof: return assert self._eof_waiter is None - self._eof_waiter = helpers.create_future(self._loop) + self._eof_waiter = self._loop.create_future() try: - yield from self._eof_waiter + await self._eof_waiter finally: self._eof_waiter = None - def unread_data(self, data): - """ rollback reading some data from stream, inserting it to buffer head. - """ + def unread_data(self, data: bytes) -> None: + """rollback reading some data from stream, inserting it to buffer head.""" + warnings.warn( + "unread_data() is deprecated " + "and will be removed in future releases (#3260)", + DeprecationWarning, + stacklevel=2, + ) if not data: return if self._buffer_offset: - self._buffer[0] = self._buffer[0][self._buffer_offset:] + self._buffer[0] = self._buffer[0][self._buffer_offset :] self._buffer_offset = 0 self._size += len(data) + self._cursor -= len(data) self._buffer.appendleft(data) + self._eof_counter = 0 - def feed_data(self, data): - assert not self._eof, 'feed_data after feed_eof' + # TODO: size is ignored, remove the param later + def feed_data(self, data: bytes, size: int = 0) -> None: + assert not self._eof, "feed_data after feed_eof" if not data: return @@ -206,31 +245,71 @@ def feed_data(self, data): waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.done(): - waiter.set_result(False) + set_result(waiter, None) + + if self._size > self._high_water and not self._protocol._reading_paused: + self._protocol.pause_reading() + + def begin_http_chunk_receiving(self) -> None: + if self._http_chunk_splits is None: + if self.total_bytes: + raise RuntimeError( + "Called begin_http_chunk_receiving when " "some data was already fed" + ) + self._http_chunk_splits = [] + + def end_http_chunk_receiving(self) -> None: + if self._http_chunk_splits is None: + raise RuntimeError( + "Called end_http_chunk_receiving without calling " + "begin_http_chunk_receiving first" + ) + + # self._http_chunk_splits contains logical byte offsets from start of + # the body transfer. Each offset is the offset of the end of a chunk. + # "Logical" means bytes accessible to the user. + # If no chunks containing logical data were received, the current + # position is definitely zero. + pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0 + + if self.total_bytes == pos: + # We should not add empty chunks here. So we check for that. + # Note, when chunked + gzip is used, we can receive a chunk + # of compressed data, but that data may not be enough for gzip FSM + # to yield any uncompressed data. That's why current position may + # not change after receiving a chunk. + return + + self._http_chunk_splits.append(self.total_bytes) - @asyncio.coroutine - def _wait(self, func_name): + # wake up readchunk when end of http chunk received + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_result(waiter, None)
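On the consuming side these recorded offsets surface through readchunk() and iter_chunks(); a client-side sketch (the URL is a placeholder)::

    async def consume(session):
        async with session.get("http://example.com/stream") as resp:
            async for data, end_of_http_chunk in resp.content.iter_chunks():
                # end_of_http_chunk is True only when chunked transfer
                # encoding is used and 'data' ends exactly on a boundary.
                if end_of_http_chunk:
                    print("chunk boundary reached")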
+ async def _wait(self, func_name: str) -> None: # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not be possible to know # which coroutine would get the next data. if self._waiter is not None: - raise RuntimeError('%s() called while another coroutine is ' - 'already waiting for incoming data' % func_name) + raise RuntimeError( + "%s() called while another coroutine is " + "already waiting for incoming data" % func_name + ) - waiter = self._waiter = helpers.create_future(self._loop) + waiter = self._waiter = self._loop.create_future() try: if self._timer: with self._timer: - yield from waiter + await waiter else: - yield from waiter + await waiter finally: self._waiter = None - @asyncio.coroutine - def readline(self): + async def readline(self) -> bytes: if self._exception is not None: raise self._exception @@ -241,7 +320,7 @@ def readline(self): while not_enough: while self._buffer and not_enough: offset = self._buffer_offset - ichar = self._buffer[0].find(b'\n', offset) + 1 + ichar = self._buffer[0].find(b"\n", offset) + 1 # Read from current offset to found b'\n' or to the end. data = self._read_nowait_chunk(ichar - offset if ichar else -1) line.append(data) @@ -249,19 +328,18 @@ def readline(self): if ichar: not_enough = False - if line_size > self._limit: - raise ValueError('Line is too long') + if line_size > self._high_water: + raise ValueError("Line is too long") if self._eof: break if not_enough: - yield from self._wait('readline') + await self._wait("readline") - return b''.join(line) + return b"".join(line) - @asyncio.coroutine - def read(self, n=-1): + async def read(self, n: int = -1) -> bytes: if self._exception is not None: raise self._exception @@ -271,15 +349,16 @@ def read(self, n=-1): # lets keep this code one major release. if __debug__: if self._eof and not self._buffer: - self._eof_counter = getattr(self, '_eof_counter', 0) + 1 + self._eof_counter = getattr(self, "_eof_counter", 0) + 1 if self._eof_counter > 5: - stack = traceback.format_stack() internal_logger.warning( - 'Multiple access to StreamReader in eof state, ' - 'might be infinite loop: \n%s', stack) + "Multiple access to StreamReader in eof state, " + "might be infinite loop.", + stack_info=True, + ) if not n: - return b'' + return b"" if n < 0: # This used to just loop creating a new waiter hoping to @@ -288,45 +367,80 @@ def read(self, n=-1): # bytes. So just call self.readany() until EOF. blocks = [] while True: - block = yield from self.readany() + block = await self.readany() if not block: break blocks.append(block) - return b''.join(blocks) + return b"".join(blocks) - if not self._buffer and not self._eof: - yield from self._wait('read') + # TODO: should be `if` instead of `while` + # because the waiter may be triggered on chunk end, + # without feeding any data + while not self._buffer and not self._eof: + await self._wait("read") return self._read_nowait(n) - @asyncio.coroutine - def readany(self): + async def readany(self) -> bytes: if self._exception is not None: raise self._exception - if not self._buffer and not self._eof: - yield from self._wait('readany') + # TODO: should be `if` instead of `while` + # because the waiter may be triggered on chunk end, + # without feeding any data + while not self._buffer and not self._eof: + await self._wait("readany") return self._read_nowait(-1) - @asyncio.coroutine - def readexactly(self, n): + async def readchunk(self) -> Tuple[bytes, bool]: + """Returns a tuple of (data, end_of_http_chunk). When chunked transfer + encoding is used, end_of_http_chunk is a boolean indicating if the end + of the data corresponds to the end of an HTTP chunk, otherwise it is + always False. 
+ """ + while True: + if self._exception is not None: + raise self._exception + + while self._http_chunk_splits: + pos = self._http_chunk_splits.pop(0) + if pos == self._cursor: + return (b"", True) + if pos > self._cursor: + return (self._read_nowait(pos - self._cursor), True) + internal_logger.warning( + "Skipping HTTP chunk end due to data " + "consumption beyond chunk boundary" + ) + + if self._buffer: + return (self._read_nowait_chunk(-1), False) + # return (self._read_nowait(-1), False) + + if self._eof: + # Special case for signifying EOF. + # (b'', True) is not a final return value actually. + return (b"", False) + + await self._wait("readchunk") + + async def readexactly(self, n: int) -> bytes: if self._exception is not None: raise self._exception - blocks = [] + blocks = [] # type: List[bytes] while n > 0: - block = yield from self.read(n) + block = await self.read(n) if not block: - partial = b''.join(blocks) - raise asyncio.streams.IncompleteReadError( - partial, len(partial) + n) + partial = b"".join(blocks) + raise asyncio.IncompleteReadError(partial, len(partial) + n) blocks.append(block) n -= len(block) - return b''.join(blocks) + return b"".join(blocks) - def read_nowait(self, n=-1): + def read_nowait(self, n: int = -1) -> bytes: # default was changed to be consistent with .read(-1) # # I believe the most users don't know about the method and @@ -336,15 +450,16 @@ def read_nowait(self, n=-1): if self._waiter and not self._waiter.done(): raise RuntimeError( - 'Called while some coroutine is waiting for incoming data.') + "Called while some coroutine is waiting for incoming data." + ) return self._read_nowait(n) - def _read_nowait_chunk(self, n): + def _read_nowait_chunk(self, n: int) -> bytes: first_buffer = self._buffer[0] offset = self._buffer_offset if n != -1 and len(first_buffer) - offset > n: - data = first_buffer[offset:offset + n] + data = first_buffer[offset : offset + n] self._buffer_offset += n elif offset: @@ -356,9 +471,19 @@ def _read_nowait_chunk(self, n): data = self._buffer.popleft() self._size -= len(data) + self._cursor += len(data) + + chunk_splits = self._http_chunk_splits + # Prevent memory leak: drop useless chunk splits + while chunk_splits and chunk_splits[0] < self._cursor: + chunk_splits.pop(0) + + if self._size < self._low_water and self._protocol._reading_paused: + self._protocol.resume_reading() return data - def _read_nowait(self, n): + def _read_nowait(self, n: int) -> bytes: + """ Read not more than n bytes, or whole buffer if n == -1 """ chunks = [] while self._buffer: @@ -369,124 +494,114 @@ def _read_nowait(self, n): if n == 0: break - return b''.join(chunks) if chunks else b'' + return b"".join(chunks) if chunks else b"" class EmptyStreamReader(AsyncStreamReaderMixin): - - def exception(self): + def exception(self) -> Optional[BaseException]: return None - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: pass - def on_eof(self, callback): + def on_eof(self, callback: Callable[[], None]) -> None: try: callback() except Exception: - internal_logger.exception('Exception in eof callback') + internal_logger.exception("Exception in eof callback") - def feed_eof(self): + def feed_eof(self) -> None: pass - def is_eof(self): + def is_eof(self) -> bool: return True - def at_eof(self): + def at_eof(self) -> bool: return True - @asyncio.coroutine - def wait_eof(self): + async def wait_eof(self) -> None: return - def feed_data(self, data): + def feed_data(self, data: bytes, n: int = 0) -> None: pass - 
@asyncio.coroutine - def readline(self): - return b'' + async def readline(self) -> bytes: + return b"" + + async def read(self, n: int = -1) -> bytes: + return b"" - @asyncio.coroutine - def read(self, n=-1): - return b'' + async def readany(self) -> bytes: + return b"" - @asyncio.coroutine - def readany(self): - return b'' + async def readchunk(self) -> Tuple[bytes, bool]: + return (b"", True) - @asyncio.coroutine - def readexactly(self, n): - raise asyncio.streams.IncompleteReadError(b'', n) + async def readexactly(self, n: int) -> bytes: + raise asyncio.IncompleteReadError(b"", n) - def read_nowait(self): - return b'' + def read_nowait(self) -> bytes: + return b"" EMPTY_PAYLOAD = EmptyStreamReader() -class DataQueue: +class DataQueue(Generic[_T]): """DataQueue is a general-purpose blocking queue with one reader.""" - def __init__(self, *, loop=None): + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._eof = False - self._waiter = None - self._exception = None + self._waiter = None # type: Optional[asyncio.Future[None]] + self._exception = None # type: Optional[BaseException] self._size = 0 - self._buffer = collections.deque() + self._buffer = collections.deque() # type: Deque[Tuple[_T, int]] - def __len__(self): + def __len__(self) -> int: return len(self._buffer) - def is_eof(self): + def is_eof(self) -> bool: return self._eof - def at_eof(self): + def at_eof(self) -> bool: return self._eof and not self._buffer - def exception(self): + def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: self._eof = True self._exception = exc waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.done(): - waiter.set_exception(exc) + set_exception(waiter, exc) - def feed_data(self, data, size=0): + def feed_data(self, data: _T, size: int = 0) -> None: self._size += size self._buffer.append((data, size)) waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.cancelled(): - waiter.set_result(True) + set_result(waiter, None) - def feed_eof(self): + def feed_eof(self) -> None: self._eof = True waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.cancelled(): - waiter.set_result(False) + set_result(waiter, None) - @asyncio.coroutine - def read(self): + async def read(self) -> _T: if not self._buffer and not self._eof: - if self._exception is not None: - raise self._exception - assert not self._waiter - self._waiter = helpers.create_future(self._loop) + self._waiter = self._loop.create_future() try: - yield from self._waiter + await self._waiter except (asyncio.CancelledError, asyncio.TimeoutError): self._waiter = None raise @@ -501,114 +616,32 @@ def read(self): else: raise EofStream - if helpers.PY_35: - def __aiter__(self): - return AsyncStreamIterator(self.read) + def __aiter__(self) -> AsyncStreamIterator[_T]: + return AsyncStreamIterator(self.read) - if not helpers.PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) - -class ChunksQueue(DataQueue): - """Like a :class:`DataQueue`, but for binary chunked data transfer.""" - - @asyncio.coroutine - def read(self): - try: - return (yield from super().read()) - except EofStream: - return b'' - - readany = read - - -class FlowControlStreamReader(StreamReader): - - def __init__(self, protocol, buffer_limit=DEFAULT_LIMIT, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._protocol = protocol - self._b_limit = 
buffer_limit * 2 - - def feed_data(self, data, size=0): - super().feed_data(data) - - if self._size > self._b_limit and not self._protocol._reading_paused: - self._protocol.pause_reading() - - @asyncio.coroutine - def read(self, n=-1): - try: - return (yield from super().read(n)) - finally: - if self._size < self._b_limit and self._protocol._reading_paused: - self._protocol.resume_reading() - - @asyncio.coroutine - def readline(self): - try: - return (yield from super().readline()) - finally: - if self._size < self._b_limit and self._protocol._reading_paused: - self._protocol.resume_reading() - - @asyncio.coroutine - def readany(self): - try: - return (yield from super().readany()) - finally: - if self._size < self._b_limit and self._protocol._reading_paused: - self._protocol.resume_reading() - - @asyncio.coroutine - def readexactly(self, n): - try: - return (yield from super().readexactly(n)) - finally: - if self._size < self._b_limit and self._protocol._reading_paused: - self._protocol.resume_reading() - - def read_nowait(self, n=-1): - try: - return super().read_nowait(n) - finally: - if self._size < self._b_limit and self._protocol._reading_paused: - self._protocol.resume_reading() - - -class FlowControlDataQueue(DataQueue): +class FlowControlDataQueue(DataQueue[_T]): """FlowControlDataQueue resumes and pauses an underlying stream. It is a destination for parsed data.""" - def __init__(self, protocol, *, limit=DEFAULT_LIMIT, loop=None): + def __init__( + self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop + ) -> None: super().__init__(loop=loop) self._protocol = protocol self._limit = limit * 2 - def feed_data(self, data, size): + def feed_data(self, data: _T, size: int = 0) -> None: super().feed_data(data, size) if self._size > self._limit and not self._protocol._reading_paused: self._protocol.pause_reading() - @asyncio.coroutine - def read(self): + async def read(self) -> _T: try: - return (yield from super().read()) + return await super().read() finally: if self._size < self._limit and self._protocol._reading_paused: self._protocol.resume_reading() - - -class FlowControlChunksQueue(FlowControlDataQueue): - - @asyncio.coroutine - def read(self): - try: - return (yield from super().read()) - except EofStream: - return b'' - - readany = read diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py new file mode 100644 index 00000000000..0e1dbf16552 --- /dev/null +++ b/aiohttp/tcp_helpers.py @@ -0,0 +1,38 @@ +"""Helper methods to tune a TCP connection""" + +import asyncio +import socket +from contextlib import suppress +from typing import Optional # noqa + +__all__ = ("tcp_keepalive", "tcp_nodelay") + + +if hasattr(socket, "SO_KEEPALIVE"): + + def tcp_keepalive(transport: asyncio.Transport) -> None: + sock = transport.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + +else: + + def tcp_keepalive(transport: asyncio.Transport) -> None: # pragma: no cover + pass + + +def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None: + sock = transport.get_extra_info("socket") + + if sock is None: + return + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return + + value = bool(value) + + # socket may be closed already, on windows OSError get raised + with suppress(OSError): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index d15a43242da..7a9ca7ddf3e 100644 --- a/aiohttp/test_utils.py +++ 
b/aiohttp/test_utils.py @@ -4,80 +4,132 @@ import contextlib import functools import gc +import inspect +import os import socket +import sys import unittest from abc import ABC, abstractmethod -from contextlib import contextmanager +from types import TracebackType +from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, Union from unittest import mock -from multidict import CIMultiDict +from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL import aiohttp -from aiohttp.client import _RequestContextManager +from aiohttp.client import ( + ClientResponse, + _RequestContextManager, + _WSRequestContextManager, +) from . import ClientSession, hdrs -from .helpers import PY_35, noop, sentinel +from .abc import AbstractCookieJar +from .client_reqrep import ClientResponse +from .client_ws import ClientWebSocketResponse +from .helpers import sentinel from .http import HttpVersion, RawRequestMessage from .signals import Signal -from .web import Application, Request, Server, UrlMappingMatchInfo +from .web import ( + Application, + AppRunner, + BaseRunner, + Request, + Server, + ServerRunner, + SockSite, + UrlMappingMatchInfo, +) +from .web_protocol import _RequestHandler +if TYPE_CHECKING: # pragma: no cover + from ssl import SSLContext +else: + SSLContext = None -def run_briefly(loop): - @asyncio.coroutine - def once(): - pass - t = asyncio.Task(once(), loop=loop) - loop.run_until_complete(t) + +REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" + + +def get_unused_port_socket(host: str) -> socket.socket: + return get_port_socket(host, 0) -def unused_port(): +def get_port_socket(host: str, port: int) -> socket.socket: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if REUSE_ADDRESS: + # Windows has different semantics for SO_REUSEADDR, + # so don't set it. 
Ref: + # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return s + + +def unused_port() -> int: """Return a port that is unused on the current host.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('127.0.0.1', 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] class BaseTestServer(ABC): - def __init__(self, *, scheme=sentinel, loop=None, - host='127.0.0.1', skip_url_asserts=False, **kwargs): + __test__ = False + + def __init__( + self, + *, + scheme: Union[str, object] = sentinel, + loop: Optional[asyncio.AbstractEventLoop] = None, + host: str = "127.0.0.1", + port: Optional[int] = None, + skip_url_asserts: bool = False, + **kwargs: Any, + ) -> None: self._loop = loop - self.port = None - self.server = None - self.handler = None - self._root = None + self.runner = None # type: Optional[BaseRunner] + self._root = None # type: Optional[URL] self.host = host + self.port = port self._closed = False self.scheme = scheme self.skip_url_asserts = skip_url_asserts - @asyncio.coroutine - def start_server(self, loop=None, **kwargs): - if self.server: + async def start_server( + self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any + ) -> None: + if self.runner: return self._loop = loop - self.port = unused_port() - self._ssl = kwargs.pop('ssl', None) + self._ssl = kwargs.pop("ssl", None) + self.runner = await self._make_runner(**kwargs) + await self.runner.setup() + if not self.port: + self.port = 0 + _sock = get_port_socket(self.host, self.port) + self.host, self.port = _sock.getsockname()[:2] + site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl) + await site.start() + server = site._server + assert server is not None + sockets = server.sockets + assert sockets is not None + self.port = sockets[0].getsockname()[1] if self.scheme is sentinel: if self._ssl: - scheme = 'https' + scheme = "https" else: - scheme = 'http' + scheme = "http" self.scheme = scheme - self._root = URL('{}://{}:{}'.format(self.scheme, - self.host, - self.port)) - - handler = yield from self._make_factory(**kwargs) - self.server = yield from self._loop.create_server( - handler, self.host, self.port, ssl=self._ssl) + self._root = URL(f"{self.scheme}://{self.host}:{self.port}") @abstractmethod # pragma: no cover - @asyncio.coroutine - def _make_factory(self, **kwargs): + async def _make_runner(self, **kwargs: Any) -> BaseRunner: pass - def make_url(self, path): + def make_url(self, path: str) -> URL: + assert self._root is not None url = URL(path) if not self.skip_url_asserts: assert not url.is_absolute() @@ -86,15 +138,23 @@ def make_url(self, path): return URL(str(self._root) + path) @property - def started(self): - return self.server is not None + def started(self) -> bool: + return self.runner is not None @property - def closed(self): + def closed(self) -> bool: return self._closed - @asyncio.coroutine - def close(self): + @property + def handler(self) -> Server: + # for backward compatibility + # web.Server instance + runner = self.runner + assert runner is not None + assert runner.server is not None + return runner.server + + async def close(self) -> None: """Close all fixtures created by the test client. After that point, the TestClient is no longer usable. 
@@ -107,72 +167,70 @@ def close(self):
        """
        if self.started and not self.closed:
-            self.server.close()
-            yield from self.server.wait_closed()
+            assert self.runner is not None
+            await self.runner.cleanup()
            self._root = None
            self.port = None
-            yield from self._close_hook()
            self._closed = True

-    @abstractmethod
-    @asyncio.coroutine
-    def _close_hook(self):
+    def __enter__(self) -> None:
+        raise TypeError("Use async with instead")
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        # __exit__ should exist in pair with __enter__ but is never executed
        pass  # pragma: no cover

-    def __enter__(self):
-        self._loop.run_until_complete(self.start_server(loop=self._loop))
+    async def __aenter__(self) -> "BaseTestServer":
+        await self.start_server(loop=self._loop)
        return self

-    def __exit__(self, exc_type, exc_value, traceback):
-        self._loop.run_until_complete(self.close())
-
-    if PY_35:
-        @asyncio.coroutine
-        def __aenter__(self):
-            yield from self.start_server(loop=self._loop)
-            return self
-
-        @asyncio.coroutine
-        def __aexit__(self, exc_type, exc_value, traceback):
-            yield from self.close()
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        await self.close()


class TestServer(BaseTestServer):
-
-    def __init__(self, app, *,
-                 scheme=sentinel, host='127.0.0.1', **kwargs):
+    def __init__(
+        self,
+        app: Application,
+        *,
+        scheme: Union[str, object] = sentinel,
+        host: str = "127.0.0.1",
+        port: Optional[int] = None,
+        **kwargs: Any,
+    ):
        self.app = app
-        super().__init__(scheme=scheme, host=host, **kwargs)
+        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

-    @asyncio.coroutine
-    def _make_factory(self, **kwargs):
-        self.handler = self.app.make_handler(loop=self._loop, **kwargs)
-        yield from self.app.startup()
-        return self.handler
-
-    @asyncio.coroutine
-    def _close_hook(self):
-        yield from self.app.shutdown()
-        yield from self.handler.shutdown()
-        yield from self.app.cleanup()
+    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
+        return AppRunner(self.app, **kwargs)


class RawTestServer(BaseTestServer):
-
-    def __init__(self, handler, *,
-                 scheme=sentinel, host='127.0.0.1', **kwargs):
+    def __init__(
+        self,
+        handler: _RequestHandler,
+        *,
+        scheme: Union[str, object] = sentinel,
+        host: str = "127.0.0.1",
+        port: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
        self._handler = handler
-        super().__init__(scheme=scheme, host=host, **kwargs)
+        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

-    @asyncio.coroutine
-    def _make_factory(self, debug=True, **kwargs):
-        self.handler = Server(
-            self._handler, loop=self._loop, debug=True, **kwargs)
-        return self.handler
-
-    @asyncio.coroutine
-    def _close_hook(self):
-        return
+    async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
+        srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
+        return ServerRunner(srv, debug=debug, **kwargs)


class TestClient:
@@ -183,51 +241,50 @@ class TestClient:
    """

-    def __init__(self, app_or_server, *, scheme=sentinel, host=sentinel,
-                 cookie_jar=None, server_kwargs=None, loop=None, **kwargs):
-        if isinstance(app_or_server, BaseTestServer):
-            if scheme is not sentinel or host is not sentinel:
-                raise ValueError("scheme and host are mutable exclusive "
-                                 "with TestServer parameter")
-            self._server = app_or_server
-        elif isinstance(app_or_server, Application):
-            scheme = "http" if scheme is sentinel else scheme
-            host = '127.0.0.1' if host is sentinel else host
-            server_kwargs = server_kwargs or {}
-            self._server = TestServer(
-                app_or_server,
-                scheme=scheme, host=host, **server_kwargs)
-        else:
-            raise TypeError("app_or_server should be either web.Application "
-                            "or TestServer instance")
+    __test__ = False
+
+    def __init__(
+        self,
+        server: BaseTestServer,
+        *,
+        cookie_jar: Optional[AbstractCookieJar] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        **kwargs: Any,
+    ) -> None:
+        if not isinstance(server, BaseTestServer):
+            raise TypeError(
+                "server must be a TestServer instance, found type: %r" % type(server)
+            )
+        self._server = server
        self._loop = loop
        if cookie_jar is None:
            cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
-        self._session = ClientSession(loop=loop,
-                                      cookie_jar=cookie_jar,
-                                      **kwargs)
+        self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
        self._closed = False
-        self._responses = []
-        self._websockets = []
+        self._responses = []  # type: List[ClientResponse]
+        self._websockets = []  # type: List[ClientWebSocketResponse]

-    @asyncio.coroutine
-    def start_server(self):
-        yield from self._server.start_server(loop=self._loop)
+    async def start_server(self) -> None:
+        await self._server.start_server(loop=self._loop)

    @property
-    def host(self):
+    def host(self) -> str:
        return self._server.host

    @property
-    def port(self):
+    def port(self) -> Optional[int]:
        return self._server.port

    @property
-    def server(self):
+    def server(self) -> BaseTestServer:
        return self._server

    @property
-    def session(self):
+    def app(self) -> Optional[Application]:
+        return getattr(self._server, "app", None)
+
+    @property
+    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
@@ -237,81 +294,67 @@ def session(self):
        """
        return self._session

-    def make_url(self, path):
+    def make_url(self, path: str) -> URL:
        return self._server.make_url(path)

-    @asyncio.coroutine
-    def request(self, method, path, *args, **kwargs):
+    async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
+        resp = await self._session.request(method, self.make_url(path), **kwargs)
+        # save it to close later
+        self._responses.append(resp)
+        return resp
+
+    def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
        """Routes a request to the tested HTTP server.

-        The interface is identical to asyncio.ClientSession.request,
+        The interface is identical to aiohttp.ClientSession.request,
        except the loop kwarg is overridden by the instance used
        by the test server.
        """
-        resp = yield from self._session.request(
-            method, self.make_url(path), *args, **kwargs
-        )
-        # save it to close later
-        self._responses.append(resp)
-        return resp
+        return _RequestContextManager(self._request(method, path, **kwargs))

-    def get(self, path, *args, **kwargs):
+    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_GET, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

-    def post(self, path, *args, **kwargs):
+    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_POST, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

-    def options(self, path, *args, **kwargs):
+    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_OPTIONS, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

-    def head(self, path, *args, **kwargs):
+    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_HEAD, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

-    def put(self, path, *args, **kwargs):
+    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_PUT, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

-    def patch(self, path, *args, **kwargs):
+    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_PATCH, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

-    def delete(self, path, *args, **kwargs):
+    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
-        return _RequestContextManager(
-            self.request(hdrs.METH_DELETE, path, *args, **kwargs)
-        )
+        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

-    @asyncio.coroutine
-    def ws_connect(self, path, *args, **kwargs):
+    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The API corresponds to aiohttp.ClientSession.ws_connect.
        """
-        ws = yield from self._session.ws_connect(
-            self.make_url(path), *args, **kwargs)
+        return _WSRequestContextManager(self._ws_connect(path, **kwargs))
+
+    async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
+        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        self._websockets.append(ws)
        return ws

-    @asyncio.coroutine
-    def close(self):
+    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.
@@ -327,27 +370,34 @@ def close(self):
        """
        for resp in self._responses:
            resp.close()
        for ws in self._websockets:
-            yield from ws.close()
-        self._session.close()
-        yield from self._server.close()
+            await ws.close()
+        await self._session.close()
+        await self._server.close()
        self._closed = True

-    def __enter__(self):
-        self._loop.run_until_complete(self.start_server())
-        return self
+    def __enter__(self) -> None:
+        raise TypeError("Use async with instead")

-    def __exit__(self, exc_type, exc_value, traceback):
-        self._loop.run_until_complete(self.close())
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        # __exit__ should exist in pair with __enter__ but is never executed
+        pass  # pragma: no cover

-    if PY_35:
-        @asyncio.coroutine
-        def __aenter__(self):
-            yield from self.start_server()
-            return self
+    async def __aenter__(self) -> "TestClient":
+        await self.start_server()
+        return self

-        @asyncio.coroutine
-        def __aexit__(self, exc_type, exc_value, traceback):
-            yield from self.close()
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()


class AioHTTPTestCase(unittest.TestCase):
@@ -366,8 +416,7 @@ class AioHTTPTestCase(unittest.TestCase):
    execute function on the test client using asynchronous methods.
    """

-    @asyncio.coroutine
-    def get_application(self):
+    async def get_application(self) -> Application:
        """
        This method should be overridden
        to return the aiohttp.web.Application
@@ -376,33 +425,46 @@ def get_application(self):
        """
        return self.get_app()

-    def get_app(self):
+    def get_app(self) -> Application:
        """Obsolete method used to construct a web application.

        Use the .get_application() coroutine instead.
        """
-        pass  # pragma: no cover
+        raise RuntimeError("Did you forget to define get_application()?")

-    def setUp(self):
+    def setUp(self) -> None:
        self.loop = setup_test_loop()
        self.app = self.loop.run_until_complete(self.get_application())
-        self.client = self.loop.run_until_complete(self._get_client(self.app))
+        self.server = self.loop.run_until_complete(self.get_server(self.app))
+        self.client = self.loop.run_until_complete(self.get_client(self.server))
        self.loop.run_until_complete(self.client.start_server())

-    def tearDown(self):
+        self.loop.run_until_complete(self.setUpAsync())
+
+    async def setUpAsync(self) -> None:
+        pass
+
+    def tearDown(self) -> None:
+        self.loop.run_until_complete(self.tearDownAsync())
        self.loop.run_until_complete(self.client.close())
        teardown_test_loop(self.loop)

-    @asyncio.coroutine
-    def _get_client(self, app):
+    async def tearDownAsync(self) -> None:
+        pass
+
+    async def get_server(self, app: Application) -> TestServer:
+        """Return a TestServer instance."""
+        return TestServer(app, loop=self.loop)
+
+    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
-        return TestClient(self.app, loop=self.loop)
+        return TestClient(server, loop=self.loop)


-def unittest_run_loop(func):
+def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """A decorator dedicated for use with asynchronous methods of an
    AioHTTPTestCase.
@@ -410,15 +472,20 @@ def unittest_run_loop(func):
    the self.loop of the AioHTTPTestCase.
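
    For illustration, a test wrapped with this decorator might look like
    the following sketch (the test name and route are hypothetical):

        @unittest_run_loop
        async def test_index(self):
            resp = await self.client.get("/")
            assert resp.status == 200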
""" - @functools.wraps(func) - def new_func(self): - return self.loop.run_until_complete(func(self)) + @functools.wraps(func, *args, **kwargs) + def new_func(self: Any, *inner_args: Any, **inner_kwargs: Any) -> Any: + return self.loop.run_until_complete(func(self, *inner_args, **inner_kwargs)) return new_func +_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop] + + @contextlib.contextmanager -def loop_context(loop_factory=asyncio.new_event_loop, fast=False): +def loop_context( + loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False +) -> Iterator[asyncio.AbstractEventLoop]: """A contextmanager that creates an event_loop, for test purposes. Handles the creation and cleanup of a test loop. @@ -428,7 +495,9 @@ def loop_context(loop_factory=asyncio.new_event_loop, fast=False): teardown_test_loop(loop, fast=fast) -def setup_test_loop(loop_factory=asyncio.new_event_loop): +def setup_test_loop( + loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, +) -> asyncio.AbstractEventLoop: """Create and return an asyncio.BaseEventLoop instance. @@ -436,11 +505,23 @@ def setup_test_loop(loop_factory=asyncio.new_event_loop): once they are done with the loop. """ loop = loop_factory() - asyncio.set_event_loop(None) + try: + module = loop.__class__.__module__ + skip_watcher = "uvloop" in module + except AttributeError: # pragma: no cover + # Just in case + skip_watcher = True + asyncio.set_event_loop(loop) + if sys.platform != "win32" and not skip_watcher: + policy = asyncio.get_event_loop_policy() + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(loop) + with contextlib.suppress(NotImplementedError): + policy.set_child_watcher(watcher) return loop -def teardown_test_loop(loop, fast=False): +def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: """Teardown and cleanup an event_loop created by setup_test_loop. 
@@ -457,18 +538,29 @@ def teardown_test_loop(loop, fast=False):
    asyncio.set_event_loop(None)


-def _create_app_mock():
-    app = mock.Mock()
+def _create_app_mock() -> mock.MagicMock:
+    def get_dict(app: Any, key: str) -> Any:
+        return app.__app_dict[key]
+
+    def set_dict(app: Any, key: str, value: Any) -> None:
+        app.__app_dict[key] = value
+
+    app = mock.MagicMock()
+    app.__app_dict = {}
+    app.__getitem__ = get_dict
+    app.__setitem__ = set_dict
+
    app._debug = False
    app.on_response_prepare = Signal(app)
+    app.on_response_prepare.freeze()
    return app


-def _create_transport(sslcontext=None):
+def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    transport = mock.Mock()

-    def get_extra_info(key):
-        if key == 'sslcontext':
+    def get_extra_info(key: str) -> Optional[SSLContext]:
+        if key == "sslcontext":
            return sslcontext
        else:
            return None
@@ -477,17 +569,23 @@ def get_extra_info(key):
    return transport


-def make_mocked_request(method, path, headers=None, *,
-                        version=HttpVersion(1, 1), closing=False,
-                        app=None,
-                        writer=sentinel,
-                        payload_writer=sentinel,
-                        protocol=sentinel,
-                        transport=sentinel,
-                        payload=sentinel,
-                        sslcontext=None,
-                        secure_proxy_ssl_header=None,
-                        client_max_size=1024**2):
+def make_mocked_request(
+    method: str,
+    path: str,
+    headers: Any = None,
+    *,
+    match_info: Any = sentinel,
+    version: HttpVersion = HttpVersion(1, 1),
+    closing: bool = False,
+    app: Any = None,
+    writer: Any = sentinel,
+    protocol: Any = sentinel,
+    transport: Any = sentinel,
+    payload: Any = sentinel,
+    sslcontext: Optional[SSLContext] = None,
+    client_max_size: int = 1024 ** 2,
+    loop: Any = ...,
+) -> Request:
    """Creates a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
@@ -496,78 +594,83 @@ def make_mocked_request(method, path, headers=None, *,
    """
    task = mock.Mock()
-    loop = mock.Mock()
-    loop.create_future.return_value = ()
+    if loop is ...:
+        loop = mock.Mock()
+        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
-        headers = CIMultiDict(headers)
+        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
-            (k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items())
+            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
+        )
    else:
-        headers = CIMultiDict()
+        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

-    chunked = 'chunked' in headers.get(hdrs.TRANSFER_ENCODING, '').lower()
+    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
-        method, path, version, headers,
-        raw_hdrs, closing, False, False, chunked, URL(path))
+        method,
+        path,
+        version,
+        headers,
+        raw_hdrs,
+        closing,
+        False,
+        False,
+        chunked,
+        URL(path),
+    )
    if app is None:
        app = _create_app_mock()

-    if protocol is sentinel:
-        protocol = mock.Mock()
-
    if transport is sentinel:
        transport = _create_transport(sslcontext)

+    if protocol is sentinel:
+        protocol = mock.Mock()
+        protocol.transport = transport
+
    if writer is sentinel:
        writer = mock.Mock()
+        writer.write_headers = make_mocked_coro(None)
+        writer.write = make_mocked_coro(None)
+        writer.write_eof = make_mocked_coro(None)
+        writer.drain = make_mocked_coro(None)
        writer.transport = transport

-    if payload_writer is sentinel:
-        payload_writer = mock.Mock()
-        payload_writer.write_eof.side_effect = noop
-        payload_writer.drain.side_effect = noop
-
-    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()
-    time_service = mock.Mock()
- time_service.time.return_value = 12345 - time_service.strtime.return_value = "Tue, 15 Nov 1994 08:12:31 GMT" - - @contextmanager - def timeout(*args, **kw): - yield + req = Request( + message, payload, protocol, writer, task, loop, client_max_size=client_max_size + ) - time_service.timeout = mock.Mock() - time_service.timeout.side_effect = timeout - - req = Request(message, payload, - protocol, payload_writer, time_service, task, - secure_proxy_ssl_header=secure_proxy_ssl_header, - client_max_size=client_max_size) - - match_info = UrlMappingMatchInfo({}, mock.Mock()) + match_info = UrlMappingMatchInfo( + {} if match_info is sentinel else match_info, mock.Mock() + ) match_info.add_app(app) req._match_info = match_info return req -def make_mocked_coro(return_value=sentinel, raise_exception=sentinel): +def make_mocked_coro( + return_value: Any = sentinel, raise_exception: Any = sentinel +) -> Any: """Creates a coroutine mock.""" - @asyncio.coroutine - def mock_coro(*args, **kwargs): + + async def mock_coro(*args: Any, **kwargs: Any) -> Any: if raise_exception is not sentinel: raise raise_exception - return return_value + if not inspect.isawaitable(return_value): + return return_value + await return_value return mock.Mock(wraps=mock_coro) diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py new file mode 100644 index 00000000000..7ae7948f9ac --- /dev/null +++ b/aiohttp/tracing.py @@ -0,0 +1,442 @@ +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, Optional, Type, TypeVar + +import attr +from multidict import CIMultiDict +from yarl import URL + +from .client_reqrep import ClientResponse +from .signals import Signal + +if TYPE_CHECKING: # pragma: no cover + from typing_extensions import Protocol + + from .client import ClientSession + + _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) + + class _SignalCallback(Protocol[_ParamT_contra]): + def __call__( + self, + __client_session: ClientSession, + __trace_config_ctx: SimpleNamespace, + __params: _ParamT_contra, + ) -> Awaitable[None]: + ... 
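
For context, a handler conforming to this callback protocol is simply a three-argument coroutine. A minimal sketch that times a request through the per-request context (the handler names and the `start` attribute are illustrative):

    import asyncio

    async def on_request_start(session, trace_config_ctx, params):
        trace_config_ctx.start = asyncio.get_event_loop().time()

    async def on_request_end(session, trace_config_ctx, params):
        elapsed = asyncio.get_event_loop().time() - trace_config_ctx.start
        print("request to {} took {:.3f}s".format(params.url, elapsed))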
+
+
+__all__ = (
+    "TraceConfig",
+    "TraceRequestStartParams",
+    "TraceRequestEndParams",
+    "TraceRequestExceptionParams",
+    "TraceConnectionQueuedStartParams",
+    "TraceConnectionQueuedEndParams",
+    "TraceConnectionCreateStartParams",
+    "TraceConnectionCreateEndParams",
+    "TraceConnectionReuseconnParams",
+    "TraceDnsResolveHostStartParams",
+    "TraceDnsResolveHostEndParams",
+    "TraceDnsCacheHitParams",
+    "TraceDnsCacheMissParams",
+    "TraceRequestRedirectParams",
+    "TraceRequestChunkSentParams",
+    "TraceResponseChunkReceivedParams",
+)
+
+
+class TraceConfig:
+    """First-class object used to trace requests launched via ClientSession
+    objects."""
+
+    def __init__(
+        self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
+    ) -> None:
+        self._on_request_start = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceRequestStartParams]]
+        self._on_request_chunk_sent = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceRequestChunkSentParams]]
+        self._on_response_chunk_received = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceResponseChunkReceivedParams]]
+        self._on_request_end = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceRequestEndParams]]
+        self._on_request_exception = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceRequestExceptionParams]]
+        self._on_request_redirect = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceRequestRedirectParams]]
+        self._on_connection_queued_start = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceConnectionQueuedStartParams]]
+        self._on_connection_queued_end = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceConnectionQueuedEndParams]]
+        self._on_connection_create_start = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceConnectionCreateStartParams]]
+        self._on_connection_create_end = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceConnectionCreateEndParams]]
+        self._on_connection_reuseconn = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceConnectionReuseconnParams]]
+        self._on_dns_resolvehost_start = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceDnsResolveHostStartParams]]
+        self._on_dns_resolvehost_end = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceDnsResolveHostEndParams]]
+        self._on_dns_cache_hit = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceDnsCacheHitParams]]
+        self._on_dns_cache_miss = Signal(
+            self
+        )  # type: Signal[_SignalCallback[TraceDnsCacheMissParams]]
+
+        self._trace_config_ctx_factory = trace_config_ctx_factory
+
+    def trace_config_ctx(
+        self, trace_request_ctx: Optional[SimpleNamespace] = None
+    ) -> SimpleNamespace:
+        """Return a new trace_config_ctx instance."""
+        return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx)
+
+    def freeze(self) -> None:
+        self._on_request_start.freeze()
+        self._on_request_chunk_sent.freeze()
+        self._on_response_chunk_received.freeze()
+        self._on_request_end.freeze()
+        self._on_request_exception.freeze()
+        self._on_request_redirect.freeze()
+        self._on_connection_queued_start.freeze()
+        self._on_connection_queued_end.freeze()
+        self._on_connection_create_start.freeze()
+        self._on_connection_create_end.freeze()
+        self._on_connection_reuseconn.freeze()
+        self._on_dns_resolvehost_start.freeze()
+        self._on_dns_resolvehost_end.freeze()
+        self._on_dns_cache_hit.freeze()
+        self._on_dns_cache_miss.freeze()
+
+    @property
+    def on_request_start(self) -> "Signal[_SignalCallback[TraceRequestStartParams]]":
+        return self._on_request_start
+
+    @property
+    def on_request_chunk_sent(
+        self,
+    ) ->
"Signal[_SignalCallback[TraceRequestChunkSentParams]]": + return self._on_request_chunk_sent + + @property + def on_response_chunk_received( + self, + ) -> "Signal[_SignalCallback[TraceResponseChunkReceivedParams]]": + return self._on_response_chunk_received + + @property + def on_request_end(self) -> "Signal[_SignalCallback[TraceRequestEndParams]]": + return self._on_request_end + + @property + def on_request_exception( + self, + ) -> "Signal[_SignalCallback[TraceRequestExceptionParams]]": + return self._on_request_exception + + @property + def on_request_redirect( + self, + ) -> "Signal[_SignalCallback[TraceRequestRedirectParams]]": + return self._on_request_redirect + + @property + def on_connection_queued_start( + self, + ) -> "Signal[_SignalCallback[TraceConnectionQueuedStartParams]]": + return self._on_connection_queued_start + + @property + def on_connection_queued_end( + self, + ) -> "Signal[_SignalCallback[TraceConnectionQueuedEndParams]]": + return self._on_connection_queued_end + + @property + def on_connection_create_start( + self, + ) -> "Signal[_SignalCallback[TraceConnectionCreateStartParams]]": + return self._on_connection_create_start + + @property + def on_connection_create_end( + self, + ) -> "Signal[_SignalCallback[TraceConnectionCreateEndParams]]": + return self._on_connection_create_end + + @property + def on_connection_reuseconn( + self, + ) -> "Signal[_SignalCallback[TraceConnectionReuseconnParams]]": + return self._on_connection_reuseconn + + @property + def on_dns_resolvehost_start( + self, + ) -> "Signal[_SignalCallback[TraceDnsResolveHostStartParams]]": + return self._on_dns_resolvehost_start + + @property + def on_dns_resolvehost_end( + self, + ) -> "Signal[_SignalCallback[TraceDnsResolveHostEndParams]]": + return self._on_dns_resolvehost_end + + @property + def on_dns_cache_hit(self) -> "Signal[_SignalCallback[TraceDnsCacheHitParams]]": + return self._on_dns_cache_hit + + @property + def on_dns_cache_miss(self) -> "Signal[_SignalCallback[TraceDnsCacheMissParams]]": + return self._on_dns_cache_miss + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestStartParams: + """ Parameters sent by the `on_request_start` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestChunkSentParams: + """ Parameters sent by the `on_request_chunk_sent` signal""" + + method: str + url: URL + chunk: bytes + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceResponseChunkReceivedParams: + """ Parameters sent by the `on_response_chunk_received` signal""" + + method: str + url: URL + chunk: bytes + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestEndParams: + """ Parameters sent by the `on_request_end` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + response: ClientResponse + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestExceptionParams: + """ Parameters sent by the `on_request_exception` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + exception: BaseException + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestRedirectParams: + """ Parameters sent by the `on_request_redirect` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + response: ClientResponse + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionQueuedStartParams: + """ Parameters sent by the `on_connection_queued_start` signal""" + + 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceConnectionQueuedEndParams:
+    """ Parameters sent by the `on_connection_queued_end` signal"""
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceConnectionCreateStartParams:
+    """ Parameters sent by the `on_connection_create_start` signal"""
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceConnectionCreateEndParams:
+    """ Parameters sent by the `on_connection_create_end` signal"""
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceConnectionReuseconnParams:
+    """ Parameters sent by the `on_connection_reuseconn` signal"""
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceDnsResolveHostStartParams:
+    """ Parameters sent by the `on_dns_resolvehost_start` signal"""
+
+    host: str
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceDnsResolveHostEndParams:
+    """ Parameters sent by the `on_dns_resolvehost_end` signal"""
+
+    host: str
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceDnsCacheHitParams:
+    """ Parameters sent by the `on_dns_cache_hit` signal"""
+
+    host: str
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class TraceDnsCacheMissParams:
+    """ Parameters sent by the `on_dns_cache_miss` signal"""
+
+    host: str
+
+
+class Trace:
+    """Internal class that keeps together the main dependencies used
+    at the moment of sending a signal."""
+
+    def __init__(
+        self,
+        session: "ClientSession",
+        trace_config: TraceConfig,
+        trace_config_ctx: SimpleNamespace,
+    ) -> None:
+        self._trace_config = trace_config
+        self._trace_config_ctx = trace_config_ctx
+        self._session = session
+
+    async def send_request_start(
+        self, method: str, url: URL, headers: "CIMultiDict[str]"
+    ) -> None:
+        return await self._trace_config.on_request_start.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceRequestStartParams(method, url, headers),
+        )
+
+    async def send_request_chunk_sent(
+        self, method: str, url: URL, chunk: bytes
+    ) -> None:
+        return await self._trace_config.on_request_chunk_sent.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceRequestChunkSentParams(method, url, chunk),
+        )
+
+    async def send_response_chunk_received(
+        self, method: str, url: URL, chunk: bytes
+    ) -> None:
+        return await self._trace_config.on_response_chunk_received.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceResponseChunkReceivedParams(method, url, chunk),
+        )
+
+    async def send_request_end(
+        self,
+        method: str,
+        url: URL,
+        headers: "CIMultiDict[str]",
+        response: ClientResponse,
+    ) -> None:
+        return await self._trace_config.on_request_end.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceRequestEndParams(method, url, headers, response),
+        )
+
+    async def send_request_exception(
+        self,
+        method: str,
+        url: URL,
+        headers: "CIMultiDict[str]",
+        exception: BaseException,
+    ) -> None:
+        return await self._trace_config.on_request_exception.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceRequestExceptionParams(method, url, headers, exception),
+        )
+
+    async def send_request_redirect(
+        self,
+        method: str,
+        url: URL,
+        headers: "CIMultiDict[str]",
+        response: ClientResponse,
+    ) -> None:
+        return await self._trace_config.on_request_redirect.send(
+            self._session,
+            self._trace_config_ctx,
+            TraceRequestRedirectParams(method, url, headers, response),
+        )
+
+    async def send_connection_queued_start(self) -> None:
+        return await self._trace_config.on_connection_queued_start.send(
+            self._session,
self._trace_config_ctx, TraceConnectionQueuedStartParams() + ) + + async def send_connection_queued_end(self) -> None: + return await self._trace_config.on_connection_queued_end.send( + self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams() + ) + + async def send_connection_create_start(self) -> None: + return await self._trace_config.on_connection_create_start.send( + self._session, self._trace_config_ctx, TraceConnectionCreateStartParams() + ) + + async def send_connection_create_end(self) -> None: + return await self._trace_config.on_connection_create_end.send( + self._session, self._trace_config_ctx, TraceConnectionCreateEndParams() + ) + + async def send_connection_reuseconn(self) -> None: + return await self._trace_config.on_connection_reuseconn.send( + self._session, self._trace_config_ctx, TraceConnectionReuseconnParams() + ) + + async def send_dns_resolvehost_start(self, host: str) -> None: + return await self._trace_config.on_dns_resolvehost_start.send( + self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host) + ) + + async def send_dns_resolvehost_end(self, host: str) -> None: + return await self._trace_config.on_dns_resolvehost_end.send( + self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host) + ) + + async def send_dns_cache_hit(self, host: str) -> None: + return await self._trace_config.on_dns_cache_hit.send( + self._session, self._trace_config_ctx, TraceDnsCacheHitParams(host) + ) + + async def send_dns_cache_miss(self, host: str) -> None: + return await self._trace_config.on_dns_cache_miss.send( + self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host) + ) diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py new file mode 100644 index 00000000000..1b68a242af5 --- /dev/null +++ b/aiohttp/typedefs.py @@ -0,0 +1,46 @@ +import json +import os +import pathlib +import sys +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Tuple, Union + +from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr +from yarl import URL + +DEFAULT_JSON_ENCODER = json.dumps +DEFAULT_JSON_DECODER = json.loads + +if TYPE_CHECKING: # pragma: no cover + _CIMultiDict = CIMultiDict[str] + _CIMultiDictProxy = CIMultiDictProxy[str] + _MultiDict = MultiDict[str] + _MultiDictProxy = MultiDictProxy[str] + from http.cookies import BaseCookie, Morsel +else: + _CIMultiDict = CIMultiDict + _CIMultiDictProxy = CIMultiDictProxy + _MultiDict = MultiDict + _MultiDictProxy = MultiDictProxy + +Byteish = Union[bytes, bytearray, memoryview] +JSONEncoder = Callable[[Any], str] +JSONDecoder = Callable[[str], Any] +LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy] +RawHeaders = Tuple[Tuple[bytes, bytes], ...] 
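
For illustration, values that satisfy the aliases defined above (assumed usage):

    headers: LooseHeaders = {"Content-Type": "text/plain"}
    raw: RawHeaders = ((b"Host", b"example.com"),)
    encoder: JSONEncoder = DEFAULT_JSON_ENCODER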
+StrOrURL = Union[str, URL] + +LooseCookiesMappings = Mapping[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] +LooseCookiesIterables = Iterable[ + Tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] +] +LooseCookies = Union[ + LooseCookiesMappings, + LooseCookiesIterables, + "BaseCookie[str]", +] + + +if sys.version_info >= (3, 6): + PathLike = Union[str, "os.PathLike[str]"] +else: + PathLike = Union[str, pathlib.PurePath] diff --git a/aiohttp/web.py b/aiohttp/web.py index 1d7167dcbbf..557e3c3b4d0 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1,473 +1,576 @@ import asyncio -import os +import logging import socket -import stat import sys -import warnings from argparse import ArgumentParser -from collections import Iterable, MutableMapping +from collections.abc import Iterable from importlib import import_module +from typing import ( + Any as Any, + Awaitable as Awaitable, + Callable as Callable, + Iterable as TypingIterable, + List as List, + Optional as Optional, + Set as Set, + Type as Type, + Union as Union, + cast as cast, +) + +from .abc import AbstractAccessLogger +from .helpers import all_tasks +from .log import access_logger +from .web_app import Application as Application, CleanupError as CleanupError +from .web_exceptions import ( + HTTPAccepted as HTTPAccepted, + HTTPBadGateway as HTTPBadGateway, + HTTPBadRequest as HTTPBadRequest, + HTTPClientError as HTTPClientError, + HTTPConflict as HTTPConflict, + HTTPCreated as HTTPCreated, + HTTPError as HTTPError, + HTTPException as HTTPException, + HTTPExpectationFailed as HTTPExpectationFailed, + HTTPFailedDependency as HTTPFailedDependency, + HTTPForbidden as HTTPForbidden, + HTTPFound as HTTPFound, + HTTPGatewayTimeout as HTTPGatewayTimeout, + HTTPGone as HTTPGone, + HTTPInsufficientStorage as HTTPInsufficientStorage, + HTTPInternalServerError as HTTPInternalServerError, + HTTPLengthRequired as HTTPLengthRequired, + HTTPMethodNotAllowed as HTTPMethodNotAllowed, + HTTPMisdirectedRequest as HTTPMisdirectedRequest, + HTTPMovedPermanently as HTTPMovedPermanently, + HTTPMultipleChoices as HTTPMultipleChoices, + HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired, + HTTPNoContent as HTTPNoContent, + HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation, + HTTPNotAcceptable as HTTPNotAcceptable, + HTTPNotExtended as HTTPNotExtended, + HTTPNotFound as HTTPNotFound, + HTTPNotImplemented as HTTPNotImplemented, + HTTPNotModified as HTTPNotModified, + HTTPOk as HTTPOk, + HTTPPartialContent as HTTPPartialContent, + HTTPPaymentRequired as HTTPPaymentRequired, + HTTPPermanentRedirect as HTTPPermanentRedirect, + HTTPPreconditionFailed as HTTPPreconditionFailed, + HTTPPreconditionRequired as HTTPPreconditionRequired, + HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired, + HTTPRedirection as HTTPRedirection, + HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge, + HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge, + HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable, + HTTPRequestTimeout as HTTPRequestTimeout, + HTTPRequestURITooLong as HTTPRequestURITooLong, + HTTPResetContent as HTTPResetContent, + HTTPSeeOther as HTTPSeeOther, + HTTPServerError as HTTPServerError, + HTTPServiceUnavailable as HTTPServiceUnavailable, + HTTPSuccessful as HTTPSuccessful, + HTTPTemporaryRedirect as HTTPTemporaryRedirect, + HTTPTooManyRequests as HTTPTooManyRequests, + HTTPUnauthorized as HTTPUnauthorized, + HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons, + 
HTTPUnprocessableEntity as HTTPUnprocessableEntity, + HTTPUnsupportedMediaType as HTTPUnsupportedMediaType, + HTTPUpgradeRequired as HTTPUpgradeRequired, + HTTPUseProxy as HTTPUseProxy, + HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates, + HTTPVersionNotSupported as HTTPVersionNotSupported, +) +from .web_fileresponse import FileResponse as FileResponse +from .web_log import AccessLogger +from .web_middlewares import ( + middleware as middleware, + normalize_path_middleware as normalize_path_middleware, +) +from .web_protocol import ( + PayloadAccessError as PayloadAccessError, + RequestHandler as RequestHandler, + RequestPayloadError as RequestPayloadError, +) +from .web_request import ( + BaseRequest as BaseRequest, + FileField as FileField, + Request as Request, +) +from .web_response import ( + ContentCoding as ContentCoding, + Response as Response, + StreamResponse as StreamResponse, + json_response as json_response, +) +from .web_routedef import ( + AbstractRouteDef as AbstractRouteDef, + RouteDef as RouteDef, + RouteTableDef as RouteTableDef, + StaticDef as StaticDef, + delete as delete, + get as get, + head as head, + options as options, + patch as patch, + post as post, + put as put, + route as route, + static as static, + view as view, +) +from .web_runner import ( + AppRunner as AppRunner, + BaseRunner as BaseRunner, + BaseSite as BaseSite, + GracefulExit as GracefulExit, + NamedPipeSite as NamedPipeSite, + ServerRunner as ServerRunner, + SockSite as SockSite, + TCPSite as TCPSite, + UnixSite as UnixSite, +) +from .web_server import Server as Server +from .web_urldispatcher import ( + AbstractResource as AbstractResource, + AbstractRoute as AbstractRoute, + DynamicResource as DynamicResource, + PlainResource as PlainResource, + Resource as Resource, + ResourceRoute as ResourceRoute, + StaticResource as StaticResource, + UrlDispatcher as UrlDispatcher, + UrlMappingMatchInfo as UrlMappingMatchInfo, + View as View, +) +from .web_ws import ( + WebSocketReady as WebSocketReady, + WebSocketResponse as WebSocketResponse, + WSMsgType as WSMsgType, +) + +__all__ = ( + # web_app + "Application", + "CleanupError", + # web_exceptions + "HTTPAccepted", + "HTTPBadGateway", + "HTTPBadRequest", + "HTTPClientError", + "HTTPConflict", + "HTTPCreated", + "HTTPError", + "HTTPException", + "HTTPExpectationFailed", + "HTTPFailedDependency", + "HTTPForbidden", + "HTTPFound", + "HTTPGatewayTimeout", + "HTTPGone", + "HTTPInsufficientStorage", + "HTTPInternalServerError", + "HTTPLengthRequired", + "HTTPMethodNotAllowed", + "HTTPMisdirectedRequest", + "HTTPMovedPermanently", + "HTTPMultipleChoices", + "HTTPNetworkAuthenticationRequired", + "HTTPNoContent", + "HTTPNonAuthoritativeInformation", + "HTTPNotAcceptable", + "HTTPNotExtended", + "HTTPNotFound", + "HTTPNotImplemented", + "HTTPNotModified", + "HTTPOk", + "HTTPPartialContent", + "HTTPPaymentRequired", + "HTTPPermanentRedirect", + "HTTPPreconditionFailed", + "HTTPPreconditionRequired", + "HTTPProxyAuthenticationRequired", + "HTTPRedirection", + "HTTPRequestEntityTooLarge", + "HTTPRequestHeaderFieldsTooLarge", + "HTTPRequestRangeNotSatisfiable", + "HTTPRequestTimeout", + "HTTPRequestURITooLong", + "HTTPResetContent", + "HTTPSeeOther", + "HTTPServerError", + "HTTPServiceUnavailable", + "HTTPSuccessful", + "HTTPTemporaryRedirect", + "HTTPTooManyRequests", + "HTTPUnauthorized", + "HTTPUnavailableForLegalReasons", + "HTTPUnprocessableEntity", + "HTTPUnsupportedMediaType", + "HTTPUpgradeRequired", + "HTTPUseProxy", + "HTTPVariantAlsoNegotiates", + 
"HTTPVersionNotSupported", + # web_fileresponse + "FileResponse", + # web_middlewares + "middleware", + "normalize_path_middleware", + # web_protocol + "PayloadAccessError", + "RequestHandler", + "RequestPayloadError", + # web_request + "BaseRequest", + "FileField", + "Request", + # web_response + "ContentCoding", + "Response", + "StreamResponse", + "json_response", + # web_routedef + "AbstractRouteDef", + "RouteDef", + "RouteTableDef", + "StaticDef", + "delete", + "get", + "head", + "options", + "patch", + "post", + "put", + "route", + "static", + "view", + # web_runner + "AppRunner", + "BaseRunner", + "BaseSite", + "GracefulExit", + "ServerRunner", + "SockSite", + "TCPSite", + "UnixSite", + "NamedPipeSite", + # web_server + "Server", + # web_urldispatcher + "AbstractResource", + "AbstractRoute", + "DynamicResource", + "PlainResource", + "Resource", + "ResourceRoute", + "StaticResource", + "UrlDispatcher", + "UrlMappingMatchInfo", + "View", + # web_ws + "WebSocketReady", + "WebSocketResponse", + "WSMsgType", + # web + "run_app", +) + + +try: + from ssl import SSLContext +except ImportError: # pragma: no cover + SSLContext = Any # type: ignore + +HostSequence = TypingIterable[str] + + +async def _run_app( + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Optional[str] = None, + sock: Optional[socket.socket] = None, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + print: Callable[..., None] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, +) -> None: + # A internal functio to actually do all dirty job for application running + if asyncio.iscoroutine(app): + app = await app # type: ignore + + app = cast(Application, app) + + runner = AppRunner( + app, + handle_signals=handle_signals, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + ) -from yarl import URL - -from . 
import (hdrs, web_exceptions, web_fileresponse, web_middlewares, - web_protocol, web_request, web_response, web_server, - web_urldispatcher, web_ws) -from .abc import AbstractMatchInfo, AbstractRouter -from .helpers import FrozenList -from .http import HttpVersion # noqa -from .log import access_logger, web_logger -from .signals import FuncSignal, PostSignal, PreSignal, Signal -from .web_exceptions import * # noqa -from .web_fileresponse import * # noqa -from .web_middlewares import * # noqa -from .web_protocol import * # noqa -from .web_request import * # noqa -from .web_response import * # noqa -from .web_server import Server -from .web_urldispatcher import * # noqa -from .web_urldispatcher import PrefixedSubAppResource -from .web_ws import * # noqa - -__all__ = (web_protocol.__all__ + - web_fileresponse.__all__ + - web_request.__all__ + - web_response.__all__ + - web_exceptions.__all__ + - web_urldispatcher.__all__ + - web_ws.__all__ + - web_server.__all__ + - web_middlewares.__all__ + - ('Application', 'HttpVersion', 'MsgType')) - - -class Application(MutableMapping): - - def __init__(self, *, logger=web_logger, router=None, middlewares=(), - handler_args=None, client_max_size=1024**2, - loop=None, debug=...): - if router is None: - router = web_urldispatcher.UrlDispatcher() - assert isinstance(router, AbstractRouter), router - - if loop is not None: - warnings.warn("loop argument is deprecated", ResourceWarning) - - self._debug = debug - self._router = router - self._secure_proxy_ssl_header = None - self._loop = loop - self._handler_args = handler_args - self.logger = logger - - self._middlewares = FrozenList(middlewares) - self._state = {} - self._frozen = False - self._subapps = [] - - self._on_pre_signal = PreSignal() - self._on_post_signal = PostSignal() - self._on_loop_available = FuncSignal(self) - self._on_response_prepare = Signal(self) - self._on_startup = Signal(self) - self._on_shutdown = Signal(self) - self._on_cleanup = Signal(self) - self._client_max_size = client_max_size - - # MutableMapping API - - def __getitem__(self, key): - return self._state[key] - - def _check_frozen(self): - if self._frozen: - warnings.warn("Changing state of started or joined " - "application is deprecated", - DeprecationWarning, - stacklevel=3) - - def __setitem__(self, key, value): - self._check_frozen() - self._state[key] = value - - def __delitem__(self, key): - self._check_frozen() - del self._state[key] - - def __len__(self): - return len(self._state) - - def __iter__(self): - return iter(self._state) - - ######## - @property - def loop(self): - return self._loop - - def _set_loop(self, loop): - if loop is None: - loop = asyncio.get_event_loop() - if self._loop is not None and self._loop is not loop: - raise RuntimeError( - "web.Application instance initialized with different loop") - - self._loop = loop - self._on_loop_available.send(self) - - # set loop debug - if self._debug is ...: - self._debug = loop.get_debug() - - # set loop to sub applications - for subapp in self._subapps: - subapp._set_loop(loop) - - @property - def frozen(self): - return self._frozen - - def freeze(self): - if self._frozen: - return - - self._frozen = True - self._middlewares = tuple(reversed(self._middlewares)) - self._router.freeze() - self._on_loop_available.freeze() - self._on_pre_signal.freeze() - self._on_post_signal.freeze() - self._on_response_prepare.freeze() - self._on_startup.freeze() - self._on_shutdown.freeze() - self._on_cleanup.freeze() - - for subapp in self._subapps: - subapp.freeze() - - 
@property - def debug(self): - return self._debug - - def _reg_subapp_signals(self, subapp): - - def reg_handler(signame): - subsig = getattr(subapp, signame) - - @asyncio.coroutine - def handler(app): - yield from subsig.send(subapp) - appsig = getattr(self, signame) - appsig.append(handler) - - reg_handler('on_startup') - reg_handler('on_shutdown') - reg_handler('on_cleanup') - - def add_subapp(self, prefix, subapp): - if self.frozen: - raise RuntimeError( - "Cannot add sub application to frozen application") - if subapp.frozen: - raise RuntimeError("Cannot add frozen application") - if prefix.endswith('/'): - prefix = prefix[:-1] - if prefix in ('', '/'): - raise ValueError("Prefix cannot be empty") - - resource = PrefixedSubAppResource(prefix, subapp) - self.router.register_resource(resource) - self._reg_subapp_signals(subapp) - self._subapps.append(subapp) - if self._loop is not None: - subapp._set_loop(self._loop) - return resource - - @property - def on_loop_available(self): - return self._on_loop_available - - @property - def on_response_prepare(self): - return self._on_response_prepare - - @property - def on_pre_signal(self): - return self._on_pre_signal - - @property - def on_post_signal(self): - return self._on_post_signal - - @property - def on_startup(self): - return self._on_startup - - @property - def on_shutdown(self): - return self._on_shutdown - - @property - def on_cleanup(self): - return self._on_cleanup - - @property - def router(self): - return self._router - - @property - def middlewares(self): - return self._middlewares - - def make_handler(self, *, loop=None, - secure_proxy_ssl_header=None, **kwargs): - self._set_loop(loop) - self.freeze() - - kwargs['debug'] = self.debug - if self._handler_args: - for k, v in self._handler_args.items(): - kwargs[k] = v - - self._secure_proxy_ssl_header = secure_proxy_ssl_header - return Server(self._handle, request_factory=self._make_request, - loop=self.loop, **kwargs) - - @asyncio.coroutine - def startup(self): - """Causes on_startup signal - - Should be called in the event loop along with the request handler. 
- """ - yield from self.on_startup.send(self) - - @asyncio.coroutine - def shutdown(self): - """Causes on_shutdown signal - - Should be called before cleanup() - """ - yield from self.on_shutdown.send(self) - - @asyncio.coroutine - def cleanup(self): - """Causes on_cleanup signal - - Should be called after shutdown() - """ - yield from self.on_cleanup.send(self) - - def _make_request(self, message, payload, protocol, writer, task, - _cls=web_request.Request): - return _cls( - message, payload, protocol, writer, protocol._time_service, task, - secure_proxy_ssl_header=self._secure_proxy_ssl_header, - client_max_size=self._client_max_size) - - @asyncio.coroutine - def _handle(self, request): - match_info = yield from self._router.resolve(request) - assert isinstance(match_info, AbstractMatchInfo), match_info - match_info.add_app(self) - - if __debug__: - match_info.freeze() - - resp = None - request._match_info = match_info - expect = request.headers.get(hdrs.EXPECT) - if expect: - resp = ( - yield from match_info.expect_handler(request)) - - if resp is None: - handler = match_info.handler - for app in match_info.apps: - for factory in app._middlewares: - handler = yield from factory(app, handler) - - resp = yield from handler(request) - - assert isinstance(resp, web_response.StreamResponse), \ - ("Handler {!r} should return response instance, " - "got {!r} [middlewares {!r}]").format( - match_info.handler, type(resp), - [middleware for middleware in app.middlewares - for app in match_info.apps]) - return resp - - def __call__(self): - """gunicorn compatibility""" - return self - - def __repr__(self): - return "".format(id(self)) - - -def run_app(app, *, host=None, port=None, path=None, sock=None, - shutdown_timeout=60.0, ssl_context=None, - print=print, backlog=128, access_log_format=None, - access_log=access_logger, loop=None): - """Run an app locally""" - if loop is None: - loop = asyncio.get_event_loop() - - make_handler_kwargs = dict() - if access_log_format is not None: - make_handler_kwargs['access_log_format'] = access_log_format - handler = app.make_handler(loop=loop, access_log=access_log, - **make_handler_kwargs) - - loop.run_until_complete(app.startup()) - - scheme = 'https' if ssl_context else 'http' - base_url = URL('{}://localhost'.format(scheme)).with_port(port) - - if path is None: - paths = () - elif isinstance(path, (str, bytes, bytearray, memoryview))\ - or not isinstance(path, Iterable): - paths = (path,) - else: - paths = path - - if sock is None: - socks = () - elif not isinstance(sock, Iterable): - socks = (sock,) - else: - socks = sock - - if host is None: - if (paths or socks) and not port: - hosts = () - else: - hosts = ("0.0.0.0",) - elif isinstance(host, (str, bytes, bytearray, memoryview))\ - or not isinstance(host, Iterable): - hosts = (host,) - else: - hosts = host - - if hosts and port is None: - port = 8443 if ssl_context else 8080 - - server_creations = [] - uris = [str(base_url.with_host(host)) for host in hosts] - if hosts: - # Multiple hosts bound to same server is available in most loop - # implementations, but only send multiple if we have multiple. - host_binding = hosts[0] if len(hosts) == 1 else hosts - server_creations.append( - loop.create_server( - handler, host_binding, port, ssl=ssl_context, backlog=backlog - ) - ) - for path in paths: - # Most loop implementations don't support multiple paths bound in same - # server, so create a server for each. 
- server_creations.append( - loop.create_unix_server( - handler, path, ssl=ssl_context, backlog=backlog + await runner.setup() + + sites = [] # type: List[BaseSite] + + try: + if host is not None: + if isinstance(host, (str, bytes, bytearray, memoryview)): + sites.append( + TCPSite( + runner, + host, + port, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + else: + for h in host: + sites.append( + TCPSite( + runner, + h, + port, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + elif path is None and sock is None or port is not None: + sites.append( + TCPSite( + runner, + port=port, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) ) - ) - uris.append('{}://unix:{}:'.format(scheme, path)) - - # Clean up prior socket path if stale and not abstract. - # CPython 3.5.3+'s event loop already does this. See - # https://github.com/python/asyncio/issues/425 - if path[0] not in (0, '\x00'): # pragma: no branch - try: - if stat.S_ISSOCK(os.stat(path).st_mode): - os.remove(path) - except FileNotFoundError: - pass - for sock in socks: - server_creations.append( - loop.create_server( - handler, sock=sock, ssl=ssl_context, backlog=backlog + + if path is not None: + if isinstance(path, (str, bytes, bytearray, memoryview)): + sites.append( + UnixSite( + runner, + path, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + else: + for p in path: + sites.append( + UnixSite( + runner, + p, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + + if sock is not None: + if not isinstance(sock, Iterable): + sites.append( + SockSite( + runner, + sock, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + else: + for s in sock: + sites.append( + SockSite( + runner, + s, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + for site in sites: + await site.start() + + if print: # pragma: no branch + names = sorted(str(s.name) for s in runner.sites) + print( + "======== Running on {} ========\n" + "(Press CTRL+C to quit)".format(", ".join(names)) ) - ) - if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: - uris.append('{}://unix:{}:'.format(scheme, sock.getsockname())) + # sleep forever by 1 hour intervals, + # on Windows before Python 3.8 wake up every 1 second to handle + # Ctrl+C smoothly + if sys.platform == "win32" and sys.version_info < (3, 8): + delay = 1 else: - host, port = sock.getsockname() - uris.append(str(base_url.with_host(host).with_port(port))) + delay = 3600 + + while True: + await asyncio.sleep(delay) + finally: + await runner.cleanup() + + +def _cancel_tasks( + to_cancel: Set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop +) -> None: + if not to_cancel: + return - servers = loop.run_until_complete( - asyncio.gather(*server_creations, loop=loop) + for task in to_cancel: + task.cancel() + + loop.run_until_complete( + asyncio.gather(*to_cancel, loop=loop, return_exceptions=True) ) - print("======== Running on {} ========\n" - "(Press CTRL+C to quit)".format(', '.join(uris))) + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during 
asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def run_app( + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Optional[str] = None, + sock: Optional[socket.socket] = None, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + print: Callable[..., None] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, +) -> None: + """Run an app locally""" + loop = asyncio.get_event_loop() + + # Configure if and only if in debugging mode and using the default logger + if loop.get_debug() and access_log and access_log.name == "aiohttp.access": + if access_log.level == logging.NOTSET: + access_log.setLevel(logging.DEBUG) + if not access_log.hasHandlers(): + access_log.addHandler(logging.StreamHandler()) try: - loop.run_forever() - except KeyboardInterrupt: # pragma: no cover + main_task = loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + loop.run_until_complete(main_task) + except (GracefulExit, KeyboardInterrupt): # pragma: no cover pass finally: - server_closures = [] - for srv in servers: - srv.close() - server_closures.append(srv.wait_closed()) - loop.run_until_complete(asyncio.gather(*server_closures, loop=loop)) - loop.run_until_complete(app.shutdown()) - loop.run_until_complete(handler.shutdown(shutdown_timeout)) - loop.run_until_complete(app.cleanup()) - loop.close() - - -def main(argv): + _cancel_tasks({main_task}, loop) + _cancel_tasks(all_tasks(loop), loop) + if sys.version_info >= (3, 6): # don't use PY_36 to pass mypy + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + +def main(argv: List[str]) -> None: arg_parser = ArgumentParser( - description="aiohttp.web Application server", - prog="aiohttp.web" + description="aiohttp.web Application server", prog="aiohttp.web" ) arg_parser.add_argument( "entry_func", - help=("Callable returning the `aiohttp.web.Application` instance to " - "run. Should be specified in the 'module:function' syntax."), - metavar="entry-func" + help=( + "Callable returning the `aiohttp.web.Application` instance to " + "run. Should be specified in the 'module:function' syntax." + ), + metavar="entry-func", ) arg_parser.add_argument( - "-H", "--hostname", + "-H", + "--hostname", help="TCP/IP hostname to serve on (default: %(default)r)", - default="localhost" + default="localhost", ) arg_parser.add_argument( - "-P", "--port", + "-P", + "--port", help="TCP/IP port to serve on (default: %(default)r)", type=int, - default="8080" + default="8080", ) arg_parser.add_argument( - "-U", "--path", + "-U", + "--path", help="Unix file system path to serve on. 
Specifying a path will cause " - "hostname and port arguments to be ignored.", + "hostname and port arguments to be ignored.", ) args, extra_argv = arg_parser.parse_known_args(argv) # Import logic mod_str, _, func_str = args.entry_func.partition(":") if not func_str or not mod_str: - arg_parser.error( - "'entry-func' not in 'module:function' syntax" - ) + arg_parser.error("'entry-func' not in 'module:function' syntax") if mod_str.startswith("."): arg_parser.error("relative module names not supported") try: module = import_module(mod_str) except ImportError as ex: - arg_parser.error("unable to import %s: %s" % (mod_str, ex)) + arg_parser.error(f"unable to import {mod_str}: {ex}") try: func = getattr(module, func_str) except AttributeError: - arg_parser.error("module %r has no attribute %r" % (mod_str, func_str)) + arg_parser.error(f"module {mod_str!r} has no attribute {func_str!r}") # Compatibility logic - if args.path is not None and not hasattr(socket, 'AF_UNIX'): - arg_parser.error("file system paths not supported by your operating" - " environment") + if args.path is not None and not hasattr(socket, "AF_UNIX"): + arg_parser.error( + "file system paths not supported by your operating" " environment" + ) + + logging.basicConfig(level=logging.DEBUG) app = func(extra_argv) run_app(app, host=args.hostname, port=args.port, path=args.path) diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py new file mode 100644 index 00000000000..14f2937ae55 --- /dev/null +++ b/aiohttp/web_app.py @@ -0,0 +1,552 @@ +import asyncio +import logging +import warnings +from functools import partial, update_wrapper +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from . 
import hdrs +from .abc import ( + AbstractAccessLogger, + AbstractMatchInfo, + AbstractRouter, + AbstractStreamWriter, +) +from .frozenlist import FrozenList +from .helpers import DEBUG +from .http_parser import RawRequestMessage +from .log import web_logger +from .signals import Signal +from .streams import StreamReader +from .web_log import AccessLogger +from .web_middlewares import _fix_request_current_app +from .web_protocol import RequestHandler +from .web_request import Request +from .web_response import StreamResponse +from .web_routedef import AbstractRouteDef +from .web_server import Server +from .web_urldispatcher import ( + AbstractResource, + AbstractRoute, + Domain, + MaskDomain, + MatchedSubAppResource, + PrefixedSubAppResource, + UrlDispatcher, +) + +__all__ = ("Application", "CleanupError") + + +if TYPE_CHECKING: # pragma: no cover + _AppSignal = Signal[Callable[["Application"], Awaitable[None]]] + _RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]] + _Handler = Callable[[Request], Awaitable[StreamResponse]] + _Middleware = Union[ + Callable[[Request, _Handler], Awaitable[StreamResponse]], + Callable[["Application", _Handler], Awaitable[_Handler]], # old-style + ] + _Middlewares = FrozenList[_Middleware] + _MiddlewaresHandlers = Optional[Sequence[Tuple[_Middleware, bool]]] + _Subapps = List["Application"] +else: + # No type checker mode, skip types + _AppSignal = Signal + _RespPrepareSignal = Signal + _Handler = Callable + _Middleware = Callable + _Middlewares = FrozenList + _MiddlewaresHandlers = Optional[Sequence] + _Subapps = List + + +class Application(MutableMapping[str, Any]): + ATTRS = frozenset( + [ + "logger", + "_debug", + "_router", + "_loop", + "_handler_args", + "_middlewares", + "_middlewares_handlers", + "_run_middlewares", + "_state", + "_frozen", + "_pre_frozen", + "_subapps", + "_on_response_prepare", + "_on_startup", + "_on_shutdown", + "_on_cleanup", + "_client_max_size", + "_cleanup_ctx", + ] + ) + + def __init__( + self, + *, + logger: logging.Logger = web_logger, + router: Optional[UrlDispatcher] = None, + middlewares: Iterable[_Middleware] = (), + handler_args: Optional[Mapping[str, Any]] = None, + client_max_size: int = 1024 ** 2, + loop: Optional[asyncio.AbstractEventLoop] = None, + debug: Any = ..., # mypy doesn't support ellipsis + ) -> None: + if router is None: + router = UrlDispatcher() + else: + warnings.warn( + "router argument is deprecated", DeprecationWarning, stacklevel=2 + ) + assert isinstance(router, AbstractRouter), router + + if loop is not None: + warnings.warn( + "loop argument is deprecated", DeprecationWarning, stacklevel=2 + ) + + if debug is not ...: + warnings.warn( + "debug argument is deprecated", DeprecationWarning, stacklevel=2 + ) + self._debug = debug + self._router = router # type: UrlDispatcher + self._loop = loop + self._handler_args = handler_args + self.logger = logger + + self._middlewares = FrozenList(middlewares) # type: _Middlewares + + # initialized on freezing + self._middlewares_handlers = None # type: _MiddlewaresHandlers + # initialized on freezing + self._run_middlewares = None # type: Optional[bool] + + self._state = {} # type: Dict[str, Any] + self._frozen = False + self._pre_frozen = False + self._subapps = [] # type: _Subapps + + self._on_response_prepare = Signal(self) # type: _RespPrepareSignal + self._on_startup = Signal(self) # type: _AppSignal + self._on_shutdown = Signal(self) # type: _AppSignal + self._on_cleanup = Signal(self) # type: _AppSignal + 
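+        # Descriptive note on the signals above: each Signal is a
+        # FrozenList of ``async def cb(app)`` callbacks. on_response_prepare
+        # fires for every prepared response; the other three fire around the
+        # application lifecycle via startup()/shutdown()/cleanup().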
        self._cleanup_ctx = CleanupContext()
+        self._on_startup.append(self._cleanup_ctx._on_startup)
+        self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
+        self._client_max_size = client_max_size
+
+    def __init_subclass__(cls: Type["Application"]) -> None:
+        warnings.warn(
+            "Inheritance class {} from web.Application "
+            "is discouraged".format(cls.__name__),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    if DEBUG:  # pragma: no cover
+
+        def __setattr__(self, name: str, val: Any) -> None:
+            if name not in self.ATTRS:
+                warnings.warn(
+                    "Setting custom web.Application.{} attribute "
+                    "is discouraged".format(name),
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+            super().__setattr__(name, val)
+
+    # MutableMapping API
+
+    def __eq__(self, other: object) -> bool:
+        return self is other
+
+    def __getitem__(self, key: str) -> Any:
+        return self._state[key]
+
+    def _check_frozen(self) -> None:
+        if self._frozen:
+            warnings.warn(
+                "Changing state of started or joined " "application is deprecated",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self._check_frozen()
+        self._state[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        self._check_frozen()
+        del self._state[key]
+
+    def __len__(self) -> int:
+        return len(self._state)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._state)
+
+    ########
+    @property
+    def loop(self) -> asyncio.AbstractEventLoop:
+        # Technically the loop can be None
+        # but we mask it by explicit type cast
+        # to provide a more convenient type annotation
+        warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2)
+        return cast(asyncio.AbstractEventLoop, self._loop)
+
+    def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
+        if loop is None:
+            loop = asyncio.get_event_loop()
+        if self._loop is not None and self._loop is not loop:
+            raise RuntimeError(
+                "web.Application instance initialized with different loop"
+            )
+
+        self._loop = loop
+
+        # set loop debug
+        if self._debug is ...:
+            self._debug = loop.get_debug()
+
+        # set loop to sub applications
+        for subapp in self._subapps:
+            subapp._set_loop(loop)
+
+    @property
+    def pre_frozen(self) -> bool:
+        return self._pre_frozen
+
+    def pre_freeze(self) -> None:
+        if self._pre_frozen:
+            return
+
+        self._pre_frozen = True
+        self._middlewares.freeze()
+        self._router.freeze()
+        self._on_response_prepare.freeze()
+        self._cleanup_ctx.freeze()
+        self._on_startup.freeze()
+        self._on_shutdown.freeze()
+        self._on_cleanup.freeze()
+        self._middlewares_handlers = tuple(self._prepare_middleware())
+
+        # If neither the current app nor any subapp has middlewares, avoid
+        # running all of the code footprint that this implies: a middleware
+        # hardcoded per app that sets up the current_app attribute. If no
+        # middlewares are configured the handler will receive the proper
+        # current_app without needing all of this code.
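+        # Illustrative example: a bare Application() keeps _run_middlewares
+        # False, so _handle() calls the matched handler directly; adding any
+        # middleware here or in a subapp flips the flag and enables the
+        # wrapper chain built by _prepare_middleware().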
+ self._run_middlewares = True if self.middlewares else False + + for subapp in self._subapps: + subapp.pre_freeze() + self._run_middlewares = self._run_middlewares or subapp._run_middlewares + + @property + def frozen(self) -> bool: + return self._frozen + + def freeze(self) -> None: + if self._frozen: + return + + self.pre_freeze() + self._frozen = True + for subapp in self._subapps: + subapp.freeze() + + @property + def debug(self) -> bool: + warnings.warn("debug property is deprecated", DeprecationWarning, stacklevel=2) + return self._debug + + def _reg_subapp_signals(self, subapp: "Application") -> None: + def reg_handler(signame: str) -> None: + subsig = getattr(subapp, signame) + + async def handler(app: "Application") -> None: + await subsig.send(subapp) + + appsig = getattr(self, signame) + appsig.append(handler) + + reg_handler("on_startup") + reg_handler("on_shutdown") + reg_handler("on_cleanup") + + def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: + if not isinstance(prefix, str): + raise TypeError("Prefix must be str") + prefix = prefix.rstrip("/") + if not prefix: + raise ValueError("Prefix cannot be empty") + factory = partial(PrefixedSubAppResource, prefix, subapp) + return self._add_subapp(factory, subapp) + + def _add_subapp( + self, resource_factory: Callable[[], AbstractResource], subapp: "Application" + ) -> AbstractResource: + if self.frozen: + raise RuntimeError("Cannot add sub application to frozen application") + if subapp.frozen: + raise RuntimeError("Cannot add frozen application") + resource = resource_factory() + self.router.register_resource(resource) + self._reg_subapp_signals(subapp) + self._subapps.append(subapp) + subapp.pre_freeze() + if self._loop is not None: + subapp._set_loop(self._loop) + return resource + + def add_domain(self, domain: str, subapp: "Application") -> AbstractResource: + if not isinstance(domain, str): + raise TypeError("Domain must be str") + elif "*" in domain: + rule = MaskDomain(domain) # type: Domain + else: + rule = Domain(domain) + factory = partial(MatchedSubAppResource, rule, subapp) + return self._add_subapp(factory, subapp) + + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + return self.router.add_routes(routes) + + @property + def on_response_prepare(self) -> _RespPrepareSignal: + return self._on_response_prepare + + @property + def on_startup(self) -> _AppSignal: + return self._on_startup + + @property + def on_shutdown(self) -> _AppSignal: + return self._on_shutdown + + @property + def on_cleanup(self) -> _AppSignal: + return self._on_cleanup + + @property + def cleanup_ctx(self) -> "CleanupContext": + return self._cleanup_ctx + + @property + def router(self) -> UrlDispatcher: + return self._router + + @property + def middlewares(self) -> _Middlewares: + return self._middlewares + + def _make_handler( + self, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + **kwargs: Any, + ) -> Server: + + if not issubclass(access_log_class, AbstractAccessLogger): + raise TypeError( + "access_log_class must be subclass of " + "aiohttp.abc.AbstractAccessLogger, got {}".format(access_log_class) + ) + + self._set_loop(loop) + self.freeze() + + kwargs["debug"] = self._debug + kwargs["access_log_class"] = access_log_class + if self._handler_args: + for k, v in self._handler_args.items(): + kwargs[k] = v + + return Server( + self._handle, # type: ignore + request_factory=self._make_request, + 
loop=self._loop, + **kwargs, + ) + + def make_handler( + self, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + **kwargs: Any, + ) -> Server: + + warnings.warn( + "Application.make_handler(...) is deprecated, " "use AppRunner API instead", + DeprecationWarning, + stacklevel=2, + ) + + return self._make_handler( + loop=loop, access_log_class=access_log_class, **kwargs + ) + + async def startup(self) -> None: + """Causes on_startup signal + + Should be called in the event loop along with the request handler. + """ + await self.on_startup.send(self) + + async def shutdown(self) -> None: + """Causes on_shutdown signal + + Should be called before cleanup() + """ + await self.on_shutdown.send(self) + + async def cleanup(self) -> None: + """Causes on_cleanup signal + + Should be called after shutdown() + """ + await self.on_cleanup.send(self) + + def _make_request( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: RequestHandler, + writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + _cls: Type[Request] = Request, + ) -> Request: + return _cls( + message, + payload, + protocol, + writer, + task, + self._loop, + client_max_size=self._client_max_size, + ) + + def _prepare_middleware(self) -> Iterator[Tuple[_Middleware, bool]]: + for m in reversed(self._middlewares): + if getattr(m, "__middleware_version__", None) == 1: + yield m, True + else: + warnings.warn( + 'old-style middleware "{!r}" deprecated, ' "see #2252".format(m), + DeprecationWarning, + stacklevel=2, + ) + yield m, False + + yield _fix_request_current_app(self), True + + async def _handle(self, request: Request) -> StreamResponse: + loop = asyncio.get_event_loop() + debug = loop.get_debug() + match_info = await self._router.resolve(request) + if debug: # pragma: no cover + if not isinstance(match_info, AbstractMatchInfo): + raise TypeError( + "match_info should be AbstractMatchInfo " + "instance, not {!r}".format(match_info) + ) + match_info.add_app(self) + + match_info.freeze() + + resp = None + request._match_info = match_info # type: ignore + expect = request.headers.get(hdrs.EXPECT) + if expect: + resp = await match_info.expect_handler(request) + await request.writer.drain() + + if resp is None: + handler = match_info.handler + + if self._run_middlewares: + for app in match_info.apps[::-1]: + for m, new_style in app._middlewares_handlers: # type: ignore + if new_style: + handler = update_wrapper( + partial(m, handler=handler), handler + ) + else: + handler = await m(app, handler) # type: ignore + + resp = await handler(request) + + return resp + + def __call__(self) -> "Application": + """gunicorn compatibility""" + return self + + def __repr__(self) -> str: + return "".format(id(self)) + + def __bool__(self) -> bool: + return True + + +class CleanupError(RuntimeError): + @property + def exceptions(self) -> List[BaseException]: + return self.args[1] + + +if TYPE_CHECKING: # pragma: no cover + _CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]] +else: + _CleanupContextBase = FrozenList + + +class CleanupContext(_CleanupContextBase): + def __init__(self) -> None: + super().__init__() + self._exits = [] # type: List[AsyncIterator[None]] + + async def _on_startup(self, app: Application) -> None: + for cb in self: + it = cb(app).__aiter__() + await it.__anext__() + self._exits.append(it) + + async def _on_cleanup(self, app: Application) -> None: + errors = [] + for it in reversed(self._exits): + try: + 
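+                # Resume the async generator behind this cleanup context; a
+                # well-formed context yields exactly once, so StopAsyncIteration
+                # is the expected outcome here.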
await it.__anext__() + except StopAsyncIteration: + pass + except Exception as exc: + errors.append(exc) + else: + errors.append(RuntimeError(f"{it!r} has more than one 'yield'")) + if errors: + if len(errors) == 1: + raise errors[0] + else: + raise CleanupError("Multiple errors on cleanup stage", errors) diff --git a/aiohttp/web_exceptions.py b/aiohttp/web_exceptions.py index 0a10d9d1f9a..2eadca0386a 100644 --- a/aiohttp/web_exceptions.py +++ b/aiohttp/web_exceptions.py @@ -1,60 +1,69 @@ +import warnings +from typing import Any, Dict, Iterable, List, Optional, Set # noqa + +from yarl import URL + +from .typedefs import LooseHeaders, StrOrURL from .web_response import Response __all__ = ( - 'HTTPException', - 'HTTPError', - 'HTTPRedirection', - 'HTTPSuccessful', - 'HTTPOk', - 'HTTPCreated', - 'HTTPAccepted', - 'HTTPNonAuthoritativeInformation', - 'HTTPNoContent', - 'HTTPResetContent', - 'HTTPPartialContent', - 'HTTPMultipleChoices', - 'HTTPMovedPermanently', - 'HTTPFound', - 'HTTPSeeOther', - 'HTTPNotModified', - 'HTTPUseProxy', - 'HTTPTemporaryRedirect', - 'HTTPPermanentRedirect', - 'HTTPClientError', - 'HTTPBadRequest', - 'HTTPUnauthorized', - 'HTTPPaymentRequired', - 'HTTPForbidden', - 'HTTPNotFound', - 'HTTPMethodNotAllowed', - 'HTTPNotAcceptable', - 'HTTPProxyAuthenticationRequired', - 'HTTPRequestTimeout', - 'HTTPConflict', - 'HTTPGone', - 'HTTPLengthRequired', - 'HTTPPreconditionFailed', - 'HTTPRequestEntityTooLarge', - 'HTTPRequestURITooLong', - 'HTTPUnsupportedMediaType', - 'HTTPRequestRangeNotSatisfiable', - 'HTTPExpectationFailed', - 'HTTPMisdirectedRequest', - 'HTTPUpgradeRequired', - 'HTTPPreconditionRequired', - 'HTTPTooManyRequests', - 'HTTPRequestHeaderFieldsTooLarge', - 'HTTPUnavailableForLegalReasons', - 'HTTPServerError', - 'HTTPInternalServerError', - 'HTTPNotImplemented', - 'HTTPBadGateway', - 'HTTPServiceUnavailable', - 'HTTPGatewayTimeout', - 'HTTPVersionNotSupported', - 'HTTPVariantAlsoNegotiates', - 'HTTPNotExtended', - 'HTTPNetworkAuthenticationRequired', + "HTTPException", + "HTTPError", + "HTTPRedirection", + "HTTPSuccessful", + "HTTPOk", + "HTTPCreated", + "HTTPAccepted", + "HTTPNonAuthoritativeInformation", + "HTTPNoContent", + "HTTPResetContent", + "HTTPPartialContent", + "HTTPMultipleChoices", + "HTTPMovedPermanently", + "HTTPFound", + "HTTPSeeOther", + "HTTPNotModified", + "HTTPUseProxy", + "HTTPTemporaryRedirect", + "HTTPPermanentRedirect", + "HTTPClientError", + "HTTPBadRequest", + "HTTPUnauthorized", + "HTTPPaymentRequired", + "HTTPForbidden", + "HTTPNotFound", + "HTTPMethodNotAllowed", + "HTTPNotAcceptable", + "HTTPProxyAuthenticationRequired", + "HTTPRequestTimeout", + "HTTPConflict", + "HTTPGone", + "HTTPLengthRequired", + "HTTPPreconditionFailed", + "HTTPRequestEntityTooLarge", + "HTTPRequestURITooLong", + "HTTPUnsupportedMediaType", + "HTTPRequestRangeNotSatisfiable", + "HTTPExpectationFailed", + "HTTPMisdirectedRequest", + "HTTPUnprocessableEntity", + "HTTPFailedDependency", + "HTTPUpgradeRequired", + "HTTPPreconditionRequired", + "HTTPTooManyRequests", + "HTTPRequestHeaderFieldsTooLarge", + "HTTPUnavailableForLegalReasons", + "HTTPServerError", + "HTTPInternalServerError", + "HTTPNotImplemented", + "HTTPBadGateway", + "HTTPServiceUnavailable", + "HTTPGatewayTimeout", + "HTTPVersionNotSupported", + "HTTPVariantAlsoNegotiates", + "HTTPInsufficientStorage", + "HTTPNotExtended", + "HTTPNetworkAuthenticationRequired", ) @@ -62,22 +71,46 @@ # HTTP Exceptions ############################################################ + class HTTPException(Response, 
Exception): # You should set in subclasses: # status = 200 - status_code = None + status_code = -1 empty_body = False - def __init__(self, *, headers=None, reason=None, - body=None, text=None, content_type=None): - Response.__init__(self, status=self.status_code, - headers=headers, reason=reason, - body=body, text=text, content_type=content_type) + __http_exception__ = True + + def __init__( + self, + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + if body is not None: + warnings.warn( + "body argument is deprecated for http web exceptions", + DeprecationWarning, + ) + Response.__init__( + self, + status=self.status_code, + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) Exception.__init__(self, self.reason) if self.body is None and not self.empty_body: - self.text = "{}: {}".format(self.status, self.reason) + self.text = f"{self.status}: {self.reason}" + + def __bool__(self) -> bool: + return True class HTTPError(HTTPException): @@ -128,14 +161,26 @@ class HTTPPartialContent(HTTPSuccessful): class _HTTPMove(HTTPRedirection): - - def __init__(self, location, *, headers=None, reason=None, - body=None, text=None, content_type=None): + def __init__( + self, + location: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: if not location: raise ValueError("HTTP redirects need a location to redirect to.") - super().__init__(headers=headers, reason=reason, - body=body, text=text, content_type=content_type) - self.headers['Location'] = str(location) + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self.headers["Location"] = str(URL(location)) self.location = location @@ -208,13 +253,27 @@ class HTTPNotFound(HTTPClientError): class HTTPMethodNotAllowed(HTTPClientError): status_code = 405 - def __init__(self, method, allowed_methods, *, headers=None, reason=None, - body=None, text=None, content_type=None): - allow = ','.join(sorted(allowed_methods)) - super().__init__(headers=headers, reason=reason, - body=body, text=text, content_type=content_type) - self.headers['Allow'] = allow - self.allowed_methods = allowed_methods + def __init__( + self, + method: str, + allowed_methods: Iterable[str], + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + allow = ",".join(sorted(allowed_methods)) + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self.headers["Allow"] = allow + self.allowed_methods = set(allowed_methods) # type: Set[str] self.method = method.upper() @@ -249,6 +308,14 @@ class HTTPPreconditionFailed(HTTPClientError): class HTTPRequestEntityTooLarge(HTTPClientError): status_code = 413 + def __init__(self, max_size: float, actual_size: float, **kwargs: Any) -> None: + kwargs.setdefault( + "text", + "Maximum request body size {} exceeded, " + "actual body size {}".format(max_size, actual_size), + ) + super().__init__(**kwargs) + class HTTPRequestURITooLong(HTTPClientError): status_code = 414 @@ -270,6 +337,14 @@ class HTTPMisdirectedRequest(HTTPClientError): status_code = 421 +class HTTPUnprocessableEntity(HTTPClientError): + 
status_code = 422 + + +class HTTPFailedDependency(HTTPClientError): + status_code = 424 + + class HTTPUpgradeRequired(HTTPClientError): status_code = 426 @@ -289,11 +364,24 @@ class HTTPRequestHeaderFieldsTooLarge(HTTPClientError): class HTTPUnavailableForLegalReasons(HTTPClientError): status_code = 451 - def __init__(self, link, *, headers=None, reason=None, - body=None, text=None, content_type=None): - super().__init__(headers=headers, reason=reason, - body=body, text=text, content_type=content_type) - self.headers['Link'] = '<%s>; rel="blocked-by"' % link + def __init__( + self, + link: str, + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self.headers["Link"] = '<%s>; rel="blocked-by"' % link self.link = link @@ -341,6 +429,10 @@ class HTTPVariantAlsoNegotiates(HTTPServerError): status_code = 506 +class HTTPInsufficientStorage(HTTPServerError): + status_code = 507 + + class HTTPNotExtended(HTTPServerError): status_code = 510 diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 99af2b08051..0737c4f42d7 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -2,98 +2,54 @@ import mimetypes import os import pathlib +import sys +from typing import ( # noqa + IO, + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Union, + cast, +) from . import hdrs -from .helpers import create_future -from .http_writer import PayloadWriter -from .log import server_logger -from .web_exceptions import (HTTPNotModified, HTTPOk, HTTPPartialContent, - HTTPRequestRangeNotSatisfiable) +from .abc import AbstractStreamWriter +from .typedefs import LooseHeaders +from .web_exceptions import ( + HTTPNotModified, + HTTPPartialContent, + HTTPPreconditionFailed, + HTTPRequestRangeNotSatisfiable, +) from .web_response import StreamResponse -__all__ = ('FileResponse',) +__all__ = ("FileResponse",) +if TYPE_CHECKING: # pragma: no cover + from .web_request import BaseRequest -NOSENDFILE = bool(os.environ.get("AIOHTTP_NOSENDFILE")) - - -class SendfilePayloadWriter(PayloadWriter): - - def set_transport(self, transport): - self._transport = transport - - if self._drain_waiter is not None: - waiter, self._drain_maiter = self._drain_maiter, None - if not waiter.done(): - waiter.set_result(None) - - def _write(self, chunk): - self.output_size += len(chunk) - self._buffer.append(chunk) - - def _sendfile_cb(self, fut, out_fd, in_fd, - offset, count, loop, registered): - if registered: - loop.remove_writer(out_fd) - if fut.cancelled(): - return - - try: - n = os.sendfile(out_fd, in_fd, offset, count) - if n == 0: # EOF reached - n = count - except (BlockingIOError, InterruptedError): - n = 0 - except Exception as exc: - fut.set_exception(exc) - return - - if n < count: - loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd, - offset + n, count - n, loop, True) - else: - fut.set_result(None) - - @asyncio.coroutine - def sendfile(self, fobj, count): - if self._transport is None: - if self._drain_waiter is None: - self._drain_waiter = create_future(self.loop) - - yield from self._drain_waiter - out_socket = self._transport.get_extra_info("socket").dup() - out_socket.setblocking(False) - out_fd = out_socket.fileno() - in_fd = fobj.fileno() - offset = fobj.tell() +_T_OnChunkSent = Optional[Callable[[bytes], 
Awaitable[None]]] - loop = self.loop - try: - yield from loop.sock_sendall(out_socket, b''.join(self._buffer)) - fut = create_future(loop) - self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False) - yield from fut - except: - server_logger.debug('Socket error') - self._transport.close() - finally: - out_socket.close() - self.output_size += count - self._transport = None - self._stream.release() - - @asyncio.coroutine - def write_eof(self, chunk=b''): - pass +NOSENDFILE = bool(os.environ.get("AIOHTTP_NOSENDFILE")) class FileResponse(StreamResponse): """A response object can be used to send files.""" - def __init__(self, path, chunk_size=256*1024, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + path: Union[str, pathlib.Path], + chunk_size: int = 256 * 1024, + status: int = 200, + reason: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + ) -> None: + super().__init__(status=status, reason=reason, headers=headers) if isinstance(path, str): path = pathlib.Path(path) @@ -101,139 +57,187 @@ def __init__(self, path, chunk_size=256*1024, *args, **kwargs): self._path = path self._chunk_size = chunk_size - @asyncio.coroutine - def _sendfile_system(self, request, fobj, count): - # Write count bytes of fobj to resp using - # the os.sendfile system call. - # - # For details check - # https://github.com/KeepSafe/aiohttp/issues/1177 - # See https://github.com/KeepSafe/aiohttp/issues/958 for details - # - # request should be a aiohttp.web.Request instance. - # fobj should be an open file object. - # count should be an integer > 0. + async def _sendfile_fallback( + self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int + ) -> AbstractStreamWriter: + # To keep memory usage low,fobj is transferred in chunks + # controlled by the constructor's chunk_size argument. - transport = request.transport - if transport.get_extra_info("sslcontext"): - writer = yield from self._sendfile_fallback(request, fobj, count) - else: - writer = request._protocol.writer.replace( - request._writer, SendfilePayloadWriter) - request._writer = writer - yield from super().prepare(request) - yield from writer.sendfile(fobj, count) + chunk_size = self._chunk_size + loop = asyncio.get_event_loop() + + await loop.run_in_executor(None, fobj.seek, offset) + + chunk = await loop.run_in_executor(None, fobj.read, chunk_size) + while chunk: + await writer.write(chunk) + count = count - chunk_size + if count <= 0: + break + chunk = await loop.run_in_executor(None, fobj.read, min(chunk_size, count)) + await writer.drain() return writer - @asyncio.coroutine - def _sendfile_fallback(self, request, fobj, count): - # Mimic the _sendfile_system() method, but without using the - # os.sendfile() system call. This should be used on systems - # that don't support the os.sendfile(). + async def _sendfile( + self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int + ) -> AbstractStreamWriter: + writer = await super().prepare(request) + assert writer is not None - # To avoid blocking the event loop & to keep memory usage low, - # fobj is transferred in chunks controlled by the - # constructor's chunk_size argument. 
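+        # Prefer the zero-copy loop.sendfile() path below; fall back to
+        # chunked reads when sendfile is disabled via AIOHTTP_NOSENDFILE,
+        # unsupported (loop.sendfile() needs Python >= 3.7), or the body is
+        # compressed and has to pass through the application anyway.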
+ if NOSENDFILE or sys.version_info < (3, 7) or self.compression: + return await self._sendfile_fallback(writer, fobj, offset, count) - writer = (yield from super().prepare(request)) + loop = request._loop + transport = request.transport + assert transport is not None - self.set_tcp_cork(True) try: - chunk_size = self._chunk_size - - chunk = fobj.read(chunk_size) - while True: - yield from writer.write(chunk) - count = count - chunk_size - if count <= 0: - break - chunk = fobj.read(min(chunk_size, count)) - finally: - self.set_tcp_nodelay(True) + await loop.sendfile(transport, fobj, offset, count) + except NotImplementedError: + return await self._sendfile_fallback(writer, fobj, offset, count) - yield from writer.drain() + await super().write_eof() return writer - if hasattr(os, "sendfile") and not NOSENDFILE: # pragma: no cover - _sendfile = _sendfile_system - else: # pragma: no cover - _sendfile = _sendfile_fallback - - @asyncio.coroutine - def prepare(self, request): + async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: filepath = self._path gzip = False - if 'gzip' in request.headers.get(hdrs.ACCEPT_ENCODING, ''): - gzip_path = filepath.with_name(filepath.name + '.gz') + if "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, ""): + gzip_path = filepath.with_name(filepath.name + ".gz") if gzip_path.is_file(): filepath = gzip_path gzip = True - st = filepath.stat() + loop = asyncio.get_event_loop() + st = await loop.run_in_executor(None, filepath.stat) modsince = request.if_modified_since if modsince is not None and st.st_mtime <= modsince.timestamp(): self.set_status(HTTPNotModified.status_code) - return (yield from super().prepare(request)) - - ct, encoding = mimetypes.guess_type(str(filepath)) - if not ct: - ct = 'application/octet-stream' + self._length_check = False + # Delete any Content-Length headers provided by user. HTTP 304 + # should always have empty response body + return await super().prepare(request) + + unmodsince = request.if_unmodified_since + if unmodsince is not None and st.st_mtime > unmodsince.timestamp(): + self.set_status(HTTPPreconditionFailed.status_code) + return await super().prepare(request) + + if hdrs.CONTENT_TYPE not in self.headers: + ct, encoding = mimetypes.guess_type(str(filepath)) + if not ct: + ct = "application/octet-stream" + should_set_ct = True + else: + encoding = "gzip" if gzip else None + should_set_ct = False - status = HTTPOk.status_code + status = self._status file_size = st.st_size count = file_size - try: - rng = request.http_range - start = rng.start - end = rng.stop - except ValueError: - self.set_status(HTTPRequestRangeNotSatisfiable.status_code) - return (yield from super().prepare(request)) - - # If a range request has been made, convert start, end slice notation - # into file pointer offset and count - if start is not None or end is not None: - if start is None and end < 0: # return tail of file - start = file_size + end - count = -end - else: - count = (end or file_size) - start - - if start + count > file_size: - # rfc7233:If the last-byte-pos value is - # absent, or if the value is greater than or equal to - # the current length of the representation data, - # the byte range is interpreted as the remainder - # of the representation (i.e., the server replaces the - # value of last-byte-pos with a value that is one less than - # the current length of the selected representation). 
- count = file_size - start - - if start >= file_size: - count = 0 - - if count != file_size: - status = HTTPPartialContent.status_code - - self.set_status(status) - self.content_type = ct + start = None + + ifrange = request.if_range + if ifrange is None or st.st_mtime <= ifrange.timestamp(): + # If-Range header check: + # condition = cached date >= last modification date + # return 206 if True else 200. + # if False: + # Range header would not be processed, return 200 + # if True but Range header missing + # return 200 + try: + rng = request.http_range + start = rng.start + end = rng.stop + except ValueError: + # https://tools.ietf.org/html/rfc7233: + # A server generating a 416 (Range Not Satisfiable) response to + # a byte-range request SHOULD send a Content-Range header field + # with an unsatisfied-range value. + # The complete-length in a 416 response indicates the current + # length of the selected representation. + # + # Will do the same below. Many servers ignore this and do not + # send a Content-Range header with HTTP 416 + self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" + self.set_status(HTTPRequestRangeNotSatisfiable.status_code) + return await super().prepare(request) + + # If a range request has been made, convert start, end slice + # notation into file pointer offset and count + if start is not None or end is not None: + if start < 0 and end is None: # return tail of file + start += file_size + if start < 0: + # if Range:bytes=-1000 in request header but file size + # is only 200, there would be trouble without this + start = 0 + count = file_size - start + else: + # rfc7233:If the last-byte-pos value is + # absent, or if the value is greater than or equal to + # the current length of the representation data, + # the byte range is interpreted as the remainder + # of the representation (i.e., the server replaces the + # value of last-byte-pos with a value that is one less than + # the current length of the selected representation). + count = ( + min(end if end is not None else file_size, file_size) - start + ) + + if start >= file_size: + # HTTP 416 should be returned in this case. + # + # According to https://tools.ietf.org/html/rfc7233: + # If a valid byte-range-set includes at least one + # byte-range-spec with a first-byte-pos that is less than + # the current length of the representation, or at least one + # suffix-byte-range-spec with a non-zero suffix-length, + # then the byte-range-set is satisfiable. Otherwise, the + # byte-range-set is unsatisfiable. + self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" + self.set_status(HTTPRequestRangeNotSatisfiable.status_code) + return await super().prepare(request) + + status = HTTPPartialContent.status_code + # Even though you are sending the whole file, you should still + # return a HTTP 206 for a Range request. 
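+        # Illustrative arithmetic for a 1000-byte file:
+        #   Range: bytes=500-699  ->  start=500, count=200  -> 206
+        #   Range: bytes=-100     ->  start=900, count=100  -> 206
+        #   Range: bytes=2000-    ->  start >= file_size    -> 416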
+ self.set_status(status) + + if should_set_ct: + self.content_type = ct # type: ignore if encoding: self.headers[hdrs.CONTENT_ENCODING] = encoding if gzip: self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING - self.last_modified = st.st_mtime + self.last_modified = st.st_mtime # type: ignore self.content_length = count - if count: - with filepath.open('rb') as fobj: - if start: - fobj.seek(start) + self.headers[hdrs.ACCEPT_RANGES] = "bytes" + + real_start = cast(int, start) + + if status == HTTPPartialContent.status_code: + self.headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format( + real_start, real_start + count - 1, file_size + ) - return (yield from self._sendfile(request, fobj, count)) + if request.method == hdrs.METH_HEAD or self.status in [204, 304]: + return await super().prepare(request) - return (yield from super().prepare(request)) + fobj = await loop.run_in_executor(None, filepath.open, "rb") + if start: # be aware that start could be None or int=0 here. + offset = start + else: + offset = 0 + + try: + return await self._sendfile(request, fobj, offset, count) + finally: + await loop.run_in_executor(None, fobj.close) diff --git a/aiohttp/web_log.py b/aiohttp/web_log.py new file mode 100644 index 00000000000..4cfa57929a9 --- /dev/null +++ b/aiohttp/web_log.py @@ -0,0 +1,208 @@ +import datetime +import functools +import logging +import os +import re +from collections import namedtuple +from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa + +from .abc import AbstractAccessLogger +from .web_request import BaseRequest +from .web_response import StreamResponse + +KeyMethod = namedtuple("KeyMethod", "key method") + + +class AccessLogger(AbstractAccessLogger): + """Helper object to log access. + + Usage: + log = logging.getLogger("spam") + log_format = "%a %{User-Agent}i" + access_logger = AccessLogger(log, log_format) + access_logger.log(request, response, time) + + Format: + %% The percent sign + %a Remote IP-address (IP-address of proxy if using reverse proxy) + %t Time when the request was started to process + %P The process ID of the child that serviced the request + %r First line of request + %s Response status code + %b Size of response in bytes, including HTTP headers + %T Time taken to serve the request, in seconds + %Tf Time taken to serve the request, in seconds with floating fraction + in .06f format + %D Time taken to serve the request, in microseconds + %{FOO}i request.headers['FOO'] + %{FOO}o response.headers['FOO'] + %{FOO}e os.environ['FOO'] + + """ + + LOG_FORMAT_MAP = { + "a": "remote_address", + "t": "request_start_time", + "P": "process_id", + "r": "first_request_line", + "s": "response_status", + "b": "response_size", + "T": "request_time", + "Tf": "request_time_frac", + "D": "request_time_micro", + "i": "request_header", + "o": "response_header", + } + + LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' + FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)") + CLEANUP_RE = re.compile(r"(%[^s])") + _FORMAT_CACHE = {} # type: Dict[str, Tuple[str, List[KeyMethod]]] + + def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None: + """Initialise the logger. + + logger is a logger object to be used for logging. + log_format is a string with apache compatible log format description. 
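+        For example (an illustrative instantiation):
+
+            AccessLogger(logging.getLogger("aiohttp.access"),
+                         '%a %t "%r" %s %b')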
+ + """ + super().__init__(logger, log_format=log_format) + + _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format) + if not _compiled_format: + _compiled_format = self.compile_format(log_format) + AccessLogger._FORMAT_CACHE[log_format] = _compiled_format + + self._log_format, self._methods = _compiled_format + + def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]: + """Translate log_format into form usable by modulo formatting + + All known atoms will be replaced with %s + Also methods for formatting of those atoms will be added to + _methods in appropriate order + + For example we have log_format = "%a %t" + This format will be translated to "%s %s" + Also contents of _methods will be + [self._format_a, self._format_t] + These method will be called and results will be passed + to translated string format. + + Each _format_* method receive 'args' which is list of arguments + given to self.log + + Exceptions are _format_e, _format_i and _format_o methods which + also receive key name (by functools.partial) + + """ + # list of (key, method) tuples, we don't use an OrderedDict as users + # can repeat the same key more than once + methods = list() + + for atom in self.FORMAT_RE.findall(log_format): + if atom[1] == "": + format_key1 = self.LOG_FORMAT_MAP[atom[0]] + m = getattr(AccessLogger, "_format_%s" % atom[0]) + key_method = KeyMethod(format_key1, m) + else: + format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1]) + m = getattr(AccessLogger, "_format_%s" % atom[2]) + key_method = KeyMethod(format_key2, functools.partial(m, atom[1])) + + methods.append(key_method) + + log_format = self.FORMAT_RE.sub(r"%s", log_format) + log_format = self.CLEANUP_RE.sub(r"%\1", log_format) + return log_format, methods + + @staticmethod + def _format_i( + key: str, request: BaseRequest, response: StreamResponse, time: float + ) -> str: + if request is None: + return "(no headers)" + + # suboptimal, make istr(key) once + return request.headers.get(key, "-") + + @staticmethod + def _format_o( + key: str, request: BaseRequest, response: StreamResponse, time: float + ) -> str: + # suboptimal, make istr(key) once + return response.headers.get(key, "-") + + @staticmethod + def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> str: + if request is None: + return "-" + ip = request.remote + return ip if ip is not None else "-" + + @staticmethod + def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str: + now = datetime.datetime.utcnow() + start_time = now - datetime.timedelta(seconds=time) + return start_time.strftime("[%d/%b/%Y:%H:%M:%S +0000]") + + @staticmethod + def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str: + return "<%s>" % os.getpid() + + @staticmethod + def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str: + if request is None: + return "-" + return "{} {} HTTP/{}.{}".format( + request.method, + request.path_qs, + request.version.major, + request.version.minor, + ) + + @staticmethod + def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int: + return response.status + + @staticmethod + def _format_b(request: BaseRequest, response: StreamResponse, time: float) -> int: + return response.body_length + + @staticmethod + def _format_T(request: BaseRequest, response: StreamResponse, time: float) -> str: + return str(round(time)) + + @staticmethod + def _format_Tf(request: BaseRequest, response: StreamResponse, time: float) -> str: + return "%06f" % 
time + + @staticmethod + def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> str: + return str(round(time * 1000000)) + + def _format_line( + self, request: BaseRequest, response: StreamResponse, time: float + ) -> Iterable[Tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]: + return [(key, method(request, response, time)) for key, method in self._methods] + + def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: + try: + fmt_info = self._format_line(request, response, time) + + values = list() + extra = dict() + for key, value in fmt_info: + values.append(value) + + if key.__class__ is str: + extra[key] = value + else: + k1, k2 = key # type: ignore + dct = extra.get(k1, {}) # type: ignore + dct[k2] = value # type: ignore + extra[k1] = dct # type: ignore + + self.logger.info(self._log_format % tuple(values), extra=extra) + except Exception: + self.logger.exception("Error in logging") diff --git a/aiohttp/web_middlewares.py b/aiohttp/web_middlewares.py index 8676a154ab2..8a8967e8131 100644 --- a/aiohttp/web_middlewares.py +++ b/aiohttp/web_middlewares.py @@ -1,75 +1,121 @@ -import asyncio import re +from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, Type, TypeVar -from aiohttp.web_exceptions import HTTPMovedPermanently -from aiohttp.web_urldispatcher import SystemRoute +from .web_exceptions import HTTPPermanentRedirect, _HTTPMove +from .web_request import Request +from .web_response import StreamResponse +from .web_urldispatcher import SystemRoute __all__ = ( - 'normalize_path_middleware', + "middleware", + "normalize_path_middleware", ) +if TYPE_CHECKING: # pragma: no cover + from .web_app import Application -@asyncio.coroutine -def _check_request_resolves(request, path): +_Func = TypeVar("_Func") + + +async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]: alt_request = request.clone(rel_url=path) - match_info = yield from request.app.router.resolve(alt_request) - alt_request._match_info = match_info + match_info = await request.app.router.resolve(alt_request) + alt_request._match_info = match_info # type: ignore - if not isinstance(match_info.route, SystemRoute): + if match_info.http_exception is None: return True, alt_request return False, request +def middleware(f: _Func) -> _Func: + f.__middleware_version__ = 1 # type: ignore + return f + + +_Handler = Callable[[Request], Awaitable[StreamResponse]] +_Middleware = Callable[[Request, _Handler], Awaitable[StreamResponse]] + + def normalize_path_middleware( - *, append_slash=True, merge_slashes=True, - redirect_class=HTTPMovedPermanently): + *, + append_slash: bool = True, + remove_slash: bool = False, + merge_slashes: bool = True, + redirect_class: Type[_HTTPMove] = HTTPPermanentRedirect +) -> _Middleware: """ - Middleware that normalizes the path of a request. By normalizing - it means: + Middleware factory which produces a middleware that normalizes + the path of a request. By normalizing it means: - - Add a trailing slash to the path. + - Add or remove a trailing slash to the path. - Double slashes are replaced by one. The middleware returns as soon as it finds a path that resolves - correctly. The order if all enable is 1) merge_slashes, 2) append_slash - and 3) both merge_slashes and append_slash. If the path resolves with - at least one of those conditions, it will redirect to the new path. + correctly. 
The order, if both merge and append/remove are enabled, is
+    1) merge slashes
+    2) append/remove slash
+    3) both merge slashes and append/remove slash.
+    If the path resolves with at least one of those conditions, it will
+    redirect to the new path.
-    If append_slash is True append slash when needed. If a resource is
-    defined with trailing slash and the request comes without it, it will
-    append it automatically.
+
+    Only one of `append_slash` and `remove_slash` can be enabled. If both
+    are `True` the factory will raise an assertion error.
+
+    If `append_slash` is `True` the middleware will append a slash when
+    needed. If a resource is defined with a trailing slash and the request
+    comes without it, it will append it automatically.
+
+    If `remove_slash` is `True`, `append_slash` must be `False`. When
+    enabled, the middleware will remove trailing slashes and redirect if
+    the resource is defined.

    If merge_slashes is True, merge multiple consecutive slashes in the path
    into one.
    """
-    @asyncio.coroutine
-    def normalize_path_factory(app, handler):
+    correct_configuration = not (append_slash and remove_slash)
+    assert correct_configuration, "Cannot both remove and append slash"
+
+    @middleware
+    async def impl(request: Request, handler: _Handler) -> StreamResponse:
+        if isinstance(request.match_info.route, SystemRoute):
+            paths_to_check = []
+            if "?" in request.raw_path:
+                path, query = request.raw_path.split("?", 1)
+                query = "?" + query
+            else:
+                query = ""
+                path = request.raw_path
-    @asyncio.coroutine
-    def middleware(request):
+            if merge_slashes:
+                paths_to_check.append(re.sub("//+", "/", path))
+            if append_slash and not request.path.endswith("/"):
+                paths_to_check.append(path + "/")
+            if remove_slash and request.path.endswith("/"):
+                paths_to_check.append(path[:-1])
+            if merge_slashes and append_slash:
+                paths_to_check.append(re.sub("//+", "/", path + "/"))
+            if merge_slashes and remove_slash:
+                merged_slashes = re.sub("//+", "/", path)
+                paths_to_check.append(merged_slashes[:-1])
-            if isinstance(request.match_info.route, SystemRoute):
-                paths_to_check = []
-                path = request.raw_path
-                if merge_slashes:
-                    paths_to_check.append(re.sub('//+', '/', path))
-                if append_slash and not request.path.endswith('/'):
-                    paths_to_check.append(path + '/')
-                if merge_slashes and append_slash:
-                    paths_to_check.append(
-                        re.sub('//+', '/', path + '/'))
+            for path in paths_to_check:
+                path = re.sub("^//+", "/", path)  # SECURITY: GHSA-v6wp-4m6f-gcjg
+                resolves, request = await _check_request_resolves(request, path)
+                if resolves:
+                    raise redirect_class(request.raw_path + query)
+
+        return await handler(request)
-            for path in paths_to_check:
-                resolves, request = yield from _check_request_resolves(
-                    request, path)
-                if resolves:
-                    return redirect_class(request.path)
+    return impl
-            return (yield from handler(request))
-        return middleware
+def _fix_request_current_app(app: "Application") -> _Middleware:
+    @middleware
+    async def impl(request: Request, handler: _Handler) -> StreamResponse:
+        with request.match_info.set_current_app(app):
+            return await handler(request)
-    return normalize_path_factory
+    return impl
diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py
index 80596a80f15..8e02bc4aab7 100644
--- a/aiohttp/web_protocol.py
+++ b/aiohttp/web_protocol.py
@@ -1,43 +1,68 @@
 import asyncio
 import asyncio.streams
-import http.server
-import socket
 import traceback
 import warnings
 from collections import deque
 from contextlib import suppress
 from html import escape as html_escape
-
-from . 
import helpers, http -from .helpers import CeilTimeout, create_future, ensure_future -from .http import (HttpProcessingError, HttpRequestParser, PayloadWriter, - StreamWriter) +from http import HTTPStatus +from logging import Logger +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast + +import yarl + +from .abc import AbstractAccessLogger, AbstractStreamWriter +from .base_protocol import BaseProtocol +from .helpers import CeilTimeout, current_task +from .http import ( + HttpProcessingError, + HttpRequestParser, + HttpVersion10, + RawRequestMessage, + StreamWriter, +) from .log import access_logger, server_logger -from .streams import EMPTY_PAYLOAD +from .streams import EMPTY_PAYLOAD, StreamReader +from .tcp_helpers import tcp_keepalive from .web_exceptions import HTTPException +from .web_log import AccessLogger from .web_request import BaseRequest -from .web_response import Response +from .web_response import Response, StreamResponse -__all__ = ('RequestHandler', 'RequestPayloadError') +__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError") -ERROR = http.RawRequestMessage( - 'UNKNOWN', '/', http.HttpVersion10, {}, - {}, True, False, False, False, http.URL('/')) +if TYPE_CHECKING: # pragma: no cover + from .web_server import Server + + +_RequestFactory = Callable[ + [ + RawRequestMessage, + StreamReader, + "RequestHandler", + AbstractStreamWriter, + "asyncio.Task[None]", + ], + BaseRequest, +] + +_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]] -if hasattr(socket, 'SO_KEEPALIVE'): - def tcp_keepalive(server, transport): - sock = transport.get_extra_info('socket') - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) -else: - def tcp_keepalive(server, transport): # pragma: no cover - pass + +ERROR = RawRequestMessage( + "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/") +) class RequestPayloadError(Exception): """Payload parsing error.""" -class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol): +class PayloadAccessError(Exception): + """Payload was accessed after response was sent.""" + + +class RequestHandler(BaseProtocol): """HTTP protocol implementation. RequestHandler handles incoming HTTP request. It reads request line, @@ -48,8 +73,6 @@ class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol): status line, bad headers or incomplete payload. If any error occurs, connection gets closed. 
- :param time_service: Low resolution time service - :param keepalive_timeout: number of seconds before closing keep-alive connection :type keepalive_timeout: int or None @@ -61,6 +84,9 @@ class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol): :param logger: custom logger object :type logger: aiohttp.log.server_logger + :param access_log_class: custom class for access_logger + :type access_log_class: aiohttp.abc.AbstractAccessLogger + :param access_log: custom logging object :type access_log: aiohttp.log.server_logger @@ -75,102 +101,120 @@ class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol): :param int max_headers: Optional maximum header size """ - _request_count = 0 - _keepalive = False # keep transport open - - def __init__(self, manager, *, loop=None, - keepalive_timeout=75, # NGINX default value is 75 secs - tcp_keepalive=True, - slow_request_timeout=None, - logger=server_logger, - access_log=access_logger, - access_log_format=helpers.AccessLogger.LOG_FORMAT, - debug=False, - max_line_size=8190, - max_headers=32768, - max_field_size=8190, - lingering_time=10.0, - max_concurrent_handlers=2, - **kwargs): - - # process deprecated params - logger = kwargs.get('logger', logger) - - if slow_request_timeout is not None: - warnings.warn( - 'slow_request_timeout is deprecated', DeprecationWarning) - - super().__init__(loop=loop) - - self._loop = loop if loop is not None else asyncio.get_event_loop() - - self._manager = manager - self._time_service = manager.time_service - self._request_handler = manager.request_handler - self._request_factory = manager.request_factory + + KEEPALIVE_RESCHEDULE_DELAY = 1 + + __slots__ = ( + "_request_count", + "_keepalive", + "_manager", + "_request_handler", + "_request_factory", + "_tcp_keepalive", + "_keepalive_time", + "_keepalive_handle", + "_keepalive_timeout", + "_lingering_time", + "_messages", + "_message_tail", + "_waiter", + "_error_handler", + "_task_handler", + "_upgrade", + "_payload_parser", + "_request_parser", + "_reading_paused", + "logger", + "debug", + "access_log", + "access_logger", + "_close", + "_force_close", + "_current_request", + ) + + def __init__( + self, + manager: "Server", + *, + loop: asyncio.AbstractEventLoop, + keepalive_timeout: float = 75.0, # NGINX default is 75 secs + tcp_keepalive: bool = True, + logger: Logger = server_logger, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log: Logger = access_logger, + access_log_format: str = AccessLogger.LOG_FORMAT, + debug: bool = False, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + lingering_time: float = 10.0, + read_bufsize: int = 2 ** 16, + ): + + super().__init__(loop) + + self._request_count = 0 + self._keepalive = False + self._current_request = None # type: Optional[BaseRequest] + self._manager = manager # type: Optional[Server] + self._request_handler = ( + manager.request_handler + ) # type: Optional[_RequestHandler] + self._request_factory = ( + manager.request_factory + ) # type: Optional[_RequestFactory] self._tcp_keepalive = tcp_keepalive - self._keepalive_time = None - self._keepalive_handle = None + # placeholder to be replaced on keepalive timeout setup + self._keepalive_time = 0.0 + self._keepalive_handle = None # type: Optional[asyncio.Handle] self._keepalive_timeout = keepalive_timeout self._lingering_time = float(lingering_time) - self._messages = deque() - self._message_tail = b'' + self._messages = deque() # type: Any # Python 3.5 has no typing.Deque + 
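+        # (message, payload) pairs parsed from the socket but not yet consumed
+        # by the handler task; several can queue up when a client pipelines
+        # requests.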
self._message_tail = b"" - self._waiters = deque() - self._error_handler = None - self._request_handlers = [] - self._max_concurrent_handlers = max_concurrent_handlers + self._waiter = None # type: Optional[asyncio.Future[None]] + self._error_handler = None # type: Optional[asyncio.Task[None]] + self._task_handler = None # type: Optional[asyncio.Task[None]] self._upgrade = False - self._payload_parser = None + self._payload_parser = None # type: Any self._request_parser = HttpRequestParser( - self, loop, + self, + loop, + read_bufsize, max_line_size=max_line_size, max_field_size=max_field_size, max_headers=max_headers, - payload_exception=RequestPayloadError) - - self.transport = None - self._reading_paused = False + payload_exception=RequestPayloadError, + ) # type: Optional[HttpRequestParser] self.logger = logger self.debug = debug self.access_log = access_log if access_log: - self.access_logger = helpers.AccessLogger( - access_log, access_log_format) + self.access_logger = access_log_class( + access_log, access_log_format + ) # type: Optional[AbstractAccessLogger] else: self.access_logger = None self._close = False self._force_close = False - def __repr__(self): - self._request = None - if self._request is None: - meth = 'none' - path = 'none' - else: - meth = 'none' - path = 'none' - # meth = self._request.method - # path = self._request.rel_url.raw_path - return "<{} {}:{} {}>".format( - self.__class__.__name__, meth, path, - 'connected' if self.transport is not None else 'disconnected') - - @property - def time_service(self): - return self._time_service + def __repr__(self) -> str: + return "<{} {}>".format( + self.__class__.__name__, + "connected" if self.transport is not None else "disconnected", + ) @property - def keepalive_timeout(self): + def keepalive_timeout(self) -> float: return self._keepalive_timeout - @asyncio.coroutine - def shutdown(self, timeout=15.0): + async def shutdown(self, timeout: Optional[float] = 15.0) -> None: """Worker process is about to exit, we need cleanup everything and stop accepting requests. 
It is especially important for keep-alive connections.""" @@ -179,53 +223,43 @@ def shutdown(self, timeout=15.0): if self._keepalive_handle is not None: self._keepalive_handle.cancel() - # cancel waiters - for waiter in self._waiters: - if not waiter.done(): - waiter.cancel() + if self._waiter: + self._waiter.cancel() # wait for handlers with suppress(asyncio.CancelledError, asyncio.TimeoutError): with CeilTimeout(timeout, loop=self._loop): - if self._error_handler and not self._error_handler.done(): - yield from self._error_handler - - while True: - h = None - for handler in self._request_handlers: - if not handler.done(): - h = handler - break - if h: - yield from h - else: - break + if self._error_handler is not None and not self._error_handler.done(): + await self._error_handler + + if self._current_request is not None: + self._current_request._cancel(asyncio.CancelledError()) - # force-close non-idle handlers - for handler in self._request_handlers: - if not handler.done(): - handler.cancel() + if self._task_handler is not None and not self._task_handler.done(): + await self._task_handler + + # force-close non-idle handler + if self._task_handler is not None: + self._task_handler.cancel() if self.transport is not None: self.transport.close() self.transport = None - if self._request_handlers: - self._request_handlers.clear() - - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: super().connection_made(transport) - self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) - + real_transport = cast(asyncio.Transport, transport) if self._tcp_keepalive: - tcp_keepalive(self, transport) + tcp_keepalive(real_transport) - self.writer.set_tcp_nodelay(True) - self._manager.connection_made(self, transport) + self._task_handler = self._loop.create_task(self.start()) + assert self._manager is not None + self._manager.connection_made(self, real_transport) - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[BaseException]) -> None: + if self._manager is None: + return self._manager.connection_lost(self, exc) super().connection_lost(exc) @@ -235,79 +269,78 @@ def connection_lost(self, exc): self._request_factory = None self._request_handler = None self._request_parser = None - self.transport = self.writer = None - - if self._payload_parser is not None: - self._payload_parser.feed_eof() - self._payload_parser = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() - for handler in self._request_handlers: - if not handler.done(): - handler.cancel() + if self._current_request is not None: + if exc is None: + exc = ConnectionResetError("Connection lost") + self._current_request._cancel(exc) if self._error_handler is not None: - if not self._error_handler.done(): - self._error_handler.cancel() + self._error_handler.cancel() + if self._task_handler is not None: + self._task_handler.cancel() + if self._waiter is not None: + self._waiter.cancel() - self._request_handlers = () + self._task_handler = None + + if self._payload_parser is not None: + self._payload_parser.feed_eof() + self._payload_parser = None - def set_parser(self, parser): + def set_parser(self, parser: Any) -> None: + # Actual type is WebReader assert self._payload_parser is None self._payload_parser = parser if self._message_tail: self._payload_parser.feed_data(self._message_tail) - self._message_tail = b'' + self._message_tail = b"" - def eof_received(self): + def eof_received(self) -> None: pass - def 
data_received(self, data): + def data_received(self, data: bytes) -> None: if self._force_close or self._close: return - # parse http messages if self._payload_parser is None and not self._upgrade: + assert self._request_parser is not None try: messages, upgraded, tail = self._request_parser.feed_data(data) except HttpProcessingError as exc: # something happened during parsing - self.close() - self._error_handler = ensure_future( + self._error_handler = self._loop.create_task( self.handle_parse_error( - PayloadWriter(self.writer, self._loop), - 400, exc, exc.message), - loop=self._loop) + StreamWriter(self, self._loop), 400, exc, exc.message + ) + ) + self.close() except Exception as exc: # 500: internal error + self._error_handler = self._loop.create_task( + self.handle_parse_error(StreamWriter(self, self._loop), 500, exc) + ) self.close() - self._error_handler = ensure_future( - self.handle_parse_error( - PayloadWriter(self.writer, self._loop), - 500, exc), loop=self._loop) else: - for (msg, payload) in messages: - self._request_count += 1 - - if self._waiters: - waiter = self._waiters.popleft() - waiter.set_result((msg, payload)) - elif self._max_concurrent_handlers: - self._max_concurrent_handlers -= 1 - data = [] - handler = ensure_future( - self.start(msg, payload, data), loop=self._loop) - data.append(handler) - self._request_handlers.append(handler) - else: + if messages: + # sometimes the parser returns no messages + for (msg, payload) in messages: + self._request_count += 1 self._messages.append((msg, payload)) - self._upgraded = upgraded - if upgraded: + waiter = self._waiter + if waiter is not None: + if not waiter.done(): + # don't set result twice + waiter.set_result(None) + + self._upgrade = upgraded + if upgraded and tail: self._message_tail = tail # no parser, just store @@ -320,78 +353,96 @@ def data_received(self, data): if eof: self.close() - def keep_alive(self, val): + def keep_alive(self, val: bool) -> None: """Set keep-alive connection mode. :param bool val: new state. 
""" self._keepalive = val + if self._keepalive_handle: + self._keepalive_handle.cancel() + self._keepalive_handle = None - def close(self): + def close(self) -> None: """Stop accepting new pipelinig messages and close connection when handlers done processing messages""" self._close = True - for waiter in self._waiters: - if not waiter.done(): - waiter.cancel() + if self._waiter: + self._waiter.cancel() - def force_close(self): + def force_close(self) -> None: """Force close connection""" self._force_close = True - for waiter in self._waiters: - if not waiter.done(): - waiter.cancel() + if self._waiter: + self._waiter.cancel() if self.transport is not None: self.transport.close() self.transport = None - def log_access(self, message, environ, response, time): - if self.access_logger: - self.access_logger.log(message, environ, response, - self.transport, time) + def log_access( + self, request: BaseRequest, response: StreamResponse, time: float + ) -> None: + if self.access_logger is not None: + self.access_logger.log(request, response, self._loop.time() - time) - def log_debug(self, *args, **kw): + def log_debug(self, *args: Any, **kw: Any) -> None: if self.debug: self.logger.debug(*args, **kw) - def log_exception(self, *args, **kw): + def log_exception(self, *args: Any, **kw: Any) -> None: self.logger.exception(*args, **kw) - def _process_keepalive(self): - if self._force_close: + def _process_keepalive(self) -> None: + if self._force_close or not self._keepalive: return next = self._keepalive_time + self._keepalive_timeout - # all handlers in idle state - if len(self._request_handlers) == len(self._waiters): - now = self._time_service.loop_time - if now + 1.0 > next: + # handler in idle state + if self._waiter: + if self._loop.time() > next: self.force_close() return - self._keepalive_handle = self._loop.call_at( - next, self._process_keepalive) - - def pause_reading(self): - if not self._reading_paused: + # not all request handlers are done, + # reschedule itself to next second + self._keepalive_handle = self._loop.call_later( + self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive + ) + + async def _handle_request( + self, + request: BaseRequest, + start_time: float, + ) -> Tuple[StreamResponse, bool]: + assert self._request_handler is not None + try: try: - self.transport.pause_reading() - except (AttributeError, NotImplementedError, RuntimeError): - pass - self._reading_paused = True + self._current_request = request + resp = await self._request_handler(request) + finally: + self._current_request = None + except HTTPException as exc: + resp = Response( + status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers + ) + reset = await self.finish_response(request, resp, start_time) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError as exc: + self.log_debug("Request handler timed out.", exc_info=exc) + resp = self.handle_error(request, 504) + reset = await self.finish_response(request, resp, start_time) + except Exception as exc: + resp = self.handle_error(request, 500, exc) + reset = await self.finish_response(request, resp, start_time) + else: + reset = await self.finish_response(request, resp, start_time) - def resume_reading(self): - if self._reading_paused: - try: - self.transport.resume_reading() - except (AttributeError, NotImplementedError, RuntimeError): - pass - self._reading_paused = False + return resp, reset - @asyncio.coroutine - def start(self, message, payload, handler): - """Start processing of incoming requests. 
+ async def start(self) -> None: + """Process incoming request. It reads request line, request headers and request payload, then calls handle_request() method. Subclass has to override @@ -400,136 +451,191 @@ def start(self, message, payload, handler): keep_alive(True) specified. """ loop = self._loop - handler = handler[0] + handler = self._task_handler + assert handler is not None manager = self._manager + assert manager is not None keepalive_timeout = self._keepalive_timeout + resp = None + assert self._request_factory is not None + assert self._request_handler is not None while not self._force_close: - if self.access_log: - now = loop.time() - - manager.requests_count += 1 - writer = PayloadWriter(self.writer, loop) - request = self._request_factory( - message, payload, self, writer, handler) - try: + if not self._messages: try: - resp = yield from self._request_handler(request) - except HTTPException as exc: - resp = exc + # wait for next request + self._waiter = loop.create_future() + await self._waiter except asyncio.CancelledError: - self.log_debug('Ignored premature client disconnection') break - except asyncio.TimeoutError: - self.log_debug('Request handler timed out.') - resp = self.handle_error(request, 504) - except Exception as exc: - resp = self.handle_error(request, 500, exc) + finally: + self._waiter = None - yield from resp.prepare(request) - yield from resp.write_eof() + message, payload = self._messages.popleft() - # notify server about keep-alive - self._keepalive = resp.keep_alive + start = loop.time() - # Restore default state. - # Should be no-op if server code didn't touch these attributes. - writer.set_tcp_cork(False) - writer.set_tcp_nodelay(True) + manager.requests_count += 1 + writer = StreamWriter(self, loop) + request = self._request_factory(message, payload, self, writer, handler) + try: + # a new task is used for copy context vars (#3406) + task = self._loop.create_task(self._handle_request(request, start)) + try: + resp, reset = await task + except (asyncio.CancelledError, ConnectionError): + self.log_debug("Ignored premature client disconnection") + break + # Deprecation warning (See #2415) + if getattr(resp, "__http_exception__", False): + warnings.warn( + "returning HTTPException object is deprecated " + "(#2415) and will be removed, " + "please raise the exception instead", + DeprecationWarning, + ) + + # Drop the processed task from asyncio.Task.all_tasks() early + del task + if reset: + self.log_debug("Ignored premature client disconnection 2") + break - # log access - if self.access_log: - self.log_access(message, None, resp, loop.time() - now) + # notify server about keep-alive + self._keepalive = bool(resp.keep_alive) # check payload if not payload.is_eof(): lingering_time = self._lingering_time if not self._force_close and lingering_time: self.log_debug( - 'Start lingering close timer for %s sec.', - lingering_time) + "Start lingering close timer for %s sec.", lingering_time + ) now = loop.time() end_t = now + lingering_time - with suppress( - asyncio.TimeoutError, asyncio.CancelledError): - while (not payload.is_eof() and now < end_t): - timeout = min(end_t - now, lingering_time) - with CeilTimeout(timeout, loop=loop): + with suppress(asyncio.TimeoutError, asyncio.CancelledError): + while not payload.is_eof() and now < end_t: + with CeilTimeout(end_t - now, loop=loop): # read and ignore - yield from payload.readany() + await payload.readany() now = loop.time() # if payload still uncompleted if not payload.is_eof() and not self._force_close: - 
self.log_debug('Uncompleted request.') + self.log_debug("Uncompleted request.") self.close() + payload.set_exception(PayloadAccessError()) + + except asyncio.CancelledError: + self.log_debug("Ignored premature client disconnection ") + break + except RuntimeError as exc: + if self.debug: + self.log_exception("Unhandled runtime exception", exc_info=exc) + self.force_close() except Exception as exc: - self.log_exception('Unhandled exception', exc_info=exc) + self.log_exception("Unhandled exception", exc_info=exc) self.force_close() finally: - if self.transport is None: - self.log_debug('Ignored premature client disconnection.') + if self.transport is None and resp is not None: + self.log_debug("Ignored premature client disconnection.") elif not self._force_close: - if self._messages: - message, payload = self._messages.popleft() + if self._keepalive and not self._close: + # start keep-alive timer + if keepalive_timeout is not None: + now = self._loop.time() + self._keepalive_time = now + if self._keepalive_handle is None: + self._keepalive_handle = loop.call_at( + now + keepalive_timeout, self._process_keepalive + ) else: - if self._keepalive and not self._close: - # start keep-alive timer - if keepalive_timeout is not None: - now = self._time_service.loop_time - self._keepalive_time = now - if self._keepalive_handle is None: - self._keepalive_handle = loop.call_at( - now + keepalive_timeout, - self._process_keepalive) - - # wait for next request - waiter = create_future(loop) - self._waiters.append(waiter) - try: - message, payload = yield from waiter - except asyncio.CancelledError: - # shutdown process - break - else: - break + break # remove handler, close transport if no handlers left if not self._force_close: - self._request_handlers.remove(handler) - if not self._request_handlers: - if self.transport is not None: - self.transport.close() + self._task_handler = None + if self.transport is not None and self._error_handler is None: + self.transport.close() - def handle_error(self, request, status=500, exc=None, message=None): + async def finish_response( + self, request: BaseRequest, resp: StreamResponse, start_time: float + ) -> bool: + """ + Prepare the response and write_eof, then log access. This has to + be called within the context of any exception so the access logger + can get exception information. Returns True if the client disconnects + prematurely. + """ + if self._request_parser is not None: + self._request_parser.set_upgraded(False) + self._upgrade = False + if self._message_tail: + self._request_parser.feed_data(self._message_tail) + self._message_tail = b"" + try: + prepare_meth = resp.prepare + except AttributeError: + if resp is None: + raise RuntimeError("Missing return " "statement on request handler") + else: + raise RuntimeError( + "Web-handler should return " + "a response instance, " + "got {!r}".format(resp) + ) + try: + await prepare_meth(request) + await resp.write_eof() + except ConnectionError: + self.log_access(request, resp, start_time) + return True + else: + self.log_access(request, resp, start_time) + return False + + def handle_error( + self, + request: BaseRequest, + status: int = 500, + exc: Optional[BaseException] = None, + message: Optional[str] = None, + ) -> StreamResponse: """Handle errors. Returns HTTP response with specific status code. Logs additional information. It always closes current connection.""" self.log_exception("Error handling request", exc_info=exc) - if status == 500: - msg = "
<h1>500 Internal Server Error</h1>"
+        ct = "text/plain"
+        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
+            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
+            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
+            tb = None
             if self.debug:
-                try:
+                with suppress(Exception):
                     tb = traceback.format_exc()
+
+            if "text/html" in request.headers.get("Accept", ""):
+                if tb:
                     tb = html_escape(tb)
-                    msg += '<br><h2>Traceback:</h2>\n<pre>'
-                    msg += tb
-                    msg += '</pre>'
-                except:  # pragma: no cover
-                    pass
+                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
+                message = (
+                    "<html><head>"
+                    "<title>{title}</title>"
+                    "</head><body>\n<h1>{title}</h1>"
+                    "\n{msg}\n</body></html>\n"
+                ).format(title=title, msg=msg)
+                ct = "text/html"
             else:
-                msg += "Server got itself in trouble"
-            msg = ("<html><head><title>500 Internal Server Error</title>"
-                   "</head><body>" + msg + "</body></html>")
-        else:
+                if tb:
+                    msg = tb
+                message = title + "\n\n" + msg
-        resp = Response(status=status, text=msg, content_type='text/html')
+        resp = Response(status=status, text=message, content_type=ct)
         resp.force_close()
         # some data already got sent, connection is broken
@@ -538,17 +644,24 @@ def handle_error(self, request, status=500, exc=None, message=None):
         return resp
-    @asyncio.coroutine
-    def handle_parse_error(self, writer, status, exc=None, message=None):
+    async def handle_parse_error(
+        self,
+        writer: AbstractStreamWriter,
+        status: int,
+        exc: Optional[BaseException] = None,
+        message: Optional[str] = None,
+    ) -> None:
+        task = current_task()
+        assert task is not None
         request = BaseRequest(
-            ERROR, EMPTY_PAYLOAD,
-            self, writer, self._time_service, None)
+            ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop  # type: ignore
+        )
         resp = self.handle_error(request, status, exc, message)
-        yield from resp.prepare(request)
-        yield from resp.write_eof()
+        await resp.prepare(request)
+        await resp.write_eof()
-        # Restore default state.
-        # Should be no-op if server code didn't touch these attributes.
-        self.writer.set_tcp_cork(False)
-        self.writer.set_tcp_nodelay(True)
+        if self.transport is not None:
+            self.transport.close()
+
+        self._error_handler = None
diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py
index 9f4ce494976..f11e7be44be 100644
--- a/aiohttp/web_request.py
+++ b/aiohttp/web_request.py
@@ -1,61 +1,190 @@
 import asyncio
-import collections
 import datetime
-import json
+import io
 import re
+import socket
+import string
 import tempfile
+import types
 import warnings
 from email.utils import parsedate
+from http.cookies import SimpleCookie
 from types import MappingProxyType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterator,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 from urllib.parse import parse_qsl
-from multidict import CIMultiDict, MultiDict, MultiDictProxy
+import attr
+from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
 from yarl import URL
-from . import hdrs, multipart
-from .helpers import HeadersMixin, SimpleCookie, reify, sentinel
+from .
import hdrs +from .abc import AbstractStreamWriter +from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel +from .http_parser import RawRequestMessage +from .http_writer import HttpVersion +from .multipart import BodyPartReader, MultipartReader +from .streams import EmptyStreamReader, StreamReader +from .typedefs import ( + DEFAULT_JSON_DECODER, + JSONDecoder, + LooseHeaders, + RawHeaders, + StrOrURL, +) from .web_exceptions import HTTPRequestEntityTooLarge +from .web_response import StreamResponse -__all__ = ('BaseRequest', 'FileField', 'Request') +__all__ = ("BaseRequest", "FileField", "Request") -FileField = collections.namedtuple( - 'Field', 'name filename file content_type headers') +if TYPE_CHECKING: # pragma: no cover + from .web_app import Application + from .web_protocol import RequestHandler + from .web_urldispatcher import UrlMappingMatchInfo + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class FileField: + name: str + filename: str + file: io.BufferedReader + content_type: str + headers: "CIMultiDictProxy[str]" + + +_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-" +# '-' at the end to prevent interpretation as range in a char class + +_TOKEN = fr"[{_TCHAR}]+" + +_QDTEXT = r"[{}]".format( + r"".join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F))) +) +# qdtext includes 0x5C to escape 0x5D ('\]') +# qdtext excludes obs-text (because obsoleted, and encoding not specified) + +_QUOTED_PAIR = r"\\[\t !-~]" + +_QUOTED_STRING = r'"(?:{quoted_pair}|{qdtext})*"'.format( + qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR +) + +_FORWARDED_PAIR = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( + token=_TOKEN, quoted_string=_QUOTED_STRING +) + +_QUOTED_PAIR_REPLACE_RE = re.compile(r"\\([\t !-~])") +# same pattern as _QUOTED_PAIR but contains a capture group + +_FORWARDED_PAIR_RE = re.compile(_FORWARDED_PAIR) ############################################################ # HTTP Request ############################################################ -class BaseRequest(collections.MutableMapping, HeadersMixin): - - POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT, - hdrs.METH_TRACE, hdrs.METH_DELETE} - - def __init__(self, message, payload, protocol, writer, time_service, task, - *, secure_proxy_ssl_header=None, client_max_size=1024**2): +class BaseRequest(MutableMapping[str, Any], HeadersMixin): + + POST_METHODS = { + hdrs.METH_PATCH, + hdrs.METH_POST, + hdrs.METH_PUT, + hdrs.METH_TRACE, + hdrs.METH_DELETE, + } + + ATTRS = HeadersMixin.ATTRS | frozenset( + [ + "_message", + "_protocol", + "_payload_writer", + "_payload", + "_headers", + "_method", + "_version", + "_rel_url", + "_post", + "_read_bytes", + "_state", + "_cache", + "_task", + "_client_max_size", + "_loop", + "_transport_sslcontext", + "_transport_peername", + ] + ) + + def __init__( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: "RequestHandler", + payload_writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + loop: asyncio.AbstractEventLoop, + *, + client_max_size: int = 1024 ** 2, + state: Optional[Dict[str, Any]] = None, + scheme: Optional[str] = None, + host: Optional[str] = None, + remote: Optional[str] = None, + ) -> None: + if state is None: + state = {} self._message = message self._protocol = protocol - self._transport = protocol.transport - self._writer = writer + self._payload_writer = payload_writer self._payload = payload self._headers = message.headers self._method = message.method self._version = 
message.version self._rel_url = message.url - self._post = None - self._read_bytes = None + self._post = ( + None + ) # type: Optional[MultiDictProxy[Union[str, bytes, FileField]]] + self._read_bytes = None # type: Optional[bytes] - self._secure_proxy_ssl_header = secure_proxy_ssl_header - self._time_service = time_service - self._state = {} - self._cache = {} + self._state = state + self._cache = {} # type: Dict[str, Any] self._task = task self._client_max_size = client_max_size - - def clone(self, *, method=sentinel, rel_url=sentinel, - headers=sentinel): + self._loop = loop + + transport = self._protocol.transport + assert transport is not None + self._transport_sslcontext = transport.get_extra_info("sslcontext") + self._transport_peername = transport.get_extra_info("peername") + + if scheme is not None: + self._cache["scheme"] = scheme + if host is not None: + self._cache["host"] = host + if remote is not None: + self._cache["remote"] = remote + + def clone( + self, + *, + method: str = sentinel, + rel_url: StrOrURL = sentinel, + headers: LooseHeaders = sentinel, + scheme: str = sentinel, + host: str = sentinel, + remote: str = sentinel, + ) -> "BaseRequest": """Clone itself with replacement some attributes. Creates and returns a new instance of Request object. If no parameters @@ -65,104 +194,189 @@ def clone(self, *, method=sentinel, rel_url=sentinel, """ if self._read_bytes: - raise RuntimeError("Cannot clone request " - "after reading it's content") + raise RuntimeError("Cannot clone request " "after reading its content") - dct = {} + dct = {} # type: Dict[str, Any] if method is not sentinel: - dct['method'] = method + dct["method"] = method if rel_url is not sentinel: - rel_url = URL(rel_url) - dct['url'] = rel_url - dct['path'] = str(rel_url) + new_url = URL(rel_url) + dct["url"] = new_url + dct["path"] = str(new_url) if headers is not sentinel: - dct['headers'] = CIMultiDict(headers) - dct['raw_headers'] = tuple((k.encode('utf-8'), v.encode('utf-8')) - for k, v in headers.items()) + # a copy semantic + dct["headers"] = CIMultiDictProxy(CIMultiDict(headers)) + dct["raw_headers"] = tuple( + (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + ) message = self._message._replace(**dct) + kwargs = {} + if scheme is not sentinel: + kwargs["scheme"] = scheme + if host is not sentinel: + kwargs["host"] = host + if remote is not sentinel: + kwargs["remote"] = remote + return self.__class__( message, self._payload, self._protocol, - self._writer, - self._time_service, + self._payload_writer, self._task, - secure_proxy_ssl_header=self._secure_proxy_ssl_header) + self._loop, + client_max_size=self._client_max_size, + state=self._state.copy(), + **kwargs, + ) @property - def task(self): + def task(self) -> "asyncio.Task[None]": return self._task @property - def protocol(self): + def protocol(self) -> "RequestHandler": return self._protocol @property - def transport(self): + def transport(self) -> Optional[asyncio.Transport]: + if self._protocol is None: + return None return self._protocol.transport @property - def writer(self): - return self._writer + def writer(self) -> AbstractStreamWriter: + return self._payload_writer - @property - def message(self): + @reify + def message(self) -> RawRequestMessage: + warnings.warn("Request.message is deprecated", DeprecationWarning, stacklevel=3) return self._message - @property - def rel_url(self): + @reify + def rel_url(self) -> URL: return self._rel_url + @reify + def loop(self) -> asyncio.AbstractEventLoop: + warnings.warn( + 
"request.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) + return self._loop + # MutableMapping API - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self._state[key] - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: self._state[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._state[key] - def __len__(self): + def __len__(self) -> int: return len(self._state) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._state) ######## - @property - def scheme(self): + @reify + def secure(self) -> bool: + """A bool indicating if the request is handled with SSL.""" + return self.scheme == "https" + + @reify + def forwarded(self) -> Tuple[Mapping[str, str], ...]: + """A tuple containing all parsed Forwarded header(s). + + Makes an effort to parse Forwarded headers as specified by RFC 7239: + + - It adds one (immutable) dictionary per Forwarded 'field-value', ie + per proxy. The element corresponds to the data in the Forwarded + field-value added by the first proxy encountered by the client. Each + subsequent item corresponds to those added by later proxies. + - It checks that every value has valid syntax in general as specified + in section 4: either a 'token' or a 'quoted-string'. + - It un-escapes found escape sequences. + - It does NOT validate 'by' and 'for' contents as specified in section + 6. + - It does NOT validate 'host' contents (Host ABNF). + - It does NOT validate 'proto' contents for valid URI scheme names. + + Returns a tuple containing one or more immutable dicts + """ + elems = [] + for field_value in self._message.headers.getall(hdrs.FORWARDED, ()): + length = len(field_value) + pos = 0 + need_separator = False + elem = {} # type: Dict[str, str] + elems.append(types.MappingProxyType(elem)) + while 0 <= pos < length: + match = _FORWARDED_PAIR_RE.match(field_value, pos) + if match is not None: # got a valid forwarded-pair + if need_separator: + # bad syntax here, skip to next comma + pos = field_value.find(",", pos) + else: + name, value, port = match.groups() + if value[0] == '"': + # quoted string: remove quotes and unescape + value = _QUOTED_PAIR_REPLACE_RE.sub(r"\1", value[1:-1]) + if port: + value += port + elem[name.lower()] = value + pos += len(match.group(0)) + need_separator = True + elif field_value[pos] == ",": # next forwarded-element + need_separator = False + elem = {} + elems.append(types.MappingProxyType(elem)) + pos += 1 + elif field_value[pos] == ";": # next forwarded-pair + need_separator = False + pos += 1 + elif field_value[pos] in " \t": + # Allow whitespace even between forwarded-pairs, though + # RFC 7239 doesn't. This simplifies code and is in line + # with Postel's law. + pos += 1 + else: + # bad syntax here, skip to next comma + pos = field_value.find(",", pos) + return tuple(elems) + + @reify + def scheme(self) -> str: """A string representing the scheme of the request. + Hostname is resolved in this order: + + - overridden value by .clone(scheme=new_scheme) call. + - type of connection to peer: HTTPS if socket is SSL, HTTP otherwise. + 'http' or 'https'. 
""" - return self.url.scheme + if self._transport_sslcontext: + return "https" + else: + return "http" @reify - def _scheme(self): - if self._transport.get_extra_info('sslcontext'): - return 'https' - secure_proxy_ssl_header = self._secure_proxy_ssl_header - if secure_proxy_ssl_header is not None: - header, value = secure_proxy_ssl_header - if self.headers.get(header) == value: - return 'https' - return 'http' - - @property - def method(self): + def method(self) -> str: """Read only property for getting HTTP method. The value is upper-cased str like 'GET', 'POST', 'PUT' etc. """ return self._method - @property - def version(self): + @reify + def version(self) -> HttpVersion: """Read only property for getting HTTP version of request. Returns aiohttp.protocol.HttpVersion instance. @@ -170,21 +384,42 @@ def version(self): return self._version @reify - def host(self): - """Read only property for getting *HOST* header of request. + def host(self) -> str: + """Hostname of the request. - Returns str or None if HTTP request has no HOST header. + Hostname is resolved in this order: + + - overridden value by .clone(host=new_host) call. + - HOST HTTP header + - socket.getfqdn() value """ - return self._message.headers.get(hdrs.HOST) + host = self._message.headers.get(hdrs.HOST) + if host is not None: + return host + else: + return socket.getfqdn() @reify - def url(self): - return URL('{}://{}{}'.format(self._scheme, - self._message.headers.get(hdrs.HOST), - str(self._rel_url))) + def remote(self) -> Optional[str]: + """Remote IP of client initiated HTTP request. - @property - def path(self): + The IP is resolved in this order: + + - overridden value by .clone(remote=new_remote) call. + - peername of opened socket + """ + if isinstance(self._transport_peername, (list, tuple)): + return self._transport_peername[0] + else: + return self._transport_peername + + @reify + def url(self) -> URL: + url = URL.build(scheme=self.scheme, host=self.host) + return url.join(self._rel_url) + + @reify + def path(self) -> str: """The URL including *PATH INFO* without the host or scheme. E.g., ``/app/blog`` @@ -192,142 +427,161 @@ def path(self): return self._rel_url.path @reify - def path_qs(self): + def path_qs(self) -> str: """The URL including PATH_INFO and the query string. E.g, /app/blog?id=10 """ return str(self._rel_url) - @property - def raw_path(self): - """ The URL including raw *PATH INFO* without the host or scheme. + @reify + def raw_path(self) -> str: + """The URL including raw *PATH INFO* without the host or scheme. Warning, the path is unquoted and may contains non valid URL characters E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters`` """ return self._message.path - @property - def query(self): - """A multidict with all the variables in the query string.""" - return self._rel_url.query - - @property - def GET(self): + @reify + def query(self) -> "MultiDictProxy[str]": """A multidict with all the variables in the query string.""" - warnings.warn("GET property is deprecated, use .query instead", - DeprecationWarning) return self._rel_url.query - @property - def query_string(self): + @reify + def query_string(self) -> str: """The query string in the URL. 
E.g., id=10 """ return self._rel_url.query_string - @property - def headers(self): + @reify + def headers(self) -> "CIMultiDictProxy[str]": """A case-insensitive multidict proxy with all headers.""" return self._headers - @property - def raw_headers(self): - """A sequence of pars for all headers.""" + @reify + def raw_headers(self) -> RawHeaders: + """A sequence of pairs for all headers.""" return self._message.raw_headers + @staticmethod + def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]: + """Process a date string, return a datetime object""" + if _date_str is not None: + timetuple = parsedate(_date_str) + if timetuple is not None: + return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) + return None + @reify - def if_modified_since(self, _IF_MODIFIED_SINCE=hdrs.IF_MODIFIED_SINCE): + def if_modified_since(self) -> Optional[datetime.datetime]: """The value of If-Modified-Since HTTP header, or None. This header is represented as a `datetime` object. """ - httpdate = self.headers.get(_IF_MODIFIED_SINCE) - if httpdate is not None: - timetuple = parsedate(httpdate) - if timetuple is not None: - return datetime.datetime(*timetuple[:6], - tzinfo=datetime.timezone.utc) - return None + return self._http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE)) - @property - def keep_alive(self): + @reify + def if_unmodified_since(self) -> Optional[datetime.datetime]: + """The value of If-Unmodified-Since HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return self._http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE)) + + @reify + def if_range(self) -> Optional[datetime.datetime]: + """The value of If-Range HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return self._http_date(self.headers.get(hdrs.IF_RANGE)) + + @reify + def keep_alive(self) -> bool: """Is keepalive enabled by client?""" return not self._message.should_close - @property - def time_service(self): - """Time service""" - return self._time_service - @reify - def cookies(self): + def cookies(self) -> Mapping[str, str]: """Return request cookies. A read-only dictionary-like object. """ - raw = self.headers.get(hdrs.COOKIE, '') - parsed = SimpleCookie(raw) - return MappingProxyType( - {key: val.value for key, val in parsed.items()}) + raw = self.headers.get(hdrs.COOKIE, "") + parsed = SimpleCookie(raw) # type: SimpleCookie[str] + return MappingProxyType({key: val.value for key, val in parsed.items()}) - @property - def http_range(self, *, _RANGE=hdrs.RANGE): + @reify + def http_range(self) -> slice: """The content of Range HTTP header. Return a slice instance. 
""" - rng = self._headers.get(_RANGE) + rng = self._headers.get(hdrs.RANGE) start, end = None, None if rng is not None: try: - pattern = r'^bytes=(\d*)-(\d*)$' + pattern = r"^bytes=(\d*)-(\d*)$" start, end = re.findall(pattern, rng)[0] except IndexError: # pattern was not found in header - raise ValueError("range not in acceptible format") + raise ValueError("range not in acceptable format") end = int(end) if end else None start = int(start) if start else None if start is None and end is not None: # end with no start is to return tail of content - end = -end + start = -end + end = None if start is not None and end is not None: # end is inclusive in range header, exclusive for slice end += 1 if start >= end: - raise ValueError('start cannot be after end') + raise ValueError("start cannot be after end") if start is end is None: # No valid range supplied - raise ValueError('No start or end of range specified') + raise ValueError("No start or end of range specified") + return slice(start, end, 1) - @property - def content(self): + @reify + def content(self) -> StreamReader: """Return raw payload stream.""" return self._payload @property - def has_body(self): - """Return True if request has HTTP BODY, False otherwise.""" + def has_body(self) -> bool: + """Return True if request's HTTP BODY can be read, False otherwise.""" + warnings.warn( + "Deprecated, use .can_read_body #2005", DeprecationWarning, stacklevel=2 + ) + return not self._payload.at_eof() + + @property + def can_read_body(self) -> bool: + """Return True if request's HTTP BODY can be read, False otherwise.""" return not self._payload.at_eof() - @asyncio.coroutine - def release(self): + @reify + def body_exists(self) -> bool: + """Return True if request has HTTP BODY, False otherwise.""" + return type(self._payload) is not EmptyStreamReader + + async def release(self) -> None: """Release request. Eat unread part of HTTP BODY if present. """ while not self._payload.at_eof(): - yield from self._payload.readany() + await self._payload.readany() - @asyncio.coroutine - def read(self): + async def read(self) -> bytes: """Read request body if present. Returns bytes object with full request content. 
@@ -335,36 +589,35 @@ def read(self): if self._read_bytes is None: body = bytearray() while True: - chunk = yield from self._payload.readany() + chunk = await self._payload.readany() body.extend(chunk) - if self._client_max_size \ - and len(body) >= self._client_max_size: - raise HTTPRequestEntityTooLarge + if self._client_max_size: + body_size = len(body) + if body_size >= self._client_max_size: + raise HTTPRequestEntityTooLarge( + max_size=self._client_max_size, actual_size=body_size + ) if not chunk: break self._read_bytes = bytes(body) return self._read_bytes - @asyncio.coroutine - def text(self): + async def text(self) -> str: """Return BODY as text using encoding from .charset.""" - bytes_body = yield from self.read() - encoding = self.charset or 'utf-8' + bytes_body = await self.read() + encoding = self.charset or "utf-8" return bytes_body.decode(encoding) - @asyncio.coroutine - def json(self, *, loads=json.loads): + async def json(self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER) -> Any: """Return BODY as JSON.""" - body = yield from self.text() + body = await self.text() return loads(body) - @asyncio.coroutine - def multipart(self, *, reader=multipart.MultipartReader): + async def multipart(self) -> MultipartReader: """Return async iterator to process BODY as multipart.""" - return reader(self._headers, self._payload) + return MultipartReader(self._headers, self._payload) - @asyncio.coroutine - def post(self): + async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]": """Return POST parameters.""" if self._post is not None: return self._post @@ -373,101 +626,199 @@ def post(self): return self._post content_type = self.content_type - if (content_type not in ('', - 'application/x-www-form-urlencoded', - 'multipart/form-data')): + if content_type not in ( + "", + "application/x-www-form-urlencoded", + "multipart/form-data", + ): self._post = MultiDictProxy(MultiDict()) return self._post - out = MultiDict() + out = MultiDict() # type: MultiDict[Union[str, bytes, FileField]] - if content_type == 'multipart/form-data': - multipart = yield from self.multipart() + if content_type == "multipart/form-data": + multipart = await self.multipart() + max_size = self._client_max_size - field = yield from multipart.next() + field = await multipart.next() while field is not None: size = 0 - max_size = self._client_max_size - content_type = field.headers.get(hdrs.CONTENT_TYPE) - - if field.filename: - # store file in temp file - tmp = tempfile.TemporaryFile() - chunk = yield from field.read_chunk(size=2**16) - while chunk: - chunk = field.decode(chunk) - tmp.write(chunk) - size += len(chunk) - if max_size > 0 and size > max_size: - raise ValueError( - 'Maximum request body size exceeded') - chunk = yield from field.read_chunk(size=2**16) - tmp.seek(0) - - ff = FileField(field.name, field.filename, - tmp, content_type, field.headers) - out.add(field.name, ff) + field_ct = field.headers.get(hdrs.CONTENT_TYPE) + + if isinstance(field, BodyPartReader): + assert field.name is not None + + # Note that according to RFC 7578, the Content-Type header + # is optional, even for files, so we can't assume it's + # present. 
+ # https://tools.ietf.org/html/rfc7578#section-4.4 + if field.filename: + # store file in temp file + tmp = tempfile.TemporaryFile() + chunk = await field.read_chunk(size=2 ** 16) + while chunk: + chunk = field.decode(chunk) + tmp.write(chunk) + size += len(chunk) + if 0 < max_size < size: + raise HTTPRequestEntityTooLarge( + max_size=max_size, actual_size=size + ) + chunk = await field.read_chunk(size=2 ** 16) + tmp.seek(0) + + if field_ct is None: + field_ct = "application/octet-stream" + + ff = FileField( + field.name, + field.filename, + cast(io.BufferedReader, tmp), + field_ct, + field.headers, + ) + out.add(field.name, ff) + else: + # deal with ordinary data + value = await field.read(decode=True) + if field_ct is None or field_ct.startswith("text/"): + charset = field.get_charset(default="utf-8") + out.add(field.name, value.decode(charset)) + else: + out.add(field.name, value) + size += len(value) + if 0 < max_size < size: + raise HTTPRequestEntityTooLarge( + max_size=max_size, actual_size=size + ) else: - value = yield from field.read(decode=True) - if content_type is None or \ - content_type.startswith('text/'): - charset = field.get_charset(default='utf-8') - value = value.decode(charset) - out.add(field.name, value) - size += len(value) - if max_size > 0 and size > max_size: - raise ValueError( - 'Maximum request body size exceeded') - - field = yield from multipart.next() + raise ValueError( + "To decode nested multipart you need " "to use custom reader", + ) + + field = await multipart.next() else: - data = yield from self.read() + data = await self.read() if data: - charset = self.charset or 'utf-8' + charset = self.charset or "utf-8" out.extend( parse_qsl( data.rstrip().decode(charset), keep_blank_values=True, - encoding=charset)) + encoding=charset, + ) + ) self._post = MultiDictProxy(out) return self._post - def __repr__(self): - ascii_encodable_path = self.path.encode('ascii', 'backslashreplace') \ - .decode('ascii') - return "<{} {} {} >".format(self.__class__.__name__, - self._method, ascii_encodable_path) + def get_extra_info(self, name: str, default: Any = None) -> Any: + """Extra info from protocol transport""" + protocol = self._protocol + if protocol is None: + return default + + transport = protocol.transport + if transport is None: + return default + + return transport.get_extra_info(name, default) + + def __repr__(self) -> str: + ascii_encodable_path = self.path.encode("ascii", "backslashreplace").decode( + "ascii" + ) + return "<{} {} {} >".format( + self.__class__.__name__, self._method, ascii_encodable_path + ) - @asyncio.coroutine - def _prepare_hook(self, response): + def __eq__(self, other: object) -> bool: + return id(self) == id(other) + + def __bool__(self) -> bool: + return True + + async def _prepare_hook(self, response: StreamResponse) -> None: return - yield # pragma: no cover + + def _cancel(self, exc: BaseException) -> None: + self._payload.set_exception(exc) class Request(BaseRequest): - def __init__(self, *args, **kwargs): + ATTRS = BaseRequest.ATTRS | frozenset(["_match_info"]) + + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) # matchdict, route_name, handler # or information about traversal lookup - self._match_info = None # initialized after route resolving - @property - def match_info(self): - """Result of route resolving.""" - return self._match_info + # initialized after route resolving + self._match_info = None # type: Optional[UrlMappingMatchInfo] + + if DEBUG: + + def __setattr__(self, name: 
str, val: Any) -> None: + if name not in self.ATTRS: + warnings.warn( + "Setting custom {}.{} attribute " + "is discouraged".format(self.__class__.__name__, name), + DeprecationWarning, + stacklevel=2, + ) + super().__setattr__(name, val) + + def clone( + self, + *, + method: str = sentinel, + rel_url: StrOrURL = sentinel, + headers: LooseHeaders = sentinel, + scheme: str = sentinel, + host: str = sentinel, + remote: str = sentinel, + ) -> "Request": + ret = super().clone( + method=method, + rel_url=rel_url, + headers=headers, + scheme=scheme, + host=host, + remote=remote, + ) + new_ret = cast(Request, ret) + new_ret._match_info = self._match_info + return new_ret @reify - def app(self): + def match_info(self) -> "UrlMappingMatchInfo": + """Result of route resolving.""" + match_info = self._match_info + assert match_info is not None + return match_info + + @property + def app(self) -> "Application": """Application instance.""" - return self._match_info.apps[-1] + match_info = self._match_info + assert match_info is not None + return match_info.current_app - @asyncio.coroutine - def _prepare_hook(self, response): + @property + def config_dict(self) -> ChainMapProxy: + match_info = self._match_info + assert match_info is not None + lst = match_info.apps + app = self.app + idx = lst.index(app) + sublist = list(reversed(lst[: idx + 1])) + return ChainMapProxy(sublist) + + async def _prepare_hook(self, response: StreamResponse) -> None: match_info = self._match_info if match_info is None: return - for app in match_info.apps: - yield from app.on_response_prepare.send(self, response) + for app in match_info._apps: + await app.on_response_prepare.send(self, response) diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 743b11a51b3..f34b00e2d95 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -1,19 +1,52 @@ import asyncio +import collections.abc import datetime import enum import json import math import time import warnings +import zlib +from concurrent.futures import Executor from email.utils import parsedate - -from multidict import CIMultiDict, CIMultiDictProxy +from http.cookies import Morsel, SimpleCookie +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, + cast, +) + +from multidict import CIMultiDict, istr from . 
import hdrs, payload -from .helpers import HeadersMixin, SimpleCookie, sentinel +from .abc import AbstractStreamWriter +from .helpers import PY_38, HeadersMixin, rfc822_formatted_time, sentinel from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11 +from .payload import Payload +from .typedefs import JSONEncoder, LooseHeaders + +__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response") + + +if TYPE_CHECKING: # pragma: no cover + from .web_request import BaseRequest -__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response') + BaseClass = MutableMapping[str, Any] +else: + BaseClass = collections.abc.MutableMapping + + +if not PY_38: + # allow samesite to be used in python < 3.8 + # already permitted in python 3.8, see https://bugs.python.org/issue29613 + Morsel._reserved["samesite"] = "SameSite" # type: ignore class ContentCoding(enum.Enum): @@ -21,9 +54,9 @@ class ContentCoding(enum.Enum): # # Additional registered codings are listed at: # https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding - deflate = 'deflate' - gzip = 'gzip' - identity = 'identity' + deflate = "deflate" + gzip = "gzip" + identity = "identity" ############################################################ @@ -31,112 +64,146 @@ class ContentCoding(enum.Enum): ############################################################ -class StreamResponse(HeadersMixin): +class StreamResponse(BaseClass, HeadersMixin): _length_check = True - def __init__(self, *, status=200, reason=None, headers=None): + def __init__( + self, + *, + status: int = 200, + reason: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + ) -> None: self._body = None - self._keep_alive = None + self._keep_alive = None # type: Optional[bool] self._chunked = False self._compression = False - self._compression_force = False - self._cookies = SimpleCookie() + self._compression_force = None # type: Optional[ContentCoding] + self._cookies = SimpleCookie() # type: SimpleCookie[str] - self._req = None - self._payload_writer = None + self._req = None # type: Optional[BaseRequest] + self._payload_writer = None # type: Optional[AbstractStreamWriter] self._eof_sent = False self._body_length = 0 + self._state = {} # type: Dict[str, Any] if headers is not None: - self._headers = CIMultiDict(headers) + self._headers = CIMultiDict(headers) # type: CIMultiDict[str] else: self._headers = CIMultiDict() self.set_status(status, reason) @property - def prepared(self): + def prepared(self) -> bool: return self._payload_writer is not None @property - def task(self): - return getattr(self._req, 'task', None) + def task(self) -> "asyncio.Task[None]": + return getattr(self._req, "task", None) @property - def status(self): + def status(self) -> int: return self._status @property - def chunked(self): + def chunked(self) -> bool: return self._chunked @property - def compression(self): + def compression(self) -> bool: return self._compression @property - def reason(self): + def reason(self) -> str: return self._reason - def set_status(self, status, reason=None, _RESPONSES=RESPONSES): - assert not self.prepared, \ - 'Cannot change the response status code after ' \ - 'the headers have been sent' + def set_status( + self, + status: int, + reason: Optional[str] = None, + _RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES, + ) -> None: + assert not self.prepared, ( + "Cannot change the response status code after " "the headers have been sent" + ) self._status = int(status) if reason is None: try: 
reason = _RESPONSES[self._status][0] - except: - reason = '' + except Exception: + reason = "" self._reason = reason @property - def keep_alive(self): + def keep_alive(self) -> Optional[bool]: return self._keep_alive - def force_close(self): + def force_close(self) -> None: self._keep_alive = False @property - def body_length(self): + def body_length(self) -> int: return self._body_length @property - def output_length(self): - warnings.warn('output_length is deprecated', DeprecationWarning) + def output_length(self) -> int: + warnings.warn("output_length is deprecated", DeprecationWarning) + assert self._payload_writer return self._payload_writer.buffer_size - def enable_chunked_encoding(self, chunk_size=None): + def enable_chunked_encoding(self, chunk_size: Optional[int] = None) -> None: """Enables automatic chunked transfer encoding.""" self._chunked = True + + if hdrs.CONTENT_LENGTH in self._headers: + raise RuntimeError( + "You can't enable chunked encoding when " "a content length is set" + ) if chunk_size is not None: - warnings.warn('Chunk size is deprecated #1615', DeprecationWarning) + warnings.warn("Chunk size is deprecated #1615", DeprecationWarning) - def enable_compression(self, force=None): + def enable_compression( + self, force: Optional[Union[bool, ContentCoding]] = None + ) -> None: """Enables response compression encoding.""" # Backwards compatibility for when force was a bool <0.17. if type(force) == bool: force = ContentCoding.deflate if force else ContentCoding.identity + warnings.warn( + "Using boolean for force is deprecated #3318", DeprecationWarning + ) elif force is not None: - assert isinstance(force, ContentCoding), ("force should one of " - "None, bool or " - "ContentEncoding") + assert isinstance(force, ContentCoding), ( + "force should one of " "None, bool or " "ContentEncoding" + ) self._compression = True self._compression_force = force @property - def headers(self): + def headers(self) -> "CIMultiDict[str]": return self._headers @property - def cookies(self): + def cookies(self) -> "SimpleCookie[str]": return self._cookies - def set_cookie(self, name, value, *, expires=None, - domain=None, max_age=None, path='/', - secure=None, httponly=None, version=None): + def set_cookie( + self, + name: str, + value: str, + *, + expires: Optional[str] = None, + domain: Optional[str] = None, + max_age: Optional[Union[int, str]] = None, + path: str = "/", + secure: Optional[bool] = None, + httponly: Optional[bool] = None, + version: Optional[str] = None, + samesite: Optional[str] = None, + ) -> None: """Set or update response cookie. Sets new cookie or updates existent with new value. 
@@ -144,7 +211,7 @@ def set_cookie(self, name, value, *, expires=None, """ old = self._cookies.get(name) - if old is not None and old.coded_value == '': + if old is not None and old.coded_value == "": # deleted cookie self._cookies.pop(name, None) @@ -152,347 +219,399 @@ def set_cookie(self, name, value, *, expires=None, c = self._cookies[name] if expires is not None: - c['expires'] = expires - elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT': - del c['expires'] + c["expires"] = expires + elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT": + del c["expires"] if domain is not None: - c['domain'] = domain + c["domain"] = domain if max_age is not None: - c['max-age'] = max_age - elif 'max-age' in c: - del c['max-age'] + c["max-age"] = str(max_age) + elif "max-age" in c: + del c["max-age"] - c['path'] = path + c["path"] = path if secure is not None: - c['secure'] = secure + c["secure"] = secure if httponly is not None: - c['httponly'] = httponly + c["httponly"] = httponly if version is not None: - c['version'] = version + c["version"] = version + if samesite is not None: + c["samesite"] = samesite - def del_cookie(self, name, *, domain=None, path='/'): + def del_cookie( + self, name: str, *, domain: Optional[str] = None, path: str = "/" + ) -> None: """Delete cookie. Creates new empty expired cookie. """ # TODO: do we need domain/path here? self._cookies.pop(name, None) - self.set_cookie(name, '', max_age=0, - expires="Thu, 01 Jan 1970 00:00:00 GMT", - domain=domain, path=path) + self.set_cookie( + name, + "", + max_age=0, + expires="Thu, 01 Jan 1970 00:00:00 GMT", + domain=domain, + path=path, + ) @property - def content_length(self): + def content_length(self) -> Optional[int]: # Just a placeholder for adding setter return super().content_length @content_length.setter - def content_length(self, value): + def content_length(self, value: Optional[int]) -> None: if value is not None: value = int(value) - # TODO: raise error if chunked enabled + if self._chunked: + raise RuntimeError( + "You can't set content length when " "chunked encoding is enable" + ) self._headers[hdrs.CONTENT_LENGTH] = str(value) else: self._headers.pop(hdrs.CONTENT_LENGTH, None) @property - def content_type(self): + def content_type(self) -> str: # Just a placeholder for adding setter return super().content_type @content_type.setter - def content_type(self, value): + def content_type(self, value: str) -> None: self.content_type # read header values if needed self._content_type = str(value) self._generate_content_type_header() @property - def charset(self): + def charset(self) -> Optional[str]: # Just a placeholder for adding setter return super().charset @charset.setter - def charset(self, value): + def charset(self, value: Optional[str]) -> None: ctype = self.content_type # read header values if needed - if ctype == 'application/octet-stream': - raise RuntimeError("Setting charset for application/octet-stream " - "doesn't make sense, setup content_type first") + if ctype == "application/octet-stream": + raise RuntimeError( + "Setting charset for application/octet-stream " + "doesn't make sense, setup content_type first" + ) + assert self._content_dict is not None if value is None: - self._content_dict.pop('charset', None) + self._content_dict.pop("charset", None) else: - self._content_dict['charset'] = str(value).lower() + self._content_dict["charset"] = str(value).lower() self._generate_content_type_header() @property - def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED): + def last_modified(self) 
-> Optional[datetime.datetime]: """The value of Last-Modified HTTP header, or None. This header is represented as a `datetime` object. """ - httpdate = self.headers.get(_LAST_MODIFIED) + httpdate = self._headers.get(hdrs.LAST_MODIFIED) if httpdate is not None: timetuple = parsedate(httpdate) if timetuple is not None: - return datetime.datetime(*timetuple[:6], - tzinfo=datetime.timezone.utc) + return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) return None @last_modified.setter - def last_modified(self, value): + def last_modified( + self, value: Optional[Union[int, float, datetime.datetime, str]] + ) -> None: if value is None: - self.headers.pop(hdrs.LAST_MODIFIED, None) + self._headers.pop(hdrs.LAST_MODIFIED, None) elif isinstance(value, (int, float)): - self.headers[hdrs.LAST_MODIFIED] = time.strftime( - "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value))) + self._headers[hdrs.LAST_MODIFIED] = time.strftime( + "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)) + ) elif isinstance(value, datetime.datetime): - self.headers[hdrs.LAST_MODIFIED] = time.strftime( - "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple()) + self._headers[hdrs.LAST_MODIFIED] = time.strftime( + "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple() + ) elif isinstance(value, str): - self.headers[hdrs.LAST_MODIFIED] = value - - @property - def tcp_nodelay(self): - payload_writer = self._payload_writer - assert payload_writer is not None, \ - "Cannot get tcp_nodelay for not prepared response" - return payload_writer.tcp_nodelay - - def set_tcp_nodelay(self, value): - payload_writer = self._payload_writer - assert payload_writer is not None, \ - "Cannot set tcp_nodelay for not prepared response" - payload_writer.set_tcp_nodelay(value) - - @property - def tcp_cork(self): - payload_writer = self._payload_writer - assert payload_writer is not None, \ - "Cannot get tcp_cork for not prepared response" - return payload_writer.tcp_cork - - def set_tcp_cork(self, value): - payload_writer = self._payload_writer - assert payload_writer is not None, \ - "Cannot set tcp_cork for not prepared response" - - payload_writer.set_tcp_cork(value) - - def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE): - params = '; '.join("%s=%s" % i for i in self._content_dict.items()) + self._headers[hdrs.LAST_MODIFIED] = value + + def _generate_content_type_header( + self, CONTENT_TYPE: istr = hdrs.CONTENT_TYPE + ) -> None: + assert self._content_dict is not None + assert self._content_type is not None + params = "; ".join(f"{k}={v}" for k, v in self._content_dict.items()) if params: - ctype = self._content_type + '; ' + params + ctype = self._content_type + "; " + params else: ctype = self._content_type - self.headers[CONTENT_TYPE] = ctype + self._headers[CONTENT_TYPE] = ctype - def _do_start_compression(self, coding): + async def _do_start_compression(self, coding: ContentCoding) -> None: if coding != ContentCoding.identity: - self.headers[hdrs.CONTENT_ENCODING] = coding.value + assert self._payload_writer is not None + self._headers[hdrs.CONTENT_ENCODING] = coding.value self._payload_writer.enable_compression(coding.value) - self._chunked = True + # Compressed payload may have different content length, + # remove the header + self._headers.popall(hdrs.CONTENT_LENGTH, None) - def _start_compression(self, request): + async def _start_compression(self, request: "BaseRequest") -> None: if self._compression_force: - self._do_start_compression(self._compression_force) + await 
self._do_start_compression(self._compression_force) else: - accept_encoding = request.headers.get( - hdrs.ACCEPT_ENCODING, '').lower() + accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() for coding in ContentCoding: if coding.value in accept_encoding: - self._do_start_compression(coding) + await self._do_start_compression(coding) return - @asyncio.coroutine - def prepare(self, request): + async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: if self._eof_sent: - return + return None if self._payload_writer is not None: return self._payload_writer - yield from request._prepare_hook(self) - return self._start(request) - - def _start(self, request, - HttpVersion10=HttpVersion10, - HttpVersion11=HttpVersion11, - CONNECTION=hdrs.CONNECTION, - DATE=hdrs.DATE, - SERVER=hdrs.SERVER, - CONTENT_TYPE=hdrs.CONTENT_TYPE, - CONTENT_LENGTH=hdrs.CONTENT_LENGTH, - SET_COOKIE=hdrs.SET_COOKIE, - SERVER_SOFTWARE=SERVER_SOFTWARE, - TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING): + return await self._start(request) + + async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: self._req = request + writer = self._payload_writer = request._payload_writer + + await self._prepare_headers() + await request._prepare_hook(self) + await self._write_headers() + + return writer + async def _prepare_headers(self) -> None: + request = self._req + assert request is not None + writer = self._payload_writer + assert writer is not None keep_alive = self._keep_alive if keep_alive is None: keep_alive = request.keep_alive self._keep_alive = keep_alive version = request.version - writer = self._payload_writer = request._writer headers = self._headers for cookie in self._cookies.values(): - value = cookie.output(header='')[1:] - headers.add(SET_COOKIE, value) + value = cookie.output(header="")[1:] + headers.add(hdrs.SET_COOKIE, value) if self._compression: - self._start_compression(request) + await self._start_compression(request) if self._chunked: if version != HttpVersion11: raise RuntimeError( "Using chunked encoding is forbidden " - "for HTTP/{0.major}.{0.minor}".format(request.version)) + "for HTTP/{0.major}.{0.minor}".format(request.version) + ) writer.enable_chunking() - headers[TRANSFER_ENCODING] = 'chunked' - if CONTENT_LENGTH in headers: - del headers[CONTENT_LENGTH] + headers[hdrs.TRANSFER_ENCODING] = "chunked" + if hdrs.CONTENT_LENGTH in headers: + del headers[hdrs.CONTENT_LENGTH] elif self._length_check: writer.length = self.content_length - if writer.length is None and version >= HttpVersion11: - writer.enable_chunking() - headers[TRANSFER_ENCODING] = 'chunked' - if CONTENT_LENGTH in headers: - del headers[CONTENT_LENGTH] - - headers.setdefault(CONTENT_TYPE, 'application/octet-stream') - headers.setdefault(DATE, request.time_service.strtime()) - headers.setdefault(SERVER, SERVER_SOFTWARE) + if writer.length is None: + if version >= HttpVersion11: + writer.enable_chunking() + headers[hdrs.TRANSFER_ENCODING] = "chunked" + if hdrs.CONTENT_LENGTH in headers: + del headers[hdrs.CONTENT_LENGTH] + else: + keep_alive = False + # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2 + # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4 + elif version >= HttpVersion11 and self.status in (100, 101, 102, 103, 204): + del headers[hdrs.CONTENT_LENGTH] + + headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream") + headers.setdefault(hdrs.DATE, rfc822_formatted_time()) + headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE) # connection header - if 
CONNECTION not in headers: + if hdrs.CONNECTION not in headers: if keep_alive: if version == HttpVersion10: - headers[CONNECTION] = 'keep-alive' + headers[hdrs.CONNECTION] = "keep-alive" else: if version == HttpVersion11: - headers[CONNECTION] = 'close' + headers[hdrs.CONNECTION] = "close" + async def _write_headers(self) -> None: + request = self._req + assert request is not None + writer = self._payload_writer + assert writer is not None # status line - status_line = 'HTTP/{}.{} {} {}\r\n'.format( - version[0], version[1], self._status, self._reason) - writer.write_headers(status_line, headers) - - return writer + version = request.version + status_line = "HTTP/{}.{} {} {}".format( + version[0], version[1], self._status, self._reason + ) + await writer.write_headers(status_line, self._headers) - def write(self, data): - assert isinstance(data, (bytes, bytearray, memoryview)), \ - "data argument must be byte-ish (%r)" % type(data) + async def write(self, data: bytes) -> None: + assert isinstance( + data, (bytes, bytearray, memoryview) + ), "data argument must be byte-ish (%r)" % type(data) if self._eof_sent: raise RuntimeError("Cannot call write() after write_eof()") if self._payload_writer is None: raise RuntimeError("Cannot call write() before prepare()") - return self._payload_writer.write(data) + await self._payload_writer.write(data) - @asyncio.coroutine - def drain(self): + async def drain(self) -> None: assert not self._eof_sent, "EOF has already been sent" - assert self._payload_writer is not None, \ - "Response has not been started" - yield from self._payload_writer.drain() - - @asyncio.coroutine - def write_eof(self, data=b''): - assert isinstance(data, (bytes, bytearray, memoryview)), \ - "data argument must be byte-ish (%r)" % type(data) + assert self._payload_writer is not None, "Response has not been started" + warnings.warn( + "drain method is deprecated, use await resp.write()", + DeprecationWarning, + stacklevel=2, + ) + await self._payload_writer.drain() + + async def write_eof(self, data: bytes = b"") -> None: + assert isinstance( + data, (bytes, bytearray, memoryview) + ), "data argument must be byte-ish (%r)" % type(data) if self._eof_sent: return - assert self._payload_writer is not None, \ - "Response has not been started" + assert self._payload_writer is not None, "Response has not been started" - yield from self._payload_writer.write_eof(data) + await self._payload_writer.write_eof(data) self._eof_sent = True self._req = None self._body_length = self._payload_writer.output_size self._payload_writer = None - def __repr__(self): + def __repr__(self) -> str: if self._eof_sent: info = "eof" elif self.prepared: - info = "{} {} ".format(self._req.method, self._req.path) + assert self._req is not None + info = f"{self._req.method} {self._req.path} " else: info = "not prepared" - return "<{} {} {}>".format(self.__class__.__name__, - self.reason, info) + return f"<{self.__class__.__name__} {self.reason} {info}>" + def __getitem__(self, key: str) -> Any: + return self._state[key] -class Response(StreamResponse): + def __setitem__(self, key: str, value: Any) -> None: + self._state[key] = value - def __init__(self, *, body=None, status=200, - reason=None, text=None, headers=None, content_type=None, - charset=None): + def __delitem__(self, key: str) -> None: + del self._state[key] + + def __len__(self) -> int: + return len(self._state) + + def __iter__(self) -> Iterator[str]: + return iter(self._state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, 
other: object) -> bool: + return self is other + + +class Response(StreamResponse): + def __init__( + self, + *, + body: Any = None, + status: int = 200, + reason: Optional[str] = None, + text: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + content_type: Optional[str] = None, + charset: Optional[str] = None, + zlib_executor_size: Optional[int] = None, + zlib_executor: Optional[Executor] = None, + ) -> None: if body is not None and text is not None: raise ValueError("body and text are not allowed together") if headers is None: - headers = CIMultiDict() - elif not isinstance(headers, (CIMultiDict, CIMultiDictProxy)): - headers = CIMultiDict(headers) + real_headers = CIMultiDict() # type: CIMultiDict[str] + elif not isinstance(headers, CIMultiDict): + real_headers = CIMultiDict(headers) + else: + real_headers = headers # = cast('CIMultiDict[str]', headers) - if content_type is not None and ";" in content_type: - raise ValueError("charset must not be in content_type " - "argument") + if content_type is not None and "charset" in content_type: + raise ValueError("charset must not be in content_type " "argument") if text is not None: - if hdrs.CONTENT_TYPE in headers: + if hdrs.CONTENT_TYPE in real_headers: if content_type or charset: - raise ValueError("passing both Content-Type header and " - "content_type or charset params " - "is forbidden") + raise ValueError( + "passing both Content-Type header and " + "content_type or charset params " + "is forbidden" + ) else: # fast path for filling headers if not isinstance(text, str): - raise TypeError("text argument must be str (%r)" % - type(text)) + raise TypeError("text argument must be str (%r)" % type(text)) if content_type is None: - content_type = 'text/plain' + content_type = "text/plain" if charset is None: - charset = 'utf-8' - headers[hdrs.CONTENT_TYPE] = ( - content_type + '; charset=' + charset) + charset = "utf-8" + real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset body = text.encode(charset) text = None else: - if hdrs.CONTENT_TYPE in headers: + if hdrs.CONTENT_TYPE in real_headers: if content_type is not None or charset is not None: - raise ValueError("passing both Content-Type header and " - "content_type or charset params " - "is forbidden") + raise ValueError( + "passing both Content-Type header and " + "content_type or charset params " + "is forbidden" + ) else: if content_type is not None: if charset is not None: - content_type += '; charset=' + charset - headers[hdrs.CONTENT_TYPE] = content_type + content_type += "; charset=" + charset + real_headers[hdrs.CONTENT_TYPE] = content_type - super().__init__(status=status, reason=reason, headers=headers) + super().__init__(status=status, reason=reason, headers=real_headers) if text is not None: self.text = text else: self.body = body + self._compressed_body = None # type: Optional[bytes] + self._zlib_executor_size = zlib_executor_size + self._zlib_executor = zlib_executor + @property - def body(self): + def body(self) -> Optional[Union[bytes, Payload]]: return self._body @body.setter - def body(self, body, - CONTENT_TYPE=hdrs.CONTENT_TYPE, - CONTENT_LENGTH=hdrs.CONTENT_LENGTH): + def body( + self, + body: bytes, + CONTENT_TYPE: istr = hdrs.CONTENT_TYPE, + CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, + ) -> None: if body is None: - self._body = None - self._body_payload = False + self._body = None # type: Optional[bytes] + self._body_payload = False # type: bool elif isinstance(body, (bytes, bytearray)): self._body = body self._body_payload = 
False @@ -500,18 +619,16 @@ def body(self, body, try: self._body = body = payload.PAYLOAD_REGISTRY.get(body) except payload.LookupError: - raise ValueError('Unsupported body type %r' % type(body)) + raise ValueError("Unsupported body type %r" % type(body)) self._body_payload = True headers = self._headers - # enable chunked encoding if needed + # set content-length header if needed if not self._chunked and CONTENT_LENGTH not in headers: size = body.size - if size is None: - self._chunked = True - elif CONTENT_LENGTH not in headers: + if size is not None: headers[CONTENT_LENGTH] = str(size) # set content-type @@ -524,76 +641,141 @@ def body(self, body, if key not in headers: headers[key] = value + self._compressed_body = None + @property - def text(self): + def text(self) -> Optional[str]: if self._body is None: return None - return self._body.decode(self.charset or 'utf-8') + return self._body.decode(self.charset or "utf-8") @text.setter - def text(self, text): - assert text is None or isinstance(text, str), \ - "text argument must be str (%r)" % type(text) + def text(self, text: str) -> None: + assert text is None or isinstance( + text, str + ), "text argument must be str (%r)" % type(text) - if self.content_type == 'application/octet-stream': - self.content_type = 'text/plain' + if self.content_type == "application/octet-stream": + self.content_type = "text/plain" if self.charset is None: - self.charset = 'utf-8' + self.charset = "utf-8" self._body = text.encode(self.charset) self._body_payload = False + self._compressed_body = None @property - def content_length(self): + def content_length(self) -> Optional[int]: if self._chunked: return None - if hdrs.CONTENT_LENGTH in self.headers: + if hdrs.CONTENT_LENGTH in self._headers: return super().content_length - if self._body is not None: + if self._compressed_body is not None: + # Return length of the compressed body + return len(self._compressed_body) + elif self._body_payload: + # A payload without content length, or a compressed payload + return None + elif self._body is not None: return len(self._body) else: return 0 @content_length.setter - def content_length(self, value): - super().content_length = value + def content_length(self, value: Optional[int]) -> None: + raise RuntimeError("Content length is set automatically") - @asyncio.coroutine - def write_eof(self): - body = self._body + async def write_eof(self, data: bytes = b"") -> None: + if self._eof_sent: + return + if self._compressed_body is None: + body = self._body # type: Optional[Union[bytes, Payload]] + else: + body = self._compressed_body + assert not data, f"data arg is not supported, got {data!r}" + assert self._req is not None + assert self._payload_writer is not None if body is not None: - if (self._req._method == hdrs.METH_HEAD or - self._status in [204, 304]): - yield from super().write_eof() + if self._req._method == hdrs.METH_HEAD or self._status in [204, 304]: + await super().write_eof() elif self._body_payload: - yield from body.write(self._payload_writer) - yield from super().write_eof() + payload = cast(Payload, body) + await payload.write(self._payload_writer) + await super().write_eof() else: - yield from super().write_eof(body) + await super().write_eof(cast(bytes, body)) else: - yield from super().write_eof() + await super().write_eof() - def _start(self, request): + async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: - if self._body is not None: - 
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
-            else:
-                self._headers[hdrs.CONTENT_LENGTH] = '0'
+        if not self._body_payload:
+            if self._body is not None:
+                self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
+            else:
+                self._headers[hdrs.CONTENT_LENGTH] = "0"

-        return super()._start(request)
+        return await super()._start(request)

+    def _compress_body(self, zlib_mode: int) -> None:
+        assert zlib_mode > 0
+        compressobj = zlib.compressobj(wbits=zlib_mode)
+        body_in = self._body
+        assert body_in is not None
+        self._compressed_body = compressobj.compress(body_in) + compressobj.flush()

-def json_response(data=sentinel, *, text=None, body=None, status=200,
-                  reason=None, headers=None, content_type='application/json',
-                  dumps=json.dumps):
+    async def _do_start_compression(self, coding: ContentCoding) -> None:
+        if self._body_payload or self._chunked:
+            return await super()._do_start_compression(coding)
+
+        if coding != ContentCoding.identity:
+            # Instead of using _payload_writer.enable_compression,
+            # compress the whole body
+            zlib_mode = (
+                16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS
+            )
+            body_in = self._body
+            assert body_in is not None
+            if (
+                self._zlib_executor_size is not None
+                and len(body_in) > self._zlib_executor_size
+            ):
+                await asyncio.get_event_loop().run_in_executor(
+                    self._zlib_executor, self._compress_body, zlib_mode
+                )
+            else:
+                self._compress_body(zlib_mode)
+
+            body_out = self._compressed_body
+            assert body_out is not None
+
+            self._headers[hdrs.CONTENT_ENCODING] = coding.value
+            self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
+
+
+def json_response(
+    data: Any = sentinel,
+    *,
+    text: Optional[str] = None,
+    body: Optional[bytes] = None,
+    status: int = 200,
+    reason: Optional[str] = None,
+    headers: Optional[LooseHeaders] = None,
+    content_type: str = "application/json",
+    dumps: JSONEncoder = json.dumps,
+) -> Response:
     if data is not sentinel:
         if text or body:
-            raise ValueError(
-                "only one of data, text, or body should be specified"
-            )
+            raise ValueError("only one of data, text, or body should be specified")
         else:
             text = dumps(data)
-    return Response(text=text, body=body, status=status, reason=reason,
-                    headers=headers, content_type=content_type)
+    return Response(
+        text=text,
+        body=body,
+        status=status,
+        reason=reason,
+        headers=headers,
+        content_type=content_type,
+    )
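To see the rewritten helper from the caller's side, here is a minimal sketch of an async handler built on json_response(); the handler and route names are illustrative only, and Application.add_routes() is assumed from web_app.py (outside this patch), accepting the route definitions introduced in web_routedef.py below.

from aiohttp import web


async def handle_health(request: web.Request) -> web.Response:
    # serializes the dict via json.dumps by default and sets
    # Content-Type: application/json on the returned Response
    return web.json_response({"status": "ok"})


app = web.Application()
app.add_routes([web.get("/health", handle_health)])
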
diff --git a/aiohttp/web_routedef.py b/aiohttp/web_routedef.py
new file mode 100644
index 00000000000..188525103de
--- /dev/null
+++ b/aiohttp/web_routedef.py
@@ -0,0 +1,215 @@
+import abc
+import os  # noqa
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+    overload,
+)
+
+import attr
+
+from . import hdrs
+from .abc import AbstractView
+from .typedefs import PathLike
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .web_request import Request
+    from .web_response import StreamResponse
+    from .web_urldispatcher import AbstractRoute, UrlDispatcher
+else:
+    Request = StreamResponse = UrlDispatcher = AbstractRoute = None
+
+
+__all__ = (
+    "AbstractRouteDef",
+    "RouteDef",
+    "StaticDef",
+    "RouteTableDef",
+    "head",
+    "options",
+    "get",
+    "post",
+    "patch",
+    "put",
+    "delete",
+    "route",
+    "view",
+    "static",
+)
+
+
+class AbstractRouteDef(abc.ABC):
+    @abc.abstractmethod
+    def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
+        pass  # pragma: no cover
+
+
+_SimpleHandler = Callable[[Request], Awaitable[StreamResponse]]
+_HandlerType = Union[Type[AbstractView], _SimpleHandler]
+
+
+@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True)
+class RouteDef(AbstractRouteDef):
+    method: str
+    path: str
+    handler: _HandlerType
+    kwargs: Dict[str, Any]
+
+    def __repr__(self) -> str:
+        info = []
+        for name, value in sorted(self.kwargs.items()):
+            info.append(f", {name}={value!r}")
+        return "<RouteDef {method} {path} -> {handler.__name__!r}" "{info}>".format(
+            method=self.method, path=self.path, handler=self.handler, info="".join(info)
+        )
+
+    def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
+        if self.method in hdrs.METH_ALL:
+            reg = getattr(router, "add_" + self.method.lower())
+            return [reg(self.path, self.handler, **self.kwargs)]
+        else:
+            return [
+                router.add_route(self.method, self.path, self.handler, **self.kwargs)
+            ]
+
+
+@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True)
+class StaticDef(AbstractRouteDef):
+    prefix: str
+    path: PathLike
+    kwargs: Dict[str, Any]
+
+    def __repr__(self) -> str:
+        info = []
+        for name, value in sorted(self.kwargs.items()):
+            info.append(f", {name}={value!r}")
+        return "<StaticDef {prefix} -> {path}" "{info}>".format(
+            prefix=self.prefix, path=self.path, info="".join(info)
+        )
+
+    def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
+        resource = router.add_static(self.prefix, self.path, **self.kwargs)
+        routes = resource.get_info().get("routes", {})
+        return list(routes.values())
+
+
+def route(method: str, path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return RouteDef(method, path, handler, kwargs)
+
+
+def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_HEAD, path, handler, **kwargs)
+
+
+def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_OPTIONS, path, handler, **kwargs)
+
+
+def get(
+    path: str,
+    handler: _HandlerType,
+    *,
+    name: Optional[str] = None,
+    allow_head: bool = True,
+    **kwargs: Any,
+) -> RouteDef:
+    return route(
+        hdrs.METH_GET, path, handler, name=name, allow_head=allow_head, **kwargs
+    )
+
+
+def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_POST, path, handler, **kwargs)
+
+
+def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_PUT, path, handler, **kwargs)
+
+
+def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_PATCH, path, handler, **kwargs)
+
+
+def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_DELETE, path, handler, **kwargs)
+
+
+def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef:
+    return route(hdrs.METH_ANY, path, handler, **kwargs)
+
+
+def static(prefix: str, path: PathLike, **kwargs: Any) -> StaticDef:
+    return StaticDef(prefix, path, kwargs)
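The helpers above only build frozen RouteDef/StaticDef records; nothing touches a router until register() runs. A sketch of the intended call pattern, again assuming Application.add_routes() from web_app.py (handlers and paths are placeholders):

from aiohttp import web


async def list_users(request: web.Request) -> web.Response:
    return web.json_response([])


async def add_user(request: web.Request) -> web.Response:
    return web.json_response({"created": True}, status=201)


app = web.Application()
# each call returns an immutable route definition; registration
# happens only when the list is handed to the router
app.add_routes(
    [
        web.get("/users", list_users),
        web.post("/users", add_user),
        web.static("/static", "./static"),
    ]
)
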
+
+
+_Deco = Callable[[_HandlerType], _HandlerType]
+
+
+class RouteTableDef(Sequence[AbstractRouteDef]):
+    """Route definition table"""
+
+    def __init__(self) -> None:
+        self._items = []  # type: List[AbstractRouteDef]
+
+    def __repr__(self) -> str:
+        return "<RouteTableDef count={}>".format(len(self._items))
+
+    @overload
+    def __getitem__(self, index: int) -> AbstractRouteDef:
+        ...
+
+    @overload
+    def __getitem__(self, index: slice) -> List[AbstractRouteDef]:
+        ...
+
+    def __getitem__(self, index):  # type: ignore
+        return self._items[index]
+
+    def __iter__(self) -> Iterator[AbstractRouteDef]:
+        return iter(self._items)
+
+    def __len__(self) -> int:
+        return len(self._items)
+
+    def __contains__(self, item: object) -> bool:
+        return item in self._items
+
+    def route(self, method: str, path: str, **kwargs: Any) -> _Deco:
+        def inner(handler: _HandlerType) -> _HandlerType:
+            self._items.append(RouteDef(method, path, handler, kwargs))
+            return handler
+
+        return inner
+
+    def head(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_HEAD, path, **kwargs)
+
+    def get(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_GET, path, **kwargs)
+
+    def post(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_POST, path, **kwargs)
+
+    def put(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_PUT, path, **kwargs)
+
+    def patch(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_PATCH, path, **kwargs)
+
+    def delete(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_DELETE, path, **kwargs)
+
+    def view(self, path: str, **kwargs: Any) -> _Deco:
+        return self.route(hdrs.METH_ANY, path, **kwargs)
+
+    def static(self, prefix: str, path: PathLike, **kwargs: Any) -> None:
+        self._items.append(StaticDef(prefix, path, kwargs))
diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py
new file mode 100644
index 00000000000..25ac28a7a89
--- /dev/null
+++ b/aiohttp/web_runner.py
@@ -0,0 +1,381 @@
+import asyncio
+import signal
+import socket
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Set
+
+from yarl import URL
+
+from .web_app import Application
+from .web_server import Server
+
+try:
+    from ssl import SSLContext
+except ImportError:
+    SSLContext = object  # type: ignore
+
+
+__all__ = (
+    "BaseSite",
+    "TCPSite",
+    "UnixSite",
+    "NamedPipeSite",
+    "SockSite",
+    "BaseRunner",
+    "AppRunner",
+    "ServerRunner",
+    "GracefulExit",
+)
+
+
+class GracefulExit(SystemExit):
+    code = 1
+
+
+def _raise_graceful_exit() -> None:
+    raise GracefulExit()
+
+
+class BaseSite(ABC):
+    __slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")
+
+    def __init__(
+        self,
+        runner: "BaseRunner",
+        *,
+        shutdown_timeout: float = 60.0,
+        ssl_context: Optional[SSLContext] = None,
+        backlog: int = 128,
+    ) -> None:
+        if runner.server is None:
+            raise RuntimeError("Call runner.setup() before making a site")
+        self._runner = runner
+        self._shutdown_timeout = shutdown_timeout
+        self._ssl_context = ssl_context
+        self._backlog = backlog
+        self._server = None  # type: Optional[asyncio.AbstractServer]
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        pass  # pragma: no cover
+
+    @abstractmethod
+    async def start(self) -> None:
+        self._runner._reg_site(self)
+
+    async def stop(self) -> None:
+        self._runner._check_site(self)
+        if self._server is None:
+            self._runner._unreg_site(self)
+            return  # not started yet
+        self._server.close()
+        # named pipes do not have wait_closed property
+        if hasattr(self._server, "wait_closed"):
+            await self._server.wait_closed()
+        await self._runner.shutdown()
+        assert self._runner.server
+        await self._runner.server.shutdown(self._shutdown_timeout)
+        self._runner._unreg_site(self)
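BaseSite fixes the runner-then-site startup order: runner.setup() must create the low-level Server before any site can bind. A hedged sketch of the lifecycle using the TCPSite defined next (the host, port, and sleep loop are placeholders):

import asyncio

from aiohttp import web


async def main() -> None:
    app = web.Application()
    runner = web.AppRunner(app)
    await runner.setup()           # creates the low-level Server
    site = web.TCPSite(runner, host="127.0.0.1", port=8080)
    await site.start()             # binds via loop.create_server()
    try:
        await asyncio.sleep(3600)  # serve until cancelled
    finally:
        await runner.cleanup()     # stops every site, then the server


asyncio.run(main())
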
+
+
+class TCPSite(BaseSite):
+    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
+
+    def __init__(
+        self,
+        runner: "BaseRunner",
+        host: Optional[str] = None,
+        port: Optional[int] = None,
+        *,
+        shutdown_timeout: float = 60.0,
+        ssl_context: Optional[SSLContext] = None,
+        backlog: int = 128,
+        reuse_address: Optional[bool] = None,
+        reuse_port: Optional[bool] = None,
+    ) -> None:
+        super().__init__(
+            runner,
+            shutdown_timeout=shutdown_timeout,
+            ssl_context=ssl_context,
+            backlog=backlog,
+        )
+        self._host = host
+        if port is None:
+            port = 8443 if self._ssl_context else 8080
+        self._port = port
+        self._reuse_address = reuse_address
+        self._reuse_port = reuse_port
+
+    @property
+    def name(self) -> str:
+        scheme = "https" if self._ssl_context else "http"
+        host = "0.0.0.0" if self._host is None else self._host
+        return str(URL.build(scheme=scheme, host=host, port=self._port))
+
+    async def start(self) -> None:
+        await super().start()
+        loop = asyncio.get_event_loop()
+        server = self._runner.server
+        assert server is not None
+        self._server = await loop.create_server(
+            server,
+            self._host,
+            self._port,
+            ssl=self._ssl_context,
+            backlog=self._backlog,
+            reuse_address=self._reuse_address,
+            reuse_port=self._reuse_port,
+        )
+
+
+class UnixSite(BaseSite):
+    __slots__ = ("_path",)
+
+    def __init__(
+        self,
+        runner: "BaseRunner",
+        path: str,
+        *,
+        shutdown_timeout: float = 60.0,
+        ssl_context: Optional[SSLContext] = None,
+        backlog: int = 128,
+    ) -> None:
+        super().__init__(
+            runner,
+            shutdown_timeout=shutdown_timeout,
+            ssl_context=ssl_context,
+            backlog=backlog,
+        )
+        self._path = path
+
+    @property
+    def name(self) -> str:
+        scheme = "https" if self._ssl_context else "http"
+        return f"{scheme}://unix:{self._path}:"
+
+    async def start(self) -> None:
+        await super().start()
+        loop = asyncio.get_event_loop()
+        server = self._runner.server
+        assert server is not None
+        self._server = await loop.create_unix_server(
+            server, self._path, ssl=self._ssl_context, backlog=self._backlog
+        )
+
+
+class NamedPipeSite(BaseSite):
+    __slots__ = ("_path",)
+
+    def __init__(
+        self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
+    ) -> None:
+        loop = asyncio.get_event_loop()
+        if not isinstance(loop, asyncio.ProactorEventLoop):  # type: ignore
+            raise RuntimeError(
+                "Named Pipes only available in proactor "
+                "loop under windows"
+            )
+        super().__init__(runner, shutdown_timeout=shutdown_timeout)
+        self._path = path
+
+    @property
+    def name(self) -> str:
+        return self._path
+
+    async def start(self) -> None:
+        await super().start()
+        loop = asyncio.get_event_loop()
+        server = self._runner.server
+        assert server is not None
+        _server = await loop.start_serving_pipe(server, self._path)  # type: ignore
+        self._server = _server[0]
+
+
+class SockSite(BaseSite):
+    __slots__ = ("_sock", "_name")
+
+    def __init__(
+        self,
+        runner: "BaseRunner",
+        sock: socket.socket,
+        *,
+        shutdown_timeout: float = 60.0,
+        ssl_context: Optional[SSLContext] = None,
+        backlog: int = 128,
+    ) -> None:
+        super().__init__(
+            runner,
+            shutdown_timeout=shutdown_timeout,
+            ssl_context=ssl_context,
+            backlog=backlog,
+        )
+        self._sock = sock
+        scheme = "https" if self._ssl_context else "http"
hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + name = f"{scheme}://unix:{sock.getsockname()}:" + else: + host, port = sock.getsockname()[:2] + name = str(URL.build(scheme=scheme, host=host, port=port)) + self._name = name + + @property + def name(self) -> str: + return self._name + + async def start(self) -> None: + await super().start() + loop = asyncio.get_event_loop() + server = self._runner.server + assert server is not None + self._server = await loop.create_server( + server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog + ) + + +class BaseRunner(ABC): + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites") + + def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None: + self._handle_signals = handle_signals + self._kwargs = kwargs + self._server = None # type: Optional[Server] + self._sites = [] # type: List[BaseSite] + + @property + def server(self) -> Optional[Server]: + return self._server + + @property + def addresses(self) -> List[Any]: + ret = [] # type: List[Any] + for site in self._sites: + server = site._server + if server is not None: + sockets = server.sockets + if sockets is not None: + for sock in sockets: + ret.append(sock.getsockname()) + return ret + + @property + def sites(self) -> Set[BaseSite]: + return set(self._sites) + + async def setup(self) -> None: + loop = asyncio.get_event_loop() + + if self._handle_signals: + try: + loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit) + loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit) + except NotImplementedError: # pragma: no cover + # add_signal_handler is not implemented on Windows + pass + + self._server = await self._make_server() + + @abstractmethod + async def shutdown(self) -> None: + pass # pragma: no cover + + async def cleanup(self) -> None: + loop = asyncio.get_event_loop() + + if self._server is None: + # no started yet, do nothing + return + + # The loop over sites is intentional, an exception on gather() + # leaves self._sites in unpredictable state. 
+
+
+class BaseRunner(ABC):
+    __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites")
+
+    def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
+        self._handle_signals = handle_signals
+        self._kwargs = kwargs
+        self._server = None  # type: Optional[Server]
+        self._sites = []  # type: List[BaseSite]
+
+    @property
+    def server(self) -> Optional[Server]:
+        return self._server
+
+    @property
+    def addresses(self) -> List[Any]:
+        ret = []  # type: List[Any]
+        for site in self._sites:
+            server = site._server
+            if server is not None:
+                sockets = server.sockets
+                if sockets is not None:
+                    for sock in sockets:
+                        ret.append(sock.getsockname())
+        return ret
+
+    @property
+    def sites(self) -> Set[BaseSite]:
+        return set(self._sites)
+
+    async def setup(self) -> None:
+        loop = asyncio.get_event_loop()
+
+        if self._handle_signals:
+            try:
+                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
+                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
+            except NotImplementedError:  # pragma: no cover
+                # add_signal_handler is not implemented on Windows
+                pass
+
+        self._server = await self._make_server()
+
+    @abstractmethod
+    async def shutdown(self) -> None:
+        pass  # pragma: no cover
+
+    async def cleanup(self) -> None:
+        loop = asyncio.get_event_loop()
+
+        if self._server is None:
+            # not started yet, do nothing
+            return
+
+        # The loop over sites is intentional, an exception on gather()
+        # leaves self._sites in unpredictable state.
+        # The loop guarantees that a site is either deleted on success or
+        # still present on failure
+        for site in list(self._sites):
+            await site.stop()
+        await self._cleanup_server()
+        self._server = None
+        if self._handle_signals:
+            try:
+                loop.remove_signal_handler(signal.SIGINT)
+                loop.remove_signal_handler(signal.SIGTERM)
+            except NotImplementedError:  # pragma: no cover
+                # remove_signal_handler is not implemented on Windows
+                pass
+
+    @abstractmethod
+    async def _make_server(self) -> Server:
+        pass  # pragma: no cover
+
+    @abstractmethod
+    async def _cleanup_server(self) -> None:
+        pass  # pragma: no cover
+
+    def _reg_site(self, site: BaseSite) -> None:
+        if site in self._sites:
+            raise RuntimeError(f"Site {site} is already registered in runner {self}")
+        self._sites.append(site)
+
+    def _check_site(self, site: BaseSite) -> None:
+        if site not in self._sites:
+            raise RuntimeError(f"Site {site} is not registered in runner {self}")
+
+    def _unreg_site(self, site: BaseSite) -> None:
+        if site not in self._sites:
+            raise RuntimeError(f"Site {site} is not registered in runner {self}")
+        self._sites.remove(site)
+
+
+class ServerRunner(BaseRunner):
+    """Low-level web server runner"""
+
+    __slots__ = ("_web_server",)
+
+    def __init__(
+        self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
+    ) -> None:
+        super().__init__(handle_signals=handle_signals, **kwargs)
+        self._web_server = web_server
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def _make_server(self) -> Server:
+        return self._web_server
+
+    async def _cleanup_server(self) -> None:
+        pass
+
+
+class AppRunner(BaseRunner):
+    """Web Application runner"""
+
+    __slots__ = ("_app",)
+
+    def __init__(
+        self, app: Application, *, handle_signals: bool = False, **kwargs: Any
+    ) -> None:
+        super().__init__(handle_signals=handle_signals, **kwargs)
+        if not isinstance(app, Application):
+            raise TypeError(
+                "The first argument should be web.Application "
+                "instance, got {!r}".format(app)
+            )
+        self._app = app
+
+    @property
+    def app(self) -> Application:
+        return self._app
+
+    async def shutdown(self) -> None:
+        await self._app.shutdown()
+
+    async def _make_server(self) -> Server:
+        loop = asyncio.get_event_loop()
+        self._app._set_loop(loop)
+        self._app.on_startup.freeze()
+        await self._app.startup()
+        self._app.freeze()
+
+        return self._app._make_handler(loop=loop, **self._kwargs)
+
+    async def _cleanup_server(self) -> None:
+        await self._app.cleanup()
diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py
index 8e240e2e0c4..5657ed9c800 100644
--- a/aiohttp/web_server.py
+++ b/aiohttp/web_server.py
@@ -1,50 +1,62 @@
 """Low level HTTP server."""
 import asyncio
+from typing import Any, Awaitable, Callable, Dict, List, Optional  # noqa

-from .helpers import TimeService
-from .web_protocol import RequestHandler
+from .abc import AbstractStreamWriter
+from .helpers import get_running_loop
+from .http_parser import RawRequestMessage
+from .streams import StreamReader
+from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
 from .web_request import BaseRequest

-__all__ = ('Server',)
+__all__ = ("Server",)


 class Server:
-
-    def __init__(self, handler, *, request_factory=None, loop=None, **kwargs):
-        if loop is None:
-            loop = asyncio.get_event_loop()
-        self._loop = loop
-        self._connections = {}
+    def __init__(
+        self,
+        handler: _RequestHandler,
+        *,
+        request_factory: Optional[_RequestFactory] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        **kwargs: Any
+    ) -> None:
+        self._loop = 
get_running_loop(loop) + self._connections = {} # type: Dict[RequestHandler, asyncio.Transport] self._kwargs = kwargs - self.time_service = TimeService(self._loop) self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request @property - def connections(self): + def connections(self) -> List[RequestHandler]: return list(self._connections.keys()) - def connection_made(self, handler, transport): + def connection_made( + self, handler: RequestHandler, transport: asyncio.Transport + ) -> None: self._connections[handler] = transport - def connection_lost(self, handler, exc=None): + def connection_lost( + self, handler: RequestHandler, exc: Optional[BaseException] = None + ) -> None: if handler in self._connections: del self._connections[handler] - def _make_request(self, message, payload, protocol, writer, task): - return BaseRequest( - message, payload, protocol, writer, - protocol.time_service, task) - - @asyncio.coroutine - def shutdown(self, timeout=None): + def _make_request( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: RequestHandler, + writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + ) -> BaseRequest: + return BaseRequest(message, payload, protocol, writer, task, self._loop) + + async def shutdown(self, timeout: Optional[float] = None) -> None: coros = [conn.shutdown(timeout) for conn in self._connections] - yield from asyncio.gather(*coros, loop=self._loop) + await asyncio.gather(*coros) self._connections.clear() - self.time_service.close() - - finish_connections = shutdown - def __call__(self): + def __call__(self) -> RequestHandler: return RequestHandler(self, loop=self._loop, **self._kwargs) diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index dddece7ee79..2afd72f13db 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -1,69 +1,139 @@ import abc import asyncio -import collections +import base64 +import hashlib import inspect import keyword import os import re import warnings -from collections.abc import Container, Iterable, Sized +from contextlib import contextmanager +from functools import wraps from pathlib import Path from types import MappingProxyType - -# do not use yarl.quote directly, -# use `URL(path).raw_path` instead of `quote(path)` -# Escaping of the URLs need to be consitent with the escaping done by yarl -from yarl import URL, unquote - -from . import hdrs, helpers +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Container, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Set, + Sized, + Tuple, + Type, + Union, + cast, +) + +from typing_extensions import TypedDict +from yarl import URL, __version__ as yarl_version # type: ignore + +from . 
import hdrs from .abc import AbstractMatchInfo, AbstractRouter, AbstractView +from .helpers import DEBUG from .http import HttpVersion11 -from .web_exceptions import (HTTPExpectationFailed, HTTPForbidden, - HTTPMethodNotAllowed, HTTPNotFound) +from .typedefs import PathLike +from .web_exceptions import ( + HTTPException, + HTTPExpectationFailed, + HTTPForbidden, + HTTPMethodNotAllowed, + HTTPNotFound, +) from .web_fileresponse import FileResponse +from .web_request import Request from .web_response import Response, StreamResponse +from .web_routedef import AbstractRouteDef + +__all__ = ( + "UrlDispatcher", + "UrlMappingMatchInfo", + "AbstractResource", + "Resource", + "PlainResource", + "DynamicResource", + "AbstractRoute", + "ResourceRoute", + "StaticResource", + "View", +) + + +if TYPE_CHECKING: # pragma: no cover + from .web_app import Application -__all__ = ('UrlDispatcher', 'UrlMappingMatchInfo', - 'AbstractResource', 'Resource', 'PlainResource', 'DynamicResource', - 'AbstractRoute', 'ResourceRoute', - 'StaticResource', 'View') + BaseDict = Dict[str, str] +else: + BaseDict = dict + +YARL_VERSION = tuple(map(int, yarl_version.split(".")[:2])) HTTP_METHOD_RE = re.compile(r"^[0-9A-Za-z!#\$%&'\*\+\-\.\^_`\|~]+$") +ROUTE_RE = re.compile(r"(\{[_a-zA-Z][^{}]*(?:\{[^{}]*\}[^{}]*)*\})") +PATH_SEP = re.escape("/") + + +_WebHandler = Callable[[Request], Awaitable[StreamResponse]] +_ExpectHandler = Callable[[Request], Awaitable[None]] +_Resolve = Tuple[Optional[AbstractMatchInfo], Set[str]] + + +class _InfoDict(TypedDict, total=False): + path: str + formatter: str + pattern: Pattern[str] -class AbstractResource(Sized, Iterable): + directory: Path + prefix: str + routes: Mapping[str, "AbstractRoute"] - def __init__(self, *, name=None): + app: "Application" + + domain: str + + rule: "AbstractRuleMatching" + + http_exception: HTTPException + + +class AbstractResource(Sized, Iterable["AbstractRoute"]): + def __init__(self, *, name: Optional[str] = None) -> None: self._name = name @property - def name(self): + def name(self) -> Optional[str]: return self._name - @abc.abstractmethod # pragma: no branch - def url(self, **kwargs): - """Construct url for resource with additional params. + @property + @abc.abstractmethod + def canonical(self) -> str: + """Exposes the resource's canonical path. - Deprecated, use url_for() instead. + For example '/foo/bar/{name}' """ - warnings.warn(".url(...) is deprecated, use .url_for instead", - DeprecationWarning, - stacklevel=3) @abc.abstractmethod # pragma: no branch - def url_for(self, **kwargs): + def url_for(self, **kwargs: str) -> URL: """Construct url for resource with additional params.""" - @asyncio.coroutine @abc.abstractmethod # pragma: no branch - def resolve(self, request): + async def resolve(self, request: Request) -> _Resolve: """Resolve resource Return (UrlMappingMatchInfo, allowed_methods) pair.""" @abc.abstractmethod - def add_prefix(self, prefix): + def add_prefix(self, prefix: str) -> None: """Add a prefix to processed URLs. Required for subapplications support. 
@@ -71,45 +141,60 @@ def add_prefix(self, prefix): """ @abc.abstractmethod - def get_info(self): + def get_info(self) -> _InfoDict: """Return a dict with additional info useful for introspection""" - def freeze(self): + def freeze(self) -> None: pass + @abc.abstractmethod + def raw_match(self, path: str) -> bool: + """Perform a raw match against path""" + class AbstractRoute(abc.ABC): - - def __init__(self, method, handler, *, - expect_handler=None, - resource=None): + def __init__( + self, + method: str, + handler: Union[_WebHandler, Type[AbstractView]], + *, + expect_handler: Optional[_ExpectHandler] = None, + resource: Optional[AbstractResource] = None, + ) -> None: if expect_handler is None: - expect_handler = _defaultExpectHandler + expect_handler = _default_expect_handler - assert asyncio.iscoroutinefunction(expect_handler), \ - 'Coroutine is expected, got {!r}'.format(expect_handler) + assert asyncio.iscoroutinefunction( + expect_handler + ), f"Coroutine is expected, got {expect_handler!r}" method = method.upper() if not HTTP_METHOD_RE.match(method): - raise ValueError("{} is not allowed HTTP method".format(method)) + raise ValueError(f"{method} is not allowed HTTP method") assert callable(handler), handler if asyncio.iscoroutinefunction(handler): pass elif inspect.isgeneratorfunction(handler): - warnings.warn("Bare generators are deprecated, " - "use @coroutine wrapper", DeprecationWarning) - elif (isinstance(handler, type) and - issubclass(handler, AbstractView)): + warnings.warn( + "Bare generators are deprecated, " "use @coroutine wrapper", + DeprecationWarning, + ) + elif isinstance(handler, type) and issubclass(handler, AbstractView): pass else: - @asyncio.coroutine - def handler_wrapper(*args, **kwargs): - result = old_handler(*args, **kwargs) + warnings.warn( + "Bare functions are deprecated, " "use async ones", DeprecationWarning + ) + + @wraps(handler) + async def handler_wrapper(request: Request) -> StreamResponse: + result = old_handler(request) if asyncio.iscoroutine(result): - result = yield from result - return result + return await result + return result # type: ignore + old_handler = handler handler = handler_wrapper @@ -119,325 +204,433 @@ def handler_wrapper(*args, **kwargs): self._resource = resource @property - def method(self): + def method(self) -> str: return self._method @property - def handler(self): + def handler(self) -> _WebHandler: return self._handler @property @abc.abstractmethod - def name(self): + def name(self) -> Optional[str]: """Optional route's name, always equals to resource's name.""" @property - def resource(self): + def resource(self) -> Optional[AbstractResource]: return self._resource @abc.abstractmethod - def get_info(self): + def get_info(self) -> _InfoDict: """Return a dict with additional info useful for introspection""" @abc.abstractmethod # pragma: no branch - def url_for(self, *args, **kwargs): + def url_for(self, *args: str, **kwargs: str) -> URL: """Construct url for route with additional params.""" - @abc.abstractmethod # pragma: no branch - def url(self, **kwargs): - """Construct url for resource with additional params. - - Deprecated, use url_for() instead. - - """ - warnings.warn(".url(...) 
is deprecated, use .url_for instead",
-                      DeprecationWarning,
-                      stacklevel=3)
+    async def handle_expect_header(self, request: Request) -> None:
+        await self._expect_handler(request)

-    @asyncio.coroutine
-    def handle_expect_header(self, request):
-        return (yield from self._expect_handler(request))

-
-class UrlMappingMatchInfo(dict, AbstractMatchInfo):
-
-    def __init__(self, match_dict, route):
+class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
+    def __init__(self, match_dict: Dict[str, str], route: AbstractRoute):
         super().__init__(match_dict)
         self._route = route
-        self._apps = ()
+        self._apps = []  # type: List[Application]
+        self._current_app = None  # type: Optional[Application]
         self._frozen = False

     @property
-    def handler(self):
+    def handler(self) -> _WebHandler:
         return self._route.handler

     @property
-    def route(self):
+    def route(self) -> AbstractRoute:
         return self._route

     @property
-    def expect_handler(self):
+    def expect_handler(self) -> _ExpectHandler:
         return self._route.handle_expect_header

     @property
-    def http_exception(self):
+    def http_exception(self) -> Optional[HTTPException]:
         return None

-    def get_info(self):
+    def get_info(self) -> _InfoDict:  # type: ignore
         return self._route.get_info()

     @property
-    def apps(self):
-        return self._apps
+    def apps(self) -> Tuple["Application", ...]:
+        return tuple(self._apps)

-    def add_app(self, app):
+    def add_app(self, app: "Application") -> None:
         if self._frozen:
             raise RuntimeError("Cannot change apps stack after .freeze() call")
-        self._apps = (app,) + self._apps
+        if self._current_app is None:
+            self._current_app = app
+        self._apps.insert(0, app)
+
+    @property
+    def current_app(self) -> "Application":
+        app = self._current_app
+        assert app is not None
+        return app
+
+    @contextmanager
+    def set_current_app(self, app: "Application") -> Generator[None, None, None]:
+        if DEBUG:  # pragma: no cover
+            if app not in self._apps:
+                raise RuntimeError(
+                    "Expected one of the following apps {!r}, got {!r}".format(
+                        self._apps, app
+                    )
+                )
+        prev = self._current_app
+        self._current_app = app
+        try:
+            yield
+        finally:
+            self._current_app = prev

-    def freeze(self):
+    def freeze(self) -> None:
         self._frozen = True

-    def __repr__(self):
-        return "<MatchInfo {}: {}>".format(super().__repr__(), self._route)
+    def __repr__(self) -> str:
+        return f"<MatchInfo {super().__repr__()}: {self._route}>"


 class MatchInfoError(UrlMappingMatchInfo):
-
-    def __init__(self, http_exception):
+    def __init__(self, http_exception: HTTPException) -> None:
         self._exception = http_exception
         super().__init__({}, SystemRoute(self._exception))

     @property
-    def http_exception(self):
+    def http_exception(self) -> HTTPException:
         return self._exception

-    def __repr__(self):
-        return "<MatchInfoError {}: {}>".format(self._exception.status,
-                                                self._exception.reason)
+    def __repr__(self) -> str:
+        return "<MatchInfoError {}: {}>".format(
+            self._exception.status, self._exception.reason
+        )
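handle_expect_header() delegates to the per-route expect_handler, so a route can veto or acknowledge Expect: 100-continue before the body is read. A sketch with an invented size limit; note that a custom handler takes over writing the interim response, mirroring _default_expect_handler() below:

from aiohttp import web


async def check_expect(request: web.Request) -> None:
    # runs before the body is read; raising here rejects the upload
    if request.content_length and request.content_length > 1024 * 1024:
        raise web.HTTPExpectationFailed(text="payload too large")
    await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")


async def upload(request: web.Request) -> web.Response:
    await request.read()
    return web.Response(status=201)


app = web.Application()
app.router.add_route("POST", "/upload", upload, expect_handler=check_expect)
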


-@asyncio.coroutine
-def _defaultExpectHandler(request):
+async def _default_expect_handler(request: Request) -> None:
     """Default handler for Expect header.

     Just send "100 Continue" to client.
     raise HTTPExpectationFailed if value of header is not "100-continue"
     """
-    expect = request.headers.get(hdrs.EXPECT)
+    expect = request.headers.get(hdrs.EXPECT, "")
     if request.version == HttpVersion11:
         if expect.lower() == "100-continue":
-            request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n", drain=False)
-            yield from request.writer.drain()
+            await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
         else:
             raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)


 class Resource(AbstractResource):
-
-    def __init__(self, *, name=None):
+    def __init__(self, *, name: Optional[str] = None) -> None:
         super().__init__(name=name)
-        self._routes = []
-
-    def add_route(self, method, handler, *,
-                  expect_handler=None):
-
-        for route in self._routes:
-            if route.method == method or route.method == hdrs.METH_ANY:
-                raise RuntimeError("Added route will never be executed, "
-                                   "method {route.method} is "
-                                   "already registered".format(route=route))
-
-        route = ResourceRoute(method, handler, self,
-                              expect_handler=expect_handler)
-        self.register_route(route)
-        return route
-
-    def register_route(self, route):
-        assert isinstance(route, ResourceRoute), \
-            'Instance of Route class is required, got {!r}'.format(route)
+        self._routes = []  # type: List[ResourceRoute]
+
+    def add_route(
+        self,
+        method: str,
+        handler: Union[Type[AbstractView], _WebHandler],
+        *,
+        expect_handler: Optional[_ExpectHandler] = None,
+    ) -> "ResourceRoute":
+
+        for route_obj in self._routes:
+            if route_obj.method == method or route_obj.method == hdrs.METH_ANY:
+                raise RuntimeError(
+                    "Added route will never be executed, "
+                    "method {route.method} is already "
+                    "registered".format(route=route_obj)
+                )
+
+        route_obj = ResourceRoute(method, handler, self, expect_handler=expect_handler)
+        self.register_route(route_obj)
+        return route_obj
+
+    def register_route(self, route: "ResourceRoute") -> None:
+        assert isinstance(
+            route, ResourceRoute
+        ), f"Instance of Route class is required, got {route!r}"
         self._routes.append(route)

-    @asyncio.coroutine
-    def resolve(self, request):
-        allowed_methods = set()
+    async def resolve(self, request: Request) -> _Resolve:
+        allowed_methods = set()  # type: Set[str]

         match_dict = self._match(request.rel_url.raw_path)
         if match_dict is None:
             return None, allowed_methods

-        for route in self._routes:
-            route_method = route.method
+        for route_obj in self._routes:
+            route_method = route_obj.method
             allowed_methods.add(route_method)

-            if (route_method == request._method or
-                    route_method == hdrs.METH_ANY):
-                return UrlMappingMatchInfo(match_dict, route), allowed_methods
+            if route_method == request.method or route_method == hdrs.METH_ANY:
+                return (UrlMappingMatchInfo(match_dict, route_obj), allowed_methods)
         else:
             return None, allowed_methods
-        yield  # pragma: no cover
+
+    @abc.abstractmethod
+    def _match(self, path: str) -> Optional[Dict[str, str]]:
+        pass  # pragma: no cover

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._routes)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[AbstractRoute]:
         return iter(self._routes)

+    # TODO: implement all abstract methods


-class PlainResource(Resource):
-
-    def __init__(self, path, *, name=None):
+class PlainResource(Resource):
+    def __init__(self, path: str, *, name: Optional[str] = None) -> None:
         super().__init__(name=name)
-        assert not path or path.startswith('/')
+        assert not path or path.startswith("/")
         self._path = path

-    def freeze(self):
+    @property
+    def canonical(self) -> str:
+        return self._path
+
+    def freeze(self) -> None:
         if not self._path:
-            self._path = '/'
+            self._path = "/"

-    def add_prefix(self, prefix):
-        assert prefix.startswith('/')
-        assert not prefix.endswith('/')
+    def add_prefix(self, prefix: str) -> None:
+        assert prefix.startswith("/")
+        assert not prefix.endswith("/")
         assert len(prefix) > 1
         self._path = prefix + self._path

-    def _match(self, path):
+    def _match(self, path: str) -> Optional[Dict[str, str]]:
         # string comparison is about 10 times faster than regexp matching
         if self._path == path:
             return {}
         else:
             return None

-    def get_info(self):
-        return {'path': self._path}
+    def raw_match(self, path: str) -> bool:
+        return self._path == path

-    def url(self, *, query=None):
-        super().url()
-        return str(self.url_for().with_query(query))
+    def get_info(self) -> _InfoDict:
+        return {"path": self._path}

-    def url_for(self):
-        return URL(self._path)
+    def url_for(self) -> URL:  # type: ignore
+        return URL.build(path=self._path, encoded=True)

-    def __repr__(self):
+    def __repr__(self) -> str:
         name = "'" + self.name + "' " if self.name is not None else ""
-        return "<PlainResource {name} {path}".format(name=name,
-                                                     path=self._path)
+        return f"<PlainResource {name} {self._path}>"
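PlainResource above matches by plain string equality; the DynamicResource that follows compiles {name} and {name:regex} segments into the pattern/formatter pair behind _match() and url_for(). A sketch of both resource kinds through the public router API (handler and paths are illustrative):

from aiohttp import web


async def show_user(request: web.Request) -> web.Response:
    # populated by DynamicResource._match() via the match_info mapping
    name = request.match_info.get("name", "anonymous")
    return web.Response(text=f"user {name}")


app = web.Application()
app.router.add_get("/about", show_user)                    # PlainResource
resource = app.router.add_resource(r"/users/{name:\w+}")   # DynamicResource
resource.add_route("GET", show_user)

# reversing a route re-applies the formatter built in __init__
url = resource.url_for(name="alice")  # URL('/users/alice')
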


 class DynamicResource(Resource):
+    DYN = re.compile(r"\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*)\}")
+    DYN_WITH_RE = re.compile(r"\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*):(?P<re>.+)\}")
+    GOOD = r"[^{}/]+"

-    def __init__(self, pattern, formatter, *, name=None):
+    def __init__(self, path: str, *, name: Optional[str] = None) -> None:
         super().__init__(name=name)
-        assert pattern.pattern.startswith('\\/')
-        assert formatter.startswith('/')
-        self._pattern = pattern
+        pattern = ""
+        formatter = ""
+        for part in ROUTE_RE.split(path):
+            match = self.DYN.fullmatch(part)
+            if match:
+                pattern += "(?P<{}>{})".format(match.group("var"), self.GOOD)
+                formatter += "{" + match.group("var") + "}"
+                continue
+
+            match = self.DYN_WITH_RE.fullmatch(part)
+            if match:
+                pattern += "(?P<{var}>{re})".format(**match.groupdict())
+                formatter += "{" + match.group("var") + "}"
+                continue
+
+            if "{" in part or "}" in part:
+                raise ValueError(f"Invalid path '{path}'['{part}']")
+
+            part = _requote_path(part)
+            formatter += part
+            pattern += re.escape(part)
+
+        try:
+            compiled = re.compile(pattern)
+        except re.error as exc:
+            raise ValueError(f"Bad pattern '{pattern}': {exc}") from None
+        assert compiled.pattern.startswith(PATH_SEP)
+        assert formatter.startswith("/")
+        self._pattern = compiled
         self._formatter = formatter

-    def add_prefix(self, prefix):
-        assert prefix.startswith('/')
-        assert not prefix.endswith('/')
+    @property
+    def canonical(self) -> str:
+        return self._formatter
+
+    def add_prefix(self, prefix: str) -> None:
+        assert prefix.startswith("/")
+        assert not prefix.endswith("/")
         assert len(prefix) > 1
-        self._pattern = re.compile(re.escape(prefix)+self._pattern.pattern)
+        self._pattern = re.compile(re.escape(prefix) + self._pattern.pattern)
         self._formatter = prefix + self._formatter

-    def _match(self, path):
+    def _match(self, path: str) -> Optional[Dict[str, str]]:
         match = self._pattern.fullmatch(path)
         if match is None:
             return None
         else:
-            return {key: unquote(value) for key, value in
-                    match.groupdict().items()}
+            return {
+                key: _unquote_path(value) for key, value in match.groupdict().items()
+            }

-    def get_info(self):
-        return {'formatter': self._formatter,
-                'pattern': self._pattern}
+    def raw_match(self, path: str) -> bool:
+        return self._formatter == path

-    def url_for(self, **parts):
-        url = self._formatter.format_map(parts)
-        return URL(url)
+    def get_info(self) -> _InfoDict:
+        return {"formatter": self._formatter, "pattern": self._pattern}

-    def url(self, *, parts, query=None):
-        super().url(**parts)
-        return str(self.url_for(**parts).with_query(query))
+    def url_for(self, **parts: str) -> URL:
+        url = self._formatter.format_map({k: _quote_path(v) for k, v in parts.items()})
+        return URL.build(path=url, encoded=True)

-    def __repr__(self):
+    def __repr__(self) -> str:
         name = "'" + self.name + "' " if self.name is not None else ""
-        return ("<DynamicResource {name} {formatter}"
-                .format(name=name, formatter=self._formatter))
+        return "<DynamicResource {name} {formatter}>".format(
+            name=name, formatter=self._formatter
+        )


 class PrefixResource(AbstractResource):
-
-    def __init__(self, prefix, *, name=None):
-        assert not prefix or prefix.startswith('/'), prefix
-        assert prefix in ('', '/') or not prefix.endswith('/'), prefix
+    def __init__(self, prefix: str, *, name: Optional[str] = None) -> None:
+        assert not prefix or prefix.startswith("/"), prefix
+        assert prefix in ("", "/") or not prefix.endswith("/"), prefix
         super().__init__(name=name)
-        self._prefix = URL(prefix).raw_path
+        self._prefix = _requote_path(prefix)

-    def add_prefix(self, prefix):
-        assert prefix.startswith('/')
-        assert not prefix.endswith('/')
+    @property
+    def canonical(self) -> str:
+        return self._prefix
+
+    def add_prefix(self, prefix: str) -> None:
+        assert prefix.startswith("/")
+        assert not prefix.endswith("/")
         assert len(prefix) > 1
         self._prefix = prefix + self._prefix

+    def raw_match(self, prefix: str) -> bool:
+        return False
+
+    # TODO: impl missing abstract methods

-class StaticResource(PrefixResource):

-    def __init__(self, prefix, directory, *, name=None,
-                 expect_handler=None, chunk_size=256*1024,
-                 response_factory=StreamResponse,
-                 show_index=False, follow_symlinks=False):
+class StaticResource(PrefixResource):
+    VERSION_KEY = "v"
+
+    def __init__(
+        self,
+        prefix: str,
+        directory: PathLike,
+        *,
+        name: Optional[str] = None,
+        expect_handler: Optional[_ExpectHandler] = None,
+        chunk_size: int = 256 * 1024,
+        show_index: bool = False,
+        follow_symlinks: bool = False,
+        append_version: bool = False,
+    ) -> None:
         super().__init__(prefix, name=name)
         try:
             directory = Path(directory)
-            if str(directory).startswith('~'):
+            if str(directory).startswith("~"):
                 directory = Path(os.path.expanduser(str(directory)))
             directory = directory.resolve()
             if not directory.is_dir():
-                raise ValueError('Not a directory')
+                raise ValueError("Not a directory")
         except (FileNotFoundError, ValueError) as error:
-            raise ValueError(
-                "No directory exists at '{}'".format(directory)) from error
+            raise ValueError(f"No directory exists at '{directory}'") from error
         self._directory = directory
         self._show_index = show_index
         self._chunk_size = chunk_size
         self._follow_symlinks = follow_symlinks
         self._expect_handler = expect_handler
-
-        self._routes = {'GET': ResourceRoute('GET', self._handle, self,
-                                             expect_handler=expect_handler),
-
-                        'HEAD': ResourceRoute('HEAD', self._handle, self,
-                                              expect_handler=expect_handler)}
+        self._append_version = append_version
+
+        self._routes = {
+            "GET": ResourceRoute(
+                "GET", self._handle, self, expect_handler=expect_handler
+            ),
+            "HEAD": ResourceRoute(
+                "HEAD", self._handle, self, expect_handler=expect_handler
+            ),
+        }

-    def url(self, *, filename, query=None):
-        return str(self.url_for(filename=filename).with_query(query))
-
-    def url_for(self, *, filename):
+    def url_for(  # type: ignore
+        self,
+        *,
+        filename: Union[str, Path],
+        append_version: Optional[bool] = None,
+    ) -> URL:
+        if append_version is None:
+            append_version = self._append_version
         if isinstance(filename, Path):
             filename = str(filename)
-        while filename.startswith('/'):
-            filename = filename[1:]
-        filename = '/' + filename
-        url = self._prefix + URL(filename).raw_path
-        return URL(url)
-
-    def get_info(self):
-        return {'directory': self._directory,
-                'prefix': self._prefix}
-
-    def set_options_route(self, handler):
-        if 'OPTIONS' in self._routes:
-            raise RuntimeError('OPTIONS route was set already')
-        self._routes['OPTIONS'] = ResourceRoute(
-            'OPTIONS', handler, self,
-            expect_handler=self._expect_handler)
-
-    @asyncio.coroutine
-    def resolve(self, request):
+        filename = filename.lstrip("/")
+
+        url = URL.build(path=self._prefix, encoded=True)
+        # filename is not encoded
+        if YARL_VERSION < (1, 6):
+            url = url / filename.replace("%", "%25")
+        else:
+            url = url / filename
+
+        if append_version:
+            try:
+                filepath = self._directory.joinpath(filename).resolve()
+                if not self._follow_symlinks:
+                    filepath.relative_to(self._directory)
+            except (ValueError, FileNotFoundError):
+                # ValueError for case when path point to symlink
+                # with follow_symlinks is False
+                return url  # relatively safe
+            if filepath.is_file():
+                # TODO cache file content
+                # with file watcher for cache invalidation
+                with filepath.open("rb") as f:
+                    file_bytes = f.read()
+                h = self._get_file_hash(file_bytes)
+                url = url.with_query({self.VERSION_KEY: h})
+                return url
+        return url
+
+    @staticmethod
+    def _get_file_hash(byte_array: bytes) -> str:
+        m = hashlib.sha256()  # todo sha256 can be configurable param
+        m.update(byte_array)
+        b64 = base64.urlsafe_b64encode(m.digest())
+        return b64.decode("ascii")
+
+    def get_info(self) -> _InfoDict:
+        return {
+            "directory": self._directory,
+            "prefix": self._prefix,
+            "routes": self._routes,
+        }
+
+    def set_options_route(self, handler: _WebHandler) -> None:
+        if "OPTIONS" in self._routes:
+            raise RuntimeError("OPTIONS route was set already")
+        self._routes["OPTIONS"] = ResourceRoute(
+            "OPTIONS", handler, self, expect_handler=self._expect_handler
+        )
+
+    async def resolve(self, request: Request) -> _Resolve:
         path = request.rel_url.raw_path
-        method = request._method
+        method = request.method
         allowed_methods = set(self._routes)
         if not path.startswith(self._prefix):
             return None, set()
@@ -445,451 +638,596 @@ def resolve(self, request):
         if method not in allowed_methods:
             return None, allowed_methods

-        match_dict = {'filename': unquote(path[len(self._prefix)+1:])}
-        return (UrlMappingMatchInfo(match_dict, self._routes[method]),
-                allowed_methods)
-        yield  # pragma: no cover
+        match_dict = {"filename": _unquote_path(path[len(self._prefix) + 1 :])}
+        return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._routes)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[AbstractRoute]:
         return iter(self._routes.values())
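When append_version is enabled, url_for() above hashes the target file with SHA-256 and appends it as the v query parameter, which keeps long-cached URLs safe to update across deploys. A sketch of registering such a resource (the directory path is a placeholder):

from aiohttp import web

app = web.Application()
app.router.add_static(
    "/static",
    "./static",
    show_index=True,      # render _directory_as_html() for directories
    append_version=True,  # url_for() results gain ?v=<urlsafe-b64 sha256>
)
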

-    @asyncio.coroutine
-    def _handle(self, request):
-        filename = unquote(request.match_info['filename'])
+    async def _handle(self, request: Request) -> StreamResponse:
+        rel_url = request.match_info["filename"]
         try:
+            filename = Path(rel_url)
+            if filename.anchor:
+                # rel_url is an absolute name like
+                # /static/\\machine_name\c$ or /static/D:\path
+                # where the static dir is totally different
+                raise HTTPForbidden()
             filepath = self._directory.joinpath(filename).resolve()
             if not self._follow_symlinks:
                 filepath.relative_to(self._directory)
         except (ValueError, FileNotFoundError) as error:
             # relatively safe
             raise HTTPNotFound() from error
+        except HTTPForbidden:
+            raise
         except Exception as error:
             # perm error or other kind!
             request.app.logger.exception(error)
             raise HTTPNotFound() from error

-        # on opening a dir, load it's contents if allowed
+        # on opening a dir, load its contents if allowed
         if filepath.is_dir():
             if self._show_index:
                 try:
-                    ret = Response(text=self._directory_as_html(filepath),
-                                   content_type="text/html")
+                    return Response(
+                        text=self._directory_as_html(filepath), content_type="text/html"
+                    )
                 except PermissionError:
                     raise HTTPForbidden()
             else:
                 raise HTTPForbidden()
         elif filepath.is_file():
-            ret = FileResponse(filepath, chunk_size=self._chunk_size)
+            return FileResponse(filepath, chunk_size=self._chunk_size)
         else:
             raise HTTPNotFound

-        return ret
+    def _directory_as_html(self, filepath: Path) -> str:
+        # returns directory's index as html

-    def _directory_as_html(self, filepath):
-        "returns directory's index as html"
         # sanity check
         assert filepath.is_dir()

-        posix_dir_len = len(self._directory.as_posix())
-
-        # remove the beginning of posix path, so it would be relative
-        # to our added static path
-        relative_path_to_dir = filepath.as_posix()[posix_dir_len:]
-        index_of = "Index of /{}".format(relative_path_to_dir)
-        head = "<head>\n<title>{}</title>\n</head>".format(index_of)
-        h1 = "<h1>{}</h1>".format(index_of)
+        relative_path_to_dir = filepath.relative_to(self._directory).as_posix()
+        index_of = f"Index of /{relative_path_to_dir}"
+        h1 = f"<h1>{index_of}</h1>"

         index_list = []
         dir_index = filepath.iterdir()
         for _file in sorted(dir_index):
             # show file url as relative to static path
-            file_url = _file.as_posix()[posix_dir_len:]
+            rel_path = _file.relative_to(self._directory).as_posix()
+            file_url = self._prefix + "/" + rel_path

             # if file is a directory, add '/' to the end of the name
             if _file.is_dir():
-                file_name = "{}/".format(_file.name)
+                file_name = f"{_file.name}/"
             else:
                 file_name = _file.name

             index_list.append(
-                '<li><a href="{url}">{name}</a></li>'.format(url=file_url,
-                                                             name=file_name)
+                '<li><a href="{url}">{name}</a></li>'.format(
+                    url=file_url, name=file_name
+                )
             )
-        ul = "<ul>\n{}\n</ul>".format('\n'.join(index_list))
-        body = "<body>\n{}\n{}\n</body>".format(h1, ul)
+        ul = "<ul>\n{}\n</ul>".format("\n".join(index_list))
+        body = f"<body>\n{h1}\n{ul}\n</body>"

-        html = "<html>\n{}\n{}\n</html>".format(head, body)
+        head_str = f"<head>\n<title>{index_of}</title>\n</head>"
+        html = f"<html>\n{head_str}\n{body}\n</html>"

         return html

-    def __repr__(self):
+    def __repr__(self) -> str:
         name = "'" + self.name + "'" if self.name is not None else ""
-        return "<StaticResource {name} {path} -> {directory!r}".format(
-            name=name, path=self._prefix, directory=self._directory)
+        return "<StaticResource {name} {path} -> {directory!r}>".format(
+            name=name, path=self._prefix, directory=self._directory
+        )
  • '.format(url=file_url, - name=file_name) + '
  • {name}
  • '.format( + url=file_url, name=file_name + ) ) - ul = "
      \n{}\n
    ".format('\n'.join(index_list)) - body = "\n{}\n{}\n".format(h1, ul) + ul = "
      \n{}\n
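The `append_version` machinery above hashes the served file with SHA-256 and appends the urlsafe-base64 digest as a `v` query parameter, while `show_index` gates the HTML listing built by `_directory_as_html()`. A minimal usage sketch of how these keywords surface through `add_static()`; the paths and names are illustrative, not part of this change:

```python
from aiohttp import web

app = web.Application()

# show_index=True serves the directory listing rendered above;
# append_version=True makes url_for() attach the content hash.
static = app.router.add_static(
    "/static",
    "path/to/static",  # illustrative docroot
    name="static",
    show_index=True,
    append_version=True,
)

# For a file that exists under the docroot this yields e.g.
# URL('/static/css/app.css?v=<urlsafe-b64-sha256>')
print(static.url_for(filename="css/app.css"))
```

As the TODO above notes, the file is re-read and re-hashed on every `url_for()` call; caching with a file watcher is left as future work.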
    ".format("\n".join(index_list)) + body = f"\n{h1}\n{ul}\n" - html = "\n{}\n{}\n".format(head, body) + head_str = f"\n{index_of}\n " + html = f"\n{head_str}\n{body}\n" return html - def __repr__(self): + def __repr__(self) -> str: name = "'" + self.name + "'" if self.name is not None else "" - return " {directory!r}".format( - name=name, path=self._prefix, directory=self._directory) + return " {directory!r}>".format( + name=name, path=self._prefix, directory=self._directory + ) class PrefixedSubAppResource(PrefixResource): - - def __init__(self, prefix, app): + def __init__(self, prefix: str, app: "Application") -> None: super().__init__(prefix) self._app = app for resource in app.router.resources(): resource.add_prefix(prefix) - def add_prefix(self, prefix): + def add_prefix(self, prefix: str) -> None: super().add_prefix(prefix) for resource in self._app.router.resources(): resource.add_prefix(prefix) - def url_for(self, *args, **kwargs): - raise RuntimeError(".url_for() is not supported " - "by sub-application root") - - def url(self, **kwargs): - """Construct url for route with additional params.""" - raise RuntimeError(".url() is not supported " - "by sub-application root") + def url_for(self, *args: str, **kwargs: str) -> URL: + raise RuntimeError(".url_for() is not supported " "by sub-application root") - def get_info(self): - return {'app': self._app, - 'prefix': self._prefix} + def get_info(self) -> _InfoDict: + return {"app": self._app, "prefix": self._prefix} - @asyncio.coroutine - def resolve(self, request): - if not request.url.raw_path.startswith(self._prefix): + async def resolve(self, request: Request) -> _Resolve: + if ( + not request.url.raw_path.startswith(self._prefix + "/") + and request.url.raw_path != self._prefix + ): return None, set() - match_info = yield from self._app.router.resolve(request) + match_info = await self._app.router.resolve(request) match_info.add_app(self._app) if isinstance(match_info.http_exception, HTTPMethodNotAllowed): methods = match_info.http_exception.allowed_methods else: methods = set() - return (match_info, methods) + return match_info, methods - def __len__(self): + def __len__(self) -> int: return len(self._app.router.routes()) - def __iter__(self): + def __iter__(self) -> Iterator[AbstractRoute]: return iter(self._app.router.routes()) - def __repr__(self): + def __repr__(self) -> str: return " {app!r}>".format( - prefix=self._prefix, app=self._app) + prefix=self._prefix, app=self._app + ) + + +class AbstractRuleMatching(abc.ABC): + @abc.abstractmethod # pragma: no branch + async def match(self, request: Request) -> bool: + """Return bool if the request satisfies the criteria""" + + @abc.abstractmethod # pragma: no branch + def get_info(self) -> _InfoDict: + """Return a dict with additional info useful for introspection""" + + @property + @abc.abstractmethod # pragma: no branch + def canonical(self) -> str: + """Return a str""" + + +class Domain(AbstractRuleMatching): + re_part = re.compile(r"(?!-)[a-z\d-]{1,63}(? 
+class MatchedSubAppResource(PrefixedSubAppResource):
+    def __init__(self, rule: AbstractRuleMatching, app: "Application") -> None:
+        AbstractResource.__init__(self)
+        self._prefix = ""
+        self._app = app
+        self._rule = rule
+
+    @property
+    def canonical(self) -> str:
+        return self._rule.canonical
+
+    def get_info(self) -> _InfoDict:
+        return {"app": self._app, "rule": self._rule}
+
+    async def resolve(self, request: Request) -> _Resolve:
+        if not await self._rule.match(request):
+            return None, set()
+        match_info = await self._app.router.resolve(request)
+        match_info.add_app(self._app)
+        if isinstance(match_info.http_exception, HTTPMethodNotAllowed):
+            methods = match_info.http_exception.allowed_methods
+        else:
+            methods = set()
+        return match_info, methods
+
+    def __repr__(self) -> str:
+        return "<MatchedSubAppResource -> {app!r}>" "".format(app=self._app)


 class ResourceRoute(AbstractRoute):
     """A route with resource"""

-    def __init__(self, method, handler, resource, *,
-                 expect_handler=None):
-        super().__init__(method, handler, expect_handler=expect_handler,
-                         resource=resource)
-
-    def __repr__(self):
+    def __init__(
+        self,
+        method: str,
+        handler: Union[_WebHandler, Type[AbstractView]],
+        resource: AbstractResource,
+        *,
+        expect_handler: Optional[_ExpectHandler] = None,
+    ) -> None:
+        super().__init__(
+            method, handler, expect_handler=expect_handler, resource=resource
+        )
+
+    def __repr__(self) -> str:
         return "<ResourceRoute [{method}] {resource} -> {handler!r}".format(
-            method=self.method, resource=self._resource,
-            handler=self.handler)
+            method=self.method, resource=self._resource, handler=self.handler
+        )

     @property
-    def name(self):
+    def name(self) -> Optional[str]:
+        if self._resource is None:
+            return None
         return self._resource.name

-    def url_for(self, *args, **kwargs):
+    def url_for(self, *args: str, **kwargs: str) -> URL:
         """Construct url for route with additional params."""
+        assert self._resource is not None
         return self._resource.url_for(*args, **kwargs)

-    def url(self, **kwargs):
-        """Construct url for route with additional params."""
-        super().url(**kwargs)
-        return self._resource.url(**kwargs)
-
-    def get_info(self):
+    def get_info(self) -> _InfoDict:
+        assert self._resource is not None
         return self._resource.get_info()
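Since `ResourceRoute` defers `name` and `url_for()` to its owning resource, the usual round trip from a named route to a URL looks like this (handler and route name are illustrative):

```python
from aiohttp import web

async def user_detail(request: web.Request) -> web.Response:
    return web.Response(text=request.match_info["id"])

app = web.Application()
app.router.add_get("/users/{id}", user_detail, name="user-detail")

# The router is a Mapping of names to resources (see UrlDispatcher
# below); the resource fills in the dynamic {id} part.
assert str(app.router["user-detail"].url_for(id="7")) == "/users/7"
```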
 class SystemRoute(AbstractRoute):
-
-    def __init__(self, http_exception):
-        super().__init__(hdrs.METH_ANY, self._handler)
+    def __init__(self, http_exception: HTTPException) -> None:
+        super().__init__(hdrs.METH_ANY, self._handle)
         self._http_exception = http_exception

-    def url_for(self, *args, **kwargs):
+    def url_for(self, *args: str, **kwargs: str) -> URL:
         raise RuntimeError(".url_for() is not allowed for SystemRoute")

-    def url(self, *args, **kwargs):
-        raise RuntimeError(".url() is not allowed for SystemRoute")
-
     @property
-    def name(self):
+    def name(self) -> Optional[str]:
         return None

-    def get_info(self):
-        return {'http_exception': self._http_exception}
+    def get_info(self) -> _InfoDict:
+        return {"http_exception": self._http_exception}

-    @asyncio.coroutine
-    def _handler(self, request):
+    async def _handle(self, request: Request) -> StreamResponse:
         raise self._http_exception

     @property
-    def status(self):
+    def status(self) -> int:
         return self._http_exception.status

     @property
-    def reason(self):
+    def reason(self) -> str:
         return self._http_exception.reason

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<SystemRoute {self.status}: {self.reason}>".format(self=self)


 class View(AbstractView):
-
-    @asyncio.coroutine
-    def __iter__(self):
-        if self.request._method not in hdrs.METH_ALL:
+    async def _iter(self) -> StreamResponse:
+        if self.request.method not in hdrs.METH_ALL:
             self._raise_allowed_methods()
-        method = getattr(self, self.request._method.lower(), None)
+        method = getattr(self, self.request.method.lower(), None)
         if method is None:
             self._raise_allowed_methods()
-        resp = yield from method()
+        resp = await method()
         return resp

-    if helpers.PY_35:
-        def __await__(self):
-            return (yield from self.__iter__())
+    def __await__(self) -> Generator[Any, None, StreamResponse]:
+        return self._iter().__await__()

-    def _raise_allowed_methods(self):
-        allowed_methods = {
-            m for m in hdrs.METH_ALL if hasattr(self, m.lower())}
+    def _raise_allowed_methods(self) -> None:
+        allowed_methods = {m for m in hdrs.METH_ALL if hasattr(self, m.lower())}
         raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
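`View._iter()` dispatches on the lower-cased request method, and any verb without a matching handler lands in `_raise_allowed_methods()`. A small class-based view sketch (class, path, and handler bodies are illustrative), registered through the `add_view()` shortcut introduced later in this diff:

```python
from aiohttp import web

class UserView(web.View):
    async def get(self) -> web.Response:
        return web.Response(text=self.request.match_info["id"])

    async def post(self) -> web.Response:
        data = await self.request.post()
        return web.Response(text="updated: {!r}".format(dict(data)))

app = web.Application()
# PUT, DELETE, ... on this path now produce 405 responses whose
# allowed-methods set is computed from the handlers defined above.
app.router.add_view("/users/{id}", UserView)
```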
-class ResourcesView(Sized, Iterable, Container):
-
-    def __init__(self, resources):
+class ResourcesView(Sized, Iterable[AbstractResource], Container[AbstractResource]):
+    def __init__(self, resources: List[AbstractResource]) -> None:
         self._resources = resources

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._resources)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[AbstractResource]:
         yield from self._resources

-    def __contains__(self, resource):
+    def __contains__(self, resource: object) -> bool:
         return resource in self._resources


-class RoutesView(Sized, Iterable, Container):
-
-    def __init__(self, resources):
-        self._routes = []
+class RoutesView(Sized, Iterable[AbstractRoute], Container[AbstractRoute]):
+    def __init__(self, resources: List[AbstractResource]):
+        self._routes = []  # type: List[AbstractRoute]
         for resource in resources:
             for route in resource:
                 self._routes.append(route)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._routes)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[AbstractRoute]:
         yield from self._routes

-    def __contains__(self, route):
+    def __contains__(self, route: object) -> bool:
         return route in self._routes


-class UrlDispatcher(AbstractRouter, collections.abc.Mapping):
+class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):

-    DYN = re.compile(r'\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*)\}')
-    DYN_WITH_RE = re.compile(
-        r'\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*):(?P<re>.+)\}')
-    GOOD = r'[^{}/]+'
-    ROUTE_RE = re.compile(r'(\{[_a-zA-Z][^{}]*(?:\{[^{}]*\}[^{}]*)*\})')
-    NAME_SPLIT_RE = re.compile(r'[.:-]')
+    NAME_SPLIT_RE = re.compile(r"[.:-]")

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self._resources = []
-        self._named_resources = {}
+        self._resources = []  # type: List[AbstractResource]
+        self._named_resources = {}  # type: Dict[str, AbstractResource]

-    @asyncio.coroutine
-    def resolve(self, request):
-        method = request._method
-        allowed_methods = set()
+    async def resolve(self, request: Request) -> AbstractMatchInfo:
+        method = request.method
+        allowed_methods = set()  # type: Set[str]

         for resource in self._resources:
-            match_dict, allowed = yield from resource.resolve(request)
+            match_dict, allowed = await resource.resolve(request)
             if match_dict is not None:
                 return match_dict
             else:
                 allowed_methods |= allowed
         else:
             if allowed_methods:
-                return MatchInfoError(HTTPMethodNotAllowed(method,
-                                                           allowed_methods))
+                return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods))
             else:
                 return MatchInfoError(HTTPNotFound())

-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._named_resources)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._named_resources)

-    def __contains__(self, name):
-        return name in self._named_resources
+    def __contains__(self, resource: object) -> bool:
+        return resource in self._named_resources

-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> AbstractResource:
         return self._named_resources[name]

-    def resources(self):
+    def resources(self) -> ResourcesView:
         return ResourcesView(self._resources)

-    def routes(self):
+    def routes(self) -> RoutesView:
         return RoutesView(self._resources)

-    def named_resources(self):
+    def named_resources(self) -> Mapping[str, AbstractResource]:
         return MappingProxyType(self._named_resources)

-    def register_resource(self, resource):
-        assert isinstance(resource, AbstractResource), \
-            'Instance of AbstractResource class is required, got {!r}'.format(
-                resource)
+    def register_resource(self, resource: AbstractResource) -> None:
+        assert isinstance(
+            resource, AbstractResource
+        ), f"Instance of AbstractResource class is required, got {resource!r}"
         if self.frozen:
-            raise RuntimeError(
-                "Cannot register a resource into frozen router.")
+            raise RuntimeError("Cannot register a resource into frozen router.")

         name = resource.name

         if name is not None:
             parts = self.NAME_SPLIT_RE.split(name)
             for part in parts:
-                if not part.isidentifier() or keyword.iskeyword(part):
-                    raise ValueError('Incorrect route name {!r}, '
-                                     'the name should be a sequence of '
-                                     'python identifiers separated '
-                                     'by dash, dot or column'.format(name))
+                if keyword.iskeyword(part):
+                    raise ValueError(
+                        f"Incorrect route name {name!r}, "
+                        "python keywords cannot be used "
+                        "for route name"
+                    )
+                if not part.isidentifier():
+                    raise ValueError(
+                        "Incorrect route name {!r}, "
+                        "the name should be a sequence of "
+                        "python identifiers separated "
+                        "by dash, dot or column".format(name)
+                    )
             if name in self._named_resources:
-                raise ValueError('Duplicate {!r}, '
-                                 'already handled by {!r}'
-                                 .format(name, self._named_resources[name]))
+                raise ValueError(
+                    "Duplicate {!r}, "
+                    "already handled by {!r}".format(name, self._named_resources[name])
+                )
             self._named_resources[name] = resource
         self._resources.append(resource)

-    def add_resource(self, path, *, name=None):
-        if path and not path.startswith('/'):
+
def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource: + if path and not path.startswith("/"): raise ValueError("path should be started with / or be empty") - if not ('{' in path or '}' in path or self.ROUTE_RE.search(path)): - url = URL(path) - resource = PlainResource(url.raw_path, name=name) + # Reuse last added resource if path and name are the same + if self._resources: + resource = self._resources[-1] + if resource.name == name and resource.raw_match(path): + return cast(Resource, resource) + if not ("{" in path or "}" in path or ROUTE_RE.search(path)): + resource = PlainResource(_requote_path(path), name=name) self.register_resource(resource) return resource - - pattern = '' - formatter = '' - for part in self.ROUTE_RE.split(path): - match = self.DYN.fullmatch(part) - if match: - pattern += '(?P<{}>{})'.format(match.group('var'), self.GOOD) - formatter += '{' + match.group('var') + '}' - continue - - match = self.DYN_WITH_RE.fullmatch(part) - if match: - pattern += '(?P<{var}>{re})'.format(**match.groupdict()) - formatter += '{' + match.group('var') + '}' - continue - - if '{' in part or '}' in part: - raise ValueError("Invalid path '{}'['{}']".format(path, part)) - - path = URL(part).raw_path - formatter += path - pattern += re.escape(path) - - try: - compiled = re.compile(pattern) - except re.error as exc: - raise ValueError( - "Bad pattern '{}': {}".format(pattern, exc)) from None - resource = DynamicResource(compiled, formatter, name=name) + resource = DynamicResource(path, name=name) self.register_resource(resource) return resource - def add_route(self, method, path, handler, - *, name=None, expect_handler=None): + def add_route( + self, + method: str, + path: str, + handler: Union[_WebHandler, Type[AbstractView]], + *, + name: Optional[str] = None, + expect_handler: Optional[_ExpectHandler] = None, + ) -> AbstractRoute: resource = self.add_resource(path, name=name) - return resource.add_route(method, handler, - expect_handler=expect_handler) - - def add_static(self, prefix, path, *, name=None, expect_handler=None, - chunk_size=256*1024, response_factory=StreamResponse, - show_index=False, follow_symlinks=False): + return resource.add_route(method, handler, expect_handler=expect_handler) + + def add_static( + self, + prefix: str, + path: PathLike, + *, + name: Optional[str] = None, + expect_handler: Optional[_ExpectHandler] = None, + chunk_size: int = 256 * 1024, + show_index: bool = False, + follow_symlinks: bool = False, + append_version: bool = False, + ) -> AbstractResource: """Add static files view. 
prefix - url prefix path - folder with files """ - # TODO: implement via PrefixedResource, not ResourceAdapter - assert prefix.startswith('/') - if prefix.endswith('/'): + assert prefix.startswith("/") + if prefix.endswith("/"): prefix = prefix[:-1] - resource = StaticResource(prefix, path, - name=name, - expect_handler=expect_handler, - chunk_size=chunk_size, - response_factory=response_factory, - show_index=show_index, - follow_symlinks=follow_symlinks) + resource = StaticResource( + prefix, + path, + name=name, + expect_handler=expect_handler, + chunk_size=chunk_size, + show_index=show_index, + follow_symlinks=follow_symlinks, + append_version=append_version, + ) self.register_resource(resource) return resource - def add_head(self, *args, **kwargs): + def add_head(self, path: str, handler: _WebHandler, **kwargs: Any) -> AbstractRoute: """ Shortcut for add_route with method HEAD """ - return self.add_route(hdrs.METH_HEAD, *args, **kwargs) + return self.add_route(hdrs.METH_HEAD, path, handler, **kwargs) - def add_get(self, *args, name=None, allow_head=True, **kwargs): + def add_options( + self, path: str, handler: _WebHandler, **kwargs: Any + ) -> AbstractRoute: + """ + Shortcut for add_route with method OPTIONS + """ + return self.add_route(hdrs.METH_OPTIONS, path, handler, **kwargs) + + def add_get( + self, + path: str, + handler: _WebHandler, + *, + name: Optional[str] = None, + allow_head: bool = True, + **kwargs: Any, + ) -> AbstractRoute: """ Shortcut for add_route with method GET, if allow_head is true another route is added allowing head requests to the same endpoint """ + resource = self.add_resource(path, name=name) if allow_head: - # it name is not None append -head to avoid it conflicting with - # the GET route below - head_name = name and '{}-head'.format(name) - self.add_route(hdrs.METH_HEAD, *args, name=head_name, **kwargs) - return self.add_route(hdrs.METH_GET, *args, name=name, **kwargs) + resource.add_route(hdrs.METH_HEAD, handler, **kwargs) + return resource.add_route(hdrs.METH_GET, handler, **kwargs) - def add_post(self, *args, **kwargs): + def add_post(self, path: str, handler: _WebHandler, **kwargs: Any) -> AbstractRoute: """ Shortcut for add_route with method POST """ - return self.add_route(hdrs.METH_POST, *args, **kwargs) + return self.add_route(hdrs.METH_POST, path, handler, **kwargs) - def add_put(self, *args, **kwargs): + def add_put(self, path: str, handler: _WebHandler, **kwargs: Any) -> AbstractRoute: """ Shortcut for add_route with method PUT """ - return self.add_route(hdrs.METH_PUT, *args, **kwargs) + return self.add_route(hdrs.METH_PUT, path, handler, **kwargs) - def add_patch(self, *args, **kwargs): + def add_patch( + self, path: str, handler: _WebHandler, **kwargs: Any + ) -> AbstractRoute: """ Shortcut for add_route with method PATCH """ - return self.add_route(hdrs.METH_PATCH, *args, **kwargs) + return self.add_route(hdrs.METH_PATCH, path, handler, **kwargs) - def add_delete(self, *args, **kwargs): + def add_delete( + self, path: str, handler: _WebHandler, **kwargs: Any + ) -> AbstractRoute: """ Shortcut for add_route with method DELETE """ - return self.add_route(hdrs.METH_DELETE, *args, **kwargs) + return self.add_route(hdrs.METH_DELETE, path, handler, **kwargs) - def freeze(self): + def add_view( + self, path: str, handler: Type[AbstractView], **kwargs: Any + ) -> AbstractRoute: + """ + Shortcut for add_route with ANY methods for a class-based view + """ + return self.add_route(hdrs.METH_ANY, path, handler, **kwargs) + + def freeze(self) -> None: 
super().freeze() for resource in self._resources: resource.freeze() + + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + """Append routes to route table. + + Parameter should be a sequence of RouteDef objects. + + Returns a list of registered AbstractRoute instances. + """ + registered_routes = [] + for route_def in routes: + registered_routes.extend(route_def.register(self)) + return registered_routes + + +def _quote_path(value: str) -> str: + if YARL_VERSION < (1, 6): + value = value.replace("%", "%25") + return URL.build(path=value, encoded=False).raw_path + + +def _unquote_path(value: str) -> str: + return URL.build(path=value, encoded=True).path + + +def _requote_path(value: str) -> str: + # Quote non-ascii characters and other characters which must be quoted, + # but preserve existing %-sequences. + result = _quote_path(value) + if "%" in value: + result = result.replace("%25", "%") + return result diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index fde267898cf..da7ce6df1c5 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -1,49 +1,82 @@ import asyncio +import base64 +import binascii +import hashlib import json -from collections import namedtuple +from typing import Any, Iterable, Optional, Tuple + +import async_timeout +import attr +from multidict import CIMultiDict from . import hdrs -from .helpers import PY_35, PY_352, Timeout, call_later, create_future -from .http import (WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, HttpProcessingError, - WebSocketError, WebSocketReader, - WSMessage, WSMsgType, do_handshake) -from .streams import FlowControlDataQueue -from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError, - HTTPMethodNotAllowed) +from .abc import AbstractStreamWriter +from .helpers import call_later, set_result +from .http import ( + WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE, + WS_KEY, + WebSocketError, + WebSocketReader, + WebSocketWriter, + WSMessage, + WSMsgType as WSMsgType, + ws_ext_gen, + ws_ext_parse, +) +from .log import ws_logger +from .streams import EofStream, FlowControlDataQueue +from .typedefs import JSONDecoder, JSONEncoder +from .web_exceptions import HTTPBadRequest, HTTPException +from .web_request import BaseRequest from .web_response import StreamResponse -__all__ = ('WebSocketResponse', 'WebSocketReady', 'MsgType', 'WSMsgType',) +__all__ = ( + "WebSocketResponse", + "WebSocketReady", + "WSMsgType", +) THRESHOLD_CONNLOST_ACCESS = 5 -# deprecated since 1.0 -MsgType = WSMsgType - +@attr.s(auto_attribs=True, frozen=True, slots=True) +class WebSocketReady: + ok: bool + protocol: Optional[str] -class WebSocketReady(namedtuple('WebSocketReady', 'ok protocol')): - def __bool__(self): + def __bool__(self) -> bool: return self.ok class WebSocketResponse(StreamResponse): - def __init__(self, *, - timeout=10.0, receive_timeout=None, - autoclose=True, autoping=True, heartbeat=None, - protocols=()): + _length_check = False + + def __init__( + self, + *, + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + protocols: Iterable[str] = (), + compress: bool = True, + max_msg_size: int = 4 * 1024 * 1024, + ) -> None: super().__init__(status=101) self._protocols = protocols - self._ws_protocol = None - self._writer = None - self._reader = None + self._ws_protocol = None # type: Optional[str] + self._writer = None # type: Optional[WebSocketWriter] + self._reader = None # type: Optional[FlowControlDataQueue[WSMessage]] 
self._closed = False self._closing = False self._conn_lost = 0 - self._close_code = None - self._loop = None - self._waiting = None - self._exception = None + self._close_code = None # type: Optional[int] + self._loop = None # type: Optional[asyncio.AbstractEventLoop] + self._waiting = None # type: Optional[asyncio.Future[bool]] + self._exception = None # type: Optional[BaseException] self._timeout = timeout self._receive_timeout = receive_timeout self._autoclose = autoclose @@ -51,10 +84,12 @@ def __init__(self, *, self._heartbeat = heartbeat self._heartbeat_cb = None if heartbeat is not None: - self._pong_heartbeat = heartbeat/2.0 + self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb = None + self._compress = compress + self._max_msg_size = max_msg_size - def _cancel_heartbeat(self): + def _cancel_heartbeat(self) -> None: if self._pong_response_cb is not None: self._pong_response_cb.cancel() self._pong_response_cb = None @@ -63,157 +98,253 @@ def _cancel_heartbeat(self): self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self): + def _reset_heartbeat(self) -> None: self._cancel_heartbeat() if self._heartbeat is not None: self._heartbeat_cb = call_later( - self._send_heartbeat, self._heartbeat, self._loop) + self._send_heartbeat, self._heartbeat, self._loop + ) - def _send_heartbeat(self): + def _send_heartbeat(self) -> None: if self._heartbeat is not None and not self._closed: - self.ping() + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + self._loop.create_task(self._writer.ping()) # type: ignore if self._pong_response_cb is not None: self._pong_response_cb.cancel() self._pong_response_cb = call_later( - self._pong_not_received, self._pong_heartbeat, self._loop) + self._pong_not_received, self._pong_heartbeat, self._loop + ) - def _pong_not_received(self): + def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: self._closed = True self._close_code = 1006 self._exception = asyncio.TimeoutError() self._req.transport.close() - @asyncio.coroutine - def prepare(self, request): + async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: # make pre-check to don't hide it by do_handshake() exceptions if self._payload_writer is not None: return self._payload_writer protocol, writer = self._pre_start(request) - payload_writer = yield from super().prepare(request) + payload_writer = await super().prepare(request) + assert payload_writer is not None self._post_start(request, protocol, writer) - yield from payload_writer.drain() + await payload_writer.drain() return payload_writer - def _pre_start(self, request): - self._loop = request.app.loop - + def _handshake( + self, request: BaseRequest + ) -> Tuple["CIMultiDict[str]", str, bool, bool]: + headers = request.headers + if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip(): + raise HTTPBadRequest( + text=( + "No WebSocket UPGRADE hdr: {}\n Can " + '"Upgrade" only to "WebSocket".' 
+                ).format(headers.get(hdrs.UPGRADE))
+            )
+
+        if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
+            raise HTTPBadRequest(
+                text="No CONNECTION upgrade hdr: {}".format(
+                    headers.get(hdrs.CONNECTION)
+                )
+            )
+
+        # find common sub-protocol between client and server
+        protocol = None
+        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
+            req_protocols = [
+                str(proto.strip())
+                for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
+            ]
+
+            for proto in req_protocols:
+                if proto in self._protocols:
+                    protocol = proto
+                    break
+            else:
+                # No overlap found: Return no protocol as per spec
+                ws_logger.warning(
+                    "Client protocols %r don't overlap server-known ones %r",
+                    req_protocols,
+                    self._protocols,
+                )
+
+        # check supported version
+        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
+        if version not in ("13", "8", "7"):
+            raise HTTPBadRequest(text=f"Unsupported version: {version}")
+
+        # check client handshake for validity
+        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
         try:
-            status, headers, _, writer, protocol = do_handshake(
-                request.method, request.headers, request._protocol.writer,
-                self._protocols)
-        except HttpProcessingError as err:
-            if err.code == 405:
-                raise HTTPMethodNotAllowed(
-                    request.method, [hdrs.METH_GET], body=b'')
-            elif err.code == 400:
-                raise HTTPBadRequest(text=err.message, headers=err.headers)
-            else:  # pragma: no cover
-                raise HTTPInternalServerError() from err
-
-        self._reset_heartbeat()
-
-        if self.status != status:
-            self.set_status(status)
-            for k, v in headers:
-                self.headers[k] = v
+            if not key or len(base64.b64decode(key)) != 16:
+                raise HTTPBadRequest(text=f"Handshake error: {key!r}")
+        except binascii.Error:
+            raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None
+
+        accept_val = base64.b64encode(
+            hashlib.sha1(key.encode() + WS_KEY).digest()
+        ).decode()
+        response_headers = CIMultiDict(  # type: ignore
+            {
+                hdrs.UPGRADE: "websocket",  # type: ignore
+                hdrs.CONNECTION: "upgrade",
+                hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
+            }
+        )
+
+        notakeover = False
+        compress = 0
+        if self._compress:
+            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
+            # Server side always get return with no exception.
+ # If something happened, just drop compress extension + compress, notakeover = ws_ext_parse(extensions, isserver=True) + if compress: + enabledext = ws_ext_gen( + compress=compress, isserver=True, server_notakeover=notakeover + ) + response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext + + if protocol: + response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol + return (response_headers, protocol, compress, notakeover) # type: ignore + + def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]: + self._loop = request._loop + + headers, protocol, compress, notakeover = self._handshake(request) + + self.set_status(101) + self.headers.update(headers) self.force_close() + self._compress = compress + transport = request._protocol.transport + assert transport is not None + writer = WebSocketWriter( + request._protocol, transport, compress=compress, notakeover=notakeover + ) + return protocol, writer - def _post_start(self, request, protocol, writer): + def _post_start( + self, request: BaseRequest, protocol: str, writer: WebSocketWriter + ) -> None: self._ws_protocol = protocol self._writer = writer - self._reader = FlowControlDataQueue( - request._protocol, limit=2 ** 16, loop=self._loop) - request.protocol.set_parser(WebSocketReader(self._reader)) - def can_prepare(self, request): + self._reset_heartbeat() + + loop = self._loop + assert loop is not None + self._reader = FlowControlDataQueue(request._protocol, 2 ** 16, loop=loop) + request.protocol.set_parser( + WebSocketReader(self._reader, self._max_msg_size, compress=self._compress) + ) + # disable HTTP keepalive for WebSocket + request.protocol.keep_alive(False) + + def can_prepare(self, request: BaseRequest) -> WebSocketReady: if self._writer is not None: - raise RuntimeError('Already started') + raise RuntimeError("Already started") try: - _, _, _, _, protocol = do_handshake( - request.method, request.headers, request._protocol.writer, - self._protocols) - except HttpProcessingError: + _, protocol, _, _ = self._handshake(request) + except HTTPException: return WebSocketReady(False, None) else: return WebSocketReady(True, protocol) @property - def closed(self): + def closed(self) -> bool: return self._closed @property - def close_code(self): + def close_code(self) -> Optional[int]: return self._close_code @property - def ws_protocol(self): + def ws_protocol(self) -> Optional[str]: return self._ws_protocol - def exception(self): + @property + def compress(self) -> bool: + return self._compress + + def exception(self) -> Optional[BaseException]: return self._exception - def ping(self, message='b'): + async def ping(self, message: bytes = b"") -> None: if self._writer is None: - raise RuntimeError('Call .prepare() first') - self._writer.ping(message) + raise RuntimeError("Call .prepare() first") + await self._writer.ping(message) - def pong(self, message='b'): + async def pong(self, message: bytes = b"") -> None: # unsolicited pong if self._writer is None: - raise RuntimeError('Call .prepare() first') - self._writer.pong(message) + raise RuntimeError("Call .prepare() first") + await self._writer.pong(message) - def send_str(self, data): + async def send_str(self, data: str, compress: Optional[bool] = None) -> None: if self._writer is None: - raise RuntimeError('Call .prepare() first') + raise RuntimeError("Call .prepare() first") if not isinstance(data, str): - raise TypeError('data argument must be str (%r)' % type(data)) - return self._writer.send(data, binary=False) + raise TypeError("data argument must be str 
(%r)" % type(data)) + await self._writer.send(data, binary=False, compress=compress) - def send_bytes(self, data): + async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None: if self._writer is None: - raise RuntimeError('Call .prepare() first') + raise RuntimeError("Call .prepare() first") if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)' % - type(data)) - return self._writer.send(data, binary=True) - - def send_json(self, data, *, dumps=json.dumps): - return self.send_str(dumps(data)) - - @asyncio.coroutine - def write_eof(self): + raise TypeError("data argument must be byte-ish (%r)" % type(data)) + await self._writer.send(data, binary=True, compress=compress) + + async def send_json( + self, + data: Any, + compress: Optional[bool] = None, + *, + dumps: JSONEncoder = json.dumps, + ) -> None: + await self.send_str(dumps(data), compress=compress) + + async def write_eof(self) -> None: # type: ignore if self._eof_sent: return if self._payload_writer is None: raise RuntimeError("Response has not been started") - yield from self.close() + await self.close() self._eof_sent = True - @asyncio.coroutine - def close(self, *, code=1000, message=b''): + async def close(self, *, code: int = 1000, message: bytes = b"") -> bool: if self._writer is None: - raise RuntimeError('Call .prepare() first') + raise RuntimeError("Call .prepare() first") self._cancel_heartbeat() + reader = self._reader + assert reader is not None # we need to break `receive()` cycle first, # `close()` may be called from different task if self._waiting is not None and not self._closed: - self._reader.feed_data(WS_CLOSING_MESSAGE, 0) - yield from self._waiting + reader.feed_data(WS_CLOSING_MESSAGE, 0) + await self._waiting if not self._closed: self._closed = True try: - self._writer.close(code, message) - yield from self.drain() + await self._writer.close(code, message) + writer = self._payload_writer + assert writer is not None + await writer.drain() except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = 1006 raise @@ -225,9 +356,11 @@ def close(self, *, code=1000, message=b''): if self._closing: return True + reader = self._reader + assert reader is not None try: - with Timeout(self._timeout, loop=self._loop): - msg = yield from self._reader.read() + with async_timeout.timeout(self._timeout, loop=self._loop): + msg = await reader.read() except asyncio.CancelledError: self._close_code = 1006 raise @@ -246,102 +379,103 @@ def close(self, *, code=1000, message=b''): else: return False - @asyncio.coroutine - def receive(self, timeout=None): + async def receive(self, timeout: Optional[float] = None) -> WSMessage: if self._reader is None: - raise RuntimeError('Call .prepare() first') + raise RuntimeError("Call .prepare() first") + loop = self._loop + assert loop is not None while True: if self._waiting is not None: - raise RuntimeError( - 'Concurrent call to receive() is not allowed') + raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: self._conn_lost += 1 if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS: - raise RuntimeError('WebSocket connection is closed.') + raise RuntimeError("WebSocket connection is closed.") return WS_CLOSED_MESSAGE elif self._closing: return WS_CLOSING_MESSAGE try: - self._waiting = create_future(self._loop) + self._waiting = loop.create_future() try: - with Timeout( - timeout or self._receive_timeout, loop=self._loop): - msg = yield from self._reader.read() + with 
async_timeout.timeout( + timeout or self._receive_timeout, loop=self._loop + ): + msg = await self._reader.read() self._reset_heartbeat() finally: waiter = self._waiting + set_result(waiter, True) self._waiting = None - waiter.set_result(True) - except (asyncio.CancelledError, asyncio.TimeoutError) as exc: + except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = 1006 raise + except EofStream: + self._close_code = 1000 + await self.close() + return WSMessage(WSMsgType.CLOSED, None, None) except WebSocketError as exc: self._close_code = exc.code - yield from self.close(code=exc.code) + await self.close(code=exc.code) return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc self._closing = True self._close_code = 1006 - yield from self.close() + await self.close() return WSMessage(WSMsgType.ERROR, exc, None) if msg.type == WSMsgType.CLOSE: self._closing = True self._close_code = msg.data if not self._closed and self._autoclose: - yield from self.close() + await self.close() elif msg.type == WSMsgType.CLOSING: self._closing = True elif msg.type == WSMsgType.PING and self._autoping: - self.pong(msg.data) + await self.pong(msg.data) continue elif msg.type == WSMsgType.PONG and self._autoping: continue return msg - @asyncio.coroutine - def receive_str(self, *, timeout=None): - msg = yield from self.receive(timeout) + async def receive_str(self, *, timeout: Optional[float] = None) -> str: + msg = await self.receive(timeout) if msg.type != WSMsgType.TEXT: raise TypeError( "Received message {}:{!r} is not WSMsgType.TEXT".format( - msg.type, msg.data)) + msg.type, msg.data + ) + ) return msg.data - @asyncio.coroutine - def receive_bytes(self, *, timeout=None): - msg = yield from self.receive(timeout) + async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + msg = await self.receive(timeout) if msg.type != WSMsgType.BINARY: - raise TypeError( - "Received message {}:{!r} is not bytes".format(msg.type, - msg.data)) + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") return msg.data - @asyncio.coroutine - def receive_json(self, *, loads=json.loads, timeout=None): - data = yield from self.receive_str(timeout=timeout) + async def receive_json( + self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None + ) -> Any: + data = await self.receive_str(timeout=timeout) return loads(data) - def write(self, data): + async def write(self, data: bytes) -> None: raise RuntimeError("Cannot call .write() for websocket") - if PY_35: - def __aiter__(self): - return self + def __aiter__(self) -> "WebSocketResponse": + return self - if not PY_352: # pragma: no cover - __aiter__ = asyncio.coroutine(__aiter__) + async def __anext__(self) -> WSMessage: + msg = await self.receive() + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + raise StopAsyncIteration + return msg - @asyncio.coroutine - def __anext__(self): - msg = yield from self.receive() - if msg.type in (WSMsgType.CLOSE, - WSMsgType.CLOSING, - WSMsgType.CLOSED): - raise StopAsyncIteration # NOQA - return msg + def _cancel(self, exc: BaseException) -> None: + if self._reader is not None: + self._reader.set_exception(exc) diff --git a/aiohttp/worker.py b/aiohttp/worker.py index e30790a716c..67b244bbd35 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -4,16 +4,29 @@ import os import re import signal -import socket -import ssl import sys +from types import FrameType +from typing import Any, Awaitable, Callable, Optional, Union 
# noqa from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat from gunicorn.workers import base -from .helpers import AccessLogger, ensure_future +from aiohttp import web -__all__ = ('GunicornWebWorker', 'GunicornUVLoopWebWorker') +from .helpers import set_result +from .web_app import Application +from .web_log import AccessLogger + +try: + import ssl + + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore + SSLContext = object # type: ignore + + +__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker") class GunicornWebWorker(base.Worker): @@ -21,14 +34,14 @@ class GunicornWebWorker(base.Worker): DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default - def __init__(self, *args, **kw): # pragma: no cover + def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover super().__init__(*args, **kw) - self.servers = {} + self._task = None # type: Optional[asyncio.Task[None]] self.exit_code = 0 - self._notify_waiter = None + self._notify_waiter = None # type: Optional[asyncio.Future[bool]] - def init_process(self): + def init_process(self) -> None: # create new event_loop after fork asyncio.get_event_loop().close() @@ -37,86 +50,61 @@ def init_process(self): super().init_process() - def run(self): - if hasattr(self.wsgi, 'startup'): - self.loop.run_until_complete(self.wsgi.startup()) - self._runner = ensure_future(self._run(), loop=self.loop) + def run(self) -> None: + self._task = self.loop.create_task(self._run()) - try: - self.loop.run_until_complete(self._runner) - finally: - self.loop.close() + try: # ignore all finalization problems + self.loop.run_until_complete(self._task) + except Exception: + self.log.exception("Exception in gunicorn worker") + if sys.version_info >= (3, 6): + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() sys.exit(self.exit_code) - def make_handler(self, app): - if hasattr(self.wsgi, 'make_handler'): - access_log = self.log.access_log if self.cfg.accesslog else None - return app.make_handler( - loop=self.loop, - logger=self.log, - slow_request_timeout=self.cfg.timeout, - keepalive_timeout=self.cfg.keepalive, - access_log=access_log, - access_log_format=self._get_valid_log_format( - self.cfg.access_log_format)) + async def _run(self) -> None: + if isinstance(self.wsgi, Application): + app = self.wsgi + elif asyncio.iscoroutinefunction(self.wsgi): + app = await self.wsgi() else: raise RuntimeError( - "aiohttp.wsgi is not supported anymore, " - "consider to switch to aiohttp.web.Application") - - @asyncio.coroutine - def close(self): - if self.servers: - servers = self.servers - self.servers = None - - # stop accepting connections - for server, handler in servers.items(): - self.log.info("Stopping server: %s, connections: %s", - self.pid, len(handler.connections)) - server.close() - yield from server.wait_closed() - - # send on_shutdown event - if hasattr(self.wsgi, 'shutdown'): - yield from self.wsgi.shutdown() - - # stop alive connections - tasks = [ - handler.shutdown( - timeout=self.cfg.graceful_timeout / 100 * 95) - for handler in servers.values()] - yield from asyncio.gather(*tasks, loop=self.loop) - - # cleanup application - if hasattr(self.wsgi, 'cleanup'): - yield from self.wsgi.cleanup() - - @asyncio.coroutine - def _run(self): + "wsgi app should be either Application or " + "async function returning Application, got {}".format(self.wsgi) + ) + access_log = 
self.log.access_log if self.cfg.accesslog else None + runner = web.AppRunner( + app, + logger=self.log, + keepalive_timeout=self.cfg.keepalive, + access_log=access_log, + access_log_format=self._get_valid_log_format(self.cfg.access_log_format), + ) + await runner.setup() ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None + runner = runner + assert runner is not None + server = runner.server + assert server is not None for sock in self.sockets: - handler = self.make_handler(self.wsgi) - - if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: - srv = yield from self.loop.create_unix_server( - handler, sock=sock.sock, ssl=ctx) - else: - srv = yield from self.loop.create_server( - handler, sock=sock.sock, ssl=ctx) - self.servers[srv] = handler + site = web.SockSite( + runner, + sock, + ssl_context=ctx, + shutdown_timeout=self.cfg.graceful_timeout / 100 * 95, + ) + await site.start() # If our parent changed then we shut down. pid = os.getpid() try: - while self.alive: + while self.alive: # type: ignore self.notify() - cnt = sum(handler.requests_count - for handler in self.servers.values()) + cnt = server.requests_count if self.cfg.max_requests and cnt > self.cfg.max_requests: self.alive = False self.log.info("Max requests, shutting down: %s", self) @@ -125,78 +113,89 @@ def _run(self): self.alive = False self.log.info("Parent changed, shutting down: %s", self) else: - yield from self._wait_next_notify() - + await self._wait_next_notify() except BaseException: pass - yield from self.close() + await runner.cleanup() - def _wait_next_notify(self): + def _wait_next_notify(self) -> "asyncio.Future[bool]": self._notify_waiter_done() - self._notify_waiter = waiter = asyncio.Future(loop=self.loop) - self.loop.call_later(1.0, self._notify_waiter_done) + loop = self.loop + assert loop is not None + self._notify_waiter = waiter = loop.create_future() + self.loop.call_later(1.0, self._notify_waiter_done, waiter) return waiter - def _notify_waiter_done(self): - waiter = self._notify_waiter - if waiter is not None and not waiter.done(): - waiter.set_result(True) + def _notify_waiter_done( + self, waiter: Optional["asyncio.Future[bool]"] = None + ) -> None: + if waiter is None: + waiter = self._notify_waiter + if waiter is not None: + set_result(waiter, True) - self._notify_waiter = None + if waiter is self._notify_waiter: + self._notify_waiter = None - def init_signals(self): + def init_signals(self) -> None: # Set up signals through the event loop API. 
- self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, - signal.SIGQUIT, None) + self.loop.add_signal_handler( + signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None + ) - self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, - signal.SIGTERM, None) + self.loop.add_signal_handler( + signal.SIGTERM, self.handle_exit, signal.SIGTERM, None + ) - self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, - signal.SIGINT, None) + self.loop.add_signal_handler( + signal.SIGINT, self.handle_quit, signal.SIGINT, None + ) - self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, - signal.SIGWINCH, None) + self.loop.add_signal_handler( + signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None + ) - self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, - signal.SIGUSR1, None) + self.loop.add_signal_handler( + signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None + ) - self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, - signal.SIGABRT, None) + self.loop.add_signal_handler( + signal.SIGABRT, self.handle_abort, signal.SIGABRT, None + ) # Don't let SIGTERM and SIGUSR1 disturb active requests # by interrupting system calls signal.siginterrupt(signal.SIGTERM, False) signal.siginterrupt(signal.SIGUSR1, False) - def handle_quit(self, sig, frame): + def handle_quit(self, sig: int, frame: FrameType) -> None: self.alive = False # worker_int callback self.cfg.worker_int(self) - # init closing process - self._closing = ensure_future(self.close(), loop=self.loop) - - # close loop - self.loop.call_later(0.1, self._notify_waiter_done) + # wakeup closing process + self._notify_waiter_done() - def handle_abort(self, sig, frame): + def handle_abort(self, sig: int, frame: FrameType) -> None: self.alive = False self.exit_code = 1 self.cfg.worker_abort(self) sys.exit(1) @staticmethod - def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. + def _create_ssl_context(cfg: Any) -> "SSLContext": + """Creates SSLContext instance for usage in asyncio.create_server. See ssl.SSLSocket.__init__ for more details. """ + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + ctx = ssl.SSLContext(cfg.ssl_version) ctx.load_cert_chain(cfg.certfile, cfg.keyfile) ctx.verify_mode = cfg.cert_reqs @@ -206,15 +205,15 @@ def _create_ssl_context(cfg): ctx.set_ciphers(cfg.ciphers) return ctx - def _get_valid_log_format(self, source_format): + def _get_valid_log_format(self, source_format: str) -> str: if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT: return self.DEFAULT_AIOHTTP_LOG_FORMAT - elif re.search(r'%\([^\)]+\)', source_format): + elif re.search(r"%\([^\)]+\)", source_format): raise ValueError( "Gunicorn's style options in form of `%(name)s` are not " "supported for the log formatting. 
Please use aiohttp's " "format specification to configure access log formatting: " - "http://aiohttp.readthedocs.io/en/stable/logging.html" + "http://docs.aiohttp.org/en/stable/logging.html" "#format-specification" ) else: @@ -222,8 +221,7 @@ def _get_valid_log_format(self, source_format): class GunicornUVLoopWebWorker(GunicornWebWorker): - - def init_process(self): + def init_process(self) -> None: import uvloop # Close any existing event loop before setting a @@ -236,3 +234,19 @@ def init_process(self): asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) super().init_process() + + +class GunicornTokioWebWorker(GunicornWebWorker): + def init_process(self) -> None: # pragma: no cover + import tokio + + # Close any existing event loop before setting a + # new policy. + asyncio.get_event_loop().close() + + # Setup tokio policy, so that every + # asyncio.get_event_loop() will create an instance + # of tokio event loop. + asyncio.set_event_loop_policy(tokio.EventLoopPolicy()) + + super().init_process() diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 96e19f09873..00000000000 --- a/appveyor.yml +++ /dev/null @@ -1,41 +0,0 @@ -version: 2.0dev{build} - -environment: - PYPI_PASSWD: - secure: syNUF3e8AEPY327rWBkKag== - matrix: - - PYTHON: "C:\\Python34" - - PYTHON: "C:\\Python34-x64" - DISTUTILS_USE_SDK: "1" - - PYTHON: "C:\\Python35" - - PYTHON: "C:\\Python35-x64" - - PYTHON: "C:\\Python36" - - PYTHON: "C:\\Python36-x64" - -install: - - "build.cmd %PYTHON%\\python.exe -m pip install wheel" - - "build.cmd %PYTHON%\\python.exe -m pip install twine" - - "build.cmd %PYTHON%\\python.exe -m pip install -r requirements-ci.txt" - -build: false - -test_script: - - "build.cmd %PYTHON%\\python.exe setup.py test" - -after_test: - - "build.cmd %PYTHON%\\python.exe setup.py bdist_wheel" - -artifacts: - - path: dist\* - -deploy_script: - - ps: >- - if($env:appveyor_repo_tag -eq 'True') { - Invoke-Expression "$env:PYTHON\\python.exe -m twine upload dist/* --username fafhrd --password $env:PYPI_PASSWD" - } - -#notifications: -# - provider: Webhook -# url: https://ci.appveyor.com/api/github/webhook?id=08c7793w1tp839fl -# on_build_success: false -# on_build_failure: True diff --git a/benchmark/async.py b/benchmark/async.py deleted file mode 100644 index 30ea995f7cc..00000000000 --- a/benchmark/async.py +++ /dev/null @@ -1,331 +0,0 @@ -import argparse -import asyncio -import collections -import cProfile -import gc -import random -import socket -import string -import sys -from multiprocessing import Barrier, Process, set_start_method -from statistics import mean, median, stdev - -import aiohttp - - -def find_port(): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('localhost', 0)) - host, port = s.getsockname() - s.close() - return host, port - - -profiler = cProfile.Profile() - - -def run_aiohttp(host, port, barrier, profile): - - from aiohttp import web - - @asyncio.coroutine - def test(request): - txt = 'Hello, ' + request.match_info['name'] - return web.Response(text=txt) - - @asyncio.coroutine - def prepare(request): - gc.collect() - return web.Response(text='OK') - - @asyncio.coroutine - def stop(request): - loop.call_later(0.1, loop.stop) - return web.Response(text='OK') - - @asyncio.coroutine - def init(loop): - app = web.Application(loop=loop) - app.router.add_route('GET', '/prepare', prepare) - app.router.add_route('GET', '/stop', stop) - app.router.add_route('GET', '/test/{name}', test) - - handler = app.make_handler(keep_alive=15, timeout=0) - srv = yield from 
loop.create_server(handler, host, port) - return srv, app, handler - - loop = asyncio.get_event_loop() - srv, app, handler = loop.run_until_complete(init(loop)) - barrier.wait() - - if profile: - profiler.enable() - - loop.run_forever() - srv.close() - loop.run_until_complete(srv.wait_closed()) - loop.run_until_complete(handler.finish_connections()) - loop.close() - - if profile: - profiler.disable() - - -def run_tornado(host, port, barrier, profile): - - import tornado.ioloop - import tornado.web - - class TestHandler(tornado.web.RequestHandler): - - def get(self, name): - txt = 'Hello, ' + name - self.set_header('Content-Type', 'text/plain; charset=utf-8') - self.write(txt) - - class PrepareHandler(tornado.web.RequestHandler): - - def get(self): - gc.collect() - self.write('OK') - - class StopHandler(tornado.web.RequestHandler): - - def get(self): - self.write('OK') - - def on_finish(self): - tornado.ioloop.IOLoop.instance().stop() - - app = tornado.web.Application([ - (r'/prepare', PrepareHandler), - (r'/stop', StopHandler), - (r'/test/(.+)', TestHandler)]) - - app.listen(port, host) - barrier.wait() - tornado.ioloop.IOLoop.instance().start() - - -def run_twisted(host, port, barrier, profile): - - if 'bsd' in sys.platform or sys.platform.startswith('darwin'): - from twisted.internet import kqreactor - kqreactor.install() - elif sys.platform in ['win32']: - from twisted.internet.iocpreactor import reactor as iocpreactor - iocpreactor.install() - elif sys.platform.startswith('linux'): - from twisted.internet import epollreactor - epollreactor.install() - else: - from twisted.internet import default as defaultreactor - defaultreactor.install() - - from twisted.web.server import Site - from twisted.web.resource import Resource - from twisted.internet import reactor - - class TestResource(Resource): - - def __init__(self, name): - super().__init__() - self.name = name - self.isLeaf = name is not None - - def render_GET(self, request): - txt = 'Hello, ' + self.name - request.setHeader(b'Content-Type', b'text/plain; charset=utf-8') - return txt.encode('utf8') - - def getChild(self, name, request): - return TestResource(name=name.decode('utf-8')) - - class PrepareResource(Resource): - - isLeaf = True - - def render_GET(self, request): - gc.collect() - return b'OK' - - class StopResource(Resource): - - isLeaf = True - - def render_GET(self, request): - reactor.callLater(0.1, reactor.stop) - return b'OK' - - root = Resource() - root.putChild(b'test', TestResource(None)) - root.putChild(b'prepare', PrepareResource()) - root.putChild(b'stop', StopResource()) - site = Site(root) - reactor.listenTCP(port, site, interface=host) - barrier.wait() - - reactor.run() - - -@asyncio.coroutine -def attack(count, concurrency, client, loop, url): - - out_times = collections.deque() - processed_count = 0 - - def gen(): - for i in range(count): - rnd = ''.join(random.sample(string.ascii_letters, 16)) - yield rnd - - @asyncio.coroutine - def do_bomb(in_iter): - nonlocal processed_count - for rnd in in_iter: - real_url = url + '/test/' + rnd - try: - t1 = loop.time() - resp = yield from client.get(real_url) - assert resp.status == 200, resp.status - if 'text/plain; charset=utf-8' != resp.headers['Content-Type']: - raise AssertionError('Invalid Content-Type: %r' % - resp.headers) - body = yield from resp.text() - yield from resp.release() - assert body == ('Hello, ' + rnd), rnd - t2 = loop.time() - out_times.append(t2 - t1) - processed_count += 1 - except Exception: - continue - - in_iter = gen() - bombers = [] - 
for i in range(concurrency): - bomber = asyncio.async(do_bomb(in_iter)) - bombers.append(bomber) - - t1 = loop.time() - yield from asyncio.gather(*bombers) - t2 = loop.time() - rps = processed_count / (t2 - t1) - return rps, out_times - - -@asyncio.coroutine -def run(test, count, concurrency, *, loop, verbose, profile): - if verbose: - print("Prepare") - else: - print('.', end='', flush=True) - host, port = find_port() - barrier = Barrier(2) - server = Process(target=test, args=(host, port, barrier, profile)) - server.start() - barrier.wait() - - url = 'http://{}:{}'.format(host, port) - - connector = aiohttp.TCPConnector(loop=loop) - with aiohttp.ClientSession(connector=connector) as client: - - for i in range(10): - # make server hot - resp = yield from client.get(url+'/prepare') - assert resp.status == 200, resp.status - yield from resp.release() - - if verbose: - test_name = test.__name__ - print("Attack", test_name) - rps, data = yield from attack(count, concurrency, client, loop, url) - if verbose: - print("Done") - - resp = yield from client.get(url+'/stop') - assert resp.status == 200, resp.status - yield from resp.release() - - server.join() - return rps, data - - -def main(argv): - args = ARGS.parse_args() - - count = args.count - concurrency = args.concurrency - verbose = args.verbose - tries = args.tries - - loop = asyncio.get_event_loop() - suite = [run_aiohttp, run_tornado, run_twisted] - - suite *= tries - random.shuffle(suite) - - all_times = collections.defaultdict(list) - all_rps = collections.defaultdict(list) - for test in suite: - test_name = test.__name__ - - rps, times = loop.run_until_complete(run(test, count, concurrency, - loop=loop, verbose=verbose, - profile=args.profile)) - all_times[test_name].extend(times) - all_rps[test_name].append(rps) - - if args.profile: - profiler.dump_stats('out.prof') - - print() - - for test_name in sorted(all_rps): - rps = all_rps[test_name] - times = [t * 1000 for t in all_times[test_name]] - - rps_mean = mean(rps) - times_mean = mean(times) - times_stdev = stdev(times) - times_median = median(times) - print('Results for', test_name) - print('RPS: {:d},\tmean: {:.3f} ms,' - '\tstandard deviation {:.3f} ms\tmedian {:.3f} ms' - .format(int(rps_mean), - times_mean, - times_stdev, - times_median)) - return 0 - - -ARGS = argparse.ArgumentParser(description="Run benchmark.") -ARGS.add_argument( - '-t', '--tries', action="store", - nargs='?', type=int, default=5, - help='count of tries (default: `%(default)s`)') -ARGS.add_argument( - '-n', '--count', action="store", - nargs='?', type=int, default=10000, - help='requests count (default: `%(default)s`)') -ARGS.add_argument( - '-c', '--concurrency', action="store", - nargs='?', type=int, default=500, - help='count of parallel requests (default: `%(default)s`)') -ARGS.add_argument( - '-p', '--plot-file-name', action="store", - type=str, default=None, - dest='plot_file_name', - help='file name for plot (default: `%(default)s`)') -ARGS.add_argument( - '-v', '--verbose', action="count", default=0, - help='verbosity level (default: `%(default)s`)') -ARGS.add_argument( - '--profile', action="store_true", default=False, - help='perform aiohttp test profiling, store result as out.prof ' - '(default: `%(default)s`)') - - -if __name__ == '__main__': - set_start_method('spawn') - sys.exit(main(sys.argv)) diff --git a/benchmark/prof.py b/benchmark/prof.py deleted file mode 100644 index 98207aaf804..00000000000 --- a/benchmark/prof.py +++ /dev/null @@ -1,26 +0,0 @@ -# Run with python3 simple_server.py 
PORT - -import asyncio -import logging -import sys - -import ujson as json -import uvloop - -from aiohttp import web - -loop = uvloop.new_event_loop() -asyncio.set_event_loop(loop) - -logging.basicConfig(level=logging.DEBUG) - - -async def handle(request): - return web.Response(body=json.dumps({"test": True}).encode('utf-8'), - content_type='application/json') - - -app = web.Application(loop=loop) -app.router.add_route('GET', '/', handle) - -web.run_app(app, port=int(sys.argv[1]), access_log=None) diff --git a/benchmark/requirements.txt b/benchmark/requirements.txt deleted file mode 100644 index 2046c8c06c4..00000000000 --- a/benchmark/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -twisted==17.1.0 -tornado==4.4.2 diff --git a/build-wheels.sh b/build-wheels.sh deleted file mode 100755 index 8033defed1c..00000000000 --- a/build-wheels.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash -PYTHON_VERSIONS="cp34-cp34m cp35-cp35m cp36-cp36m" - -echo "Compile wheels" -for PYTHON in ${PYTHON_VERSIONS}; do - /opt/python/${PYTHON}/bin/pip install -r /io/requirements-wheel.txt - /opt/python/${PYTHON}/bin/pip wheel /io/ -w /io/dist/ -done - -echo "Bundle external shared libraries into the wheels" -for whl in /io/dist/aiohttp*.whl; do - auditwheel repair $whl -w /io/dist/ -done - -echo "Install packages and test" -for PYTHON in ${PYTHON_VERSIONS}; do - /opt/python/${PYTHON}/bin/pip install aiohttp --no-index -f file:///io/dist - rm -rf /io/tests/__pycache__ - rm -rf /io/tests/test_py35/__pycache__ - /opt/python/${PYTHON}/bin/py.test /io/tests - rm -rf /io/tests/__pycache__ - rm -rf /io/tests/test_py35/__pycache__ -done diff --git a/build.cmd b/build.cmd deleted file mode 100644 index 243dc9a1f0f..00000000000 --- a/build.cmd +++ /dev/null @@ -1,21 +0,0 @@ -@echo off -:: To build extensions for 64 bit Python 3, we need to configure environment -:: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of: -:: MS Windows SDK for Windows 7 and .NET Framework 4 -:: -:: More details at: -:: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows - -IF "%DISTUTILS_USE_SDK%"=="1" ( - ECHO Configuring environment to build with MSVC on a 64bit architecture - ECHO Using Windows SDK 7.1 - "C:\Program Files\Microsoft SDKs\Windows\v7.1\Setup\WindowsSdkVer.exe" -q -version:v7.1 - CALL "C:\Program Files\Microsoft SDKs\Windows\v7.1\Bin\SetEnv.cmd" /x64 /release - SET MSSdk=1 - REM Need the following to allow tox to see the SDK compiler - SET TOX_TESTENV_PASSENV=DISTUTILS_USE_SDK MSSdk INCLUDE LIB -) ELSE ( - ECHO Using default MSVC build environment -) - -CALL %* diff --git a/codecov.yml b/codecov.yml deleted file mode 100644 index 543a9f13646..00000000000 --- a/codecov.yml +++ /dev/null @@ -1,2 +0,0 @@ -coverage: - range: "95..100" diff --git a/demos/README.rst b/demos/README.rst deleted file mode 100644 index b1db1c43370..00000000000 --- a/demos/README.rst +++ /dev/null @@ -1,2 +0,0 @@ -aiohttp demos -============= diff --git a/demos/chat/aiohttpdemo_chat/__init__.py b/demos/chat/aiohttpdemo_chat/__init__.py deleted file mode 100644 index b8023d8bc0c..00000000000 --- a/demos/chat/aiohttpdemo_chat/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '0.0.1' diff --git a/demos/chat/aiohttpdemo_chat/main.py b/demos/chat/aiohttpdemo_chat/main.py deleted file mode 100644 index 3ab668a6f93..00000000000 --- a/demos/chat/aiohttpdemo_chat/main.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import logging - -import jinja2 - -import aiohttp_jinja2 -from aiohttp import web -from aiohttpdemo_chat.views 
import setup as setup_routes - - -async def init(loop): - app = web.Application(loop=loop) - app['sockets'] = {} - app.on_shutdown.append(shutdown) - - aiohttp_jinja2.setup( - app, loader=jinja2.PackageLoader('aiohttpdemo_chat', 'templates')) - - setup_routes(app) - - return app - - -async def shutdown(app): - for ws in app['sockets'].values(): - await ws.close() - app['sockets'].clear() - - -def main(): - # init logging - logging.basicConfig(level=logging.DEBUG) - - loop = asyncio.get_event_loop() - app = loop.run_until_complete(init(loop)) - web.run_app(app) - - -if __name__ == '__main__': - main() diff --git a/demos/chat/aiohttpdemo_chat/templates/index.html b/demos/chat/aiohttpdemo_chat/templates/index.html deleted file mode 100644 index 6b51fb63734..00000000000 --- a/demos/chat/aiohttpdemo_chat/templates/index.html +++ /dev/null @@ -1,112 +0,0 @@ - - - - - - - - -

-    Chat!
-    | Status: UNKNOWN / disconnected
    - - diff --git a/demos/chat/aiohttpdemo_chat/views.py b/demos/chat/aiohttpdemo_chat/views.py deleted file mode 100644 index 15e6ae28a37..00000000000 --- a/demos/chat/aiohttpdemo_chat/views.py +++ /dev/null @@ -1,50 +0,0 @@ -import json -import logging -import random -import string - -import aiohttp_jinja2 -from aiohttp import web - -log = logging.getLogger(__name__) - - -async def index(request): - resp = web.WebSocketResponse() - ok, protocol = resp.can_start(request) - if not ok: - return aiohttp_jinja2.render_template('index.html', request, {}) - - await resp.prepare(request) - name = (random.choice(string.ascii_uppercase) + - ''.join(random.sample(string.ascii_lowercase*10, 10))) - log.info('%s joined.', name) - resp.send_str(json.dumps({'action': 'connect', - 'name': name})) - for ws in request.app['sockets'].values(): - ws.send_str(json.dumps({'action': 'join', - 'name': name})) - request.app['sockets'][name] = resp - - while True: - msg = await resp.receive() - - if msg.type == web.MsgType.text: - for ws in request.app['sockets'].values(): - if ws is not resp: - ws.send_str(json.dumps({'action': 'sent', - 'name': name, - 'text': msg.data})) - else: - break - - del request.app['sockets'][name] - log.info('%s disconnected.', name) - for ws in request.app['sockets'].values(): - ws.send_str(json.dumps({'action': 'disconnect', - 'name': name})) - return resp - - -def setup(app): - app.router.add_get('/', index) diff --git a/demos/chat/setup.py b/demos/chat/setup.py deleted file mode 100644 index 9a927835dab..00000000000 --- a/demos/chat/setup.py +++ /dev/null @@ -1,32 +0,0 @@ -import os -import re - -from setuptools import find_packages, setup - - -def read_version(): - regexp = re.compile(r"^__version__\W*=\W*'([\d.abrc]+)'") - init_py = os.path.join(os.path.dirname(__file__), - 'aiohttpdemo_chat', '__init__.py') - with open(init_py) as f: - for line in f: - match = regexp.match(line) - if match is not None: - return match.group(1) - else: - msg = 'Cannot find version in aiohttpdemo_chat/__init__.py' - raise RuntimeError(msg) - - -install_requires = ['aiohttp', - 'aiohttp_jinja2'] - - -setup(name='aiohttpdemo_chat', - version=read_version(), - description='Chat example from aiohttp', - platforms=['POSIX'], - packages=find_packages(), - include_package_data=True, - install_requires=install_requires, - zip_safe=False) diff --git a/demos/polls/Makefile b/demos/polls/Makefile deleted file mode 100644 index 22a96238dd7..00000000000 --- a/demos/polls/Makefile +++ /dev/null @@ -1,28 +0,0 @@ -# Some simple testing tasks (sorry, UNIX only). - -FLAGS= - - -flake: - pyflakes aiohttpdemo_polls - pep8 aiohttpdemo_polls setup.py - -test: - pytest tests - -clean: - rm -rf `find . -name __pycache__` - rm -f `find . -type f -name '*.py[co]' ` - rm -f `find . -type f -name '*~' ` - rm -f `find . -type f -name '.*~' ` - rm -f `find . -type f -name '@*' ` - rm -f `find . -type f -name '#*#' ` - rm -f `find . -type f -name '*.orig' ` - rm -f `find . -type f -name '*.rej' ` - rm -f .coverage - rm -rf coverage - rm -rf build - rm -rf htmlcov - rm -rf dist - -.PHONY: flake clean test diff --git a/demos/polls/README.rst b/demos/polls/README.rst deleted file mode 100644 index 2650a3d5bae..00000000000 --- a/demos/polls/README.rst +++ /dev/null @@ -1,48 +0,0 @@ -Polls (demo for aiohttp) -======================== - -Example of polls project using aiohttp_, aiopg_ and aiohttp_jinja2_, -similar to django one. - -Installation -============ - -Install the app:: - - $ cd demos/polls - $ pip install -e . 
- -Create database for your project:: - - bash sql/install.sh - -Run application:: - - $ python -m aiohttpdemo_polls - - -Open browser:: - - http://localhost:8080/ - -.. image:: https://raw.githubusercontent.com/andriisoldatenko/aiohttp_polls/master/images/example.png - :align: center - - -Run integration tests:: - - pip install tox - tox - - -Requirements -============ -* aiohttp_ -* aiopg_ -* aiohttp_jinja2_ - - -.. _Python: https://www.python.org -.. _aiohttp: https://github.com/aio-libs/aiohttp -.. _aiopg: https://github.com/aio-libs/aiopg -.. _aiohttp_jinja2: https://github.com/aio-libs/aiohttp_jinja2 diff --git a/demos/polls/aiohttpdemo_polls/__init__.py b/demos/polls/aiohttpdemo_polls/__init__.py deleted file mode 100644 index b8023d8bc0c..00000000000 --- a/demos/polls/aiohttpdemo_polls/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '0.0.1' diff --git a/demos/polls/aiohttpdemo_polls/__main__.py b/demos/polls/aiohttpdemo_polls/__main__.py deleted file mode 100644 index 1ea11eb3a16..00000000000 --- a/demos/polls/aiohttpdemo_polls/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -from aiohttpdemo_polls.main import main - -main(sys.argv[1:]) diff --git a/demos/polls/aiohttpdemo_polls/db.py b/demos/polls/aiohttpdemo_polls/db.py deleted file mode 100644 index be9fef35351..00000000000 --- a/demos/polls/aiohttpdemo_polls/db.py +++ /dev/null @@ -1,83 +0,0 @@ -import sqlalchemy as sa - -import aiopg.sa - -__all__ = ['question', 'choice'] - -meta = sa.MetaData() - - -question = sa.Table( - 'question', meta, - sa.Column('id', sa.Integer, nullable=False), - sa.Column('question_text', sa.String(200), nullable=False), - sa.Column('pub_date', sa.Date, nullable=False), - - # Indexes # - sa.PrimaryKeyConstraint('id', name='question_id_pkey')) - -choice = sa.Table( - 'choice', meta, - sa.Column('id', sa.Integer, nullable=False), - sa.Column('question_id', sa.Integer, nullable=False), - sa.Column('choice_text', sa.String(200), nullable=False), - sa.Column('votes', sa.Integer, server_default="0", nullable=False), - - # Indexes # - sa.PrimaryKeyConstraint('id', name='choice_id_pkey'), - sa.ForeignKeyConstraint(['question_id'], [question.c.id], - name='choice_question_id_fkey', - ondelete='CASCADE'), -) - - -class RecordNotFound(Exception): - """Requested record in database was not found""" - - -async def init_pg(app): - conf = app['config']['postgres'] - engine = await aiopg.sa.create_engine( - database=conf['database'], - user=conf['user'], - password=conf['password'], - host=conf['host'], - port=conf['port'], - minsize=conf['minsize'], - maxsize=conf['maxsize'], - loop=app.loop) - app['db'] = engine - - -async def close_pg(app): - app['db'].close() - await app['db'].wait_closed() - - -async def get_question(conn, question_id): - result = await conn.execute( - question.select() - .where(question.c.id == question_id)) - question_record = await result.first() - if not question_record: - msg = "Question with id: {} does not exists" - raise RecordNotFound(msg.format(question_id)) - result = await conn.execute( - choice.select() - .where(choice.c.question_id == question_id) - .order_by(choice.c.id)) - choice_recoreds = await result.fetchall() - return question_record, choice_recoreds - - -async def vote(conn, question_id, choice_id): - result = await conn.execute( - choice.update() - .returning(*choice.c) - .where(choice.c.question_id == question_id) - .where(choice.c.id == choice_id) - .values(votes=choice.c.votes+1)) - record = await result.fetchone() - if not record: - msg = "Question with 
id: {} or choice id: {} does not exists" - raise RecordNotFound(msg.format(question_id, choice_id)) diff --git a/demos/polls/aiohttpdemo_polls/main.py b/demos/polls/aiohttpdemo_polls/main.py deleted file mode 100644 index 58c2e31bfa1..00000000000 --- a/demos/polls/aiohttpdemo_polls/main.py +++ /dev/null @@ -1,64 +0,0 @@ -import argparse -import asyncio -import logging -import sys - -import jinja2 - -from trafaret_config import commandline - - -import aiohttp_jinja2 -from aiohttp import web -from aiohttpdemo_polls.db import close_pg, init_pg -from aiohttpdemo_polls.middlewares import setup_middlewares -from aiohttpdemo_polls.routes import setup_routes -from aiohttpdemo_polls.utils import TRAFARET - - -def init(loop, argv): - ap = argparse.ArgumentParser() - commandline.standard_argparse_options(ap, - default_config='./config/polls.yaml') - # - # define your command-line arguments here - # - options = ap.parse_args(argv) - - config = commandline.config_from_options(options, TRAFARET) - - # setup application and extensions - app = web.Application(loop=loop) - - # load config from yaml file in current dir - app['config'] = config - - # setup Jinja2 template renderer - aiohttp_jinja2.setup( - app, loader=jinja2.PackageLoader('aiohttpdemo_polls', 'templates')) - - # create connection to the database - app.on_startup.append(init_pg) - # shutdown db connection on exit - app.on_cleanup.append(close_pg) - # setup views and routes - setup_routes(app) - setup_middlewares(app) - - return app - - -def main(argv): - # init logging - logging.basicConfig(level=logging.DEBUG) - - loop = asyncio.get_event_loop() - - app = init(loop, argv) - web.run_app(app, - host=app['config']['host'], - port=app['config']['port']) - - -if __name__ == '__main__': - main(sys.argv[1:]) diff --git a/demos/polls/aiohttpdemo_polls/middlewares.py b/demos/polls/aiohttpdemo_polls/middlewares.py deleted file mode 100644 index a7c75406d4f..00000000000 --- a/demos/polls/aiohttpdemo_polls/middlewares.py +++ /dev/null @@ -1,42 +0,0 @@ -import aiohttp_jinja2 -from aiohttp import web - - -async def handle_404(request, response): - response = aiohttp_jinja2.render_template('404.html', - request, - {}) - return response - - -async def handle_500(request, response): - response = aiohttp_jinja2.render_template('500.html', - request, - {}) - return response - - -def error_pages(overrides): - async def middleware(app, handler): - async def middleware_handler(request): - try: - response = await handler(request) - override = overrides.get(response.status) - if override is None: - return response - else: - return await override(request, response) - except web.HTTPException as ex: - override = overrides.get(ex.status) - if override is None: - raise - else: - return await override(request, ex) - return middleware_handler - return middleware - - -def setup_middlewares(app): - error_middleware = error_pages({404: handle_404, - 500: handle_500}) - app.middlewares.append(error_middleware) diff --git a/demos/polls/aiohttpdemo_polls/routes.py b/demos/polls/aiohttpdemo_polls/routes.py deleted file mode 100644 index fc74a766689..00000000000 --- a/demos/polls/aiohttpdemo_polls/routes.py +++ /dev/null @@ -1,16 +0,0 @@ -import pathlib - -from .views import index, poll, results, vote - -PROJECT_ROOT = pathlib.Path(__file__).parent - - -def setup_routes(app): - app.router.add_get('/', index) - app.router.add_get('/poll/{question_id}', poll, name='poll') - app.router.add_get('/poll/{question_id}/results', - results, name='results') - 
app.router.add_post('/poll/{question_id}/vote', vote, name='vote') - app.router.add_static('/static/', - path=str(PROJECT_ROOT / 'static'), - name='static') diff --git a/demos/polls/aiohttpdemo_polls/static/style.css b/demos/polls/aiohttpdemo_polls/static/style.css deleted file mode 100644 index a9db566399f..00000000000 --- a/demos/polls/aiohttpdemo_polls/static/style.css +++ /dev/null @@ -1,7 +0,0 @@ -li a { - color: green; -} - -body { - background: white url("images/background.gif") no-repeat right bottom; -} diff --git a/demos/polls/aiohttpdemo_polls/templates/404.html b/demos/polls/aiohttpdemo_polls/templates/404.html deleted file mode 100644 index 1d47f08a585..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/404.html +++ /dev/null @@ -1,3 +0,0 @@ -{% extends "base.html" %} - -{% set title = "Page Not Found" %} diff --git a/demos/polls/aiohttpdemo_polls/templates/500.html b/demos/polls/aiohttpdemo_polls/templates/500.html deleted file mode 100644 index a9201ce52e3..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/500.html +++ /dev/null @@ -1,3 +0,0 @@ -{% extends "base.html" %} - -{% set title = "Internal Server Error" %} diff --git a/demos/polls/aiohttpdemo_polls/templates/base.html b/demos/polls/aiohttpdemo_polls/templates/base.html deleted file mode 100644 index 0a3f5cf6adb..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/base.html +++ /dev/null @@ -1,17 +0,0 @@ - - - - {% block head %} - - {{title}} - {% endblock %} - - -

-    {{title}}
-    {% block content %}
-    {% endblock %}
    - - diff --git a/demos/polls/aiohttpdemo_polls/templates/detail.html b/demos/polls/aiohttpdemo_polls/templates/detail.html deleted file mode 100644 index 72e7d60e462..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/detail.html +++ /dev/null @@ -1,15 +0,0 @@ -{% extends "base.html" %} - -{% set title = question.question_text %} - -{% block content %} -{% if error_message %}

-    {{ error_message }}
-{% endif %}
    -{% for choice in choices %} - -
    -{% endfor %} - -
    -{% endblock %} diff --git a/demos/polls/aiohttpdemo_polls/templates/index.html b/demos/polls/aiohttpdemo_polls/templates/index.html deleted file mode 100644 index 3e21f80ba22..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/index.html +++ /dev/null @@ -1,15 +0,0 @@ -{% extends "base.html" %} - -{% set title = "Main" %} - -{% block content %} -{% if questions %} - -{% else %} -

-    No polls are available.
    -{% endif %} -{% endblock %} diff --git a/demos/polls/aiohttpdemo_polls/templates/results.html b/demos/polls/aiohttpdemo_polls/templates/results.html deleted file mode 100644 index 923beae878a..00000000000 --- a/demos/polls/aiohttpdemo_polls/templates/results.html +++ /dev/null @@ -1,13 +0,0 @@ -{% extends "base.html" %} - -{% set title = question.question_text %} - -{% block content %} -
-{% for choice in choices %}
-    {{ choice.choice_text }} -- {{ choice.votes }} vote(s)
-{% endfor %}
    - -Vote again? -{% endblock %} diff --git a/demos/polls/aiohttpdemo_polls/utils.py b/demos/polls/aiohttpdemo_polls/utils.py deleted file mode 100644 index 9283dd6cdd4..00000000000 --- a/demos/polls/aiohttpdemo_polls/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import trafaret as T - -primitive_ip_regexp = r'^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$' - -TRAFARET = T.Dict({ - T.Key('postgres'): - T.Dict({ - 'database': T.String(), - 'user': T.String(), - 'password': T.String(), - 'host': T.String(), - 'port': T.Int(), - 'minsize': T.Int(), - 'maxsize': T.Int(), - }), - T.Key('host'): T.String(regex=primitive_ip_regexp), - T.Key('port'): T.Int(), -}) diff --git a/demos/polls/aiohttpdemo_polls/views.py b/demos/polls/aiohttpdemo_polls/views.py deleted file mode 100644 index 06a468b0bae..00000000000 --- a/demos/polls/aiohttpdemo_polls/views.py +++ /dev/null @@ -1,63 +0,0 @@ -import aiohttp_jinja2 -from aiohttp import web - -from . import db - - -@aiohttp_jinja2.template('index.html') -async def index(request): - async with request.app['db'].acquire() as conn: - cursor = await conn.execute(db.question.select()) - records = await cursor.fetchall() - questions = [dict(q) for q in records] - return {'questions': questions} - - -@aiohttp_jinja2.template('detail.html') -async def poll(request): - async with request.app['db'].acquire() as conn: - question_id = request.match_info['question_id'] - try: - question, choices = await db.get_question(conn, - question_id) - except db.RecordNotFound as e: - raise web.HTTPNotFound(text=str(e)) - return { - 'question': question, - 'choices': choices - } - - -@aiohttp_jinja2.template('results.html') -async def results(request): - async with request.app['db'].acquire() as conn: - question_id = request.match_info['question_id'] - - try: - question, choices = await db.get_question(conn, - question_id) - except db.RecordNotFound as e: - raise web.HTTPNotFound(text=str(e)) - - return { - 'question': question, - 'choices': choices - } - - -async def vote(request): - async with request.app['db'].acquire() as conn: - question_id = int(request.match_info['question_id']) - data = await request.post() - try: - choice_id = int(data['choice']) - except (KeyError, TypeError, ValueError) as e: - raise web.HTTPBadRequest( - text='You have not specified choice value') from e - try: - await db.vote(conn, question_id, choice_id) - except db.RecordNotFound as e: - raise web.HTTPNotFound(text=str(e)) - router = request.app.router - url = router['results'].url(parts={'question_id': question_id}) - return web.HTTPFound(location=url) diff --git a/demos/polls/config/polls.yaml b/demos/polls/config/polls.yaml deleted file mode 100644 index c790f1c6f0e..00000000000 --- a/demos/polls/config/polls.yaml +++ /dev/null @@ -1,11 +0,0 @@ -postgres: - database: aiohttpdemo_polls - user: aiohttpdemo_user - password: aiohttpdemo_user - host: localhost - port: 5432 - minsize: 1 - maxsize: 5 - -host: 127.0.0.1 -port: 8080 diff --git a/demos/polls/images/example.png b/demos/polls/images/example.png deleted file mode 100644 index 51343c30f78..00000000000 Binary files a/demos/polls/images/example.png and /dev/null differ diff --git a/demos/polls/requirements.txt b/demos/polls/requirements.txt deleted file mode 100644 index 1cc98029514..00000000000 --- a/demos/polls/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ --e . 
-docker-py==1.10.6 -pytest-aiohttp==0.1.3 diff --git a/demos/polls/setup.py b/demos/polls/setup.py deleted file mode 100644 index a72a5fd9033..00000000000 --- a/demos/polls/setup.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import re - -from setuptools import find_packages, setup - - -def read_version(): - regexp = re.compile(r"^__version__\W*=\W*'([\d.abrc]+)'") - init_py = os.path.join(os.path.dirname(__file__), - 'aiohttpdemo_polls', '__init__.py') - with open(init_py) as f: - for line in f: - match = regexp.match(line) - if match is not None: - return match.group(1) - else: - msg = 'Cannot find version in aiohttpdemo_polls/__init__.py' - raise RuntimeError(msg) - - -install_requires = ['aiohttp', - 'aiopg[sa]', - 'aiohttp-jinja2', - 'trafaret-config'] - - -setup(name='aiohttpdemo-polls', - version=read_version(), - description='Polls project example from aiohttp', - platforms=['POSIX'], - packages=find_packages(), - package_data={ - '': ['templates/*.html', 'static/*.*'] - }, - include_package_data=True, - install_requires=install_requires, - zip_safe=False) diff --git a/demos/polls/sql/create_tables.sql b/demos/polls/sql/create_tables.sql deleted file mode 100644 index 9bb4769a49c..00000000000 --- a/demos/polls/sql/create_tables.sql +++ /dev/null @@ -1,20 +0,0 @@ -SET ROLE 'aiohttpdemo_user'; - -BEGIN; --- --- Create model Choice --- -CREATE TABLE "choice" ("id" serial NOT NULL PRIMARY KEY, "choice_text" varchar(200) NOT NULL, "votes" integer NOT NULL); --- --- Create model Question --- -CREATE TABLE "question" ("id" serial NOT NULL PRIMARY KEY, "question_text" varchar(200) NOT NULL, "pub_date" timestamp with time zone NOT NULL); --- --- Add field question to choice --- -ALTER TABLE "choice" ADD COLUMN "question_id" integer NOT NULL; -ALTER TABLE "choice" ALTER COLUMN "question_id" DROP DEFAULT; -CREATE INDEX "choice_7aa0f6ee" ON "choice" ("question_id"); -ALTER TABLE "choice" ADD CONSTRAINT "choice_question_id_c5b4b260_fk_question_id" FOREIGN KEY ("question_id") REFERENCES "question" ("id") DEFERRABLE INITIALLY DEFERRED; - -COMMIT; diff --git a/demos/polls/sql/install.sh b/demos/polls/sql/install.sh deleted file mode 100755 index b1207c4eb0c..00000000000 --- a/demos/polls/sql/install.sh +++ /dev/null @@ -1,8 +0,0 @@ -sudo -u postgres psql -c "DROP DATABASE IF EXISTS aiohttpdemo_polls" -sudo -u postgres psql -c "DROP ROLE IF EXISTS aiohttpdemo_user" -sudo -u postgres psql -c "CREATE USER aiohttpdemo_user WITH PASSWORD 'aiohttpdemo_user';" -sudo -u postgres psql -c "CREATE DATABASE aiohttpdemo_polls ENCODING 'UTF8';" -sudo -u postgres psql -c "GRANT ALL PRIVILEGES ON DATABASE aiohttpdemo_polls TO aiohttpdemo_user;" - -cat sql/create_tables.sql | sudo -u postgres psql -d aiohttpdemo_polls -a -cat sql/sample_data.sql | sudo -u postgres psql -d aiohttpdemo_polls -a diff --git a/demos/polls/sql/sample_data.sql b/demos/polls/sql/sample_data.sql deleted file mode 100644 index 4274fa9d25a..00000000000 --- a/demos/polls/sql/sample_data.sql +++ /dev/null @@ -1,22 +0,0 @@ -SET ROLE 'aiohttpdemo_user'; - -INSERT INTO question (id, question_text, pub_date) VALUES (1, 'What''s new?', '2015-12-15 17:17:49.629+02'); - - --- --- Name: question_id_seq; Type: SEQUENCE SET; Schema: public; Owner: polls --- - -SELECT pg_catalog.setval('question_id_seq', 1, true); - - -INSERT INTO choice (id, choice_text, votes, question_id) VALUES (1, 'Not much', 0, 1); -INSERT INTO choice (id, choice_text, votes, question_id) VALUES (2, 'The sky', 0, 1); -INSERT INTO choice (id, choice_text, votes, question_id) VALUES 
(3, 'Just hacking again', 0, 1); - - --- -- Name: choice_id_seq; Type: SEQUENCE SET; Schema: public; Owner: polls --- - -SELECT pg_catalog.setval('choice_id_seq', 3, true); diff --git a/demos/polls/tests/conftest.py b/demos/polls/tests/conftest.py deleted file mode 100644 index f0d66a13e72..00000000000 --- a/demos/polls/tests/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -import pathlib -import subprocess - -import pytest - -from aiohttpdemo_polls.main import init - -BASE_DIR = pathlib.Path(__file__).parent.parent - - -@pytest.fixture -def config_path(): - path = BASE_DIR / 'config' / 'polls.yaml' - return path.as_posix() - - -@pytest.fixture -def cli(loop, test_client, config_path): - app = init(loop, ['-c', config_path]) - return loop.run_until_complete(test_client(app)) - - -@pytest.fixture -def app_db(): - subprocess.call( - [(BASE_DIR / 'sql' / 'install.sh').as_posix()], - shell=True, - cwd=BASE_DIR.as_posix() - ) diff --git a/demos/polls/tests/test_integration.py b/demos/polls/tests/test_integration.py deleted file mode 100644 index 5a9f174f025..00000000000 --- a/demos/polls/tests/test_integration.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Integration tests. They need a running database. - -Beware, they destroy your db using sudo. -""" - - -async def test_index(cli, app_db): - response = await cli.get('/poll/1') - assert response.status == 200 - assert 'What\'s new?' in await response.text() - - -async def test_results(cli, app_db): - response = await cli.get('/poll/1/results') - assert response.status == 200 - assert 'Just hacking again' in await response.text() diff --git a/demos/polls/tox.ini b/demos/polls/tox.ini deleted file mode 100644 index 70e4fd07894..00000000000 --- a/demos/polls/tox.ini +++ /dev/null @@ -1,9 +0,0 @@ -[tox] -envlist = py35 - -[testenv] -deps = - pytest - pytest-aiohttp -usedevelop = True -commands=py.test tests -s diff --git a/docs/_static/aiohttp-icon-128x128.png b/docs/_static/aiohttp-icon-128x128.png index 1a3c9498119..e486a04e36e 100644 Binary files a/docs/_static/aiohttp-icon-128x128.png and b/docs/_static/aiohttp-icon-128x128.png differ diff --git a/docs/_static/aiohttp-icon-32x32.png b/docs/_static/aiohttp-icon-32x32.png deleted file mode 100644 index c8f7862339f..00000000000 Binary files a/docs/_static/aiohttp-icon-32x32.png and /dev/null differ diff --git a/docs/_static/aiohttp-icon-64x64.png b/docs/_static/aiohttp-icon-64x64.png deleted file mode 100644 index f7768b9c46d..00000000000 Binary files a/docs/_static/aiohttp-icon-64x64.png and /dev/null differ diff --git a/docs/_static/aiohttp-icon-96x96.png b/docs/_static/aiohttp-icon-96x96.png deleted file mode 100644 index d9dd02e1776..00000000000 Binary files a/docs/_static/aiohttp-icon-96x96.png and /dev/null differ diff --git a/docs/abc.rst b/docs/abc.rst index 7caae117d9b..7930b2850e8 100644 --- a/docs/abc.rst +++ b/docs/abc.rst @@ -3,9 +3,7 @@ Abstract Base Classes ===================== -.. module:: aiohttp - -.. currentmodule:: aiohttp +.. module:: aiohttp.abc Abstract routing ---------------- @@ -18,7 +16,7 @@ but few of them are. aiohttp.web is built on top of few concepts: *application*, *router*, *request* and *response*. -*router* is a *pluggable* part: a library user may build a *router* +*router* is a *pluggable* part from scratch, all other parts should work with new router seamlessly. :class:`AbstractRouter` has the only mandatory method: @@ -35,7 +33,7 @@ Not Allowed*.
:meth:`AbstractMatchInfo.handler` raises :attr:`~AbstractMatchInfo.http_exception` on call. -.. class:: AbstractRouter +.. class:: aiohttp.abc.AbstractRouter Abstract router, :class:`aiohttp.web.Application` accepts it as *router* parameter and returns as @@ -54,7 +52,7 @@ Not Allowed*. :meth:`AbstractMatchInfo.handler` raises :return: :class:`AbstractMatchInfo` instance. -.. class:: AbstractMatchInfo +.. class:: aiohttp.abc.AbstractMatchInfo Abstract *match info*, returned by :meth:`AbstractRouter.resolve` call. @@ -102,7 +100,7 @@ attribute. Abstract Cookie Jar ------------------- -.. class:: AbstractCookieJar +.. class:: aiohttp.abc.AbstractCookieJar The cookie jar instance is available as :attr:`ClientSession.cookie_jar`. @@ -147,6 +145,20 @@ Abstract Cookie Jar :return: :class:`http.cookies.SimpleCookie` with filtered cookies for given URL. +Abstract Access Logger +---------------------- + +.. class:: aiohttp.abc.AbstractAccessLogger + + An abstract class, the base for all :class:`RequestHandler` + ``access_logger`` implementations. + + Method ``log`` should be overridden. + + .. method:: log(request, response, time) + + :param request: :class:`aiohttp.web.Request` object. + + :param response: :class:`aiohttp.web.Response` object. -.. disqus:: - :title: aiohttp abstact base classes + :param float time: Time taken to serve the request. diff --git a/docs/aiohttp-icon.ico b/docs/aiohttp-icon.ico deleted file mode 100644 index 56b6e563e13..00000000000 Binary files a/docs/aiohttp-icon.ico and /dev/null differ diff --git a/docs/aiohttp-icon.svg b/docs/aiohttp-icon.svg index 2b87a55c5c0..9356d47aaa4 100644 --- a/docs/aiohttp-icon.svg +++ b/docs/aiohttp-icon.svg @@ -1,487 +1,62 @@ -aiohttp-icon -Created with Sketch. \ No newline at end of file +image/svg+xml diff --git a/docs/aiohttp-plain.svg b/docs/aiohttp-plain.svg new file mode 100644 index 00000000000..f45ccd0f92f --- /dev/null +++ b/docs/aiohttp-plain.svg @@ -0,0 +1,62 @@ +image/svg+xml diff --git a/docs/built_with.rst b/docs/built_with.rst new file mode 100644 index 00000000000..12fa64c55db --- /dev/null +++ b/docs/built_with.rst @@ -0,0 +1,27 @@ +.. _aiohttp-built-with: + +Built with aiohttp +================== + +Many useful libraries are built on top of aiohttp, +and there's a page dedicated to listing them: :ref:`aiohttp-3rd-party`. + +There are also projects that leverage the power of aiohttp to +provide end-user tools, like command lines or software with +full user interfaces.
+ +This page aims to list those projects. If you are using aiohttp +in your software and if it's playing a central role, you +can add it here in this list. + +You can also add a **Built with aiohttp** link somewhere in your +project, pointing to ``_. + + +* `Molotov `_ Load testing tool. +* `Arsenic `_ Async WebDriver. +* `Home Assistant `_ Home Automation Platform. +* `Backend.AI `_ Code execution API service. +* `doh-proxy `_ DNS Over HTTPS Proxy. +* `Mariner `_ Command-line torrent searcher. +* `DEEPaaS API `_ REST API for Machine learning, Deep learning and artificial intelligence applications. diff --git a/docs/changes.rst b/docs/changes.rst index 005bd1fdc38..0ecf1d76af8 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -3,7 +3,3 @@ .. include:: ../CHANGES.rst .. include:: ../HISTORY.rst - - -.. disqus:: - :title: aiohttp changelog diff --git a/docs/client.rst b/docs/client.rst index c94b8c477ef..0c57de57472 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -3,695 +3,16 @@ Client ====== -.. module:: aiohttp - .. currentmodule:: aiohttp +The page contains all information about aiohttp Client API: -Make a Request --------------- - -Begin by importing the aiohttp module:: - - import aiohttp - -Now, let's try to get a web-page. For example let's get GitHub's public -time-line:: - - async with aiohttp.ClientSession() as session: - async with session.get('https://api.github.com/events') as resp: - print(resp.status) - print(await resp.text()) - -Now, we have a :class:`ClientSession` called ``session`` and -a :class:`ClientResponse` object called ``resp``. We can get all the -information we need from the response. The mandatory parameter of -:meth:`ClientSession.get` coroutine is an HTTP url. - -In order to make an HTTP POST request use :meth:`ClientSession.post` coroutine:: - - session.post('http://httpbin.org/post', data=b'data') - -Other HTTP methods are available as well:: - - session.put('http://httpbin.org/put', data=b'data') - session.delete('http://httpbin.org/delete') - session.head('http://httpbin.org/get') - session.options('http://httpbin.org/get') - session.patch('http://httpbin.org/patch', data=b'data') - -.. note:: - - Don't create a session per request. Most likely you need a session - per application which performs all requests altogether. - - A session contains a connection pool inside, connection reusage and - keep-alives (both are on by default) may speed up total performance. - - -Passing Parameters In URLs --------------------------- - -You often want to send some sort of data in the URL's query string. If -you were constructing the URL by hand, this data would be given as key/value -pairs in the URL after a question mark, e.g. ``httpbin.org/get?key=val``. -Requests allows you to provide these arguments as a :class:`dict`, using the -``params`` keyword argument. As an example, if you wanted to pass -``key1=value1`` and ``key2=value2`` to ``httpbin.org/get``, you would use the -following code:: - - params = {'key1': 'value1', 'key2': 'value2'} - async with session.get('http://httpbin.org/get', - params=params) as resp: - assert resp.url == 'http://httpbin.org/get?key2=value2&key1=value1' - -You can see that the URL has been correctly encoded by printing the URL. - -For sending data with multiple values for the same key -:class:`MultiDict` may be used as well. 
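For instance, a minimal sketch of that :class:`MultiDict` variant (it assumes the same ``httpbin.org/get`` echo endpoint and an open ``session`` as in the surrounding snippets; ``MultiDict`` comes from the ``multidict`` package that aiohttp depends on)::

    from multidict import MultiDict

    params = MultiDict([('key', 'value1'), ('key', 'value2')])
    # both pairs are sent for the same query key:
    # /get?key=value1&key=value2
    async with session.get('http://httpbin.org/get',
                           params=params) as resp:
        assert resp.status == 200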
- - -It is also possible to pass a list of 2 item tuples as parameters, in -that case you can specify multiple values for each key:: - - params = [('key', 'value1'), ('key', 'value2')] - async with session.get('http://httpbin.org/get', - params=params) as r: - assert r.url == 'http://httpbin.org/get?key=value2&key=value1' - -You can also pass :class:`str` content as param, but beware -- content -is not encoded by library. Note that ``+`` is not encoded:: - - async with session.get('http://httpbin.org/get', - params='key=value+1') as r: - assert r.url == 'http://httpbin.org/get?key=value+1' - -Response Content ----------------- - -We can read the content of the server's response. Consider the GitHub time-line -again:: - - async with session.get('https://api.github.com/events') as resp: - print(await resp.text()) - -will printout something like:: - - '[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{... - -``aiohttp`` will automatically decode the content from the server. You can -specify custom encoding for the :meth:`~ClientResponse.text` method:: - - await resp.text(encoding='windows-1251') - - -Binary Response Content ------------------------ - -You can also access the response body as bytes, for non-text requests:: - - print(await resp.read()) - -:: - - b'[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{... - -The ``gzip`` and ``deflate`` transfer-encodings are automatically -decoded for you. - -.. note:: - - This methods reads whole response body into memory. If you are planing - planing to read a lot of data consider to use streaming response. - - -JSON Response Content ---------------------- - -There's also a built-in JSON decoder, in case you're dealing with JSON data:: - - async with session.get('https://api.github.com/events') as resp: - print(await resp.json()) - -In case that JSON decoding fails, :meth:`~ClientResponse.json` will -raise an exception. It is possible to specify custom encoding and -decoder functions for the :meth:`~ClientResponse.json` call. - - -Streaming Response Content --------------------------- - -While methods :meth:`~ClientResponse.read`, -:meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` are very -convenient you should use them carefully. All these methods load the -whole response in memory. For example if you want to download several -gigabyte sized files, these methods will load all the data in -memory. Instead you can use the :attr:`~ClientResponse.content` -attribute. It is an instance of the :class:`aiohttp.StreamReader` -class. The ``gzip`` and ``deflate`` transfer-encodings are -automatically decoded for you:: - - async with session.get('https://api.github.com/events') as resp: - await resp.content.read(10) - -In general, however, you should use a pattern like this to save what is being -streamed to a file:: - - with open(filename, 'wb') as fd: - while True: - chunk = await resp.content.read(chunk_size) - if not chunk: - break - fd.write(chunk) - -It is not possible to use :meth:`~ClientResponse.read`, -:meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` after -explicit reading from :attr:`~ClientResponse.content`. - - -Custom Headers --------------- - -If you need to add HTTP headers to a request, pass them in a -:class:`dict` to the *headers* parameter. 
- -For example, if you want to specify the content-type for the previous -example:: - - import json - url = 'https://api.github.com/some/endpoint' - payload = {'some': 'data'} - headers = {'content-type': 'application/json'} - - await session.post(url, - data=json.dumps(payload), - headers=headers) - - -Custom Cookies --------------- - -To send your own cookies to the server, you can use the *cookies* -parameter of :class:`ClientSession` constructor:: - - url = 'http://httpbin.org/cookies' - cookies = {'cookies_are': 'working'} - async with ClientSession(cookies=cookies) as session: - async with session.get(url) as resp: - assert await resp.json() == { - "cookies": {"cookies_are": "working"}} - -.. note:: - ``httpbin.org/cookies`` endpoint returns request cookies - in JSON-encoded body. - To access session cookies see :attr:`ClientSession.cookie_jar`. - - -More complicated POST requests ------------------------------- - -Typically, you want to send some form-encoded data -- much like an HTML form. -To do this, simply pass a dictionary to the *data* argument. Your -dictionary of data will automatically be form-encoded when the request is made:: - - payload = {'key1': 'value1', 'key2': 'value2'} - async with session.post('http://httpbin.org/post', - data=payload) as resp: - print(await resp.text()) - -:: - - { - ... - "form": { - "key2": "value2", - "key1": "value1" - }, - ... - } - -If you want to send data that is not form-encoded you can do it by -passing a :class:`str` instead of a :class:`dict`. This data will be -posted directly. - -For example, the GitHub API v3 accepts JSON-Encoded POST/PATCH data:: - - import json - url = 'https://api.github.com/some/endpoint' - payload = {'some': 'data'} - - async with session.post(url, data=json.dumps(payload)) as resp: - ... - - -POST a Multipart-Encoded File ------------------------------ - -To upload Multipart-encoded files:: - - url = 'http://httpbin.org/post' - files = {'file': open('report.xls', 'rb')} - - await session.post(url, data=files) - -You can set the filename, content_type explicitly:: - - url = 'http://httpbin.org/post' - data = FormData() - data.add_field('file', - open('report.xls', 'rb'), - filename='report.xls', - content_type='application/vnd.ms-excel') - - await session.post(url, data=data) - -If you pass a file object as data parameter, aiohttp will stream it to -the server automatically. Check :class:`~aiohttp.streams.StreamReader` -for supported format information. - -.. seealso:: :ref:`aiohttp-multipart` - - -Streaming uploads ------------------ - -:mod:`aiohttp` supports multiple types of streaming uploads, which allows you to -send large files without reading them into memory. - -As a simple case, simply provide a file-like object for your body:: - - with open('massive-body', 'rb') as f: - await session.post('http://some.url/streamed', data=f) - - -Or you can use `aiohttp.streamer` object:: - - @aiohttp.streamer - def file_sender(writer, file_name=None): - with open(file_name, 'rb') as f: - chunk = f.read(2**16) - while chunk: - yield from writer.write(chunk) - chunk = f.read(2**16) - - # Then you can use `file_sender` as a data provider: - - async with session.post('http://httpbin.org/post', - data=file_sender(file_name='huge_file')) as resp: - print(await resp.text()) - -Also it is possible to use a :class:`~aiohttp.streams.StreamReader` -object. 
Lets say we want to upload a file from another request and -calculate the file SHA1 hash:: - - async def feed_stream(resp, stream): - h = hashlib.sha256() - - while True: - chunk = await resp.content.readany() - if not chunk: - break - h.update(chunk) - stream.feed_data(chunk) - - return h.hexdigest() - - resp = session.get('http://httpbin.org/post') - stream = StreamReader() - loop.create_task(session.post('http://httpbin.org/post', data=stream)) - - file_hash = await feed_stream(resp, stream) - - -Because the response content attribute is a -:class:`~aiohttp.streams.StreamReader`, you can chain get and post -requests together:: - - r = await session.get('http://python.org') - await session.post('http://httpbin.org/post', - data=r.content) - - -Uploading pre-compressed data ------------------------------ - -To upload data that is already compressed before passing it to aiohttp, call -the request function with the used compression algorithm name (usually deflate or zlib) -as the value of the ``Content-Encoding`` header:: - - async def my_coroutine(session, headers, my_data): - data = zlib.compress(my_data) - headers = {'Content-Encoding': 'deflate'} - async with session.post('http://httpbin.org/post', - data=data, - headers=headers) - pass - - -.. _aiohttp-client-session: - -Keep-Alive, connection pooling and cookie sharing -------------------------------------------------- - -:class:`~aiohttp.ClientSession` may be used for sharing cookies -between multiple requests:: - - async with aiohttp.ClientSession() as session: - await session.get( - 'http://httpbin.org/cookies/set?my_cookie=my_value') - filtered = session.cookie_jar.filter_cookies('http://httpbin.org') - assert filtered['my_cookie'].value == 'my_value' - async with session.get('http://httpbin.org/cookies') as r: - json_body = await r.json() - assert json_body['cookies']['my_cookie'] == 'my_value' - -You also can set default headers for all session requests:: - - async with aiohttp.ClientSession( - headers={"Authorization": "Basic bG9naW46cGFzcw=="}) as session: - async with session.get("http://httpbin.org/headers") as r: - json_body = await r.json() - assert json_body['headers']['Authorization'] == \ - 'Basic bG9naW46cGFzcw==' - -:class:`~aiohttp.ClientSession` supports keep-alive requests -and connection pooling out-of-the-box. - -.. _aiohttp-client-cookie-safety: - -Cookie safety -------------- - -By default :class:`~aiohttp.ClientSession` uses strict version of -:class:`aiohttp.CookieJar`. :rfc:`2109` explicitly forbids cookie -accepting from URLs with IP address instead of DNS name -(e.g. `http://127.0.0.1:80/cookie`). - -It's good but sometimes for testing we need to enable support for such -cookies. It should be done by passing `unsafe=True` to -:class:`aiohttp.CookieJar` constructor:: - - - jar = aiohttp.CookieJar(unsafe=True) - session = aiohttp.ClientSession(cookie_jar=jar) - - -Connectors ----------- - -To tweak or change *transport* layer of requests you can pass a custom -*connector* to :class:`~aiohttp.ClientSession` and family. For example:: - - conn = aiohttp.TCPConnector() - session = aiohttp.ClientSession(connector=conn) - -.. note:: - - You can not re-use custom *connector*, *session* object takes ownership - of the *connector*. - -.. seealso:: :ref:`aiohttp-client-reference-connectors` section for - more information about different connector types and - configuration options. 
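As a small sketch of the ownership note above (assuming any reachable URL), closing the session is what releases the connector it owns::

    conn = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=conn) as session:
        async with session.get('http://python.org') as resp:
            assert resp.status == 200
    # leaving the block closed the session *and* its connector,
    # so `conn` must not be passed to another session afterwards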
- - -Limiting connection pool size ------------------------------ - -To limit amount of simultaneously opened connections you can pass *limit* -parameter to *connector*:: - - conn = aiohttp.TCPConnector(limit=30) - -The example limits total amount of parallel connections to `30`. - -The default is `100`. - -If you explicitly want not to have limits, pass `0`. For example:: - - conn = aiohttp.TCPConnector(limit=0) - -To limit amount of simultaneously opened connection to the same -endpoint (``(host, port, is_ssl)`` triple) you can pass *limit_per_host* -parameter to *connector*:: - - conn = aiohttp.TCPConnector(limit_per_host=30) - -The example limits amount of parallel connections to the same to `30`. - -The default is `0` (no limit on per host bases). - - -Resolving using custom nameservers ----------------------------------- - -In order to specify the nameservers to when resolving the hostnames, -:term:`aiodns` is required:: - - from aiohttp.resolver import AsyncResolver - - resolver = AsyncResolver(nameservers=["8.8.8.8", "8.8.4.4"]) - conn = aiohttp.TCPConnector(resolver=resolver) - - -SSL control for TCP sockets ---------------------------- - -:class:`~aiohttp.TCPConnector` constructor accepts mutually -exclusive *verify_ssl* and *ssl_context* params. - -By default it uses strict checks for HTTPS protocol. Certification -checks can be relaxed by passing ``verify_ssl=False``:: - - conn = aiohttp.TCPConnector(verify_ssl=False) - session = aiohttp.ClientSession(connector=conn) - r = await session.get('https://example.com') - - -If you need to setup custom ssl parameters (use own certification -files for example) you can create a :class:`ssl.SSLContext` instance and -pass it into the connector:: - - sslcontext = ssl.create_default_context( - cafile='/path/to/ca-bundle.crt') - conn = aiohttp.TCPConnector(ssl_context=sslcontext) - session = aiohttp.ClientSession(connector=conn) - r = await session.get('https://example.com') - -You may also verify certificates via MD5, SHA1, or SHA256 fingerprint:: - - # Attempt to connect to https://www.python.org - # with a pin to a bogus certificate: - bad_md5 = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=' - conn = aiohttp.TCPConnector(fingerprint=bad_md5) - session = aiohttp.ClientSession(connector=conn) - exc = None - try: - r = yield from session.get('https://www.python.org') - except FingerprintMismatch as e: - exc = e - assert exc is not None - assert exc.expected == bad_md5 - - # www.python.org cert's actual md5 - assert exc.got == b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb' - -Note that this is the fingerprint of the DER-encoded certificate. -If you have the certificate in PEM format, you can convert it to -DER with e.g. ``openssl x509 -in crt.pem -inform PEM -outform DER > crt.der``. - -Tip: to convert from a hexadecimal digest to a binary byte-string, you can use -:attr:`binascii.unhexlify`:: - - md5_hex = 'ca3b499c75768e7313384e243f15cacb' - from binascii import unhexlify - assert unhexlify(md5_hex) == b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb' - -Unix domain sockets -------------------- - -If your HTTP server uses UNIX domain sockets you can use -:class:`~aiohttp.UnixConnector`:: - - conn = aiohttp.UnixConnector(path='/path/to/socket') - session = aiohttp.ClientSession(connector=conn) - - -Proxy support -------------- - -aiohttp supports proxy. 
You have to use -:attr:`proxy`:: - - async with aiohttp.ClientSession() as session: - async with session.get("http://python.org", - proxy="http://some.proxy.com") as resp: - print(resp.status) - -it also supports proxy authorization:: - - async with aiohttp.ClientSession() as session: - proxy_auth = aiohttp.BasicAuth('user', 'pass') - async with session.get("http://python.org", - proxy="http://some.proxy.com", - proxy_auth=proxy_auth) as resp: - print(resp.status) - -Authentication credentials can be passed in proxy URL:: - - session.get("http://python.org", - proxy="http://user:pass@some.proxy.com") - - -Response Status Codes ---------------------- - -We can check the response status code:: - - async with session.get('http://httpbin.org/get') as resp: - assert resp.status == 200 - - -Response Headers ----------------- - -We can view the server's response :attr:`ClientResponse.headers` using -a :class:`CIMultiDictProxy`:: - - >>> resp.headers - {'ACCESS-CONTROL-ALLOW-ORIGIN': '*', - 'CONTENT-TYPE': 'application/json', - 'DATE': 'Tue, 15 Jul 2014 16:49:51 GMT', - 'SERVER': 'gunicorn/18.0', - 'CONTENT-LENGTH': '331', - 'CONNECTION': 'keep-alive'} - -The dictionary is special, though: it's made just for HTTP -headers. According to `RFC 7230 -`_, HTTP Header names -are case-insensitive. It also supports multiple values for the same -key as HTTP protocol does. - -So, we can access the headers using any capitalization we want:: - - >>> resp.headers['Content-Type'] - 'application/json' - - >>> resp.headers.get('content-type') - 'application/json' - -All headers converted from binary data using UTF-8 with -``surrogateescape`` option. That works fine on most cases but -sometimes unconverted data is needed if a server uses nonstandard -encoding. While these headers are malformed from :rfc:`7230` -perspective they are may be retrieved by using -:attr:`ClientResponse.raw_headers` property:: - - >>> resp.raw_headers - ((b'SERVER', b'nginx'), - (b'DATE', b'Sat, 09 Jan 2016 20:28:40 GMT'), - (b'CONTENT-TYPE', b'text/html; charset=utf-8'), - (b'CONTENT-LENGTH', b'12150'), - (b'CONNECTION', b'keep-alive')) - - -Response Cookies ----------------- - -If a response contains some Cookies, you can quickly access them:: - - url = 'http://example.com/some/cookie/setting/url' - async with session.get(url) as resp: - print(resp.cookies['example_cookie_name']) - -.. note:: - - Response cookies contain only values, that were in ``Set-Cookie`` headers - of the **last** request in redirection chain. To gather cookies between all - redirection requests please use :ref:`aiohttp.ClientSession - ` object. - - -Response History ----------------- - -If a request was redirected, it is possible to view previous responses using -the :attr:`~ClientResponse.history` attribute:: - - >>> resp = await session.get('http://example.com/some/redirect/') - >>> resp - - >>> resp.history - (,) - -If no redirects occurred or ``allow_redirects`` is set to ``False``, -history will be an empty sequence. - - -.. _aiohttp-client-websockets: - - -WebSockets ----------- - -:mod:`aiohttp` works with client websockets out-of-the-box. - -You have to use the :meth:`aiohttp.ClientSession.ws_connect` coroutine -for client websocket connection. 
It accepts a *url* as a first -parameter and returns :class:`ClientWebSocketResponse`, with that -object you can communicate with websocket server using response's -methods:: - - session = aiohttp.ClientSession() - async with session.ws_connect('http://example.org/websocket') as ws: - - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == 'close cmd': - await ws.close() - break - else: - ws.send_str(msg.data + '/answer') - elif msg.type == aiohttp.WSMsgType.CLOSED: - break - elif msg.type == aiohttp.WSMsgType.ERROR: - break - - -You **must** use the only websocket task for both reading (e.g. ``await -ws.receive()`` or ``async for msg in ws:``) and writing but may have -multiple writer tasks which can only send data asynchronously (by -``ws.send_str('data')`` for example). - - -Timeouts --------- - -By default all IO operations have 5min timeout. The timeout may be -overridden by passing ``timeout`` parameter into -:meth:`ClientSession.get` and family:: - - async with session.get('https://github.com', timeout=60) as r: - ... - -``None`` or ``0`` disables timeout check. - -The example wraps a client call in :func:`async_timeout.timeout` context -manager, adding timeout for both connecting and response body -reading procedures:: - - import async_timeout - - with async_timeout.timeout(0.001, loop=session.loop): - async with session.get('https://github.com') as r: - await r.text() - - -.. note:: - - Timeout is cumulative time, it includes all operations like sending request, - redirects, response parsing, consuming response, etc. +.. toctree:: + :name: client -.. disqus:: - :title: aiohttp client usage + Quickstart + Advanced Usage + Reference + Tracing Reference + The aiohttp Request Lifecycle diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst new file mode 100644 index 00000000000..e4e0919c7f0 --- /dev/null +++ b/docs/client_advanced.rst @@ -0,0 +1,605 @@ +.. _aiohttp-client-advanced: + +Advanced Client Usage +===================== + +.. currentmodule:: aiohttp + +.. _aiohttp-client-session: + +Client Session +-------------- + +:class:`ClientSession` is the heart and the main entry point for all +client API operations. + +Create the session first, use the instance for performing HTTP +requests and initiating WebSocket connections. + +The session contains a cookie storage and connection pool, thus +cookies and connections are shared between HTTP requests sent by the +same session. + +Custom Request Headers +---------------------- + +If you need to add HTTP headers to a request, pass them in a +:class:`dict` to the *headers* parameter. + +For example, if you want to specify the content-type directly:: + + url = 'http://example.com/image' + payload = b'GIF89a\x01\x00\x01\x00\x00\xff\x00,\x00\x00' + b'\x00\x00\x01\x00\x01\x00\x00\x02\x00;' + headers = {'content-type': 'image/gif'} + + await session.post(url, + data=payload, + headers=headers) + +You also can set default headers for all session requests:: + + headers={"Authorization": "Basic bG9naW46cGFzcw=="} + async with aiohttp.ClientSession(headers=headers) as session: + async with session.get("http://httpbin.org/headers") as r: + json_body = await r.json() + assert json_body['headers']['Authorization'] == \ + 'Basic bG9naW46cGFzcw==' + +Typical use case is sending JSON body. 
You can specify content type +directly as shown above, but it is more convenient to use the special keyword +``json``:: + + await session.post(url, json={'example': 'text'}) + +For *text/plain* :: + + await session.post(url, data='Привет, Мир!') + +Custom Cookies +-------------- + +To send your own cookies to the server, you can use the *cookies* +parameter of the :class:`ClientSession` constructor:: + + url = 'http://httpbin.org/cookies' + cookies = {'cookies_are': 'working'} + async with ClientSession(cookies=cookies) as session: + async with session.get(url) as resp: + assert await resp.json() == { + "cookies": {"cookies_are": "working"}} + +.. note:: + The ``httpbin.org/cookies`` endpoint returns request cookies + in a JSON-encoded body. + To access session cookies see :attr:`ClientSession.cookie_jar`. + +:class:`~aiohttp.ClientSession` may be used for sharing cookies +between multiple requests:: + + async with aiohttp.ClientSession() as session: + await session.get( + 'http://httpbin.org/cookies/set?my_cookie=my_value') + filtered = session.cookie_jar.filter_cookies( + 'http://httpbin.org') + assert filtered['my_cookie'].value == 'my_value' + async with session.get('http://httpbin.org/cookies') as r: + json_body = await r.json() + assert json_body['cookies']['my_cookie'] == 'my_value' + +Response Headers and Cookies +---------------------------- + +We can view the server's response :attr:`ClientResponse.headers` using +a :class:`~multidict.CIMultiDictProxy`:: + + assert resp.headers == { + 'ACCESS-CONTROL-ALLOW-ORIGIN': '*', + 'CONTENT-TYPE': 'application/json', + 'DATE': 'Tue, 15 Jul 2014 16:49:51 GMT', + 'SERVER': 'gunicorn/18.0', + 'CONTENT-LENGTH': '331', + 'CONNECTION': 'keep-alive'} + +The dictionary is special, though: it's made just for HTTP +headers. According to `RFC 7230 +`_, HTTP Header names +are case-insensitive. It also supports multiple values for the same +key as the HTTP protocol does. + +So, we can access the headers using any capitalization we want:: + + assert resp.headers['Content-Type'] == 'application/json' + + assert resp.headers.get('content-type') == 'application/json' + +All headers are converted from binary data using UTF-8 with +the ``surrogateescape`` option. That works fine in most cases, but +sometimes unconverted data is needed if a server uses a nonstandard +encoding. While these headers are malformed from the :rfc:`7230` +perspective, they may be retrieved by using the +:attr:`ClientResponse.raw_headers` property:: + + assert resp.raw_headers == ( + (b'SERVER', b'nginx'), + (b'DATE', b'Sat, 09 Jan 2016 20:28:40 GMT'), + (b'CONTENT-TYPE', b'text/html; charset=utf-8'), + (b'CONTENT-LENGTH', b'12150'), + (b'CONNECTION', b'keep-alive')) + + +If a response contains some *HTTP Cookies*, you can quickly access them:: + + url = 'http://example.com/some/cookie/setting/url' + async with session.get(url) as resp: + print(resp.cookies['example_cookie_name']) + +.. note:: + + Response cookies contain only the values that were in ``Set-Cookie`` headers + of the **last** request in the redirection chain. To gather cookies between all + redirection requests please use the :ref:`aiohttp.ClientSession + <aiohttp-client-session>` object.
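The cookies collected by the session can also be inspected directly; a minimal sketch, assuming :class:`aiohttp.CookieJar` is iterated (it yields :class:`http.cookies.Morsel` items)::

    async with aiohttp.ClientSession() as session:
        await session.get('http://example.com/some/redirect/')
        # every cookie stored by the session, including ones set
        # by intermediate redirect responses
        for cookie in session.cookie_jar:
            print(cookie.key, cookie.value)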
+
+
+Redirection History
+-------------------
+
+If a request was redirected, it is possible to view previous responses using
+the :attr:`~ClientResponse.history` attribute::
+
+    resp = await session.get('http://example.com/some/redirect/')
+    assert resp.status == 200
+    assert resp.url == URL('http://example.com/some/other/url/')
+    assert len(resp.history) == 1
+    assert resp.history[0].status == 301
+    assert resp.history[0].url == URL(
+        'http://example.com/some/redirect/')
+
+If no redirects occurred or ``allow_redirects`` is set to ``False``,
+history will be an empty sequence.
+
+
+Cookie Jar
+----------
+
+.. _aiohttp-client-cookie-safety:
+
+Cookie Safety
+^^^^^^^^^^^^^
+
+By default :class:`~aiohttp.ClientSession` uses the strict version of
+:class:`aiohttp.CookieJar`. :rfc:`2109` explicitly forbids accepting
+cookies from URLs with an IP address instead of a DNS name
+(e.g. ``http://127.0.0.1:80/cookie``).
+
+This is good, but sometimes we need to accept such cookies, e.g. for
+testing. It can be done by passing ``unsafe=True`` to the
+:class:`aiohttp.CookieJar` constructor::
+
+    jar = aiohttp.CookieJar(unsafe=True)
+    session = aiohttp.ClientSession(cookie_jar=jar)
+
+
+.. _aiohttp-client-cookie-quoting-routine:
+
+Cookie Quoting Routine
+^^^^^^^^^^^^^^^^^^^^^^
+
+The client uses the :class:`~aiohttp.SimpleCookie` quoting routines, which
+conform to :rfc:`2109`, which in turn references the character definitions
+from :rfc:`2068`. They provide a two-way quoting algorithm where any non-text
+character is translated into a 4-character sequence: a backslash
+followed by the three-digit octal equivalent of the character.
+Any ``\`` or ``"`` is quoted with a preceding backslash.
+Because of the way browsers really handle cookies (as opposed to what the RFC
+says) we also encode ``,`` and ``;``.
+
+Some backend systems do not support quoted cookies. You can skip this
+quoting routine by passing ``quote_cookie=False`` to the
+:class:`~aiohttp.CookieJar` constructor::
+
+    jar = aiohttp.CookieJar(quote_cookie=False)
+    session = aiohttp.ClientSession(cookie_jar=jar)
+
+
+.. _aiohttp-client-dummy-cookie-jar:
+
+Dummy Cookie Jar
+^^^^^^^^^^^^^^^^
+
+Sometimes cookie processing is not desirable. For this purpose it's
+possible to pass an :class:`aiohttp.DummyCookieJar` instance into the
+client session::
+
+    jar = aiohttp.DummyCookieJar()
+    session = aiohttp.ClientSession(cookie_jar=jar)
+
+
+Uploading pre-compressed data
+-----------------------------
+
+To upload data that is already compressed before passing it to
+aiohttp, call the request function with the name of the compression
+algorithm used (usually ``deflate`` or ``gzip``) as the value of the
+``Content-Encoding`` header::
+
+    async def my_coroutine(session, my_data):
+        data = zlib.compress(my_data)
+        headers = {'Content-Encoding': 'deflate'}
+        async with session.post('http://httpbin.org/post',
+                                data=data,
+                                headers=headers) as resp:
+            pass
+
+Disabling content type validation for JSON responses
+----------------------------------------------------
+
+The standard explicitly restricts the JSON ``Content-Type`` HTTP header to
+``application/json`` or any extended form, e.g. ``application/vnd.custom-type+json``.
+Unfortunately, some servers send a wrong type, like ``text/html``.
+
+This can be worked around in two ways:
+
+1. Pass the expected type explicitly (in this case checking will be strict, without the extended form support,
+   so ``custom/xxx+type`` won't be accepted):
+
+   ``await resp.json(content_type='custom/type')``.
+2. Disable the check entirely:
+
+   ``await resp.json(content_type=None)``.
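+
+For instance, for a misbehaving endpoint that returns JSON labelled as
+``text/html``, the second approach could look like this (a sketch; the
+URL is hypothetical)::
+
+    async with session.get('http://example.com/broken-json') as resp:
+        data = await resp.json(content_type=None)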
+
+.. _aiohttp-client-tracing:
+
+Client Tracing
+--------------
+
+The execution flow of a specific request can be followed by attaching
+listener coroutines to the signals provided by a :class:`TraceConfig`
+instance. The instance is passed to the :class:`ClientSession`
+constructor, and the resulting client triggers the different signals
+supported by :class:`TraceConfig`. By default, any instance of the
+:class:`ClientSession` class comes with tracing disabled. The following
+snippet shows how the start and the end signals of a request flow can
+be followed::
+
+    async def on_request_start(
+            session, trace_config_ctx, params):
+        print("Starting request")
+
+    async def on_request_end(session, trace_config_ctx, params):
+        print("Ending request")
+
+    trace_config = aiohttp.TraceConfig()
+    trace_config.on_request_start.append(on_request_start)
+    trace_config.on_request_end.append(on_request_end)
+    async with aiohttp.ClientSession(
+            trace_configs=[trace_config]) as client:
+        await client.get('http://example.com/some/redirect/')
+
+``trace_configs`` is a list that can contain instances of the
+:class:`TraceConfig` class; the signal handlers of every listed
+instance are run. The following example shows how two
+:class:`TraceConfig` instances of a different nature are installed to
+perform their job on each handled signal::
+
+    from mylib.traceconfig import AuditRequest
+    from mylib.traceconfig import XRay
+
+    async with aiohttp.ClientSession(
+            trace_configs=[AuditRequest(), XRay()]) as client:
+        await client.get('http://example.com/some/redirect/')
+
+
+All signals take as parameters, first, the :class:`ClientSession`
+instance used by the specific request related to that signal and,
+second, a :class:`~types.SimpleNamespace` instance called
+``trace_config_ctx``. The ``trace_config_ctx`` object can be used to
+share state between the different signals that belong to the
+same request and to the same :class:`TraceConfig` class, for example::
+
+    async def on_request_start(
+            session, trace_config_ctx, params):
+        trace_config_ctx.start = asyncio.get_event_loop().time()
+
+    async def on_request_end(session, trace_config_ctx, params):
+        elapsed = asyncio.get_event_loop().time() - trace_config_ctx.start
+        print("Request took {}".format(elapsed))
+
+
+The ``trace_config_ctx`` param is by default a
+:class:`~types.SimpleNamespace` that is initialized at the beginning of
+the request flow. However, the factory used to create this object can be
+overridden using the ``trace_config_ctx_factory`` constructor param of
+the :class:`TraceConfig` class.
+
+The ``trace_request_ctx`` param can be given at the beginning of the
+request execution, is accepted by all of the HTTP verbs, and will be
+passed as a keyword argument to the ``trace_config_ctx_factory``
+factory. This param is useful for passing data that is only available
+at request time, for example::
+
+    async def on_request_start(
+            session, trace_config_ctx, params):
+        print(trace_config_ctx.trace_request_ctx)
+
+
+    await session.get('http://example.com/some/redirect/',
+                      trace_request_ctx={'foo': 'bar'})
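+
+The same machinery can follow finer-grained phases. A sketch that times
+DNS resolution only, via the ``on_dns_resolvehost_start`` and
+``on_dns_resolvehost_end`` signals (the printed format is arbitrary)::
+
+    async def on_dns_start(session, trace_config_ctx, params):
+        trace_config_ctx.dns_start = asyncio.get_event_loop().time()
+
+    async def on_dns_end(session, trace_config_ctx, params):
+        elapsed = asyncio.get_event_loop().time() - trace_config_ctx.dns_start
+        print("DNS resolving took {:.3f} seconds".format(elapsed))
+
+    trace_config = aiohttp.TraceConfig()
+    trace_config.on_dns_resolvehost_start.append(on_dns_start)
+    trace_config.on_dns_resolvehost_end.append(on_dns_end)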
+
+.. seealso:: The :ref:`aiohttp-client-tracing-reference` section for
+   more information about the different signals supported.
+
+Connectors
+----------
+
+To tweak or change the *transport* layer of requests you can pass a custom
+*connector* to :class:`~aiohttp.ClientSession` and family. For example::
+
+    conn = aiohttp.TCPConnector()
+    session = aiohttp.ClientSession(connector=conn)
+
+.. note::
+
+   By default the *session* object takes ownership of the connector, which
+   among other things means closing the connections once the *session* is
+   closed. If you want to share the same *connector* between different
+   *session* instances you must pass ``connector_owner=False`` to each
+   *session* instance.
+
+.. seealso:: The :ref:`aiohttp-client-reference-connectors` section for
+   more information about different connector types and
+   configuration options.
+
+
+Limiting connection pool size
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To limit the number of simultaneously opened connections you can pass the
+*limit* parameter to the *connector*::
+
+    conn = aiohttp.TCPConnector(limit=30)
+
+The example limits the total number of parallel connections to ``30``.
+
+The default is ``100``.
+
+If you explicitly do not want any limit, pass ``0``. For example::
+
+    conn = aiohttp.TCPConnector(limit=0)
+
+To limit the number of simultaneously opened connections to the same
+endpoint (a ``(host, port, is_ssl)`` triple) you can pass the *limit_per_host*
+parameter to the *connector*::
+
+    conn = aiohttp.TCPConnector(limit_per_host=30)
+
+The example limits the number of parallel connections to the same
+endpoint to ``30``.
+
+The default is ``0`` (no limit on a per-host basis).
+
+Tuning the DNS cache
+^^^^^^^^^^^^^^^^^^^^
+
+By default :class:`~aiohttp.TCPConnector` comes with the DNS cache
+table enabled, and resolutions are cached for ``10`` seconds.
+This behavior can be changed either by changing the TTL for a resolution,
+as can be seen in the following example::
+
+    conn = aiohttp.TCPConnector(ttl_dns_cache=300)
+
+or by disabling the use of the DNS cache table, meaning that all requests
+will end up making a DNS resolution, as the following example shows::
+
+    conn = aiohttp.TCPConnector(use_dns_cache=False)
+
+
+Resolving using custom nameservers
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to specify the nameservers to use when resolving hostnames,
+:term:`aiodns` is required::
+
+    from aiohttp.resolver import AsyncResolver
+
+    resolver = AsyncResolver(nameservers=["8.8.8.8", "8.8.4.4"])
+    conn = aiohttp.TCPConnector(resolver=resolver)
+
+
+Unix domain sockets
+^^^^^^^^^^^^^^^^^^^
+
+If your HTTP server uses UNIX domain sockets you can use
+:class:`~aiohttp.UnixConnector`::
+
+    conn = aiohttp.UnixConnector(path='/path/to/socket')
+    session = aiohttp.ClientSession(connector=conn)
+
+
+Named pipes in Windows
+^^^^^^^^^^^^^^^^^^^^^^
+
+If your HTTP server uses named pipes you can use
+:class:`~aiohttp.NamedPipeConnector`::
+
+    conn = aiohttp.NamedPipeConnector(path=r'\\.\pipe\<name>')
+    session = aiohttp.ClientSession(connector=conn)
+
+It will only work with the proactor event loop
+(:class:`~asyncio.ProactorEventLoop`).
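+
+The connector options above can be combined in a single
+:class:`~aiohttp.TCPConnector`; the values below are illustrative only::
+
+    conn = aiohttp.TCPConnector(
+        limit=100,          # total pool size
+        limit_per_host=10,  # per (host, port, is_ssl) endpoint
+        ttl_dns_cache=300,  # cache DNS entries for 5 minutes
+    )
+    session = aiohttp.ClientSession(connector=conn)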
+
+SSL control for TCP sockets
+---------------------------
+
+By default *aiohttp* uses strict checks for the HTTPS protocol. Certificate
+checks can be relaxed by setting *ssl* to ``False``::
+
+    r = await session.get('https://example.com', ssl=False)
+
+
+If you need to set up custom SSL parameters (for example, to use your own
+certificate files) you can create an :class:`ssl.SSLContext` instance and
+pass it into the proper :class:`ClientSession` method::
+
+    sslcontext = ssl.create_default_context(
+        cafile='/path/to/ca-bundle.crt')
+    r = await session.get('https://example.com', ssl=sslcontext)
+
+If you need to verify *self-signed* certificates, you can do the
+same thing as the previous example, but add another call to
+:meth:`ssl.SSLContext.load_cert_chain` with the key pair::
+
+    sslcontext = ssl.create_default_context(
+        cafile='/path/to/ca-bundle.crt')
+    sslcontext.load_cert_chain('/path/to/client/public/device.pem',
+                               '/path/to/client/private/device.key')
+    r = await session.get('https://example.com', ssl=sslcontext)
+
+There are explicit errors raised when SSL verification fails.
+
+:class:`aiohttp.ClientConnectorSSLError`::
+
+    try:
+        await session.get('https://expired.badssl.com/')
+    except aiohttp.ClientConnectorSSLError as e:
+        assert isinstance(e, ssl.SSLError)
+
+:class:`aiohttp.ClientConnectorCertificateError`::
+
+    try:
+        await session.get('https://wrong.host.badssl.com/')
+    except aiohttp.ClientConnectorCertificateError as e:
+        assert isinstance(e, ssl.CertificateError)
+
+If you need to handle both SSL-related errors with a single exception class:
+
+:class:`aiohttp.ClientSSLError`::
+
+    try:
+        await session.get('https://expired.badssl.com/')
+    except aiohttp.ClientSSLError as e:
+        assert isinstance(e, ssl.SSLError)
+
+    try:
+        await session.get('https://wrong.host.badssl.com/')
+    except aiohttp.ClientSSLError as e:
+        assert isinstance(e, ssl.CertificateError)
+
+You may also verify certificates via a *SHA256* fingerprint::
+
+    # Attempt to connect to https://www.python.org
+    # with a pin to a bogus certificate:
+    bad_fp = b'0'*64
+    exc = None
+    try:
+        r = await session.get('https://www.python.org',
+                              ssl=aiohttp.Fingerprint(bad_fp))
+    except aiohttp.FingerprintMismatch as e:
+        exc = e
+    assert exc is not None
+    assert exc.expected == bad_fp
+
+    # www.python.org cert's actual fingerprint
+    assert exc.got == b'...'
+
+Note that this is the fingerprint of the DER-encoded certificate.
+If you have the certificate in PEM format, you can convert it to
+DER with, e.g.::
+
+    openssl x509 -in crt.pem -inform PEM -outform DER > crt.der
+
+.. note::
+
+   Tip: to convert from a hexadecimal digest to a binary byte-string,
+   you can use :func:`binascii.unhexlify`.
+
+   The *ssl* parameter may be passed to :class:`TCPConnector` as a
+   default; the value given to :meth:`ClientSession.get` and others
+   overrides that default.
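+
+To obtain the pin itself, one option is to hash the DER file produced
+by the ``openssl`` command above (a sketch using only the standard
+library)::
+
+    import hashlib
+
+    with open('crt.der', 'rb') as f:
+        digest = hashlib.sha256(f.read()).digest()
+    # then pass ssl=aiohttp.Fingerprint(digest) to the request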
+
+Proxy support
+-------------
+
+aiohttp supports plain HTTP proxies and HTTP proxies that can be upgraded to HTTPS
+via the HTTP CONNECT method. aiohttp does not support proxies that must be
+connected to via ``https://``. To connect, use the *proxy* parameter::
+
+    async with aiohttp.ClientSession() as session:
+        async with session.get("http://python.org",
+                               proxy="http://proxy.com") as resp:
+            print(resp.status)
+
+It also supports proxy authorization::
+
+    async with aiohttp.ClientSession() as session:
+        proxy_auth = aiohttp.BasicAuth('user', 'pass')
+        async with session.get("http://python.org",
+                               proxy="http://proxy.com",
+                               proxy_auth=proxy_auth) as resp:
+            print(resp.status)
+
+Authentication credentials can also be passed in the proxy URL::
+
+    await session.get("http://python.org",
+                      proxy="http://user:pass@some.proxy.com")
+
+Contrary to the ``requests`` library, aiohttp won't read proxy settings
+from environment variables by default. But you can enable this by passing
+``trust_env=True`` to the :class:`aiohttp.ClientSession` constructor;
+the proxy configuration is then extracted from the *HTTP_PROXY* or
+*HTTPS_PROXY* *environment variables* (both are case insensitive)::
+
+    async with aiohttp.ClientSession(trust_env=True) as session:
+        async with session.get("http://python.org") as resp:
+            print(resp.status)
+
+Proxy credentials are read from the ``~/.netrc`` file if present (see
+:class:`aiohttp.ClientSession` for more details).
+
+Graceful Shutdown
+-----------------
+
+When a :class:`ClientSession` closes at the end of an ``async with``
+block (or through a direct :meth:`ClientSession.close()` call), the
+underlying connection remains open due to asyncio internal details. In
+practice, the underlying connection will close after a short
+while. However, if the event loop is stopped before the underlying
+connection is closed, a ``ResourceWarning: unclosed transport``
+warning is emitted (when warnings are enabled).
+
+To avoid this situation, a small delay must be added before closing
+the event loop to allow any open underlying connections to close.
+
+For a :class:`ClientSession` without SSL, a simple zero-sleep (``await
+asyncio.sleep(0)``) will suffice::
+
+    async def read_website():
+        async with aiohttp.ClientSession() as session:
+            async with session.get('http://example.org/') as resp:
+                await resp.read()
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(read_website())
+    # Zero-sleep to allow underlying connections to close
+    loop.run_until_complete(asyncio.sleep(0))
+    loop.close()
+
+For a :class:`ClientSession` with SSL, the application must wait a
+short duration before closing::
+
+    ...
+    # Wait 250 ms for the underlying SSL connections to close
+    loop.run_until_complete(asyncio.sleep(0.250))
+    loop.close()
+
+Note that the appropriate amount of time to wait will vary from
+application to application.
+
+All of this will eventually become obsolete when the asyncio internals
+are changed so that aiohttp itself can wait on the underlying
+connection to close. Please follow issue `#1925
+<https://github.com/aio-libs/aiohttp/issues/1925>`_ for the progress
+on this.
diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst
new file mode 100644
index 00000000000..fe770243ec8
--- /dev/null
+++ b/docs/client_quickstart.rst
@@ -0,0 +1,472 @@
+.. _aiohttp-client-quickstart:
+
+===================
+ Client Quickstart
+===================
+
+.. currentmodule:: aiohttp
+
+Eager to get started? This page gives a good introduction to how to
+get started with the aiohttp client API.
+
+First, make sure that aiohttp is :ref:`installed
+<aiohttp-installation>` and *up-to-date*.
+
+Let's get started with some simple examples.
+
+
+
+Make a Request
+==============
+
+Begin by importing the aiohttp module, and asyncio::
+
+    import aiohttp
+    import asyncio
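+
+.. note::
+
+   On Python 3.7+ the event loop boilerplate used in the examples
+   below (``loop.run_until_complete(...)``) can be replaced with
+   :func:`asyncio.run`::
+
+       async def main():
+           ...  # any of the snippets below
+
+       asyncio.run(main())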
+
+Now, let's try to get a web-page. For example, let's query
+``http://httpbin.org/get``::
+
+    async def main():
+        async with aiohttp.ClientSession() as session:
+            async with session.get('http://httpbin.org/get') as resp:
+                print(resp.status)
+                print(await resp.text())
+
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())
+
+Now, we have a :class:`ClientSession` called ``session`` and a
+:class:`ClientResponse` object called ``resp``. We can get all the
+information we need from the response. The mandatory parameter of the
+:meth:`ClientSession.get` coroutine is an HTTP *url* (a :class:`str` or
+:class:`yarl.URL` instance).
+
+In order to make an HTTP POST request use the :meth:`ClientSession.post`
+coroutine::
+
+    session.post('http://httpbin.org/post', data=b'data')
+
+Other HTTP methods are available as well::
+
+    session.put('http://httpbin.org/put', data=b'data')
+    session.delete('http://httpbin.org/delete')
+    session.head('http://httpbin.org/get')
+    session.options('http://httpbin.org/get')
+    session.patch('http://httpbin.org/patch', data=b'data')
+
+.. note::
+
+   Don't create a session per request. Most likely you need a session
+   per application which performs all requests altogether.
+
+   More complex cases may require a session per site, e.g. one for
+   GitHub and another one for the Facebook APIs. In any case, making a
+   session for every request is a **very bad** idea.
+
+   A session contains a connection pool inside. Connection reuse and
+   keep-alives (both are on by default) may speed up total performance.
+
+Using the session as a context manager is not mandatory, but in that
+case the ``await session.close()`` method should be called explicitly,
+e.g.::
+
+    session = aiohttp.ClientSession()
+    async with session.get('...'):
+        # ...
+    await session.close()
+
+
+Passing Parameters In URLs
+==========================
+
+You often want to send some sort of data in the URL's query string. If
+you were constructing the URL by hand, this data would be given as key/value
+pairs in the URL after a question mark, e.g. ``httpbin.org/get?key=val``.
+aiohttp allows you to provide these arguments as a :class:`dict`, using the
+``params`` keyword argument. As an example, if you wanted to pass
+``key1=value1`` and ``key2=value2`` to ``httpbin.org/get``, you would use the
+following code::
+
+    params = {'key1': 'value1', 'key2': 'value2'}
+    async with session.get('http://httpbin.org/get',
+                           params=params) as resp:
+        expect = 'http://httpbin.org/get?key1=value1&key2=value2'
+        assert str(resp.url) == expect
+
+You can see that the URL has been correctly encoded by printing the URL.
+
+For sending data with multiple values for the same key
+:class:`MultiDict` may be used (see the sketch below); the library also
+supports nested lists (``{'key': ['value1', 'value2']}``) as an
+alternative.
+
+It is also possible to pass a list of 2-item tuples as parameters, in
+that case you can specify multiple values for each key::
+
+    params = [('key', 'value1'), ('key', 'value2')]
+    async with session.get('http://httpbin.org/get',
+                           params=params) as r:
+        expect = 'http://httpbin.org/get?key=value2&key=value1'
+        assert str(r.url) == expect
+
+You can also pass :class:`str` content as a param, but beware -- the
+content is not encoded by the library. Note that ``+`` is not encoded::
+
+    async with session.get('http://httpbin.org/get',
+                           params='key=value+1') as r:
+        assert str(r.url) == 'http://httpbin.org/get?key=value+1'
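+
+A sketch of the :class:`MultiDict` variant mentioned above (here both
+values are sent under the same key)::
+
+    from multidict import MultiDict
+
+    params = MultiDict([('key', 'value1'), ('key', 'value2')])
+    async with session.get('http://httpbin.org/get',
+                           params=params) as r:
+        ...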
+
+.. note::
+
+   *aiohttp* internally performs URL canonicalization before sending a
+   request.
+
+   Canonicalization encodes the *host* part by the :term:`IDNA` codec and
+   applies :term:`requoting` to the *path* and *query* parts.
+
+   For example ``URL('http://example.com/путь/%30?a=%31')`` is converted to
+   ``URL('http://example.com/%D0%BF%D1%83%D1%82%D1%8C/0?a=1')``.
+
+   Sometimes canonicalization is not desirable if the server accepts the
+   exact representation and does not requote the URL itself.
+
+   To disable canonicalization use the ``encoded=True`` parameter for URL
+   construction::
+
+      await session.get(
+          URL('http://example.com/%30', encoded=True))
+
+.. warning::
+
+   Passing *params* overrides ``encoded=True``, never use both options.
+
+Response Content and Status Code
+================================
+
+We can read the content of the server's response and its status
+code. Consider the GitHub time-line again::
+
+    async with session.get('https://api.github.com/events') as resp:
+        print(resp.status)
+        print(await resp.text())
+
+prints out something like::
+
+    200
+    '[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{...
+
+``aiohttp`` automatically decodes the content from the server. You can
+specify a custom encoding for the :meth:`~ClientResponse.text` method::
+
+    await resp.text(encoding='windows-1251')
+
+
+Binary Response Content
+=======================
+
+You can also access the response body as bytes, for non-text requests::
+
+    print(await resp.read())
+
+::
+
+    b'[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{...
+
+The ``gzip`` and ``deflate`` transfer-encodings are automatically
+decoded for you.
+
+You can enable ``brotli`` transfer-encoding support:
+just install `brotlipy <https://github.com/python-hyper/brotlipy>`_.
+
+JSON Request
+============
+
+Any of the session's request methods like :func:`request`,
+:meth:`ClientSession.get`, :meth:`ClientSession.post` etc. accept the
+``json`` parameter::
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, json={'test': 'object'}) as resp:
+            ...
+
+
+By default the session uses Python's standard :mod:`json` module for
+serialization. But it is possible to use a different
+``serializer``. :class:`ClientSession` accepts the ``json_serialize``
+parameter::
+
+    import ujson
+
+    async with aiohttp.ClientSession(
+            json_serialize=ujson.dumps) as session:
+        await session.post(url, json={'test': 'object'})
+
+.. note::
+
+   The ``ujson`` library is faster than the standard :mod:`json` but
+   slightly incompatible.
+
+JSON Response Content
+=====================
+
+There's also a built-in JSON decoder, in case you're dealing with JSON data::
+
+    async with session.get('https://api.github.com/events') as resp:
+        print(await resp.json())
+
+If JSON decoding fails, :meth:`~ClientResponse.json` will
+raise an exception. It is possible to specify custom encoding and
+decoder functions for the :meth:`~ClientResponse.json` call.
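+
+For example, parsing the body with a custom decoder function (a
+sketch; ``ujson`` is an optional third-party package)::
+
+    import ujson
+
+    async with session.get('https://api.github.com/events') as resp:
+        events = await resp.json(loads=ujson.loads)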
+
+.. note::
+
+   The methods above read the whole response body into memory. If you are
+   planning on reading lots of data, consider using the streaming response
+   method documented below.
+
+
+Streaming Response Content
+==========================
+
+While the methods :meth:`~ClientResponse.read`,
+:meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` are very
+convenient, you should use them carefully. All these methods load the
+whole response in memory. For example, if you want to download several
+gigabyte-sized files, these methods will load all the data in
+memory. Instead you can use the :attr:`~ClientResponse.content`
+attribute. It is an instance of the :class:`aiohttp.StreamReader`
+class. The ``gzip`` and ``deflate`` transfer-encodings are
+automatically decoded for you::
+
+    async with session.get('https://api.github.com/events') as resp:
+        await resp.content.read(10)
+
+In general, however, you should use a pattern like this to save what is being
+streamed to a file::
+
+    with open(filename, 'wb') as fd:
+        while True:
+            chunk = await resp.content.read(chunk_size)
+            if not chunk:
+                break
+            fd.write(chunk)
+
+It is not possible to use :meth:`~ClientResponse.read`,
+:meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` after
+explicit reading from :attr:`~ClientResponse.content`.
+
+More complicated POST requests
+==============================
+
+Typically, you want to send some form-encoded data -- much like an HTML form.
+To do this, simply pass a dictionary to the *data* argument. Your
+dictionary of data will automatically be form-encoded when the request is made::
+
+    payload = {'key1': 'value1', 'key2': 'value2'}
+    async with session.post('http://httpbin.org/post',
+                            data=payload) as resp:
+        print(await resp.text())
+
+::
+
+    {
+      ...
+      "form": {
+        "key2": "value2",
+        "key1": "value1"
+      },
+      ...
+    }
+
+If you want to send data that is not form-encoded you can do it by
+passing :class:`bytes` instead of a :class:`dict`. This data will be
+posted directly and the content type set to ``'application/octet-stream'``
+by default::
+
+    async with session.post(url, data=b'\x00Binary-data\x00') as resp:
+        ...
+
+If you want to send JSON data::
+
+    async with session.post(url, json={'example': 'test'}) as resp:
+        ...
+
+To send text with an appropriate content type just use the ``data``
+argument::
+
+    async with session.post(url, data='Тест') as resp:
+        ...
+
+POST a Multipart-Encoded File
+=============================
+
+To upload multipart-encoded files::
+
+    url = 'http://httpbin.org/post'
+    files = {'file': open('report.xls', 'rb')}
+
+    await session.post(url, data=files)
+
+You can set the ``filename`` and ``content_type`` explicitly::
+
+    url = 'http://httpbin.org/post'
+    data = FormData()
+    data.add_field('file',
+                   open('report.xls', 'rb'),
+                   filename='report.xls',
+                   content_type='application/vnd.ms-excel')
+
+    await session.post(url, data=data)
+
+If you pass a file object as the data parameter, aiohttp will stream it to
+the server automatically. Check :class:`~aiohttp.streams.StreamReader`
+for supported format information.
+
+.. seealso:: :ref:`aiohttp-multipart`
+
+
+Streaming uploads
+=================
+
+:mod:`aiohttp` supports multiple types of streaming uploads, which allows you to
+send large files without reading them into memory.
+
+As a simple case, provide a file-like object for your body::
+
+    with open('massive-body', 'rb') as f:
+        await session.post('http://httpbin.org/post', data=f)
+
+
+Or you can use an *asynchronous generator*::
+
+    async def file_sender(file_name=None):
+        async with aiofiles.open(file_name, 'rb') as f:
+            chunk = await f.read(64*1024)
+            while chunk:
+                yield chunk
+                chunk = await f.read(64*1024)
+
+    # Then you can use file_sender as a data provider:
+
+    async with session.post('http://httpbin.org/post',
+                            data=file_sender(file_name='huge_file')) as resp:
+        print(await resp.text())
+
+
+Because the :attr:`~aiohttp.ClientResponse.content` attribute is a
+:class:`~aiohttp.StreamReader` (it provides the async iterator protocol),
+you can chain get and post requests together::
+
+    resp = await session.get('http://python.org')
+    await session.post('http://httpbin.org/post',
+                       data=resp.content)
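+
+Because :class:`~aiohttp.StreamReader` is an async iterable, a body
+like the one above can also be consumed chunk by chunk (a sketch;
+``iter_chunked`` takes the chunk size in bytes)::
+
+    resp = await session.get('http://python.org')
+    async for chunk in resp.content.iter_chunked(1024):
+        print(len(chunk))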
+
+.. note::
+
+   Python 3.5 has no native support for asynchronous generators, use
+   the ``async_generator`` library as a workaround.
+
+.. deprecated:: 3.1
+
+   ``aiohttp`` still supports the ``aiohttp.streamer`` decorator, but this
+   approach is deprecated in favor of *asynchronous generators* as
+   shown above.
+
+
+.. _aiohttp-client-websockets:
+
+
+WebSockets
+==========
+
+:mod:`aiohttp` works with client websockets out-of-the-box.
+
+You have to use the :meth:`aiohttp.ClientSession.ws_connect` coroutine
+for client websocket connections. It accepts a *url* as a first
+parameter and returns a :class:`ClientWebSocketResponse`; with that
+object you can communicate with the websocket server using the
+response's methods::
+
+    async with session.ws_connect('http://example.org/ws') as ws:
+        async for msg in ws:
+            if msg.type == aiohttp.WSMsgType.TEXT:
+                if msg.data == 'close cmd':
+                    await ws.close()
+                    break
+                else:
+                    await ws.send_str(msg.data + '/answer')
+            elif msg.type == aiohttp.WSMsgType.ERROR:
+                break
+
+
+You **must** use a single websocket task for reading (e.g. ``await
+ws.receive()`` or ``async for msg in ws:``) and writing, but you may
+have multiple writer tasks which can only send data asynchronously (by
+``await ws.send_str('data')`` for example).
+
+
+.. _aiohttp-client-timeouts:
+
+Timeouts
+========
+
+Timeout settings are stored in the :class:`ClientTimeout` data structure.
+
+By default *aiohttp* uses a *total* 300-second (5-minute) timeout; it
+means that the whole operation should finish in 5 minutes.
+
+The value can be overridden with the *timeout* parameter of the session
+(specified in seconds)::
+
+    timeout = aiohttp.ClientTimeout(total=60)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        ...
+
+The timeout can also be overridden for a single request like
+:meth:`ClientSession.get`::
+
+    async with session.get(url, timeout=timeout) as resp:
+        ...
+
+Supported :class:`ClientTimeout` fields are:
+
+   ``total``
+
+      The maximal number of seconds for the whole operation including
+      connection establishment, request sending and response reading.
+
+   ``connect``
+
+      The maximal number of seconds for
+      connection establishment of a new connection or
+      for waiting for a free connection from a pool if pool connection
+      limits are exceeded.
+
+   ``sock_connect``
+
+      The maximal number of seconds for connecting to a peer for a new
+      connection, not given from a pool.
+
+   ``sock_read``
+
+      The maximal number of seconds allowed for the period between reads
+      of new data portions from a peer.
+
+All fields are floats; ``None`` or ``0`` disables a particular timeout
+check, see the :class:`ClientTimeout` reference for defaults and
+additional details.
+
+Thus the default timeout is::
+
+   aiohttp.ClientTimeout(total=5*60, connect=None,
+                         sock_connect=None, sock_read=None)
+
+.. note::
+
+   *aiohttp* **ceils** the timeout if the value is equal to or greater
+   than 5 seconds. The timeout expires at the next integer second
+   greater than ``current_time + timeout``.
+
+   The ceiling is done for the sake of optimization, when many concurrent
+   tasks are scheduled to wake up at almost the same, but still different,
+   absolute times. This leads to a great number of event loop wakeups,
+   which kills performance.
+
+   The optimization shifts absolute wakeup times by scheduling them to
+   exactly the same time as other neighbors; the loop wakes up once per
+   second for timeout expiration.
+
+   Smaller timeouts are not rounded to help testing; in real life,
+   network timeouts are usually greater than tens of seconds.
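+
+The fields can be combined for per-phase control; the values below are
+illustrative only::
+
+    timeout = aiohttp.ClientTimeout(
+        total=None,       # no overall cap
+        sock_connect=10,  # 10 s to reach the peer
+        sock_read=30,     # 30 s between data portions
+    )
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        ...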
diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 25052b1cfb9..5a420e0142d 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -3,7 +3,6 @@ Client Reference ================ -.. module:: aiohttp .. currentmodule:: aiohttp @@ -28,30 +27,36 @@ Usage example:: assert resp.status == 200 return await resp.text() - async def main(loop): - async with aiohttp.ClientSession(loop=loop) as client: + async def main(): + async with aiohttp.ClientSession() as client: html = await fetch(client) print(html) loop = asyncio.get_event_loop() - loop.run_until_complete(main(loop)) + loop.run_until_complete(main()) -.. versionadded:: 0.17 - The client session supports the context manager protocol for self closing. .. class:: ClientSession(*, connector=None, loop=None, cookies=None, \ headers=None, skip_auto_headers=None, \ - auth=None, json_serialize=func:`json.dumps`, \ + auth=None, json_serialize=json.dumps, \ version=aiohttp.HttpVersion11, \ - cookie_jar=None, read_timeout=None, conn_timeout=None, \ - raise_for_status=False) + cookie_jar=None, read_timeout=None, \ + conn_timeout=None, \ + timeout=sentinel, \ + raise_for_status=False, \ + connector_owner=True, \ + auto_decompress=True, \ + read_bufsize=2**16, \ + requote_redirect_url=False, \ + trust_env=False, \ + trace_configs=None) The class for creating client sessions and making requests. - :param aiohttp.connector.BaseConnector connector: BaseConnector + :param aiohttp.BaseConnector connector: BaseConnector sub-class instance to support connection pooling. :param loop: :ref:`event loop` used for @@ -63,6 +68,8 @@ The client session supports the context manager protocol for self closing. :func:`asyncio.get_event_loop` is used for getting default event loop otherwise. + .. deprecated:: 2.0 + :param dict cookies: Cookies to send with the request (optional) :param headers: HTTP Headers to send with every request (optional). @@ -70,7 +77,7 @@ The client session supports the context manager protocol for self closing. May be either *iterable of key-value pairs* or :class:`~collections.abc.Mapping` (e.g. :class:`dict`, - :class:`~aiohttp.CIMultiDict`). + :class:`~multidict.CIMultiDict`). :param skip_auto_headers: set of headers for which autogeneration should be skipped. @@ -81,15 +88,13 @@ The client session supports the context manager protocol for self closing. that generation. Note that ``Content-Length`` autogeneration can't be skipped. - Iterable of :class:`str` or :class:`~aiohttp.upstr` (optional) + Iterable of :class:`str` or :class:`~aiohttp.istr` (optional) :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) :param version: supported HTTP version, ``HTTP 1.1`` by default. - .. versionadded:: 0.21 - :param cookie_jar: Cookie Jar, :class:`AbstractCookieJar` instance. By default every session instance has own private cookie jar for @@ -99,26 +104,96 @@ The client session supports the context manager protocol for self closing. One example is not processing cookies at all when working in proxy mode. - .. versionadded:: 0.22 + If no cookie processing is needed, a + :class:`aiohttp.DummyCookieJar` instance can be + provided. + + :param callable json_serialize: Json *serializer* callable. - :param callable json_serialize: Json `serializer` function. (:func:`json.dumps` by default) + By default :func:`json.dumps` function. - :param bool raise_for_status: Automatically call `raise_for_status()` for each response. - (default is False) + :param bool raise_for_status: - .. 
versionadded:: 2.0 + Automatically call :meth:`ClientResponse.raise_for_status()` for + each response, ``False`` by default. + + This parameter can be overridden when you making a request, e.g.:: + + client_session = aiohttp.ClientSession(raise_for_status=True) + resp = await client_session.get(url, raise_for_status=False) + async with resp: + assert resp.status == 200 - :param float read_timeout: Request operations timeout. read_timeout is + Set the parameter to ``True`` if you need ``raise_for_status`` + for most of cases but override ``raise_for_status`` for those + requests where you need to handle responses with status 400 or + higher. + + :param timeout: a :class:`ClientTimeout` settings structure, 300 seconds (5min) + total timeout by default. + + .. versionadded:: 3.3 + + :param float read_timeout: Request operations timeout. ``read_timeout`` is cumulative for all request operations (request, redirects, responses, - data consuming) + data consuming). By default, the read timeout is 5*60 seconds. + Use ``None`` or ``0`` to disable timeout checks. + + .. deprecated:: 3.3 + + Use ``timeout`` parameter instead. :param float conn_timeout: timeout for connection establishing (optional). Values ``0`` or ``None`` mean no timeout. - .. versionchanged:: 1.0 + .. deprecated:: 3.3 + + Use ``timeout`` parameter instead. + + :param bool connector_owner: + + Close connector instance on session closing. + + Setting the parameter to ``False`` allows to share + connection pool between sessions without sharing session state: + cookies etc. + + :param bool auto_decompress: Automatically decompress response body, + ``True`` by default + + .. versionadded:: 2.3 + + :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). + 64 KiB by default. + + .. versionadded:: 3.7 + + :param bool trust_env: Get proxies information from *HTTP_PROXY* / + *HTTPS_PROXY* environment variables if the parameter is ``True`` + (``False`` by default). + + Get proxy credentials from ``~/.netrc`` file if present. + + .. seealso:: + + ``.netrc`` documentation: https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html + + .. versionadded:: 2.3 + + .. versionchanged:: 3.0 - ``.cookies`` attribute was dropped. Use :attr:`cookie_jar` - instead. + Added support for ``~/.netrc`` file. + + :param bool requote_redirect_url: Apply *URL requoting* for redirection URLs if + automatic redirection is enabled (``True`` by + default). + + .. versionadded:: 3.5 + + :param trace_configs: A list of :class:`TraceConfig` instances used for client + tracing. ``None`` (default) is used for request tracing + disabling. See :ref:`aiohttp-client-tracing-reference` for + more information. .. attribute:: closed @@ -128,7 +203,7 @@ The client session supports the context manager protocol for self closing. .. attribute:: connector - :class:`aiohttp.connector.BaseConnector` derived instance used + :class:`aiohttp.BaseConnector` derived instance used for the session. A read-only property. @@ -141,7 +216,19 @@ The client session supports the context manager protocol for self closing. A read-only property. - .. versionadded:: 1.0 + .. attribute:: requote_redirect_url + + aiohttp re quote's redirect urls by default, but some servers + require exact url from location header. To disable *re-quote* system + set :attr:`requote_redirect_url` attribute to ``False``. + + .. versionadded:: 2.1 + + .. note:: This parameter affects all subsequent requests. + + .. deprecated:: 3.5 + + The attribute modification is deprecated. 
.. attribute:: loop @@ -149,15 +236,106 @@ The client session supports the context manager protocol for self closing. A read-only property. + .. deprecated:: 3.5 + + .. attribute:: timeout + + Default client timeouts, :class:`ClientTimeout` instance. The value can + be tuned by passing *timeout* parameter to :class:`ClientSession` + constructor. + + .. versionadded:: 3.7 + + .. attribute:: headers + + HTTP Headers that sent with every request + + May be either *iterable of key-value pairs* or + :class:`~collections.abc.Mapping` + (e.g. :class:`dict`, + :class:`~multidict.CIMultiDict`). + + .. versionadded:: 3.7 + + .. attribute:: skip_auto_headers + + Set of headers for which autogeneration skipped. + + :class:`frozenset` of :class:`str` or :class:`~aiohttp.istr` (optional) + + .. versionadded:: 3.7 + + .. attribute:: auth + + An object that represents HTTP Basic Authorization. + + :class:`~aiohttp.BasicAuth` (optional) + + .. versionadded:: 3.7 + + .. attribute:: json_serialize + + Json serializer callable. + + By default :func:`json.dumps` function. + + .. versionadded:: 3.7 + + .. attribute:: connector_owner + + Should connector be closed on session closing + + :class:`bool` (optional) + + .. versionadded:: 3.7 + + .. attribute:: raise_for_status + + Should :meth:`ClientResponse.raise_for_status()` be called for each response + + Either :class:`bool` or :class:`callable` + + .. versionadded:: 3.7 + + .. attribute:: auto_decompress + + Should the body response be automatically decompressed + + :class:`bool` default is ``True`` + + .. versionadded:: 3.7 + + .. attribute:: trust_env + + Should get proxies information from HTTP_PROXY / HTTPS_PROXY environment + variables or ~/.netrc file if present + + :class:`bool` default is ``False`` + + .. versionadded:: 3.7 + + .. attribute:: trace_config + + A list of :class:`TraceConfig` instances used for client + tracing. ``None`` (default) is used for request tracing + disabling. See :ref:`aiohttp-client-tracing-reference` for more information. + + .. versionadded:: 3.7 + .. comethod:: request(method, url, *, params=None, data=None, json=None,\ - headers=None, skip_auto_headers=None, \ + cookies=None, headers=None, skip_auto_headers=None, \ auth=None, allow_redirects=True,\ - max_redirects=10, version=HttpVersion(major=1, minor=1),\ - compress=None, chunked=None, expect100=False,\ - read_until_eof=True, proxy=None, proxy_auth=None,\ - timeout=5*60) + max_redirects=10,\ + compress=None, chunked=None, expect100=False, raise_for_status=None,\ + read_until_eof=True, \ + read_bufsize=None, \ + proxy=None, proxy_auth=None,\ + timeout=sentinel, ssl=None, \ + verify_ssl=None, fingerprint=None, \ + ssl_context=None, proxy_headers=None) :async-with: :coroutine: + :noindex: Performs an asynchronous HTTP request. Returns a response object. @@ -180,11 +358,22 @@ The client session supports the context manager protocol for self closing. - :class:`str` with preferably url-encoded content (**Warning:** content will not be encoded by *aiohttp*) - :param data: Dictionary, bytes, or file-like object to - send in the body of the request (optional) + :param data: The data to send in the body of the request. This can be a + :class:`FormData` object or anything that can be passed into + :class:`FormData`, e.g. a dictionary, bytes, or file-like object. + (optional) + + :param json: Any json compatible python object + (optional). *json* and *data* parameters could not + be used at the same time. 
+ + :param dict cookies: HTTP Cookies to send with + the request (optional) - :param json: Any json compatible python object (optional). `json` and `data` - parameters could not be used at the same time. + Global session cookies and the explicitly set cookies will be merged + when sending the request. + + .. versionadded:: 3.5 :param dict headers: HTTP Headers to send with the request (optional) @@ -197,7 +386,7 @@ The client session supports the context manager protocol for self closing. passed. Using ``skip_auto_headers`` parameter allows to skip that generation. - Iterable of :class:`str` or :class:`~aiohttp.upstr` + Iterable of :class:`str` or :class:`~aiohttp.istr` (optional) :param aiohttp.BasicAuth auth: an object that represents HTTP @@ -206,15 +395,16 @@ The client session supports the context manager protocol for self closing. :param bool allow_redirects: If set to ``False``, do not follow redirects. ``True`` by default (optional). - :param aiohttp.protocol.HttpVersion version: Request HTTP version - (optional) + :param int max_redirects: Maximum number of redirects to follow. + ``10`` by default. :param bool compress: Set to ``True`` if request has to be compressed with deflate encoding. If `compress` can not be combined with a *Content-Encoding* and *Content-Length* headers. ``None`` by default (optional). - :param int chunked: Enable chunked transfer encoding. It is up to the developer + :param int chunked: Enable chunked transfer encoding. + It is up to the developer to decide how to chunk data streams. If chunking is enabled, aiohttp encodes the provided chunks in the "Transfer-encoding: chunked" format. If *chunked* is set, then the *Transfer-encoding* and *content-length* @@ -223,29 +413,99 @@ The client session supports the context manager protocol for self closing. :param bool expect100: Expect 100-continue response from server. ``False`` by default (optional). + :param bool raise_for_status: Automatically call :meth:`ClientResponse.raise_for_status()` for + response if set to ``True``. + If set to ``None`` value from ``ClientSession`` will be used. + ``None`` by default (optional). + + .. versionadded:: 3.4 + :param bool read_until_eof: Read response until EOF if response does not have Content-Length header. ``True`` by default (optional). + :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). + ``None`` by default, + it means that the session global value is used. + + .. versionadded:: 3.7 + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) - :param int timeout: a timeout for IO operations, 5min by default. - Use ``None`` or ``0`` to disable timeout checks. + :param int timeout: override the session's timeout. + + .. versionchanged:: 3.3 + + The parameter is :class:`ClientTimeout` instance, + :class:`float` is still supported for sake of backward + compatibility. + + If :class:`float` is passed it is a *total* timeout (in seconds). + + :param ssl: SSL validation mode. ``None`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + + Supersedes *verify_ssl*, *ssl_context* and + *fingerprint* parameters. + + .. 
versionadded:: 3.0 + + :param bool verify_ssl: Perform SSL certificate validation for + *HTTPS* requests (enabled by default). May be disabled to + skip validation for sites with invalid certificates. + + .. versionadded:: 2.3 - :return ClientResponse: a :class:`client response ` object. + .. deprecated:: 3.0 - .. versionadded:: 1.0 + Use ``ssl=False`` - Added ``proxy`` and ``proxy_auth`` parameters. + :param bytes fingerprint: Pass the SHA256 digest of the expected + certificate in DER format to verify that the certificate the + server presents matches. Useful for `certificate pinning + `_. - Added ``timeout`` parameter. + Warning: use of MD5 or SHA1 digests is insecure and removed. - .. versionchanged:: 1.1 + .. versionadded:: 2.3 - URLs may be either :class:`str` or :class:`~yarl.URL` + .. deprecated:: 3.0 + + Use ``ssl=aiohttp.Fingerprint(digest)`` + + :param ssl.SSLContext ssl_context: ssl context used for processing + *HTTPS* requests (optional). + + *ssl_context* may be used for configuring certification + authority channel, supported SSL options etc. + + .. versionadded:: 2.3 + + .. deprecated:: 3.0 + + Use ``ssl=ssl_context`` + + :param abc.Mapping proxy_headers: HTTP headers to send to the proxy if the + parameter proxy has been provided. + + .. versionadded:: 2.3 + + :param trace_request_ctx: Object used to give as a kw param for each new + :class:`TraceConfig` object instantiated, + used to give information to the + tracers that is only available at request time. + + .. versionadded:: 3.0 + + :return ClientResponse: a :class:`client response ` + object. .. comethod:: get(url, *, allow_redirects=True, **kwargs) :async-with: @@ -254,7 +514,7 @@ The client session supports the context manager protocol for self closing. Perform a ``GET`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` @@ -265,10 +525,6 @@ The client session supports the context manager protocol for self closing. :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: post(url, *, data=None, **kwargs) :async-with: :coroutine: @@ -276,22 +532,19 @@ The client session supports the context manager protocol for self closing. Perform a ``POST`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` - :param data: Dictionary, bytes, or file-like object to - send in the body of the request (optional) + :param data: Data to send in the body of the request; see + :meth:`request` + for details (optional) :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: put(url, *, data=None, **kwargs) :async-with: :coroutine: @@ -299,22 +552,19 @@ The client session supports the context manager protocol for self closing. Perform a ``PUT`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` - :param data: Dictionary, bytes, or file-like object to - send in the body of the request (optional) + :param data: Data to send in the body of the request; see + :meth:`request` + for details (optional) :return ClientResponse: a :class:`client response ` object. - .. 
versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: delete(url, **kwargs) :async-with: :coroutine: @@ -322,7 +572,7 @@ The client session supports the context manager protocol for self closing. Perform a ``DELETE`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` @@ -330,10 +580,6 @@ The client session supports the context manager protocol for self closing. :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: head(url, *, allow_redirects=False, **kwargs) :async-with: :coroutine: @@ -341,7 +587,7 @@ The client session supports the context manager protocol for self closing. Perform a ``HEAD`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` @@ -352,10 +598,6 @@ The client session supports the context manager protocol for self closing. :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: options(url, *, allow_redirects=True, **kwargs) :async-with: :coroutine: @@ -363,7 +605,7 @@ The client session supports the context manager protocol for self closing. Perform an ``OPTIONS`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. @@ -375,10 +617,6 @@ The client session supports the context manager protocol for self closing. :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. comethod:: patch(url, *, data=None, **kwargs) :async-with: :coroutine: @@ -386,30 +624,31 @@ The client session supports the context manager protocol for self closing. Perform a ``PATCH`` request. In order to modify inner - :meth:`request` + :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` - :param data: Dictionary, bytes, or file-like object to - send in the body of the request (optional) - + :param data: Data to send in the body of the request; see + :meth:`request` + for details (optional) :return ClientResponse: a :class:`client response ` object. - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - - .. comethod:: ws_connect(url, *, protocols=(), timeout=10.0,\ + .. comethod:: ws_connect(url, *, method='GET', \ + protocols=(), timeout=10.0,\ receive_timeout=None,\ auth=None,\ autoclose=True,\ autoping=True,\ heartbeat=None,\ origin=None, \ - proxy=None, proxy_auth=None) + headers=None, \ + proxy=None, proxy_auth=None, ssl=None, \ + verify_ssl=None, fingerprint=None, \ + ssl_context=None, proxy_headers=None, \ + compress=0, max_msg_size=4194304) :async-with: :coroutine: @@ -420,51 +659,109 @@ The client session supports the context manager protocol for self closing. :param tuple protocols: Websocket protocols - :param float timeout: Timeout for websocket to close. 10 seconds by default + :param float timeout: Timeout for websocket to close. ``10`` seconds + by default - :param float receive_timeout: Timeout for websocket to receive complete message. - None(unlimited) seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. 
``None`` (unlimited) + seconds by default :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) :param bool autoclose: Automatically close websocket connection on close - message from server. If `autoclose` is False - them close procedure has to be handled manually + message from server. If *autoclose* is False + then close procedure has to be handled manually. + ``True`` by default + + :param bool autoping: automatically send *pong* on *ping* + message from server. ``True`` by default - :param bool autoping: automatically send `pong` on `ping` - message from server + :param float heartbeat: Send *ping* message every *heartbeat* + seconds and wait *pong* response, if + *pong* response is not received then + close connection. The timer is reset on any data + reception.(optional) - :param float heartbeat: Send `ping` message every `heartbeat` seconds - and wait `pong` response, if `pong` response is not received - then close connection. + :param str origin: Origin header to send to server(optional) - :param str origin: Origin header to send to server + :param dict headers: HTTP Headers to send with + the request (optional) :param str proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) - .. versionadded:: 0.16 + :param ssl: SSL validation mode. ``None`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + + Supersedes *verify_ssl*, *ssl_context* and + *fingerprint* parameters. + + .. versionadded:: 3.0 + + :param bool verify_ssl: Perform SSL certificate validation for + *HTTPS* requests (enabled by default). May be disabled to + skip validation for sites with invalid certificates. + + .. versionadded:: 2.3 + + .. deprecated:: 3.0 + + Use ``ssl=False`` + + :param bytes fingerprint: Pass the SHA256 digest of the expected + certificate in DER format to verify that the certificate the + server presents matches. Useful for `certificate pinning + `_. + + Note: use of MD5 or SHA1 digests is insecure and deprecated. + + .. versionadded:: 2.3 - Add :meth:`ws_connect`. + .. deprecated:: 3.0 - .. versionadded:: 0.18 + Use ``ssl=aiohttp.Fingerprint(digest)`` - Add *auth* parameter. + :param ssl.SSLContext ssl_context: ssl context used for processing + *HTTPS* requests (optional). - .. versionadded:: 0.19 + *ssl_context* may be used for configuring certification + authority channel, supported SSL options etc. - Add *origin* parameter. + .. versionadded:: 2.3 - .. versionadded:: 1.0 + .. deprecated:: 3.0 - Added ``proxy`` and ``proxy_auth`` parameters. + Use ``ssl=ssl_context`` - .. versionchanged:: 1.1 + :param dict proxy_headers: HTTP headers to send to the proxy if the + parameter proxy has been provided. + + .. versionadded:: 2.3 + + :param int compress: Enable Per-Message Compress Extension support. + 0 for disable, 9 to 15 for window bit support. + Default value is 0. + + .. versionadded:: 2.3 + + :param int max_msg_size: maximum size of read websocket message, + 4 MB by default. To disable the size + limit use ``0``. + + .. versionadded:: 3.3 + + :param str method: HTTP method to establish WebSocket connection, + ``'GET'`` by default. + + .. versionadded:: 3.5 - URLs may be either :class:`str` or :class:`~yarl.URL` .. 
comethod:: close() @@ -479,6 +776,7 @@ The client session supports the context manager protocol for self closing. Session is switched to closed state anyway. + Basic API --------- @@ -490,17 +788,20 @@ keepaliving, cookies and complex connection stuff like properly configured SSL certification chaining. -.. coroutinefunction:: request(method, url, *, params=None, data=None, json=None,\ - headers=None, cookies=None, auth=None, \ - allow_redirects=True, max_redirects=10, \ - encoding='utf-8', \ - version=HttpVersion(major=1, minor=1), \ - compress=None, chunked=None, expect100=False, \ - connector=None, loop=None,\ - read_until_eof=True) +.. cofunction:: request(method, url, *, params=None, data=None, \ + json=None,\ + headers=None, cookies=None, auth=None, \ + allow_redirects=True, max_redirects=10, \ + encoding='utf-8', \ + version=HttpVersion(major=1, minor=1), \ + compress=None, chunked=None, expect100=False, raise_for_status=False, \ + read_bufsize=None, \ + connector=None, loop=None,\ + read_until_eof=True, timeout=sentinel) + :async-with: - Perform an asynchronous HTTP request. Return a response object - (:class:`ClientResponse` or derived from). + Asynchronous context manager for performing an asynchronous HTTP + request. Returns a :class:`ClientResponse` response object. :param str method: HTTP method @@ -509,10 +810,12 @@ certification chaining. :param dict params: Parameters to be sent in the query string of the new request (optional) - :param data: Dictionary, bytes, or file-like object to - send in the body of the request (optional) + :param data: The data to send in the body of the request. This can be a + :class:`FormData` object or anything that can be passed into + :class:`FormData`, e.g. a dictionary, bytes, or file-like object. + (optional) - :param json: Any json compatible python object (optional). `json` and `data` + :param json: Any json compatible python object (optional). *json* and *data* parameters could not be used at the same time. :param dict headers: HTTP Headers to send with the request (optional) @@ -538,20 +841,37 @@ certification chaining. :param bool expect100: Expect 100-continue response from server. ``False`` by default (optional). - :param aiohttp.connector.BaseConnector connector: BaseConnector sub-class + :param bool raise_for_status: Automatically call + :meth:`ClientResponse.raise_for_status()` + for response if set to ``True``. If + set to ``None`` value from + ``ClientSession`` will be used. + ``None`` by default (optional). + + .. versionadded:: 3.4 + + :param aiohttp.BaseConnector connector: BaseConnector sub-class instance to support connection pooling. :param bool read_until_eof: Read response until EOF if response does not have Content-Length header. ``True`` by default (optional). + :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). + ``None`` by default, + it means that the session global value is used. + + .. versionadded:: 3.7 + + :param timeout: a :class:`ClientTimeout` settings structure, 300 seconds (5min) + total timeout by default. + :param loop: :ref:`event loop` used for processing HTTP requests. If param is ``None``, :func:`asyncio.get_event_loop` - is used for getting default event loop, but we strongly - recommend to use explicit loops everywhere. - (optional) + is used for getting default event loop. + .. deprecated:: 2.0 :return ClientResponse: a :class:`client response ` object. @@ -560,14 +880,11 @@ certification chaining. 
import aiohttp async def fetch(): - async with aiohttp.request('GET', 'http://python.org/') as resp: + async with aiohttp.request('GET', + 'http://python.org/') as resp: assert resp.status == 200 print(await resp.text()) - .. versionchanged:: 1.1 - - URLs may be either :class:`str` or :class:`~yarl.URL` - .. _aiohttp-client-reference-connectors: @@ -585,16 +902,16 @@ There are standard connectors: All connector classes should be derived from :class:`BaseConnector`. -By default all *connectors* support *keep-alive connections* (behavior is controlled by -*force_close* constructor's parameter). +By default all *connectors* support *keep-alive connections* (behavior +is controlled by *force_close* constructor's parameter). BaseConnector ^^^^^^^^^^^^^ -.. class:: BaseConnector(*, keepalive_timeout=30, \ - limit=100, limit_per_host=None, \ - force_close=False, loop=None) +.. class:: BaseConnector(*, keepalive_timeout=15, \ + force_close=False, limit=100, limit_per_host=0, \ + enable_cleanup_closed=False, loop=None) Base class for all connectors. @@ -604,23 +921,29 @@ BaseConnector feature use ``force_close=True`` flag. - :param int limit: Total number simultaneous connections. If *limit* is + :param int limit: total number of simultaneous connections. If *limit* is ``None`` the connector has no limit (default: 100). - :param int limit_by_host: limit for simultaneous connections to the same + :param int limit_per_host: limit simultaneous connections to the same endpoint. Endpoints are the same if they have equal ``(host, port, is_ssl)`` triple. - If *limit* is ``None`` the connector has no limit (default: None). + If *limit* is ``0`` the connector has no limit (default: 0). - :param bool force_close: do close underlying sockets after + :param bool force_close: close underlying sockets after connection releasing (optional). + :param bool enable_cleanup_closed: some SSL servers do not properly complete + the SSL shutdown process; in that case asyncio leaks SSL connections. + If this parameter is set to ``True``, aiohttp additionally aborts the underlying + transport after 2 seconds. It is off by default. + + :param loop: :ref:`event loop` used for handling connections. If param is ``None``, :func:`asyncio.get_event_loop` - is used for getting default event loop, but we strongly - recommend to use explicit loops everywhere. - (optional) + is used for getting default event loop. + + .. deprecated:: 2.0 .. attribute:: closed @@ -631,8 +954,6 @@ BaseConnector Read-only property, ``True`` if connector should ultimately close connections on releasing. - .. versionadded:: 0.16 - .. attribute:: limit The total number for simultaneous connections. @@ -662,7 +983,7 @@ BaseConnector The call may be paused if :attr:`limit` is exhausted until used connections return to pool. - :param aiohttp.client.ClientRequest request: request object + :param aiohttp.ClientRequest request: request object which is connection initiator. @@ -679,12 +1000,12 @@ BaseConnector TCPConnector ^^^^^^^^^^^^ -.. class:: TCPConnector(*, verify_ssl=True, fingerprint=None,\ - use_dns_cache=True, \ - family=0, ssl_context=None, conn_timeout=None, \ - keepalive_timeout=30, limit=None, \ - force_close=False, loop=None, local_addr=None, \ - disable_cleanup_closed=True) +.. 
class:: TCPConnector(*, ssl=None, verify_ssl=True, fingerprint=None, \ + use_dns_cache=True, ttl_dns_cache=10, \ + family=0, ssl_context=None, local_addr=None, \ + resolver=None, keepalive_timeout=sentinel, \ + force_close=False, limit=100, limit_per_host=0, \ + enable_cleanup_closed=False, loop=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -696,18 +1017,36 @@ TCPConnector Constructor accepts all parameters suitable for :class:`BaseConnector` plus several TCP-specific ones: - :param bool verify_ssl: Perform SSL certificate validation for + :param ssl: SSL validation mode. ``None`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` to skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + + Supersedes *verify_ssl*, *ssl_context* and + *fingerprint* parameters. + + .. versionadded:: 3.0 + + :param bool verify_ssl: perform SSL certificate validation for *HTTPS* requests (enabled by default). May be disabled to skip validation for sites with invalid certificates. - :param bytes fingerprint: Pass the SHA256 digest of the expected - certificate in DER format to verify that the certificate the - server presents matches. Useful for `certificate pinning - `_. + .. deprecated:: 2.3 + + Pass *verify_ssl* to ``ClientSession.get()`` etc. + + :param bytes fingerprint: pass the SHA256 digest of the expected + certificate in DER format to verify that the certificate the + server presents matches. Useful for `certificate pinning + `_. + + Note: use of MD5 or SHA1 digests is insecure and deprecated. - Note: use of MD5 or SHA1 digests is insecure and deprecated. + .. deprecated:: 2.3 - .. versionadded:: 0.16 + Pass *fingerprint* to ``ClientSession.get()`` etc. :param bool use_dns_cache: use internal cache for DNS lookups, ``True`` by default. @@ -716,38 +1055,43 @@ TCPConnector establishing a bit but may introduce some *side effects* also. - .. versionadded:: 0.17 + :param int ttl_dns_cache: expire DNS cache entries after this number of seconds, ``None`` + means cached forever. By default 10 seconds (optional). - .. versionchanged:: 1.0 + In some environments the IP addresses related to a specific HOST can + change after a specific time. Use this option to keep the DNS cache + updated, refreshing each entry after N seconds. - The default is changed to ``True`` + :param int limit: total number of simultaneous connections. If *limit* is + ``None`` the connector has no limit (default: 100). - :param aiohttp.abc.AbstractResolver resolver: Custom resolver + :param int limit_per_host: limit simultaneous connections to the same + endpoint. Endpoints are the same if they + have equal ``(host, port, is_ssl)`` triple. + If *limit* is ``0`` the connector has no limit (default: 0). + + :param aiohttp.abc.AbstractResolver resolver: custom resolver instance to use. ``aiohttp.DefaultResolver`` by default (asynchronous if ``aiodns>=1.1`` is installed). Custom resolvers allow to resolve hostnames differently than the way the host is configured. - .. versionadded:: 0.22 - - .. versionchanged:: 1.0 - - The resolver is ``aiohttp.AsyncResolver`` now if - :term:`aiodns` is installed. + The resolver is ``aiohttp.ThreadedResolver`` by default; the + asynchronous version is pretty robust but might fail in + very rare cases. :param int family: TCP socket family, both IPv4 and IPv6 by default. For *IPv4* only use :const:`socket.AF_INET`, for *IPv6* only -- :const:`socket.AF_INET6`. - .. 
versionchanged:: 0.18 + *family* is ``0`` by default, which means both + IPv4 and IPv6 are accepted. To specify only a + concrete version please pass + :const:`socket.AF_INET` or + :const:`socket.AF_INET6` explicitly. - *family* is `0` by default, that means both IPv4 and IPv6 are - accepted. To specify only concrete version please pass - :const:`socket.AF_INET` or :const:`socket.AF_INET6` - explicitly. - - :param ssl.SSLContext ssl_context: ssl context used for processing + :param ssl.SSLContext ssl_context: SSL context used for processing *HTTPS* requests (optional). *ssl_context* may be used for configuring certification @@ -756,23 +1100,14 @@ TCPConnector :param tuple local_addr: tuple of ``(local_host, local_port)`` used to bind socket locally if specified. - .. versionadded:: 0.21 + :param bool force_close: close underlying sockets after + connection releasing (optional). - :param tuple enable_cleanup_closed: Some ssl servers do not properly complete - ssl shutdown process, in that case asyncio leaks ssl connections. + :param bool enable_cleanup_closed: Some SSL servers do not properly complete + the SSL shutdown process; in that case asyncio leaks SSL connections. If this parameter is set to True, aiohttp additionally aborts the underlying transport after 2 seconds. It is off by default. - .. attribute:: verify_ssl - - Check *ssl certifications* if ``True``. - - Read-only :class:`bool` property. - - .. attribute:: ssl_context - - :class:`ssl.SSLContext` instance for *https* requests, read-only property. - .. attribute:: family *TCP* socket family e.g. :const:`socket.AF_INET` or @@ -786,26 +1121,12 @@ TCPConnector Read-only :class:`bool` property. - .. versionadded:: 0.17 - .. attribute:: cached_hosts The cache of resolved hosts if :attr:`dns_cache` is enabled. Read-only :class:`types.MappingProxyType` property. - .. versionadded:: 0.17 - - .. attribute:: fingerprint - - MD5, SHA1, or SHA256 hash of the expected certificate in DER - format, or ``None`` if no certificate fingerprint check - required. - - Read-only :class:`bytes` property. - - .. versionadded:: 0.16 - .. method:: clear_dns_cache(self, host=None, port=None) Clear internal *DNS* cache. @@ -813,14 +1134,12 @@ TCPConnector Remove specific entry if both *host* and *port* are specified, clear all cache otherwise. - .. versionadded:: 0.17 - UnixConnector ^^^^^^^^^^^^^ .. class:: UnixConnector(path, *, conn_timeout=None, \ - keepalive_timeout=30, limit=None, \ + keepalive_timeout=30, limit=100, \ force_close=False, loop=None) Unix socket connector. @@ -870,6 +1189,8 @@ Connection Event loop used for connection + .. deprecated:: 3.5 + .. attribute:: transport Connection transport @@ -886,20 +1207,13 @@ Connection later if timeout (30 seconds by default) for connection was not expired. - .. method:: detach() - - Detach underlying socket from connection. - - Underlying socket is not closed, next :meth:`close` or - :meth:`release` calls don't return socket to free pool. - Response object --------------- .. class:: ClientResponse - Client response returned be :meth:`ClientSession.request` and family. + Client response returned by :meth:`ClientSession.request` and family. User never creates the instance of ClientResponse class but gets it from API calls. @@ -913,10 +1227,6 @@ Response object After exiting from ``async with`` block response object will be *released* (see :meth:`release` coroutine). - .. versionadded:: 0.18 - - Support for ``async with``. - .. attribute:: version Response's version, :class:`HttpVersion` instance. 
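+ The connector parameters documented above compose naturally with a
+ session. A minimal sketch (the URL and the limit values here are
+ illustrative assumptions, not recommended defaults)::
+
+     import asyncio
+     import aiohttp
+
+     async def main():
+         # Cap total and per-host connections; cache DNS lookups
+         # for 300 seconds instead of the default 10.
+         connector = aiohttp.TCPConnector(limit=30,
+                                          limit_per_host=10,
+                                          ttl_dns_cache=300)
+         async with aiohttp.ClientSession(connector=connector) as session:
+             async with session.get('http://example.com/') as resp:
+                 print(resp.status)
+
+     asyncio.run(main())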
@@ -929,6 +1239,11 @@ Response object HTTP status reason of response (:class:`str`), e.g. ``"OK"``. + .. attribute:: ok + + Boolean representation of HTTP status code (:class:`bool`). + ``True`` if ``status`` is less than ``400``; otherwise, ``False``. + .. attribute:: method Request's method (:class:`str`). @@ -937,18 +1252,28 @@ Response object URL of request (:class:`~yarl.URL`). + .. attribute:: real_url + + Unmodified URL of request with URL fragment unstripped (:class:`~yarl.URL`). + + .. versionadded:: 3.2 + .. attribute:: connection :class:`Connection` used for handling response. .. attribute:: content - Payload stream, contains response's BODY (:class:`StreamReader`). + Payload stream, which contains response's BODY (:class:`StreamReader`). + It supports various reading methods depending on the expected format. + When chunked transfer encoding is used by the server, it allows + retrieving the actual HTTP chunks. Reading from the stream may raise :exc:`aiohttp.ClientPayloadError` if the response object is - closed before response receives all data or in case if any transfer encoding - related errors like mis-formed chunked encoding of broken compression data. + closed before the response receives all data, or in case of any + transfer encoding related errors like malformed chunked + encoding or broken compression data. .. attribute:: cookies @@ -958,13 +1283,23 @@ Response object .. attribute:: headers A case-insensitive multidict proxy with HTTP headers of - response, :class:`CIMultiDictProxy`. + response, :class:`~multidict.CIMultiDictProxy`. .. attribute:: raw_headers Unmodified HTTP headers of response as unconverted bytes, a sequence of ``(key, value)`` pairs. + .. attribute:: links + + Link HTTP header parsed into a :class:`~multidict.MultiDictProxy`. + + For each link, the key is the link param ``rel`` when it exists, or the + link URL as :class:`str` otherwise; the value is a + :class:`~multidict.MultiDictProxy` of link params, with the URL at key + ``url`` as a :class:`~yarl.URL` instance. + + .. versionadded:: 3.2 + .. attribute:: content_type Read-only property with *content* part of *Content-Type* header. @@ -986,6 +1321,13 @@ Response object Returns :class:`str` like ``'utf-8'`` or ``None`` if no *Content-Type* header present in HTTP headers or it has no charset information. + .. attribute:: content_disposition + + Read-only property that specifies the *Content-Disposition* HTTP header. + + Instance of :class:`ContentDisposition` or ``None`` if no *Content-Disposition* + header is present in HTTP headers. + .. attribute:: history A :class:`~collections.abc.Sequence` of :class:`ClientResponse` @@ -1005,19 +1347,25 @@ Response object Close underlying connection if data reading gets an error, release connection otherwise. + Raise an :exc:`aiohttp.ClientResponseError` if the data can't + be read. + :return bytes: read *BODY*. .. seealso:: :meth:`close`, :meth:`release`. .. comethod:: release() - It is not required to call `release` on the response object. When the - client fully receives the payload, the underlying connection automatically - returns back to pool. If the payload is not fully read, the connection is closed + It is not required to call `release` on the response + object. When the client fully receives the payload, the + underlying connection automatically returns back to the pool. If the + payload is not fully read, the connection is closed. .. method:: raise_for_status() - Raise an :exc:`aiohttp.ClientResponseError` if the response status is 400 or higher. 
+ Raise an :exc:`aiohttp.ClientResponseError` if the response + status is 400 or higher. + + Do nothing for success responses (less than 400). .. comethod:: text(encoding=None) @@ -1026,7 +1374,10 @@ Response object specified *encoding* parameter. If *encoding* is ``None`` content encoding is autocalculated - using :term:`cchardet` or :term:`chardet` as fallback if + using the ``Content-Type`` HTTP header and the *chardet* tool if the + header is not provided by the server. + + :term:`cchardet` is used with fallback to :term:`chardet` if *cchardet* is not available. Close underlying connection if data reading gets an error, @@ -1038,23 +1389,45 @@ Response object :return str: decoded *BODY* - .. comethod:: json(encoding=None, loads=json.loads, content_type='application/json') + :raise LookupError: if the encoding detected by chardet or cchardet is + unknown to Python (e.g. VISCII). + + .. note:: + + If the response has no ``charset`` info in the ``Content-Type`` HTTP + header, :term:`cchardet` / :term:`chardet` is used for content + encoding autodetection. + + It may hurt performance. If the page encoding is known, passing an + explicit *encoding* parameter might help:: + + await resp.text('ISO-8859-1') + + .. comethod:: json(*, encoding=None, loads=json.loads, \ + content_type='application/json') Read response's body as *JSON*, return :class:`dict` using - specified *encoding* and *loader*. + specified *encoding* and *loader*. If the data is still not available, + a ``read`` call will be done. If *encoding* is ``None`` content encoding is autocalculated using :term:`cchardet` or :term:`chardet` as fallback if *cchardet* is not available. if response's `content-type` does not match `content_type` parameter - :exc:`aiohttp.ClientResponseError` get raised. To disable content type - check pass ``None`` value. + :exc:`aiohttp.ContentTypeError` is raised. + To disable the content type check pass ``None`` value. :param str encoding: text encoding used for *BODY* decoding, or ``None`` for encoding autodetection (default). + By the standard, JSON encoding should be + ``UTF-8``, but practice beats purity: some + servers return non-UTF-8 + responses. Autodetection works pretty well + anyway. + :param callable loads: :func:`callable` used for loading *JSON* data, :func:`json.loads` by default. @@ -1066,6 +1439,27 @@ Response object :return: *BODY* as *JSON* data parsed by *loads* parameter or ``None`` if *BODY* is empty or contains white-spaces only. + .. attribute:: request_info + + A namedtuple with request URL and headers from :class:`ClientRequest` + object, an :class:`aiohttp.RequestInfo` instance. + + .. method:: get_encoding() + + Automatically detect content encoding using ``charset`` info in the + ``Content-Type`` HTTP header. If this info does not exist or there + are no appropriate codecs for the encoding, then :term:`cchardet` / + :term:`chardet` is used. + + Beware that it is not always safe to use the result of this function to + decode a response. Some encodings detected by cchardet are not known to + Python (e.g. VISCII). + + :raise RuntimeError: if called before the body has been read + (required for :term:`cchardet` usage) + + .. versionadded:: 3.0 + ClientWebSocketResponse ----------------------- @@ -1081,7 +1475,7 @@ manually. .. attribute:: closed - Read-only property, ``True`` if :meth:`close` has been called of + Read-only property, ``True`` if :meth:`close` has been called or :const:`~aiohttp.WSMsgType.CLOSE` message has been received from peer. .. attribute:: protocol @@ -1099,7 +1493,7 @@ manually. 
Returns exception if any occurs or returns None. - .. method:: ping(message=b'') + .. comethod:: ping(message=b'') Send :const:`~aiohttp.WSMsgType.PING` to peer. @@ -1107,29 +1501,67 @@ manually. :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. - .. comethod:: send_str(data) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine` + + .. comethod:: pong(message=b'') + + Send :const:`~aiohttp.WSMsgType.PONG` to peer. + + :param message: optional payload of *pong* message, + :class:`str` (converted to *UTF-8* encoded bytes) + or :class:`bytes`. + + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine` + + .. comethod:: send_str(data, compress=None) Send *data* to peer as :const:`~aiohttp.WSMsgType.TEXT` message. :param str data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :raise TypeError: if data is not :class:`str` - .. comethod:: send_bytes(data) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + + .. comethod:: send_bytes(data, compress=None) Send *data* to peer as :const:`~aiohttp.WSMsgType.BINARY` message. :param data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. - .. comethod:: send_json(data, *, dumps=json.loads) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + + .. comethod:: send_json(data, compress=None, *, dumps=json.dumps) Send *data* to peer as JSON string. :param data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :param callable dumps: any :term:`callable` that accepts an object and returns a JSON string (:func:`json.dumps` by default). @@ -1141,16 +1573,21 @@ manually. :raise TypeError: if value returned by ``dumps(data)`` is not :class:`str` + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + .. comethod:: close(*, code=1000, message=b'') A :ref:`coroutine` that initiates closing handshake by sending :const:`~aiohttp.WSMsgType.CLOSE` message. It waits for - close response from server. It add timeout to `close()` call just wrap - call with `asyncio.wait()` or `asyncio.wait_for()`. + close response from server. To add a timeout to `close()` call + just wrap the call with `asyncio.wait()` or `asyncio.wait_for()`. :param int code: closing code - :param message: optional payload of *pong* message, + :param message: optional payload of *close* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. .. comethod:: receive() @@ -1166,8 +1603,7 @@ manually. It process *ping-pong game* and performs *closing handshake* internally. - :return: :class:`~aiohttp.WSMessage`, `tp` is a type from - :class:`~aiohttp.WSMsgType` enumeration. + :return: :class:`~aiohttp.WSMessage` .. coroutinemethod:: receive_str() @@ -1209,6 +1645,76 @@ Utilities --------- +ClientTimeout +^^^^^^^^^^^^^ + +.. class:: ClientTimeout(*, total=None, connect=None, \ + sock_connect=None, sock_read=None) + + A data class for client timeout settings. + + See :ref:`aiohttp-client-timeouts` for usage examples. + + .. 
attribute:: total + + Total number of seconds for the whole request. + + :class:`float`, ``None`` by default. + + .. attribute:: connect + + Maximum number of seconds for acquiring a connection from the pool. The time + consists of connection establishment for a new connection or + waiting for a free connection from the pool if pool connection + limits are exceeded. + + For pure socket connection establishment time use + :attr:`sock_connect`. + + :class:`float`, ``None`` by default. + + .. attribute:: sock_connect + + Maximum number of seconds for connecting to a peer for a new connection, not + taken from a pool. See also :attr:`connect`. + + :class:`float`, ``None`` by default. + + .. attribute:: sock_read + + Maximum number of seconds for reading a portion of data from a peer. + + :class:`float`, ``None`` by default. + + .. versionadded:: 3.3 + +RequestInfo +^^^^^^^^^^^ + +.. class:: RequestInfo() + + A data class with request URL and headers from :class:`ClientRequest` + object, available as :attr:`ClientResponse.request_info` attribute. + + .. attribute:: url + + Requested *url*, :class:`yarl.URL` instance. + + .. attribute:: method + + Request HTTP method like ``'GET'`` or ``'POST'``, :class:`str`. + + .. attribute:: headers + + HTTP headers for request, :class:`multidict.CIMultiDict` instance. + + .. attribute:: real_url + + Requested *url* with URL fragment unstripped, :class:`yarl.URL` instance. + + .. versionadded:: 3.2 + + BasicAuth ^^^^^^^^^ @@ -1218,7 +1724,7 @@ BasicAuth :param str login: login :param str password: password - :param str encoding: encoding (`'latin1'` by default) + :param str encoding: encoding (``'latin1'`` by default) Should be used for specifying authorization data in client API, @@ -1234,6 +1740,15 @@ BasicAuth :return: decoded authentication data, :class:`BasicAuth`. + .. classmethod:: from_url(url) + + Construct credentials info from the URL's *user* and *password* + parts. + + :return: credentials data, :class:`BasicAuth` or ``None`` if + credentials are not provided. + + .. versionadded:: 2.3 .. method:: encode() @@ -1246,7 +1761,7 @@ BasicAuth CookieJar ^^^^^^^^^ -.. class:: CookieJar(unsafe=False, loop=None) +.. class:: CookieJar(*, unsafe=False, quote_cookie=True, loop=None) The cookie jar instance is available as :attr:`ClientSession.cookie_jar`. @@ -1271,9 +1786,19 @@ CookieJar :param bool unsafe: (optional) Whether to accept cookies from IPs. + :param bool quote_cookie: (optional) Whether to quote cookies according to + :rfc:`2109`. Some backend systems + (not compatible with the RFC mentioned above) + do not support quoted cookies. + + .. versionadded:: 3.7 + :param bool loop: an :ref:`event loop` instance. See :class:`aiohttp.abc.AbstractCookieJar` + .. deprecated:: 2.0 + + .. method:: update_cookies(cookies, response_url=None) Update cookies returned by server in ``Set-Cookie`` header. @@ -1315,50 +1840,344 @@ CookieJar imported, :class:`str` or :class:`pathlib.Path` instance. +.. class:: DummyCookieJar(*, loop=None) + + Dummy cookie jar which does not store cookies but ignores them. + + Could be useful e.g. for web crawlers to iterate over the Internet + without blowing up with saved cookie information. + + To install the dummy cookie jar pass it into the session instance:: + + jar = aiohttp.DummyCookieJar() + session = aiohttp.ClientSession(cookie_jar=jar) + + +.. class:: Fingerprint(digest) + + Fingerprint helper for checking SSL certificates by *SHA256* digest. 
+ + :param bytes digest: *SHA256* digest for certificate in DER-encoded + binary form (see + :meth:`ssl.SSLSocket.getpeercert`). + + To check fingerprint pass the object into :meth:`ClientSession.get` + call, e.g.:: + + import hashlib + + with open(path_to_cert, 'rb') as f: + digest = hashlib.sha256(f.read()).digest() + + await session.get(url, ssl=aiohttp.Fingerprint(digest)) + + .. versionadded:: 3.0 + +FormData +^^^^^^^^ + +A :class:`FormData` object contains the form data and also handles +encoding it into a body that is either ``multipart/form-data`` or +``application/x-www-form-urlencoded``. ``multipart/form-data`` is +used if at least one field is an :class:`io.IOBase` object or was +added with at least one optional argument to :meth:`add_field` +(``content_type``, ``filename``, or ``content_transfer_encoding``). +Otherwise, ``application/x-www-form-urlencoded`` is used. + +:class:`FormData` instances are callable and return a :class:`Payload` +on being called. + +.. class:: FormData(fields, quote_fields=True, charset=None) + + Helper class for multipart/form-data and application/x-www-form-urlencoded body generation. + + :param fields: A container for the key/value pairs of this form. + + Possible types are: + + - :class:`dict` + - :class:`tuple` or :class:`list` + - :class:`io.IOBase`, e.g. a file-like object + - :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` + + If it is a :class:`tuple` or :class:`list`, it must be a valid argument + for :meth:`add_fields`. + + For :class:`dict`, :class:`multidict.MultiDict`, and :class:`multidict.MultiDictProxy`, + the keys and values must be valid `name` and `value` arguments to + :meth:`add_field`, respectively. + + .. method:: add_field(name, value, content_type=None, filename=None,\ + content_transfer_encoding=None) + + Add a field to the form. + + :param str name: Name of the field + + :param value: Value of the field + + Possible types are: + + - :class:`str` + - :class:`bytes`, :class:`bytearray`, or :class:`memoryview` + - :class:`io.IOBase`, e.g. a file-like object + + :param str content_type: The field's content-type header (optional) + + :param str filename: The field's filename (optional) + + If this is not set and ``value`` is a :class:`bytes`, :class:`bytearray`, + or :class:`memoryview` object, the `name` argument is used as the filename + unless ``content_transfer_encoding`` is specified. + + If ``filename`` is not set and ``value`` is an :class:`io.IOBase` + object, the filename is extracted from the object if possible. + + :param str content_transfer_encoding: The field's content-transfer-encoding + header (optional) + + .. method:: add_fields(fields) + + Add one or more fields to the form. + + :param fields: An iterable containing: + + - :class:`io.IOBase`, e.g. a file-like object + - :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` + - :class:`tuple` or :class:`list` of length two, containing a name-value pair Client exceptions +----------------- + +Exception hierarchy has been significantly modified in version +2.0. aiohttp defines only exceptions that cover connection handling +and server response misbehaviors. For developer-specific mistakes, +aiohttp uses standard Python exceptions like :exc:`ValueError` or +:exc:`TypeError`. + +Reading the response content may raise a :exc:`ClientPayloadError` +exception. This exception indicates errors specific to the payload +encoding. 
Examples include invalid compressed data, malformed chunked-encoded +chunks, or not enough data to satisfy the ``Content-Length`` header. + +All exceptions are available as members of the *aiohttp* module. + +.. exception:: ClientError + + Base class for all client specific exceptions. + + Derived from :exc:`Exception` + + +.. class:: ClientPayloadError + + This exception can only be raised while reading the response + payload if one of these errors occurs: + + 1. invalid compression + 2. malformed chunked encoding + 3. not enough data to satisfy the ``Content-Length`` HTTP header. + + Derived from :exc:`ClientError` + +.. exception:: InvalidURL + + URL used for fetching is malformed, e.g. it does not contain host + part. + + Derived from :exc:`ClientError` and :exc:`ValueError` + + .. attribute:: url + + Invalid URL, :class:`yarl.URL` instance. + +.. class:: ContentDisposition + + Represents the Content-Disposition header + + .. attribute:: value + + A :class:`str` instance. Value of Content-Disposition header + itself, e.g. ``attachment``. + + .. attribute:: filename + + A :class:`str` instance. Content filename extracted from + parameters. May be ``None``. + + .. attribute:: parameters + + Read-only mapping that contains all parameters. + +Response errors +^^^^^^^^^^^^^^^ + +.. exception:: ClientResponseError + + These exceptions can happen after we get a response from the server. + + Derived from :exc:`ClientError` + + .. attribute:: request_info + + Instance of :class:`RequestInfo` object; contains information + about the request. + + .. attribute:: status + + HTTP status code of response (:class:`int`), e.g. ``400``. + + .. attribute:: message + + Message of response (:class:`str`), e.g. ``"OK"``. + + .. attribute:: headers + + Headers in response, a list of pairs. + + .. attribute:: history + + History from failed response, if available, else empty tuple. + + A :class:`tuple` of :class:`ClientResponse` objects used to + handle redirection responses. + + .. attribute:: code + + HTTP status code of response (:class:`int`), e.g. ``400``. + + .. deprecated:: 3.1 + + +.. class:: WSServerHandshakeError + + Web socket server response error. + + Derived from :exc:`ClientResponseError` + + +.. class:: ContentTypeError + + Invalid content type. + + Derived from :exc:`ClientResponseError` + + .. versionadded:: 2.3 + + +.. class:: TooManyRedirects + + Client was redirected too many times. + + Maximum number of redirects can be configured by using + parameter ``max_redirects`` in :meth:`request`. + + Derived from :exc:`ClientResponseError` + + .. versionadded:: 3.2 + +Connection errors ^^^^^^^^^^^^^^^^^ -Exception hierarchy has been significantly modified in version 2.0. aiohttp defines only -exceptions that covers connection handling and server response misbehaviors. -For developer specific mistakes, aiohttp uses python standard exceptions -like `ValueError` or `TypeError`. +.. class:: ClientConnectionError + + These exceptions are related to low-level connection problems. + + Derived from :exc:`ClientError` + +.. class:: ClientOSError + + Subset of connection errors that are initiated by an :exc:`OSError` + exception. + + Derived from :exc:`ClientConnectionError` and :exc:`OSError` + +.. class:: ClientConnectorError + + Connector related exceptions. + + Derived from :exc:`ClientOSError` + +.. class:: ClientProxyConnectionError + + Derived from :exc:`ClientConnectorError` + +.. class:: ServerConnectionError + + Derived from :exc:`ClientConnectionError` + +.. class:: ClientSSLError + + Derived from :exc:`ClientConnectorError` + +.. 
class:: ClientConnectorSSLError + + Response SSL error. + + Derived from :exc:`ClientSSLError` and :exc:`ssl.SSLError` + +.. class:: ClientConnectorCertificateError + + Response certificate error. + + Derived from :exc:`ClientSSLError` and :exc:`ssl.CertificateError` + +.. class:: ServerDisconnectedError + + Server disconnected. + + Derived from :exc:`ServerConnectionError` + + .. attribute:: message + + Partially parsed HTTP message (optional). + + +.. class:: ServerTimeoutError + + Server operation timeout: read timeout, etc. + + Derived from :exc:`ServerConnectionError` and :exc:`asyncio.TimeoutError` + +.. class:: ServerFingerprintMismatch + + Server fingerprint mismatch. + + Derived from :exc:`ServerConnectionError` -Reading a response content may raise a :exc:`ClientPayloadError` exception. This exception -indicates errors specific to the payload encoding. Such as invalid compressed data, -malformed chunked-encoded chunks or not enough data that satisfy the content-length header. -All exceptions are available as attributes in `aiohttp` module. +Hierarchy of exceptions +^^^^^^^^^^^^^^^^^^^^^^^ -Hierarchy of exceptions: +* :exc:`ClientError` -* `aiohttp.ClientError` - Base class for all client specific exceptions + * :exc:`ClientResponseError` - - `aiohttp.ClientResponseError` - exceptions that could happen after we get response from server. + * :exc:`ContentTypeError` + * :exc:`WSServerHandshakeError` + * :exc:`ClientHttpProxyError` - - `aiohttp.WSServerHandshakeError` - web socket server response error + * :exc:`ClientConnectionError` - - `aiohttp.ClientHttpProxyError` - proxy response + * :exc:`ClientOSError` - - `aiohttp.ClientConnectionError` - exceptions related to low-level connection problems + * :exc:`ClientConnectorError` - - `aiohttp.ClientOSError` - subset of connection errors that are initiated by an OSError exception + * :exc:`ClientSSLError` - - `aiohttp.ClientConnectorError` - connector related exceptions - - - `aiohttp.ClientProxyConnectionError` - proxy connection initialization error + * :exc:`ClientConnectorCertificateError` - - `aiohttp.ServerConnectionError` - server connection related errors + * :exc:`ClientConnectorSSLError` - - `aiohttp.ServerDisconnectedError` - server disconnected + * :exc:`ClientProxyConnectionError` - - `aiohttp.ServerTimeoutError` - server operation timeout, (read timeout, etc) + * :exc:`ServerConnectionError` - - `aiohttp.ServerFingerprintMismatch` - server fingerprint mismatch + * :exc:`ServerDisconnectedError` + * :exc:`ServerTimeoutError` - - `aiohttp.ClientPayloadError` - This exception can only be raised while reading the response - payload if one of these errors occurs: invalid compression, malformed chunked encoding or - not enough data that satisfy content-length header. + * :exc:`ServerFingerprintMismatch` + * :exc:`ClientPayloadError` -.. disqus:: - :title: aiohttp client reference + * :exc:`InvalidURL` diff --git a/docs/conf.py b/docs/conf.py index ed5da4b166d..6532648d399 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # # aiohttp documentation build configuration file, created by # sphinx-quickstart on Wed Mar 5 12:35:35 2014. @@ -13,34 +12,29 @@ # All configuration values have a default; values that are commented out # serve to show the default. 
-import sys +import io import os -import codecs import re _docs_path = os.path.dirname(__file__) -_version_path = os.path.abspath(os.path.join(_docs_path, - '..', 'aiohttp', '__init__.py')) -with codecs.open(_version_path, 'r', 'latin1') as fp: +_version_path = os.path.abspath( + os.path.join(_docs_path, "..", "aiohttp", "__init__.py") +) +with open(_version_path, encoding="latin1") as fp: try: - _version_info = re.search(r"^__version__ = '" - r"(?P\d+)" - r"\.(?P\d+)" - r"\.(?P\d+)" - r"(?P.*)?'$", - fp.read(), re.M).groupdict() + _version_info = re.search( + r'^__version__ = "' + r"(?P\d+)" + r"\.(?P\d+)" + r"\.(?P\d+)" + r'(?P.*)?"$', + fp.read(), + re.M, + ).groupdict() except IndexError: - raise RuntimeError('Unable to determine version.') + raise RuntimeError("Unable to determine version.") -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) -sys.path.insert(0, os.path.abspath('.')) - -# import alabaster - # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. @@ -50,57 +44,55 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx.ext.intersphinx', - 'alabaster', - 'sphinxcontrib.asyncio', - 'sphinxcontrib.newsfeed', + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinxcontrib.asyncio", + "sphinxcontrib.blockdiag", ] try: import sphinxcontrib.spelling # noqa - extensions.append('sphinxcontrib.spelling') + + extensions.append("sphinxcontrib.spelling") except ImportError: pass intersphinx_mapping = { - 'python': ('http://docs.python.org/3', None), - 'multidict': - ('https://multidict.readthedocs.io/en/stable/', None), - 'yarl': - ('https://yarl.readthedocs.io/en/stable/', None), - 'aiohttpjinja2': - ('https://aiohttp-jinja2.readthedocs.io/en/stable/', None), - 'aiohttpsession': - ('https://aiohttp-session.readthedocs.io/en/stable/', None)} + "python": ("http://docs.python.org/3", None), + "multidict": ("https://multidict.readthedocs.io/en/stable/", None), + "yarl": ("https://yarl.readthedocs.io/en/stable/", None), + "aiohttpjinja2": ("https://aiohttp-jinja2.readthedocs.io/en/stable/", None), + "aiohttpremotes": ("https://aiohttp-remotes.readthedocs.io/en/stable/", None), + "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), + "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), +} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'aiohttp' -copyright = '2013-2017, Aiohttp contributors' +project = "aiohttp" +copyright = "2013-2020, aiohttp maintainers" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. 
-version = '{major}.{minor}'.format(**_version_info) +version = "{major}.{minor}".format(**_version_info) # The full version, including alpha/beta/rc tags. -release = '{major}.{minor}.{patch}-{tag}'.format(**_version_info) +release = "{major}.{minor}.{patch}{tag}".format(**_version_info) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -114,7 +106,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -132,10 +124,10 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +# pygments_style = 'sphinx' # The default language to highlight source code in. -highlight_language = 'python3' +highlight_language = "python3" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -148,27 +140,52 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'alabaster' +html_theme = "aiohttp_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'logo': 'aiohttp-icon-128x128.png', - 'description': 'http client/server for asyncio', - 'github_user': 'aio-libs', - 'github_repo': 'aiohttp', - 'github_button': True, - 'github_type': 'star', - 'github_banner': True, - 'travis_button': True, - 'codecov_button': True, - 'pre_bg': '#FFF6E5', - 'note_bg': '#E5ECD1', - 'note_border': '#BFCF8C', - 'body_text': '#482C0A', - 'sidebar_text': '#49443E', - 'sidebar_header': '#4B4032', + "logo": "aiohttp-icon-128x128.png", + "description": "Async HTTP client/server for asyncio and Python", + "canonical_url": "http://docs.aiohttp.org/en/stable/", + "github_user": "aio-libs", + "github_repo": "aiohttp", + "github_button": True, + "github_type": "star", + "github_banner": True, + "badges": [ + { + "image": "https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg", + "target": "https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI", + "height": "20", + "alt": "Azure Pipelines CI status", + }, + { + "image": "https://codecov.io/github/aio-libs/aiohttp/coverage.svg?branch=master", + "target": "https://codecov.io/github/aio-libs/aiohttp", + "height": "20", + "alt": "Code coverage status", + }, + { + "image": "https://badge.fury.io/py/aiohttp.svg", + "target": "https://badge.fury.io/py/aiohttp", + "height": "20", + "alt": "Latest PyPI package version", + }, + { + "image": "https://img.shields.io/discourse/status?server=https%3A%2F%2Faio-libs.discourse.group", + "target": "https://aio-libs.discourse.group", + "height": "20", + "alt": "Discourse status", + }, + { + "image": "https://badges.gitter.im/Join%20Chat.svg", + "target": "https://gitter.im/aio-libs/Lobby", + "height": "20", + "alt": "Chat on Gitter", + }, + ], } # Add any paths that contain custom themes here, relative to this directory. @@ -188,12 +205,12 @@ # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. 
-html_favicon = 'aiohttp-icon.ico' +html_favicon = "favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -210,8 +227,10 @@ # Custom sidebar templates, maps document names to template names. html_sidebars = { - '**': [ - 'about.html', 'navigation.html', 'searchbox.html', + "**": [ + "about.html", + "navigation.html", + "searchbox.html", ] } @@ -246,7 +265,7 @@ # html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'aiohttpdoc' +htmlhelp_basename = "aiohttpdoc" # -- Options for LaTeX output --------------------------------------------- @@ -254,10 +273,8 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # 'preamble': '', } @@ -266,8 +283,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'aiohttp.tex', 'aiohttp Documentation', - 'aiohttp contributors', 'manual'), + ("index", "aiohttp.tex", "aiohttp Documentation", "aiohttp contributors", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -295,10 +311,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'aiohttp', 'aiohttp Documentation', - ['aiohttp'], 1) -] +man_pages = [("index", "aiohttp", "aiohttp Documentation", ["aiohttp"], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -310,9 +323,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'aiohttp', 'aiohttp Documentation', - 'Aiohttp contributors', 'aiohttp', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "aiohttp", + "aiohttp Documentation", + "Aiohttp contributors", + "aiohttp", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. @@ -326,6 +345,3 @@ # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = False - - -disqus_shortname = 'aiohttp' diff --git a/docs/contributing.rst b/docs/contributing.rst index 2f4d72b12c4..5ecb4454a72 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -1,7 +1,325 @@ .. _aiohttp-contributing: -.. include:: ../CONTRIBUTING.rst +Contributing +============ +Instructions for contributors +----------------------------- -.. disqus:: - :title: instructions for aiohttp contributors +In order to make a clone of the GitHub_ repo: open the link and press the "Fork" button on the upper-right menu of the web page. + +I hope everybody knows how to work with git and github nowadays :) + +Workflow is pretty straightforward: + + 0. Make sure you are reading the latest version of this document. + It can be found in the GitHub_ repo in the ``docs`` subdirectory. + + 1. Clone the GitHub_ repo using the ``--recurse-submodules`` argument + + 2. Setup your machine with the required dev environment + + 3. Make a change + + 4. 
Make sure all tests passed + + 5. Add a file into the ``CHANGES`` folder (see `Changelog update`_ for how). + + 6. Commit changes to your own aiohttp clone + + 7. Make a pull request from the github page of your clone against the master branch + + 8. Optionally make backport Pull Request(s) for landing a bug fix into released aiohttp versions. + +.. note:: + + The project uses the *Squash-and-Merge* strategy for the *GitHub Merge* button. + + Basically it means that there is **no need to rebase** a Pull Request against + *master* branch. Just ``git merge`` *master* into your working copy (a fork) if + needed. The Pull Request is automatically squashed into a single commit + once the PR is accepted. + +.. note:: + + GitHub issue and pull request threads are automatically locked when there has + not been any recent activity for one year. Please open a `new issue + `_ for related bugs. + + If you feel like there are important points in the locked discussions, + please include those excerpts into that new issue. + + +Preconditions for running aiohttp test suite +-------------------------------------------- + +We expect you to use a python virtual environment to run our tests. + +There are several ways to make a virtual environment. + +If you like to use *virtualenv* please run: + +.. code-block:: shell + + $ cd aiohttp + $ virtualenv --python=`which python3` venv + $ . venv/bin/activate + +For standard python *venv*: + +.. code-block:: shell + + $ cd aiohttp + $ python3 -m venv venv + $ . venv/bin/activate + +For *virtualenvwrapper*: + +.. code-block:: shell + + $ cd aiohttp + $ mkvirtualenv --python=`which python3` aiohttp + +There are other tools like *pyvenv* but you know the rule of thumb now: create a python3 virtual environment and activate it. + +After that please install the libraries required for development: + +.. code-block:: shell + + $ pip install -r requirements/dev.txt + +.. note:: + + For now, the development tooling depends on ``make`` and assumes a Unix OS. If you wish to contribute to aiohttp from a Windows machine, the easiest way is probably to `configure the WSL `_ so you can use the same instructions. If it's not possible for you or if it doesn't work, please contact us so we can find a solution together. + +Install pre-commit hooks: + +.. code-block:: shell + + $ pre-commit install + +.. warning:: + + If you plan to use temporary ``print()``, ``pdb`` or ``ipdb`` within the test suite, execute it with ``-s``: + + .. code-block:: shell + + $ pytest tests -s + + in order to run the tests without output capturing. + +Congratulations, you are ready to run the test suite! + + +Run autoformatter ----------------- + +The project uses black_ + isort_ formatters to keep the source code style. +Please run `make fmt` after every change before starting tests. + + .. code-block:: shell + + $ make fmt + + +Run aiohttp test suite ---------------------- + +After all the preconditions are met you can run the tests by typing the following +command: + +.. code-block:: shell + + $ make test + +The command will first run the *linters* (sorry, we don't accept +pull requests with pyflakes, black, isort, or mypy errors). + +On *lint* success the tests will be run. + +Please take a look at the produced output. + +Any extra text (print statements and so on) should be removed. + + +Tests coverage +-------------- + +We are trying hard to have good test coverage; please don't make it worse. + +Use: + +.. code-block:: shell + + $ make cov-dev + +to run the test suite and collect coverage information. Once the command 
Once the command +has finished check your coverage at the file that appears in the last +line of the output: +``open file:///.../aiohttp/htmlcov/index.html`` + +Please go to the link and make sure that your code change is covered. + + +The project uses *codecov.io* for storing coverage results. Visit +https://codecov.io/gh/aio-libs/aiohttp for looking on coverage of +master branch, history, pull requests etc. + +The browser extension https://docs.codecov.io/docs/browser-extension +is highly recommended for analyzing the coverage just in *Files +Changed* tab on *GitHub Pull Request* review page. + +Documentation +------------- + +We encourage documentation improvements. + +Please before making a Pull Request about documentation changes run: + +.. code-block:: shell + + $ make doc + +Once it finishes it will output the index html page +``open file:///.../aiohttp/docs/_build/html/index.html``. + +Go to the link and make sure your doc changes looks good. + +Spell checking +-------------- + +We use ``pyenchant`` and ``sphinxcontrib-spelling`` for running spell +checker for documentation: + +.. code-block:: shell + + $ make doc-spelling + +Unfortunately there are problems with running spell checker on MacOS X. + +To run spell checker on Linux box you should install it first: + +.. code-block:: shell + + $ sudo apt-get install enchant + $ pip install sphinxcontrib-spelling + +Changelog update +---------------- + +The ``CHANGES.rst`` file is managed using `towncrier +`_ tool and all non trivial +changes must be accompanied by a news entry. + +To add an entry to the news file, first you need to have created an +issue describing the change you want to make. A Pull Request itself +*may* function as such, but it is preferred to have a dedicated issue +(for example, in case the PR ends up rejected due to code quality +reasons). + +Once you have an issue or pull request, you take the number and you +create a file inside of the ``CHANGES/`` directory named after that +issue number with an extension of ``.removal``, ``.feature``, +``.bugfix``, or ``.doc``. Thus if your issue or PR number is ``1234`` and +this change is fixing a bug, then you would create a file +``CHANGES/1234.bugfix``. PRs can span multiple categories by creating +multiple files (for instance, if you added a feature and +deprecated/removed the old feature at the same time, you would create +``CHANGES/NNNN.feature`` and ``CHANGES/NNNN.removal``). Likewise if a PR touches +multiple issues/PRs you may create a file for each of them with the +exact same contents and *Towncrier* will deduplicate them. + +The contents of this file are *reStructuredText* formatted text that +will be used as the content of the news file entry. You do not need to +reference the issue or PR numbers here as *towncrier* will automatically +add a reference to all of the affected issues when rendering the news +file. + + + +Making a Pull Request +--------------------- + +After finishing all steps make a GitHub_ Pull Request with *master* base branch. + + +Backporting +----------- + +All Pull Requests are created against *master* git branch. + +If the Pull Request is not a new functionality but bug fixing +*backport* to maintenance branch would be desirable. + +*aiohttp* project committer may ask for making a *backport* of the PR +into maintained branch(es), in this case he or she adds a github label +like *needs backport to 3.1*. + +*Backporting* is performed *after* main PR merging into master. + Please do the following steps: + +1. 
Find *Pull Request's commit* for cherry-picking. + + *aiohttp* does *squashing* PRs on merging, so open your PR page on + github and scroll down to a message like ``asvetlov merged commit + f7b8921 into master 9 days ago``. ``f7b8921`` is the required commit number. + +2. Run `cherry_picker + `_ + tool for making a backport PR (the tool is already pre-installed from + ``./requirements/dev.txt``), e.g. ``cherry_picker f7b8921 3.1``. + +3. In case of conflicts, fix them and continue cherry-picking by + ``cherry_picker --continue``. + + ``cherry_picker --abort`` stops the process. + + ``cherry_picker --status`` shows current cherry-picking status + (like ``git status``). + +4. After all conflicts are resolved the tool opens a New Pull Request page + in a browser with pre-filled information. Create a backport Pull + Request and wait for review/merging. + +5. *aiohttp* *committer* should remove *backport Git label* after + merging the backport. + +How to become an aiohttp committer ---------------------------------- + +Contribute! + +The easiest way is providing Pull Requests for issues in our bug +tracker. But if you have a great idea for a library improvement +-- please make an issue and a Pull Request. + + + +The rules for committers are simple: + +1. No wild commits! Everything should go through PRs. +2. Take part in reviews. It's a very important part of a maintainer's activity. +3. Pick up issues created by others, especially if they are simple. +4. Keep the test suite comprehensive. In practice it means leveling up + coverage. 97% is not bad but we wish to have 100% someday. Well, 99% + is a good target too. +5. Don't hesitate to improve our docs. Documentation is a very important + thing; it's the key to project success. The documentation should + not only cover our public API but also help newbies to start using the + project and shed light on non-obvious gotchas. + + + +After a positive answer an aiohttp committer creates an issue on github +with the proposal for nomination. If the proposal collects only +positive votes and no strong objections -- you'll be a new member of +our team. + + +.. _GitHub: https://github.com/aio-libs/aiohttp + +.. _ipdb: https://pypi.python.org/pypi/ipdb + +.. _black: https://pypi.python.org/pypi/black + +.. _isort: https://pypi.python.org/pypi/isort diff --git a/docs/deployment.rst b/docs/deployment.rst index 49e18a16842..e542a3409e2 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -1,6 +1,8 @@ -========================= -aiohttp server deployment -========================= +.. _aiohttp-deployment: + +================= +Server Deployment +================= There are several options for aiohttp server deployment: @@ -24,7 +26,7 @@ Just call :func:`aiohttp.web.run_app` function passing The method is very simple and could be the best solution in some -trivial cases. But it doesn't utilize all CPU cores. +trivial cases. But it does not utilize all CPU cores. For running multiple aiohttp server instances use *reverse proxies*. @@ -50,7 +52,7 @@ Nginx configuration -------------------- Here is short extraction about writing Nginx configuration file. -It doesn't cover all available Nginx options. +It does not cover all available Nginx options. 
For full reference read `Nginx tutorial `_ and `official Nginx @@ -64,9 +66,9 @@ First configure HTTP server itself: http { server { listen 80; - client_max_body_size 4G; + client_max_body_size 4G; - server example.com; + server_name example.com; location / { proxy_set_header Host $http_host; @@ -145,10 +147,10 @@ Here we'll use `Supervisord `_ for example: process_name = example_%(process_num)s ; Unix socket paths are specified by command line. - cmd=/path/to/aiohttp_example.py --path=/tmp/example%(process_num)s.sock + command=/path/to/aiohttp_example.py --path=/tmp/example_%(process_num)s.sock ; We can just as easily pass TCP port numbers: - ; cmd=/path/to/aiohttp_example.py --port=808%(process_num)s + ; command=/path/to/aiohttp_example.py --port=808%(process_num)s user=nobody autostart=true @@ -195,7 +197,7 @@ pre-fork worker model. Gunicorn launches your app as worker processes for handling incoming requests. In opposite to deployment with :ref:`bare Nginx -` the solution doesn't need to +` the solution does not need to manually run several aiohttp processes and use tool like supervisord for monitoring it. But nothing is for free: running aiohttp application under gunicorn is slightly slower. @@ -205,27 +207,23 @@ Prepare environment ------------------- You firstly need to setup your deployment environment. This example is -based on `Ubuntu` 14.04. +based on `Ubuntu `_ 16.04. Create a directory for your application:: >> mkdir myapp >> cd myapp -`Ubuntu` has a bug in pyenv, so to create virtualenv you need to do some -extra manipulation:: +Create Python virtual environment:: - >> pyvenv-3.4 --without-pip venv - >> source venv/bin/activate - >> curl https://bootstrap.pypa.io/get-pip.py | python - >> deactivate + >> python3 -m venv venv >> source venv/bin/activate Now that the virtual environment is ready, we'll proceed to install aiohttp and gunicorn:: >> pip install gunicorn - >> pip install -e git+https://github.com/aio-libs/aiohttp.git#egg=aiohttp + >> pip install aiohttp Application @@ -236,7 +234,7 @@ name this file *my_app_module.py*:: from aiohttp import web - def index(request): + async def index(request): return web.Response(text="Welcome home!") @@ -244,36 +242,53 @@ name this file *my_app_module.py*:: my_web_app.router.add_get('/', index) +Application factory +------------------- + +As an option an entry point could be a coroutine that accepts no +parameters and returns an application instance:: + + from aiohttp import web + + async def index(request): + return web.Response(text="Welcome home!") + + + async def my_web_app(): + app = web.Application() + app.router.add_get('/', index) + return app + + Start Gunicorn -------------- When `Running Gunicorn `_, you provide the name -of the module, i.e. *my_app_module*, and the name of the app, -i.e. *my_web_app*, along with other `Gunicorn Settings -`_ provided as -command line flags or in your config file. +of the module, i.e. *my_app_module*, and the name of the app or +application factory, i.e. *my_web_app*, along with other `Gunicorn +Settings `_ provided +as command line flags or in your config file. 
In this case, we will use: -* the *'--bind'* flag to set the server's socket address; -* the *'--worker-class'* flag to tell Gunicorn that we want to use a +* the ``--bind`` flag to set the server's socket address; +* the ``--worker-class`` flag to tell Gunicorn that we want to use a custom worker subclass instead of one of the Gunicorn default worker types; -* you may also want to use the *'--workers'* flag to tell Gunicorn how +* you may also want to use the ``--workers`` flag to tell Gunicorn how many worker processes to use for handling requests. (See the documentation for recommendations on `How Many Workers? `_) +* you may also want to use the ``--accesslog`` flag to enable the access + log to be populated. (See :ref:`logging ` for more information.) -The custom worker subclass is defined in -*aiohttp.GunicornWebWorker* and should be used instead of the -*gaiohttp* worker provided by Gunicorn, which supports only -aiohttp.wsgi applications:: +The custom worker subclass is defined in ``aiohttp.GunicornWebWorker``:: >> gunicorn my_app_module:my_web_app --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker - [2015-03-11 18:27:21 +0000] [1249] [INFO] Starting gunicorn 19.3.0 - [2015-03-11 18:27:21 +0000] [1249] [INFO] Listening at: http://127.0.0.1:8080 (1249) - [2015-03-11 18:27:21 +0000] [1249] [INFO] Using worker: aiohttp.worker.GunicornWebWorker + [2017-03-11 18:27:21 +0000] [1249] [INFO] Starting gunicorn 19.7.1 + [2017-03-11 18:27:21 +0000] [1249] [INFO] Listening at: http://127.0.0.1:8080 (1249) + [2017-03-11 18:27:21 +0000] [1249] [INFO] Using worker: aiohttp.worker.GunicornWebWorker [2015-03-11 18:27:21 +0000] [1253] [INFO] Booting worker with pid: 1253 Gunicorn is now running and ready to serve requests to your app's @@ -285,15 +300,99 @@ worker processes. `uvloop `_, you can use the ``aiohttp.GunicornUVLoopWebWorker`` worker class. +Proxy through NGINX +---------------------- + +We can proxy our gunicorn workers through NGINX with a configuration like this: + +.. code-block:: nginx + + worker_processes 1; + user nobody nogroup; + events { + worker_connections 1024; + } + http { + ## Main Server Block + server { + ## Open by default. + listen 80 default_server; + server_name main; + client_max_body_size 200M; + + ## Main site location. + location / { + proxy_pass http://127.0.0.1:8080; + proxy_set_header Host $host; + proxy_set_header X-Forwarded-Host $server_name; + proxy_set_header X-Real-IP $remote_addr; + } + } + } + +Since gunicorn listens for requests at our localhost address on port 8080, we can +use the `proxy_pass `_ +directive to send web traffic to our workers. If everything is configured correctly, +we should reach our application at the ip address of our web server. + +Proxy through NGINX + SSL +---------------------------- + +Here is an example NGINX configuration setup to accept SSL connections: + +.. code-block:: nginx + + worker_processes 1; + user nobody nogroup; + events { + worker_connections 1024; + } + http { + ## SSL Redirect + server { + listen 80 default; + return 301 https://$host$request_uri; + } + + ## Main Server Block + server { + # Open by default. + listen 443 ssl default_server; + listen [::]:443 ssl default_server; + server_name main; + client_max_body_size 200M; + + ssl_certificate /etc/secrets/cert.pem; + ssl_certificate_key /etc/secrets/key.pem; + + ## Main site location. 
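+        ## The proxy_set_header directives below mirror the plain-HTTP
+        ## example earlier on this page: they forward the original Host
+        ## header and the client's address to the aiohttp workers.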
+ location / { + proxy_pass http://127.0.0.1:8080; + proxy_set_header Host $host; + proxy_set_header X-Forwarded-Host $server_name; + proxy_set_header X-Real-IP $remote_addr; + } + } + } + + +The first server block accepts regular http connections on port 80 and redirects +them to our secure SSL connection. The second block matches our previous example +except we need to change our open port to https and specify where our SSL +certificates are being stored with the ``ssl_certificate`` and ``ssl_certificate_key`` +directives. + +During development, you may want to `create your own self-signed certificates for testing purposes `_ +and use another service like `Let's Encrypt `_ when you +are ready to move to production. More information ---------------- -The Gunicorn documentation recommends deploying Gunicorn behind an -Nginx proxy server. See the `official documentation +See the `official documentation `_ for more -information about suggested nginx configuration. - +information about suggested nginx configuration. You can also find out more about +`configuring for secure https connections as well. `_ Logging configuration --------------------- @@ -302,10 +401,13 @@ Logging configuration By default aiohttp uses own defaults:: - '%a %l %u %t "%r" %s %b "%{Referrer}i" "%{User-Agent}i"' + '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' -For more information please read :ref:`Format Specification for Accees +For more information please read :ref:`Format Specification for Access Log `. -.. disqus:: - :title: aiohttp deployment with gunicorn + +Proxy through Apache at your own risk +------------------------------------- +Issues have been reported using Apache2 in front of aiohttp server: +`#2687 Intermittent 502 proxy errors when running behind Apache `. diff --git a/docs/essays.rst b/docs/essays.rst index 7329ff6ba28..df83cd1915a 100644 --- a/docs/essays.rst +++ b/docs/essays.rst @@ -6,3 +6,5 @@ Essays new_router whats_new_1_1 + migration_to_2xx + whats_new_3_0 diff --git a/docs/external.rst b/docs/external.rst new file mode 100644 index 00000000000..55892dd8e77 --- /dev/null +++ b/docs/external.rst @@ -0,0 +1,18 @@ +Who uses aiohttp? +================= + +The list of *aiohttp* users: both libraries, big projects and web sites. + +Please don't hesitate to add your awesome project to the list by +making a Pull Request on GitHub_. + +If you like the project -- please go to GitHub_ and press *Star* button! + + +.. toctree:: + + third_party + built_with + powered_by + +.. _GitHub: https://github.com/aio-libs/aiohttp diff --git a/docs/faq.rst b/docs/faq.rst index 063ef278a65..4e1c30b7683 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,50 +1,65 @@ -Frequently Asked Questions -========================== +FAQ +=== + .. contents:: :local: -Are there any plans for @app.route decorator like in Flask? ------------------------------------------------------------ -There are couple issues here: +Are there plans for an @app.route decorator like in Flask? +---------------------------------------------------------- -* This adds huge problem name "configuration as side effect of importing". -* Route matching is order specific, it is very hard to maintain import order. -* In semi large application better to have routes table defined in one place. +As of aiohttp 2.3, :class:`~aiohttp.web.RouteTableDef` provides an API +similar to Flask's ``@app.route``. See +:ref:`aiohttp-web-alternative-routes-definition`. -For this reason feature will not be implemented. 
But if you really want to -use decorators just derive from web.Application and add desired method. +Unlike Flask's ``@app.route``, :class:`~aiohttp.web.RouteTableDef` +does not require an ``app`` in the module namespace (which often leads +to circular imports). +Instead, a :class:`~aiohttp.web.RouteTableDef` is decoupled from an application instance:: -Has aiohttp the Flask Blueprint or Django App concept? ------------------------------------------------------- + routes = web.RouteTableDef() -If you're planing to write big applications, maybe you must consider -use nested applications. They acts as a Flask Blueprint or like the -Django application concept. + @routes.get('/get') + async def handle_get(request): + ... -Using nested application you can add sub-applications to the main application. -see: :ref:`aiohttp-web-nested-applications`. + @routes.post('/post') + async def handle_post(request): + ... + app.router.add_routes(routes) -How to create route that catches urls with given prefix? ---------------------------------------------------------- -Try something like:: - app.router.add_route('*', '/path/to/{tail:.+}', sink_handler) +Does aiohttp have a concept like Flask's "blueprint" or Django's "app"? +----------------------------------------------------------------------- + +If you're writing a large application, you may want to consider +using :ref:`nested applications `, which +are similar to Flask's "blueprints" or Django's "apps". -Where first argument, star, means catch any possible method -(*GET, POST, OPTIONS*, etc), second matching ``url`` with desired prefix, -third -- handler. +See: :ref:`aiohttp-web-nested-applications`. -Where to put my database connection so handlers can access it? +How do I create a route that matches urls with a given prefix? -------------------------------------------------------------- -:class:`aiohttp.web.Application` object supports :class:`dict` -interface, and right place to store your database connections or any -other resource you want to share between handlers. Take a look on -following example:: +You can do something like the following: :: + + app.router.add_route('*', '/path/to/{tail:.+}', sink_handler) + +The first argument, ``*``, matches any HTTP method +(*GET, POST, OPTIONS*, etc). The second argument matches URLS with the desired prefix. +The third argument is the handler function. + + +Where do I put my database connection so handlers can access it? +---------------------------------------------------------------- + +:class:`aiohttp.web.Application` object supports the :class:`dict` +interface and provides a place to store your database connections or any +other resource you want to share between handlers. +:: async def go(request): db = request.app['db'] @@ -62,65 +77,45 @@ following example:: return app -Why the minimal supported version is Python 3.4.2 --------------------------------------------------- - -As of aiohttp **v0.18.0** we dropped support for Python 3.3 up to -3.4.1. The main reason for that is the :meth:`object.__del__` method, -which is fully working since Python 3.4.1 and we need it for proper -resource closing. - -The last Python 3.3, 3.4.0 compatible version of aiohttp is -**v0.17.4**. - -This should not be an issue for most aiohttp users (for example `Ubuntu` -14.04.3 LTS provides python upgraded to 3.4.3), however libraries -depending on aiohttp should consider this and either freeze aiohttp -version or drop Python 3.3 support as well. 
- -As of aiohttp **v1.0.0** we dropped support for Python 3.4.1 up to -3.4.2+ also. The reason is: `loop.is_closed` appears in 3.4.2+ - -Again, it should be not an issue at 2016 Summer because all major -distributions are switched to Python 3.5 now. - - -How a middleware may store a data for using by web-handler later? ------------------------------------------------------------------ +How can middleware store data for web handlers to use? +------------------------------------------------------ -:class:`aiohttp.web.Request` supports :class:`dict` interface as well -as :class:`aiohttp.web.Application`. +Both :class:`aiohttp.web.Request` and :class:`aiohttp.web.Application` +support the :class:`dict` interface. -Just put data inside *request*:: +Therefore, data may be stored inside a request object. :: async def handler(request): request['unique_key'] = data -See https://github.com/aio-libs/aiohttp_session code for inspiration, -``aiohttp_session.get_session(request)`` method uses ``SESSION_KEY`` -for saving request specific session info. +See https://github.com/aio-libs/aiohttp_session code for an example. +The ``aiohttp_session.get_session(request)`` method uses ``SESSION_KEY`` +for saving request-specific session information. + +As of aiohttp 3.0, all response objects are dict-like structures as +well. .. _aiohttp_faq_parallel_event_sources: -How to receive an incoming events from different sources in parallel? ---------------------------------------------------------------------- +Can a handler receive incoming events from different sources in parallel? +------------------------------------------------------------------------- -For example we have two event sources: +Yes. - 1. WebSocket for event from end user +As an example, we may have two event sources: - 2. Redis PubSub from receiving events from other parts of app for - sending them to user via websocket. + 1. WebSocket for events from an end user -The most native way to perform it is creation of separate task for -pubsub handling. + 2. Redis PubSub for events from other parts of the application -Parallel :meth:`aiohttp.web.WebSocketResponse.receive` calls are forbidden, only -the single task should perform websocket reading. +The most native way to handle this is to create a separate task for +PubSub handling. -But other tasks may use the same websocket object for sending data to -peer:: +Parallel :meth:`aiohttp.web.WebSocketResponse.receive` calls are forbidden; +a single task should perform WebSocket reading. +However, other tasks may use the same WebSocket object for sending data to +peers. :: async def handler(request): @@ -143,225 +138,267 @@ peer:: try: async for msg in channel.iter(): - answer = process message(msg) - ws.send_str(answer) + answer = process_the_message(msg) # your function here + await ws.send_str(answer) finally: await redis.unsubscribe('channel:1') .. _aiohttp_faq_terminating_websockets: -How to programmatically close websocket server-side? ----------------------------------------------------- - +How do I programmatically close a WebSocket server-side? +-------------------------------------------------------- -For example we have an application with two endpoints: +Let's say we have an application with two endpoints: - 1. ``/echo`` a websocket echo server that authenticates the user somehow - 2. ``/logout_user`` that when invoked needs to close all open - websockets for that user. + 1. ``/echo`` a WebSocket echo server that authenticates the user + 2. 
``/logout_user`` that, when invoked, closes all open
+      WebSockets for that user.

-Keep in mind that you can only ``.close()`` a websocket from inside
-the handler task, and since the handler task is busy reading from the
-websocket, it can't react to other events.
-
-One simple solution is keeping a shared registry of websocket handler
-tasks for a user in the :class:`aiohttp.web.Application` instance and
-``cancel()`` them in ``/logout_user`` handler::
+One simple solution is to keep a shared registry of WebSocket
+responses for a user in the :class:`aiohttp.web.Application` instance
+and call :meth:`aiohttp.web.WebSocketResponse.close` on all of them in
+the ``/logout_user`` handler::

    async def echo_handler(request):

        ws = web.WebSocketResponse()
        user_id = authenticate_user(request)
        await ws.prepare(request)
-        request.app['websockets'][user_id].add(asyncio.Task.current_task())
-
+        request.app['websockets'][user_id].add(ws)
        try:
            async for msg in ws:
-                # handle incoming messages
-                ...
-
-        except asyncio.CancelledError:
-            print('websocket cancelled')
+                await ws.send_str(msg.data)
        finally:
-            request.app['websockets'][user_id].remove(asyncio.Task.current_task())
-        await ws.close()
+            request.app['websockets'][user_id].remove(ws)

+        return ws
+
    async def logout_handler(request):

        user_id = authenticate_user(request)

-        for task in request.app['websockets'][user_id]:
-            task.cancel()
+        ws_closers = [ws.close()
+                      for ws in request.app['websockets'][user_id]
+                      if not ws.closed]
+
+        # Watch out, this will keep us from returning the response
+        # until all are closed
+        ws_closers and await asyncio.gather(*ws_closers)
+
+        return web.Response(text='OK')

-        # return response
-        ...

    def main():
        loop = asyncio.get_event_loop()
-        app = aiohttp.web.Application(loop=loop)
+        app = web.Application(loop=loop)
        app.router.add_route('GET', '/echo', echo_handler)
        app.router.add_route('POST', '/logout', logout_handler)
-        app['handlers'] = defaultdict(set)
-        aiohttp.web.run_app(app, host='localhost', port=8080)
+        app['websockets'] = defaultdict(set)
+        web.run_app(app, host='localhost', port=8080)


-How to make request from a specific IP address?
------------------------------------------------
+How do I make a request from a specific IP address?
+---------------------------------------------------

-If your system has several IP interfaces you may choose one which will
-be used used to bind socket locally::
+If your system has several IP interfaces, you may choose one which will
+be used to bind a socket locally::

-    conn = aiohttp.TCPConnector(local_addr=('127.0.0.1, 0), loop=loop)
+    conn = aiohttp.TCPConnector(local_addr=('127.0.0.1', 0), loop=loop)
    async with aiohttp.ClientSession(connector=conn) as session:
        ...

.. seealso:: :class:`aiohttp.TCPConnector` and ``local_addr`` parameter.


-.. _aiohttp_faq_tests_and_implicit_loop:
+What is the API stability and deprecation policy?
+-------------------------------------------------
+*aiohttp* follows strong `Semantic Versioning `_ (SemVer).

-How to use aiohttp test features with code which works with implicit loop?
--------------------------------------------------------------------------
+Obsolete attributes and methods are marked as *deprecated* in the
+documentation and raise :class:`DeprecationWarning` upon usage.

-Passing explicit loop everywhere is the recommended way. But
-sometimes, in case you have many nested non well-written services,
-this is impossible.
+Assume aiohttp ``X.Y.Z`` where ``X`` is major version,
+``Y`` is minor version and ``Z`` is bugfix number.
-
-There is a technique based on monkey-patching your low level service
-that depends on aioes, to inject the loop at that level. This way, you
-just need your ``AioESService`` with the loop in its signature. An
-example would be the following::
+For example, if the latest released version is ``aiohttp==3.0.6``:

-    import pytest
+``3.0.7`` fixes some bugs but has no new features.

-    from unittest.mock import patch, MagicMock
+``3.1.0`` introduces new features and can deprecate some API but never
+removes it; all bug fixes from the previous release are also merged.

-    class TestAcceptance:
+``4.0.0`` removes all deprecations collected from ``3.Y`` versions
+**except** deprecations from the **last** ``3.Y`` release. These
+deprecations will be removed by ``5.0.0``.

-        async def test_get(self, test_client, loop):
-            with patch("main.AioESService", MagicMock(
-                    side_effect=lambda *args, **kwargs: AioESService(*args,
-                                                                     **kwargs,
-                                                                     loop=loop))):
-                client = await test_client(create_app)
-                resp = await client.get("/")
-                assert resp.status == 200
+Unfortunately we may have to break these rules when a **security
+vulnerability** is found.
+If a security problem cannot be fixed without breaking backward
+compatibility, a bugfix release may break compatibility. This is unlikely, but
+possible.

-Note how we are patching the ``AioESService`` with and instance of itself but
-adding the explicit loop as an extra (you need to load the loop fixture in your
-test signature).
+All backward incompatible changes are explicitly marked in
+:ref:`the changelog `.

-The final code to test all this (you will need a local instance of
-elasticsearch running)::
+How do I enable gzip compression globally for my entire application?
+--------------------------------------------------------------------

-    import asyncio
+It's impossible. Choosing what to compress and what not to compress is
+a tricky matter.

-    from aioes import Elasticsearch
-    from aiohttp import web
+If you need global compression, write a custom middleware, or
+enable compression in NGINX (you are deploying aiohttp behind a reverse
+proxy, right?).

-    class AioESService:
+How do I manage a ClientSession within a web server?
+----------------------------------------------------

-        def __init__(self, loop=None):
-            self.es = Elasticsearch(["127.0.0.1:9200"], loop=loop)
+:class:`aiohttp.ClientSession` should be created once for the lifetime
+of the server in order to benefit from connection pooling.

-        async def get_info(self):
-            return await self.es.info()
+Sessions save cookies internally. If you don't need cookie processing,
+use :class:`aiohttp.DummyCookieJar`. If you need separate cookies
+for different HTTP calls but process them in logical chains, use a single
+:class:`aiohttp.TCPConnector` with separate
+client sessions and ``connector_owner=False``.

-    class MyService:
+How do I access database connections from a subapplication?
+-----------------------------------------------------------

-        def __init__(self):
-            self.aioes_service = AioESService()
+Restricting access from a subapplication to the main (or outer) app is a
+deliberate choice.

-        async def get_es_info(self):
-            return await self.aioes_service.get_info()
+A subapplication is an isolated unit by design.  If you need to share a
+database object, do it explicitly::

+    subapp['db'] = mainapp['db']
+    mainapp.add_subapp('/prefix', subapp)

-    async def hello_aioes(request):
-        my_service = MyService()
-        cluster_info = await my_service.get_es_info()
-        return web.Response(text="{}".format(cluster_info))
+How do I perform operations in a request handler after sending the response?
+-----------------------------------------------------------------------------

-    def create_app(loop=None):
+Middlewares can be written to handle post-response operations, but
+they run after every request. You can explicitly send the response by
+calling :meth:`aiohttp.web.Response.write_eof`, which starts sending
+before the handler returns, giving you a chance to execute follow-up
+operations::

-        app = web.Application(loop=loop)
-        app.router.add_route('GET', '/', hello_aioes)
-        return app
+    async def ping_handler(request):
+        """Send PONG and increase DB counter."""

+        # explicitly send the response
+        resp = web.json_response({'message': 'PONG'})
+        await resp.prepare(request)
+        await resp.write_eof()

-    if __name__ == "__main__":
-        web.run_app(create_app())
+        # increase the pong count
+        APP['db'].inc_pong()

+        return resp

-And the full tests file::
+A :class:`aiohttp.web.Response` object must be returned. This is
+required by aiohttp web contracts, even though the response
+has already been sent.

-    from unittest.mock import patch, MagicMock
+How do I make sure my custom middleware response will behave correctly?
+------------------------------------------------------------------------

-    from main import AioESService, create_app
+Sometimes your middleware handlers might need to send a custom response.
+This is just fine as long as you always create a new
+:class:`aiohttp.web.Response` object when required.

+The response object is a Finite State Machine. Once it has been dispatched
+by the server, it will reach its final state and cannot be used again.

-    class TestAioESService:
+The following middleware will make the server hang, once it serves the second
+response::

-        async def test_get_info(self, loop):
-            cluster_info = await AioESService("random_arg", loop=loop).get_info()
-            assert isinstance(cluster_info, dict)
+    from aiohttp import web

+    def misbehaved_middleware():
+        # don't do this!
+        cached = web.Response(status=200, text='Hi, I am cached!')

-    class TestAcceptance:
+        @web.middleware
+        async def middleware(request, handler):
+            # run the handler, but ignore its response
+            # for the sake of this example
+            _res = await handler(request)
+            return cached

-        async def test_get(self, test_client, loop):
-            with patch("main.AioESService", MagicMock(
-                    side_effect=lambda *args, **kwargs: AioESService(*args,
-                                                                     **kwargs,
-                                                                     loop=loop))):
-                client = await test_client(create_app)
-                resp = await client.get("/")
-                assert resp.status == 200
+        return middleware

-Note how we are using the ``side_effect`` feature for injecting the loop to the
-``AioESService.__init__`` call. The use of ``**args, **kwargs`` is mandatory
-in order to propagate the arguments being used by the caller.
+The rule of thumb is *one request, one response*.

-API stability and deprecation policy
-------------------------------------
+Why is creating a ClientSession outside of an event loop dangerous?
+-------------------------------------------------------------------

-aiohttp tries to not break existing users code.
+Short answer: the life-cycle of all asyncio objects should be shorter
+than the life-cycle of the event loop.
-
-Obsolete attributes and methods are marked as *deprecated* in
-documentation and raises :class:`DeprecationWarning` on usage.
+The full explanation is longer. All asyncio objects should be correctly
+finished/disconnected/closed before event loop shutdown. Otherwise the
+user can get unexpected behavior. In the best case it is a warning
+about an unclosed resource; in the worst case the program just hangs,
+a coroutine is never resumed, etc.

-Deprecation period is usually a year and half.
+Consider the following code from ``mod.py``::

-After the period is passed out deprecated code is be removed.
+    import aiohttp

-Unfortunately we should break own rules if new functionality or bug
-fixing forces us to do it (for example proper cookies support on
-client side forced us to break backward compatibility twice).
+    session = aiohttp.ClientSession()

-All *backward incompatible* changes are explicitly marked in
-:ref:`CHANGES ` chapter.
+    async def fetch(url):
+        async with session.get(url) as resp:
+            return await resp.text()

+The session grabs the current event loop instance and stores it in a
+private variable.

-How to enable gzip compression globally for the whole application?
-------------------------------------------------------------------
+The main module imports the module and installs ``uvloop`` (an
+alternative fast event loop implementation).

-It's impossible. Choosing what to compress and where don't apply such
-time consuming operation is very tricky matter.
+``main.py``::

-If you need global compression -- write own custom middleware. Or
-enable compression in NGINX (you are deploying aiohttp behind reverse
-proxy, isn't it).
+    import asyncio
+    import uvloop
+    import mod
+
+    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+    asyncio.run(mod.fetch('http://python.org'))
+
+The code is broken: ``session`` is bound to the default ``asyncio`` loop
+at import time, but the loop is changed **after the import** by
+``set_event_loop_policy()``. As a result, the ``fetch()`` call hangs.
+
+
+To avoid import dependency hell, *aiohttp* encourages creating the
+``ClientSession`` from an async function. The same policy works for
+``web.Application`` too.
+
+Another use case is writing unit tests. Many test libraries
+(*aiohttp test tools* first of all) create a new loop instance for every
+test function execution. This is done for the sake of test isolation.
+Otherwise pending activity (timers, network packets etc.) from a
+previous test may interfere with the current one, producing very cryptic
+and unstable test failures.
+
+Note: *class variables* are actually hidden globals. The following
+code has the same problem as the ``mod.py`` example; the ``session``
+variable is the hidden global object::
+
+    class A:
+        session = aiohttp.ClientSession()

-.. disqus::
-    :title: aiohttp FAQ
+        async def fetch(self, url):
+            async with self.session.get(url) as resp:
+                return await resp.text()
diff --git a/docs/favicon.ico b/docs/favicon.ico
new file mode 100644
index 00000000000..666937af428
Binary files /dev/null and b/docs/favicon.ico differ
diff --git a/docs/glossary.rst b/docs/glossary.rst
index 0288a848e8c..bc5e1169c33 100644
--- a/docs/glossary.rst
+++ b/docs/glossary.rst
@@ -52,6 +52,18 @@

   http://gunicorn.org/

+   IDNA
+
+      An Internationalized Domain Name in Applications (IDNA) is an
+      industry standard for encoding Internet Domain Names that contain,
+      in whole or in part, a language-specific script or alphabet,
+      such as Arabic, Chinese, Cyrillic, Tamil, Hebrew or the Latin
+      alphabet-based characters with diacritics or ligatures, such as
+      French.
These writing systems are encoded by computers in + multi-byte Unicode. Internationalized domain names are stored + in the Domain Name System as ASCII strings using Punycode + transcription. + keep-alive A technique for communicating between HTTP client and server @@ -68,6 +80,35 @@ https://nginx.org/en/ + percent-encoding + + A mechanism for encoding information in a Uniform Resource + Locator (URL) if URL parts don't fit in safe characters space. + + requests + + Currently the most popular synchronous library to make + HTTP requests in Python. + + https://requests.readthedocs.io + + requoting + + Applying :term:`percent-encoding` to non-safe symbols and decode + percent encoded safe symbols back. + + According to :rfc:`3986` allowed path symbols are:: + + allowed = unreserved / pct-encoded / sub-delims + / ":" / "@" / "/" + + pct-encoded = "%" HEXDIG HEXDIG + + unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + + sub-delims = "!" / "$" / "&" / "'" / "(" / ")" + / "*" / "+" / "," / ";" / "=" + resource A concept reflects the HTTP **path**, every resource corresponds @@ -96,8 +137,3 @@ A library for operating with URL objects. https://pypi.python.org/pypi/yarl - - - -.. disqus:: - :title: aiohttp glossary diff --git a/docs/http_request_lifecycle.rst b/docs/http_request_lifecycle.rst new file mode 100644 index 00000000000..e14fb03de5f --- /dev/null +++ b/docs/http_request_lifecycle.rst @@ -0,0 +1,110 @@ + + +.. _aiohttp-request-lifecycle: + + +The aiohttp Request Lifecycle +============================= + + +Why is aiohttp client API that way? +-------------------------------------- + + +The first time you use aiohttp, you'll notice that a simple HTTP request is performed not with one, but with up to three steps: + + +.. code-block:: python + + + async with aiohttp.ClientSession() as session: + async with session.get('http://python.org') as response: + print(await response.text()) + + +It's especially unexpected when coming from other libraries such as the very popular :term:`requests`, where the "hello world" looks like this: + + +.. code-block:: python + + + response = requests.get('http://python.org') + print(response.text) + + +So why is the aiohttp snippet so verbose? + + +Because aiohttp is asynchronous, its API is designed to make the most out of non-blocking network operations. In code like this, requests will block three times, and does it transparently, while aiohttp gives the event loop three opportunities to switch context: + + +- When doing the ``.get()``, both libraries send a GET request to the remote server. For aiohttp, this means asynchronous I/O, which is marked here with an ``async with`` that gives you the guarantee that not only it doesn't block, but that it's cleanly finalized. +- When doing ``response.text`` in requests, you just read an attribute. The call to ``.get()`` already preloaded and decoded the entire response payload, in a blocking manner. aiohttp loads only the headers when ``.get()`` is executed, letting you decide to pay the cost of loading the body afterward, in a second asynchronous operation. Hence the ``await response.text()``. +- ``async with aiohttp.ClientSession()`` does not perform I/O when entering the block, but at the end of it, it will ensure all remaining resources are closed correctly. Again, this is done asynchronously and must be marked as such. The session is also a performance tool, as it manages a pool of connections for you, allowing you to reuse them instead of opening and closing a new one at each request. 
You can even `manage the pool size by passing a connector object
+`_ (see the sketch at the end of this page).
+
+Using a session as a best practice
+-----------------------------------
+
+The requests library does in fact also provide a session system. Indeed, it lets you do:
+
+.. code-block:: python
+
+    with requests.Session() as session:
+        response = session.get('http://python.org')
+        print(response.text)
+
+It's just not the default behavior, nor is it advertised early in the documentation. Because of this, most users take a hit in performance, but can quickly start hacking. And for requests, it's an understandable trade-off, since its goal is to be "HTTP for humans" and simplicity has always been more important than performance in this context.
+
+However, if one uses aiohttp, one chooses asynchronous programming, a paradigm that makes the opposite trade-off: more verbosity for better performance. And so the library default behavior reflects this, encouraging you to use performant best practices from the start.
+
+How to use the ClientSession?
+------------------------------
+
+By default the :class:`aiohttp.ClientSession` object will hold a connector with a maximum of 100 connections, putting the rest in a queue. This is quite a big number: it means you must be connected to a hundred different servers (not pages!) concurrently before even having to consider whether your task needs resource adjustment.
+
+In fact, you can picture the session object as a user starting and closing a browser: it wouldn't make sense to do that every time you want to load a new tab.
+
+So you are expected to reuse a session object and make many requests from it. For most scripts and average-sized software, this means you can create a single session, and reuse it for the entire execution of the program. You can even pass the session around as a parameter in functions. For example, the typical "hello world":
+
+.. code-block:: python
+
+    import aiohttp
+    import asyncio
+
+    async def main():
+        async with aiohttp.ClientSession() as session:
+            async with session.get('http://python.org') as response:
+                html = await response.text()
+                print(html)
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())
+
+
+Can become this:
+
+
+.. code-block:: python
+
+    import aiohttp
+    import asyncio
+
+    async def fetch(session, url):
+        async with session.get(url) as response:
+            return await response.text()
+
+    async def main():
+        async with aiohttp.ClientSession() as session:
+            html = await fetch(session, 'http://python.org')
+            print(html)
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())
+
+On more complex code bases, you can even create a central registry to hold the session object from anywhere in the code, or a higher level ``Client`` class that holds a reference to it.
+
+When should you create more than one session object, then? The need arises when you want more granularity in your resource management:
+
+- you want to group connections by a common configuration. e.g.: sessions can set cookies, headers, timeout values, etc. that are shared for all connections they hold.
+- you need several threads and want to avoid sharing a mutable object between them.
+- you want several connection pools to benefit from different queues and assign priorities. e.g.: one session never uses the queue and is for high priority requests; the other one has a small concurrency limit and a very long queue, for unimportant requests.
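+
+As a sketch of the pool-size tuning mentioned above (the limit values
+below are illustrative assumptions, not recommended settings), a
+connector can be created explicitly and handed to the session:
+
+.. code-block:: python
+
+    import aiohttp
+    import asyncio
+
+    async def main():
+        # Cap the pool at 50 connections overall and 10 per host.
+        connector = aiohttp.TCPConnector(limit=50, limit_per_host=10)
+        async with aiohttp.ClientSession(connector=connector) as session:
+            async with session.get('http://python.org') as response:
+                print(response.status)
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())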
diff --git a/docs/index.rst b/docs/index.rst
index 0008d5e3003..13fe723b412 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,26 +3,31 @@
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

-aiohttp: Asynchronous HTTP Client/Server
-========================================
+==================
+Welcome to AIOHTTP
+==================
+
+Asynchronous HTTP Client/Server for :term:`asyncio` and Python.

-HTTP client/server for :term:`asyncio` (:pep:`3156`).
+Current version is |release|.

.. _GitHub: https://github.com/aio-libs/aiohttp
-.. _Freenode: http://freenode.net

-Features
---------
+Key Features
+============

- Supports both :ref:`aiohttp-client` and :ref:`HTTP Server `.
-
-- Supports both :ref:`Server WebSockets ` and
-  :ref:`Client WebSockets ` out-of-the-box.
+- Supports both :ref:`Server WebSockets ` and
+  :ref:`Client WebSockets ` out-of-the-box
+  without the Callback Hell.
- Web-server has :ref:`aiohttp-web-middlewares`,
-  :ref:`aiohttp-web-signals` and pluggable routing.
+  :ref:`aiohttp-web-signals` and pluggable routing.
+
+.. _aiohttp-installation:

Library Installation
---------------------
+====================

.. code-block:: bash

@@ -43,29 +48,56 @@ This option is highly recommended:

   $ pip install aiodns

+Installing speedups altogether
+------------------------------
+
+The following will get you ``aiohttp`` along with :term:`chardet`,
+:term:`aiodns` and ``brotlipy`` in one bundle. No need to type
+separate commands anymore!
+
+.. code-block:: bash
+
+   $ pip install aiohttp[speedups]
+
Getting Started
----------------
+===============
+
+Client example
+--------------
+
+.. code-block:: python
+
+    import aiohttp
+    import asyncio
+
+    async def main():
+
+        async with aiohttp.ClientSession() as session:
+            async with session.get('http://python.org') as response:

-Client example::
+                print("Status:", response.status)
+                print("Content-type:", response.headers['content-type'])

-    import aiohttp
-    import asyncio
-    import async_timeout
+                html = await response.text()
+                print("Body:", html[:15], "...")

-    async def fetch(session, url):
-        with async_timeout.timeout(10):
-            async with session.get(url) as response:
-                return await response.text()
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())

-    async def main(loop):
-        async with aiohttp.ClientSession(loop=loop) as session:
-            html = await fetch(session, 'http://python.org')
-            print(html)
+This prints:

-    loop = asyncio.get_event_loop()
-    loop.run_until_complete(main(loop))
+.. code-block:: text

-Server example::
+    Status: 200
+    Content-type: text/html; charset=utf-8
+    Body: ...
+
+Coming from :term:`requests`? Read :ref:`why we need so many lines `.
+
+Server example
+--------------
+
+.. code-block:: python

    from aiohttp import web

@@ -75,38 +107,31 @@ Server example::
        return web.Response(text=text)

    app = web.Application()
-    app.router.add_get('/', handle)
-    app.router.add_get('/{name}', handle)
-
-    web.run_app(app)
+    app.add_routes([web.get('/', handle),
+                    web.get('/{name}', handle)])

-.. note::
+    if __name__ == '__main__':
+        web.run_app(app)

-   Throughout this documentation, examples utilize the `async/await` syntax
-   introduced by :pep:`492` that is only valid for Python 3.5+.
-   If you are using Python 3.4, please replace ``await`` with
-   ``yield from`` and ``async def`` with a ``@coroutine`` decorator.
-   For example, this::
+For more information please visit :ref:`aiohttp-client` and
+:ref:`aiohttp-web` pages.
- async def coro(...): - ret = await f() +What's new in aiohttp 3? +======================== - should be replaced by:: - - @asyncio.coroutine - def coro(...): - ret = yield from f() +Go to :ref:`aiohttp_whats_new_3_0` page for aiohttp 3.0 major release +changes. Tutorial --------- +======== -:ref:`Polls tutorial ` +:ref:`Polls tutorial ` Source code ------------ +=========== The project is hosted on GitHub_ @@ -114,17 +139,18 @@ Please feel free to file an issue on the `bug tracker `_ if you have found a bug or have some suggestion in order to improve the library. -The library uses `Travis `_ for +The library uses `Azure Pipelines `_ for Continuous Integration. Dependencies ------------- +============ -- Python 3.4.2+ +- Python 3.6+ +- *async_timeout* +- *attrs* - *chardet* - *multidict* -- *async_timeout* - *yarl* - *Optional* :term:`cchardet` as faster replacement for :term:`chardet`. @@ -143,22 +169,28 @@ Dependencies $ pip install aiodns -Discussion list ---------------- +Communication channels +====================== -*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs +*aio-libs discourse group*: https://aio-libs.discourse.group Feel free to post your questions and ideas here. +*gitter chat* https://gitter.im/aio-libs/Lobby + +We support `Stack Overflow +`_. +Please add *aiohttp* tag to your question there. + Contributing ------------- +============ Please read the :ref:`instructions for contributors` before making a Pull Request. Authors and License -------------------- +=================== The ``aiohttp`` package is written mostly by Nikolay Kim and Andrew Svetlov. @@ -170,7 +202,7 @@ Feel free to improve this package and send a pull request to GitHub_. .. _aiohttp-backward-compatibility-policy: Policy for Backward Incompatible Changes ----------------------------------------- +======================================== *aiohttp* keeps backward compatibility. @@ -188,39 +220,17 @@ solved without major API change, but we are working hard for keeping these changes as rare as possible. -Contents --------- +Table Of Contents +================= .. toctree:: + :name: mastertoc + :maxdepth: 2 - migration client - client_reference - tutorial web - web_reference - web_lowlevel - abc - multipart - streams - api - logging - testing - deployment + utilities faq - third_party - essays + misc + external contributing - changes - glossary - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - - -.. disqus:: - :title: aiohttp documentation diff --git a/docs/logging.rst b/docs/logging.rst index b6ad5ed1f51..916a7feff67 100644 --- a/docs/logging.rst +++ b/docs/logging.rst @@ -23,27 +23,40 @@ page does not provide instructions for logging subscribing while the most friendly method is :func:`logging.config.dictConfig` for configuring whole loggers in your application. +Logging does not work out of the box. It requires at least minimal ``'logging'`` +configuration. +Example of minimal working logger setup:: + import logging + from aiohttp import web + + app = web.Application() + logging.basicConfig(level=logging.DEBUG) + web.run_app(app, port=5000) + +.. versionadded:: 4.0.0 Access logs ----------- -Access log by default is switched on and uses ``'aiohttp.access'`` -logger name. +Access logs are enabled by default. If the `debug` flag is set, and the default +logger ``'aiohttp.access'`` is used, access logs will be output to +:obj:`~sys.stderr` if no handlers are attached. 
+Furthermore, if the default logger has no log level set, the log level will be
+set to :obj:`logging.DEBUG`.

-The log may be controlled by :meth:`aiohttp.web.Application.make_handler` call.
+This logging may be controlled by :meth:`aiohttp.web.AppRunner` and
+:func:`aiohttp.web.run_app`.

-Pass *access_log* parameter with value of :class:`logging.Logger`
-instance to override default logger.
+To override the default logger, pass an instance of :class:`logging.Logger`
+as the *access_log* parameter.

.. note::

-   Use ``app.make_handler(access_log=None)`` for disabling access logs.
-
+   Use ``web.run_app(app, access_log=None)`` to disable access logs.

-Other parameter called *access_log_format* may be used for specifying log
-format (see below).
+In addition, *access_log_format* may be used to specify the log format.

.. _aiohttp-logging-access-log-format-spec:

@@ -69,9 +82,7 @@ request and response:
+--------------+---------------------------------------------------------+
| ``%s``       | Response status code                                    |
+--------------+---------------------------------------------------------+
-| ``%b``       | Size of response in bytes, excluding HTTP headers       |
-+--------------+---------------------------------------------------------+
-| ``%O``       | Bytes sent, including headers                           |
+| ``%b``       | Size of response in bytes, including HTTP headers       |
+--------------+---------------------------------------------------------+
| ``%T``       | The time taken to serve the request, in seconds         |
+--------------+---------------------------------------------------------+
@@ -84,40 +95,62 @@ request and response:
+--------------+---------------------------------------------------------+
| ``%{FOO}o``  | ``response.headers['FOO']``                             |
+--------------+---------------------------------------------------------+
-| ``%{FOO}e``  | ``os.environ['FOO']``                                   |
-+--------------+---------------------------------------------------------+

-Default access log format is::
+The default access log format is::

-    '%a %l %u %t "%r" %s %b "%{Referrer}i" "%{User-Agent}i"'
+    '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'

+.. versionadded:: 2.3.0

-.. note::
+*access_log_class* introduced.
+
+Example of a drop-in replacement for the default access logger::
+
+    from aiohttp.abc import AbstractAccessLogger
+
+    class AccessLogger(AbstractAccessLogger):
+
+        def log(self, request, response, time):
+            self.logger.info(f'{request.remote} '
+                             f'"{request.method} {request.path} '
+                             f'done in {time}s: {response.status}')
+
+
+.. _gunicorn-accesslog:
+
+Gunicorn access logs
+^^^^^^^^^^^^^^^^^^^^
+When `Gunicorn `_ is used for
+:ref:`deployment `, its default access log format
+will be automatically replaced with aiohttp's default access log format.
+
+If Gunicorn's option access_logformat_ is
+specified explicitly, it should use aiohttp's format specification.
+
+Gunicorn's access log works only if accesslog_ is specified explicitly in your
+config or as a command line option.
+This configuration can be either a path or ``'-'``. If the application uses
+a custom logging setup intercepting the ``'gunicorn.access'`` logger,
+accesslog_ should be set to ``'-'`` to prevent Gunicorn from creating an empty
+access log file upon every startup.

-   When `Gunicorn `_ is used for
-   :ref:`deployment ` its default access log format
-   will be automatically replaced with the default aiohttp's access log format.
-   If Gunicorn's option access_logformat_ is
-   specified explicitly it should use aiohttp's format specification.
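+
+For example, with the ``my_app_module:my_web_app`` names used on the
+deployment page, the access log can be routed to stderr from the command
+line (a sketch; adjust the names to your application)::
+
+    gunicorn my_app_module:my_web_app --bind localhost:8080 \
+        --worker-class aiohttp.GunicornWebWorker --access-logfile -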
Error logs ---------- -*aiohttp.web* uses logger named ``'aiohttp.server'`` to store errors +:mod:`aiohttp.web` uses a logger named ``'aiohttp.server'`` to store errors given on web requests handling. -The log is enabled by default. +This log is enabled by default. -To use different logger name please specify *logger* parameter -(:class:`logging.Logger` instance) on performing -:meth:`aiohttp.web.Application.make_handler` call. +To use a different logger name, pass *logger* (:class:`logging.Logger` +instance) to the :meth:`aiohttp.web.AppRunner` constructor. .. _access_logformat: http://docs.gunicorn.org/en/stable/settings.html#access-log-format - -.. disqus:: - :title: aiohttp logging +.. _accesslog: + http://docs.gunicorn.org/en/stable/settings.html#accesslog diff --git a/docs/migration.rst b/docs/migration_to_2xx.rst similarity index 71% rename from docs/migration.rst rename to docs/migration_to_2xx.rst index 56ac96408f1..2f3c57ce24a 100644 --- a/docs/migration.rst +++ b/docs/migration_to_2xx.rst @@ -13,9 +13,10 @@ aiohttp does not support custom chunking sizes. It is up to the developer to decide how to chunk data streams. If chunking is enabled, aiohttp encodes the provided chunks in the "Transfer-encoding: chunked" format. -aiohttp does not enable chunked encoding automatically even if a *transfer-encoding* -header is supplied: *chunked* has to be set explicitly. If *chunked* is set, -then the *Transfer-encoding* and *content-length* headers are disallowed. +aiohttp does not enable chunked encoding automatically even if a +*transfer-encoding* header is supplied: *chunked* has to be set +explicitly. If *chunked* is set, then the *Transfer-encoding* and +*content-length* headers are disallowed. compression ^^^^^^^^^^^ @@ -29,21 +30,24 @@ Compression can not be combined with a *Content-Length* header. Client Connector ^^^^^^^^^^^^^^^^ -1. By default a connector object manages a total number of concurrent connections. - This limit was a per host rule in version 1.x. In 2.x, the `limit` parameter - defines how many concurrent connection connector can open and a new `limit_per_host` - parameter defines the limit per host. By default there is no per-host limit. -2. BaseConnector.close is now a normal function as opposed to coroutine in version 1.x +1. By default a connector object manages a total number of concurrent + connections. This limit was a per host rule in version 1.x. In + 2.x, the `limit` parameter defines how many concurrent connection + connector can open and a new `limit_per_host` parameter defines the + limit per host. By default there is no per-host limit. +2. BaseConnector.close is now a normal function as opposed to + coroutine in version 1.x 3. BaseConnector.conn_timeout was moved to ClientSession ClientResponse.release ^^^^^^^^^^^^^^^^^^^^^^ -Internal implementation was significantly redesigned. It is not required -to call `release` on the response object. When the client fully receives the payload, -the underlying connection automatically returns back to pool. If the payload is not -fully read, the connection is closed +Internal implementation was significantly redesigned. It is not +required to call `release` on the response object. When the client +fully receives the payload, the underlying connection automatically +returns back to pool. If the payload is not fully read, the connection +is closed Client exceptions @@ -54,25 +58,30 @@ exceptions that covers connection handling and server response misbehaviors. 
For developer specific mistakes, aiohttp uses python standard exceptions like ValueError or TypeError. -Reading a response content may raise a ClientPayloadError exception. This exception -indicates errors specific to the payload encoding. Such as invalid compressed data, -malformed chunked-encoded chunks or not enough data that satisfy the content-length header. +Reading a response content may raise a ClientPayloadError +exception. This exception indicates errors specific to the payload +encoding. Such as invalid compressed data, malformed chunked-encoded +chunks or not enough data that satisfy the content-length header. -All exceptions are moved from `aiohttp.errors` module to top level `aiohttp` module. +All exceptions are moved from `aiohttp.errors` module to top level +`aiohttp` module. New hierarchy of exceptions: * `ClientError` - Base class for all client specific exceptions - - `ClientResponseError` - exceptions that could happen after we get response from server + - `ClientResponseError` - exceptions that could happen after we get + response from server * `WSServerHandshakeError` - web socket server response error - `ClientHttpProxyError` - proxy response - - `ClientConnectionError` - exceptions related to low-level connection problems + - `ClientConnectionError` - exceptions related to low-level + connection problems - * `ClientOSError` - subset of connection errors that are initiated by an OSError exception + * `ClientOSError` - subset of connection errors that are initiated + by an OSError exception - `ClientConnectorError` - connector related exceptions @@ -86,24 +95,26 @@ New hierarchy of exceptions: * `ServerFingerprintMismatch` - server fingerprint mismatch - - `ClientPayloadError` - This exception can only be raised while reading the response - payload if one of these errors occurs: invalid compression, malformed chunked encoding or - not enough data that satisfy content-length header. + - `ClientPayloadError` - This exception can only be raised while + reading the response payload if one of these errors occurs: + invalid compression, malformed chunked encoding or not enough data + that satisfy content-length header. Client payload (form-data) ^^^^^^^^^^^^^^^^^^^^^^^^^^ -To unify form-data/payload handling a new `Payload` system was introduced. It handles -customized handling of existing types and provide implementation for user-defined types. +To unify form-data/payload handling a new `Payload` system was +introduced. It handles customized handling of existing types and +provide implementation for user-defined types. 1. FormData.__call__ does not take an encoding arg anymore and its return value changes from an iterator or bytes to a Payload instance. aiohttp provides payload adapters for some standard types like `str`, `byte`, `io.IOBase`, `StreamReader` or `DataQueue`. -2. a generator is not supported as data provider anymore, `streamer` can be used instead. - For example, to upload data from file:: +2. a generator is not supported as data provider anymore, `streamer` + can be used instead. For example, to upload data from file:: @aiohttp.streamer def file_sender(writer, file_name=None): @@ -132,15 +143,17 @@ Various 3. `aiohttp.MsgType` dropped, use `aiohttp.WSMsgType` instead. -4. `ClientResponse.url` is an instance of `yarl.URL` class (`url_obj` is deprecated) +4. `ClientResponse.url` is an instance of `yarl.URL` class (`url_obj` + is deprecated) -5. `ClientResponse.raise_for_status()` raises :exc:`aiohttp.ClientResponseError` exception +5. 
`ClientResponse.raise_for_status()` raises
+   :exc:`aiohttp.ClientResponseError` exception

-6. `ClientResponse.json()` is strict about response's content type. if content type
-   does not match, it raises :exc:`aiohttp.ClientResponseError` exception.
-   To disable content type check you can pass ``None`` as `content_type` parameter.
+6. `ClientResponse.json()` is strict about the response's content type. If
+   the content type does not match, it raises
+   :exc:`aiohttp.ClientResponseError` exception. To disable the content
+   type check you can pass ``None`` as the `content_type` parameter.

-7. `ClientSession.close()` is a regular function returning None, not a coroutine.

@@ -181,12 +194,13 @@ WebRequest and WebResponse

4. `FileSender` API is dropped; it is replaced with the more general
   `FileResponse` class::

     async def handle(request):
-         return web.FileResponse('path-to-file.txt)
+         return web.FileResponse('path-to-file.txt')

5. `WebSocketResponse.protocol` is renamed to
   `WebSocketResponse.ws_protocol`. `WebSocketResponse.protocol` is an
   instance of the `RequestHandler` class.

+
RequestPayloadError
^^^^^^^^^^^^^^^^^^^
diff --git a/docs/misc.rst b/docs/misc.rst
new file mode 100644
index 00000000000..dd9df05e9f5
--- /dev/null
+++ b/docs/misc.rst
@@ -0,0 +1,24 @@
+.. _aiohttp-misc:
+
+Miscellaneous
+=============
+
+Helpful pages.
+
+.. toctree::
+   :name: misc
+
+   essays
+   glossary
+
+.. toctree::
+   :titlesonly:
+
+   changes
+
+Indices and tables
+------------------
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/multipart.rst b/docs/multipart.rst
index 0958459f591..b6ecc639c51 100644
--- a/docs/multipart.rst
+++ b/docs/multipart.rst
@@ -1,4 +1,4 @@
-.. module:: aiohttp.multipart
+.. currentmodule:: aiohttp

.. _aiohttp-multipart:

@@ -6,7 +6,7 @@ Working with Multipart
======================

``aiohttp`` supports a full featured multipart reader and writer. Both
-are designed with steaming processing in mind to avoid unwanted
+are designed with streaming processing in mind to avoid unwanted
footprint which may be significant if you're dealing with large
payloads, but this also means that most I/O operations can only
be executed a single time.

@@ -55,7 +55,7 @@ body part headers: this allows you to filter parts by their attributes::
            metadata = await part.json()
            continue

-Nor :class:`BodyPartReader` or :class:`MultipartReader` instances doesn't
+Neither :class:`BodyPartReader` nor :class:`MultipartReader` instances
read the whole body part data without explicitly asking for it.
:class:`BodyPartReader` provides a set of helper methods to fetch popular
content types in a friendly way:

@@ -80,10 +80,10 @@ from it::

        if part.filename != 'secret.txt':
            continue

-If current body part doesn't matches your expectation and you want to skip it
+If the current body part does not match your expectations and you want to skip it
- just continue the loop to start its next iteration.

Here is where magic happens. Before fetching the next body part ``await reader.next()``
-ensures that the previous one was read completely. If it wasn't, all its content
+ensures that the previous one was read completely. If it was not, all its content
is sent to the void in order to fetch the next part. So you don't have to care
about cleanup routines while you're within a loop.
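+
+Putting it together, a minimal server-side sketch of such a loop (the
+handler and the field names here are illustrative assumptions) might be::
+
+    async def handler(request):
+        reader = await request.multipart()
+        while True:
+            part = await reader.next()
+            if part is None:  # the final boundary was reached
+                break
+            if part.filename != 'secret.txt':
+                continue  # skipping drains this part before the next one
+            data = await part.read(decode=True)
+        return web.Response()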
@@ -146,15 +146,15 @@ the second argument::

                       {'CONTENT-TYPE': 'image/gif'})

For file objects `Content-Type` will be determined by using Python's
-`mimetypes`_ module and additionally `Content-Disposition` header will include
-the file's basename::
+:mod:`mimetypes` module and additionally the `Content-Disposition` header
+will include the file's basename::

    part = root.append(open(__file__, 'rb'))

If you want to send a file with a different name, just handle the
-:class:`BodyPartWriter` instance which :meth:`MultipartWriter.append` will
+:class:`Payload` instance which :meth:`MultipartWriter.append` will
always return and set `Content-Disposition` explicitly by using
-the :meth:`BodyPartWriter.set_content_disposition` helper::
+the :meth:`Payload.set_content_disposition` helper::

    part.set_content_disposition('attachment', filename='secret.txt')

@@ -175,40 +175,69 @@ and form urlencoded data, so you don't have to encode it every time manually::

    mpwriter.append_form([('key', 'value')])

When it's done, to make a request just pass a root :class:`MultipartWriter`
-instance as :func:`aiohttp.client.request` `data` argument::
+instance as the :meth:`aiohttp.ClientSession.request` ``data`` argument::

-    await aiohttp.post('http://example.com', data=mpwriter)
+    await session.post('http://example.com', data=mpwriter)

-Behind the scenes :meth:`MultipartWriter.serialize` will yield chunks of every
+Behind the scenes :meth:`MultipartWriter.write` will yield chunks of every
part and if body part has `Content-Encoding` or `Content-Transfer-Encoding`
they will be applied on streaming content.

-Please note, that on :meth:`MultipartWriter.serialize` all the file objects
+Please note, that on :meth:`MultipartWriter.write` all the file objects
will be read until the end and there is no way to repeat a request without
rewinding their pointers to the start.

+An example of MJPEG streaming with ``multipart/x-mixed-replace``. By default
+:meth:`MultipartWriter.write` appends a closing ``--boundary--`` and breaks your
+content. Providing ``close_boundary=False`` prevents this::
+
+    my_boundary = 'some-boundary'
+    response = web.StreamResponse(
+        status=200,
+        reason='OK',
+        headers={
+            'Content-Type': 'multipart/x-mixed-replace;boundary={}'.format(my_boundary)
+        }
+    )
+    while True:
+        frame = get_jpeg_frame()
+        with MultipartWriter('image/jpeg', boundary=my_boundary) as mpwriter:
+            mpwriter.append(frame, {
+                'Content-Type': 'image/jpeg'
+            })
+            await mpwriter.write(response, close_boundary=False)
+        await response.drain()
+
Hacking Multipart
-----------------

The Internet is full of terror and sometimes you may find a server
which implements multipart support in strange ways when an obvious solution
-doesn't work.
+does not work.

-For instance, is server used `cgi.FieldStorage`_ then you have to ensure that
-no body part contains a `Content-Length` header::
+For instance, if the server uses :class:`cgi.FieldStorage` then you have
+to ensure that no body part contains a `Content-Length` header::

    for part in mpwriter:
        part.headers.pop(aiohttp.hdrs.CONTENT_LENGTH, None)

On the other hand, some servers may require `Content-Length` to be specified for the
-whole multipart request. `aiohttp` doesn't do that since it sends multipart
+whole multipart request. `aiohttp` does not do that since it sends multipart
using chunked transfer encoding by default.
To overcome this issue, you have to serialize a :class:`MultipartWriter` on your own in order to calculate its size:: - body = b''.join(mpwriter.serialize()) + class Writer: + def __init__(self): + self.buffer = bytearray() + + async def write(self, data): + self.buffer.extend(data) + + writer = Writer() + await mpwriter.write(writer) await aiohttp.post('http://example.com', - data=body, headers=mpwriter.headers) + data=writer.buffer, headers=mpwriter.headers) Sometimes the server response may not be well formed: it may or may not contain nested parts. For instance, we request a resource which returns @@ -324,12 +353,4 @@ And this gives us a cleaner solution:: result.append((doc, files)) -.. seealso:: Multipart API in :ref:`aiohttp-api` section. - - -.. _cgi.FieldStorage: https://docs.python.org/3.4/library/cgi.html -.. _mimetypes: https://docs.python.org/3.4/library/mimetypes.html - - -.. disqus:: - :title: aiohttp suppport for multipart encoding +.. seealso:: :ref:`aiohttp-multipart-reference`
diff --git a/docs/multipart_reference.rst b/docs/multipart_reference.rst new file mode 100644 index 00000000000..032ecc8b7aa --- /dev/null +++ b/docs/multipart_reference.rst @@ -0,0 +1,204 @@ +.. currentmodule:: aiohttp + +.. _aiohttp-multipart-reference: + +Multipart reference +=================== + +.. class:: MultipartResponseWrapper(resp, stream) + + Wrapper around the :class:`MultipartBodyReader` that takes care of + the underlying connection and closes it when needed. + + + .. method:: at_eof() + + Returns ``True`` when all response data has been read. + + :rtype: bool + + .. comethod:: next() + + Emits the next multipart reader object. + + .. comethod:: release() + + Releases the connection gracefully, reading all the content + to the void. + + +.. class:: BodyPartReader(boundary, headers, content) + + Multipart reader for a single body part. + + .. comethod:: read(*, decode=False) + + Reads body part data. + + :param bool decode: Decodes the data according to the encoding method + from the ``Content-Encoding`` header. If the + header is missing, the data remains untouched + + :rtype: bytearray + + .. comethod:: read_chunk(size=chunk_size) + + Reads a chunk of body part content of the specified size. + + :param int size: chunk size + + :rtype: bytearray + + .. comethod:: readline() + + Reads the body part line by line. + + :rtype: bytearray + + .. comethod:: release() + + Like :meth:`read`, but reads all the data to the void. + + :rtype: None + + .. comethod:: text(*, encoding=None) + + Like :meth:`read`, but assumes that the body part contains text data. + + :param str encoding: Custom text encoding. Overrides the encoding + specified in the charset param of the ``Content-Type`` header + + :rtype: str + + .. comethod:: json(*, encoding=None) + + Like :meth:`read`, but assumes that the body part contains JSON data. + + :param str encoding: Custom JSON encoding. Overrides the encoding + specified in the charset param of the ``Content-Type`` header + + .. comethod:: form(*, encoding=None) + + Like :meth:`read`, but assumes that the body part contains form + urlencoded data. + + :param str encoding: Custom form encoding. Overrides the encoding + specified in the charset param of the ``Content-Type`` header + + .. method:: at_eof() + + Returns ``True`` if the boundary was reached or ``False`` otherwise. + + :rtype: bool + + .. method:: decode(data) + + Decodes data according to the specified ``Content-Encoding`` + or ``Content-Transfer-Encoding`` headers value. + + Supports ``gzip``, ``deflate`` and ``identity`` encodings for + ``Content-Encoding`` header.
+ + Supports ``base64``, ``quoted-printable`` and ``binary`` encodings for + ``Content-Transfer-Encoding`` header. + + :param bytearray data: Data to decode. + + :raises: :exc:`RuntimeError` - if the encoding is unknown. + + :rtype: bytes + + .. method:: get_charset(default=None) + + Returns the charset parameter from the ``Content-Type`` header or *default*. + + .. attribute:: name + + A field *name* specified in the ``Content-Disposition`` header, or ``None`` + if it is missing or the header is malformed. + + Readonly :class:`str` property. + + .. attribute:: filename + + A field *filename* specified in the ``Content-Disposition`` header, or ``None`` + if it is missing or the header is malformed. + + Readonly :class:`str` property. + + +.. class:: MultipartReader(headers, content) + + Multipart body reader. + + .. classmethod:: from_response(cls, response) + + Constructs a reader instance from an HTTP response. + + :param response: :class:`~aiohttp.client.ClientResponse` instance + + .. method:: at_eof() + + Returns ``True`` if the final boundary was reached or + ``False`` otherwise. + + :rtype: bool + + .. comethod:: next() + + Emits the next multipart body part. + + .. comethod:: release() + + Reads all the body parts to the void until the final boundary. + + .. comethod:: fetch_next_part() + + Returns the next body part reader. + + +.. class:: MultipartWriter(subtype='mixed', boundary=None, close_boundary=True) + + Multipart body writer. + + ``boundary`` may be an ASCII-only string. + + .. attribute:: boundary + + The string (:class:`str`) representation of the boundary. + + .. versionchanged:: 3.0 + + Property type was changed from :class:`bytes` to :class:`str`. + + .. method:: append(obj, headers=None) + + Appends an object to the writer. + + .. method:: append_payload(payload) + + Adds a new body part to the multipart writer. + + .. method:: append_json(obj, headers=None) + + Helper to append a JSON part. + + .. method:: append_form(obj, headers=None) + + Helper to append a form urlencoded part. + + .. attribute:: size + + Size of the payload. + + .. comethod:: write(writer, close_boundary=True) + + Write body. + + :param bool close_boundary: Whether to emit the closing + boundary. You may want to disable it + when streaming (``multipart/x-mixed-replace``) + + .. versionadded:: 3.4 + + Support for the ``close_boundary`` argument.
diff --git a/docs/new_router.rst b/docs/new_router.rst index 7dd05384017..a88b20838aa 100644 --- a/docs/new_router.rst +++ b/docs/new_router.rst @@ -82,7 +82,3 @@ shortcut for:: ``app.router.register_route(...)`` is still supported, it creates :class:`aiohttp.web.ResourceAdapter` for every call (but it's deprecated now). - - -.. disqus:: - :title: aiohttp router refactoring notes
diff --git a/docs/old-logo.png b/docs/old-logo.png new file mode 100644 index 00000000000..eac760bd8c9 Binary files /dev/null and b/docs/old-logo.png differ
diff --git a/docs/old-logo.svg b/docs/old-logo.svg new file mode 100644 index 00000000000..4d7ac2d278a --- /dev/null +++ b/docs/old-logo.svg @@ -0,0 +1,487 @@ + + + + aiohttp-icon + Created with Sketch.
+ (SVG path markup omitted)
diff --git a/docs/powered_by.rst b/docs/powered_by.rst new file mode 100644 index 00000000000..c6e497134ff --- /dev/null +++ b/docs/powered_by.rst @@ -0,0 +1,38 @@ +.. _aiohttp-powered-by: + +Powered by aiohttp +================== + +Web sites powered by aiohttp. + +Feel free to fork the documentation on GitHub, add a link to your site and +make a Pull Request! + +* `Farmer Business Network `_ +* `Home Assistant `_ +* `KeepSafe `_ +* `Skyscanner Hotels `_ +* `Ocean S.A. `_ +* `GNS3 `_ +* `TutorCruncher socket + `_ +* `Morpheus messaging microservice `_ +* `Eyepea - Custom telephony solutions `_ +* `ALLOcloud - Telephony in the cloud `_ +* `helpmanual - comprehensive help and man page database + `_ +* `bedevere `_ - CPython's GitHub + bot, helps maintain and identify issues with a CPython pull request. +* `miss-islington `_ - + CPython's GitHub bot, backports and merges CPython's pull requests +* `noa technologies - Bike-sharing management platform + `_ - SSE endpoint, pushes real-time updates of + bike locations. +* `Wargaming: World of Tanks `_ +* `Yandex `_ +* `Rambler `_ +* `Escargot `_ - Chat server +* `Prom.ua `_ - Online trading platform +* `globo.com `_ - (some parts) Brazil's largest media portal +* `Glose `_ - Social reader for E-Books +* `Emoji Generator `_ - Text icon generator
diff --git a/docs/signals.rst b/docs/signals.rst new file mode 100644 index 00000000000..1126c2ffc99 --- /dev/null +++ b/docs/signals.rst @@ -0,0 +1,45 @@ +Signals +======= + +.. currentmodule:: aiohttp + +A signal is a list of registered asynchronous callbacks. + +The signal's life-cycle has two stages: after creation its content +can be filled by using standard list operations: ``sig.append()`` +etc. + +After the ``sig.freeze()`` call the signal is *frozen*: adding and +removing callbacks is forbidden. + +The only available operation is calling the previously registered +callbacks by ``await sig.send(data)``. + +For concrete usage examples see :ref:`signals in aiohttp.web +` chapter. + +.. versionchanged:: 3.0 + + The ``sig.send()`` call is forbidden for non-frozen signals. + + Support for regular (non-async) callbacks is dropped. All callbacks + should be async functions. + + +.. class:: Signal + + The signal; implements the :class:`collections.abc.MutableSequence` + interface. + + .. comethod:: send(*args, **kwargs) + + Call all registered callbacks one by one, starting from the + beginning of the list. + + .. attribute:: frozen + + ``True`` if :meth:`freeze` was called, read-only property. + + .. method:: freeze() + + Freeze the list.
After the call any content modification is forbidden. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 04cebf56095..ebf58fdfd66 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,231 +1,340 @@ +# de-facto: +Arsenic +Backport +Backporting +BaseEventLoop +BasicAuth +BodyPartReader +Bugfixes +BytesIO +brotli +brotlipy +pydantic +CIMultiDict +CPython +Changelog +ClientSession +Codings +Config +CookieJar +Coroutine +Ctrl +Cython +Cythonize +DER +DNSResolver +Dev +Dict +Discord +Django +Dup +Facebook +HTTPException +HttpProcessingError +IP +IPv +Indices +Jinja +KiB +Locator +Mako +Mixcloud +Mongo +Mongo +MsgType +Multidicts +Multipart +Nagle +Nagle’s +Nginx +Nikolay +OAuth +Online +Overridable +PRs +Paolini +Postgres +Punycode +Pytest +Quickstart +Redis +RequestContextManager +Request’s +Runit +SSLContext +Satisfiable +Skyscanner +SocketSocketTransport +Supervisord +Svetlov +Systemd +TCP +TLS +Teardown +TestClient +Testsuite +Tf +UI +Unittest +WSMessage +WSMsgType +Websockets +Workflow abc aiodns aioes aiohttp aiohttpdemo +aiohttp’s aiopg alives api +api’s app +apps +app’s arg async asyncio auth autocalculated autodetection +autoformatter +autoformatters autogenerates autogeneration awaitable -BaseEventLoop backend backends +backport +backports basename -BasicAuth -BodyPartReader +boolean +botocore +bugfix builtin -BytesIO -cchardet cChardet +cancelled +canonicalization +canonicalize +cchardet +ceil charset charsetdetect chunked -CIMultiDict -ClientSession +chunking cls cmd -Codings +codec +committer +committers config -Config configs +conjunction contextmanager -CookieJar coroutine -Coroutine coroutines +cpu css ctor -Ctrl +cython cythonized -DER +de +deduplicate +deprecations +dev dict -Dict django -Django dns -DNSResolver -Dup +docstring elasticsearch encodings env environ eof epoll +facto fallback +fallbacks filename finalizers +formatters frontend getall gethostbyname github google gunicorn -Gunicorn +gunicorn’s gzipped hackish highlevel hostnames -HTTPException -HttpProcessingError +httpretty https -incapsulates impl -Indices +incapsulates infos +initializer inline +intaking io ip -IP ipdb -IPv ish iterable iterables -Jinja javascript json keepalive keepalived keepalives keepaliving +kib kwarg latin +lifecycle linux localhost login lookup lookups +lossless manylinux metadata +microservice middleware middlewares miltidict -Mongo +misbehaviors +misformed msg -MsgType +multi multidict multidicts -Multidicts +multidict’s multipart -Multipart -Nagle +mypy namedtuple nameservers namespace +netrc nginx -Nginx -Nikolay +noop nowait optimizations os outcoming -Overridable -Paolini param params +parsers pathlib +peername +performant +pickleable ping -pluggable +pipelining +plugable plugin -Postgres poller pong pre +preloaded +proactor programmatically proxied pubsub py pyenv +pyflakes pytest -Pytest +quote’s readonly readpayload rebase -Redis -Refactor +redirections refactor refactored refactoring regex regexps regexs +reloader renderer renderers repo repr -RequestContextManager +repr’s +request’s requote +requoting +resolvehost resolvers reusage +reuseconn +runtime sa schemas sendfile serializable +serializer shourtcuts skipuntil -SocketSocketTransport +softwares ssl -SSLContext +startup +subapplication subclasses +subdirectory submodules subpackage subprotocol subprotocols subtype -Supervisord supervisord -Systemd -Runit -Svetlov symlink symlinks syscall syscalls -TCP +tarball teardown -Teardown -TestClient -Testsuite -Tf timestamps toolbar 
toplevel +towncrier tp tuples -UI +uWSGI un +unawaited +unclosed +unhandled unicode unittest -Unittest unix unsets +unstripped upstr url urldispatcher urlencoded urls +url’s utf utils uvloop vcvarsall waituntil +wakeup +wakeups +webapp websocket websockets -Websockets +websocket’s wildcard -Workflow ws wsgi -WSMessage -WSMsgType wss www +xxx +yarl
diff --git a/docs/streams.rst b/docs/streams.rst index 0d195fea0a2..8356c390772 100644 --- a/docs/streams.rst +++ b/docs/streams.rst @@ -1,11 +1,8 @@ -.. module:: aiohttp.streams - .. _aiohttp-streams: Streaming API ============= -.. module:: aiohttp .. currentmodule:: aiohttp @@ -74,6 +71,21 @@ Reading Methods :return bytes: the given line +.. comethod:: StreamReader.readchunk() + + Read a chunk of data as it was received by the server. + + Returns a tuple of (data, end_of_HTTP_chunk). + + When chunked transfer encoding is used, end_of_HTTP_chunk is a :class:`bool` + indicating if the end of the data corresponds to the end of an HTTP chunk, + otherwise it is always ``False``. + + :return tuple[bytes, bool]: a chunk of data and a :class:`bool` that is ``True`` + when the end of the returned chunk corresponds + to the end of an HTTP chunk. + + Asynchronous Iteration Support ------------------------------ @@ -96,7 +108,7 @@ size limit and over any available data. async for data in response.content.iter_chunked(1024): print(data) -.. comethod:: StreamReader.iter_any(n) +.. comethod:: StreamReader.iter_any() :async-for: Iterates over data chunks in the order they were received into the stream:: @@ -104,6 +116,25 @@ size limit and over any available data. async for data in response.content.iter_any(): print(data) +.. comethod:: StreamReader.iter_chunks() + :async-for: + + Iterates over data chunks as received from the server:: + + async for data, _ in response.content.iter_chunks(): + print(data) + + If chunked transfer encoding is used, the original HTTP chunk formatting + can be retrieved by reading the second element of the returned tuples:: + + buffer = b"" + + async for data, end_of_http_chunk in response.content.iter_chunks(): + buffer += data + if end_of_http_chunk: + print(buffer) + buffer = b"" + Helpers ------- @@ -144,7 +175,7 @@ Helpers :param bytes data: data to push back into the stream. - .. warning:: The method doesn't wake up waiters. + .. warning:: The method does not wake up waiters. E.g. :meth:`~StreamReader.read()` will not be resumed. @@ -152,7 +183,3 @@ Helpers .. comethod:: wait_eof() Wait for EOF. The given data may be accessible by upcoming read calls. - - -.. disqus:: - :title: aiohttp streaming api
diff --git a/docs/structures.rst b/docs/structures.rst new file mode 100644 index 00000000000..a47bdc09578 --- /dev/null +++ b/docs/structures.rst @@ -0,0 +1,55 @@ +.. _aiohttp-structures: + + +Common data structures +====================== + +.. module:: aiohttp + +.. currentmodule:: aiohttp + + +Common data structures used by *aiohttp* internally. + + +FrozenList +---------- + +A list-like structure which implements +:class:`collections.abc.MutableSequence`. + +The list is *mutable* unless :meth:`FrozenList.freeze` is called; +after that, list modification raises :exc:`RuntimeError`. + + +.. class:: FrozenList(items) + + Construct a new *non-frozen* list from the *items* iterable. + + The list implements all :class:`collections.abc.MutableSequence` + methods plus two additional APIs. + + .. attribute:: frozen + + A read-only property, ``True`` if the list is *frozen* + (modifications are forbidden). + + ..
method:: freeze() + + Freeze the list. There is no way to *thaw* it back. + + +ChainMapProxy +------------- + +An *immutable* version of :class:`collections.ChainMap`. Internally +the proxy is a list of mappings (dictionaries); if the requested key +is not present in the first mapping, the second is looked up, and so on. + +The class supports the :class:`collections.abc.Mapping` interface. + +.. class:: ChainMapProxy(maps) + + Create a new chained mapping proxy from a list of mappings (*maps*). + + .. versionadded:: 3.2
diff --git a/docs/testing.rst b/docs/testing.rst index 9827ba29ccf..d722f3aef39 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -8,7 +8,7 @@ Testing Testing aiohttp web servers --------------------------- -aiohttp provides plugin for pytest_ making writing web server tests +aiohttp provides a plugin for *pytest* that makes writing web server tests extremely easy; it also provides :ref:`test framework agnostic utilities ` for testing with other frameworks such as :ref:`unittest @@ -43,7 +43,7 @@ The module is **provisional**. But for :mod:`aiohttp.test_tools` the deprecation period could be reduced. Moreover we may break *backward compatibility* without a *deprecation -peroid* for some very strong reason. +period* for some very strong reason. The Test Client and Servers @@ -71,7 +71,9 @@ proxy methods to the client for common operations such as Pytest ~~~~~~ +.. currentmodule:: pytest_aiohttp + The :data:`aiohttp_client` fixture available from the pytest-aiohttp_ plugin allows you to create a client to make requests to test your app. A simple example would be:: @@ -81,10 +83,10 @@ A simple example would be:: async def hello(request): return web.Response(text='Hello, world') - async def test_hello(test_client, loop): + async def test_hello(aiohttp_client, loop): app = web.Application() app.router.add_get('/', hello) - client = await test_client(app) + client = await aiohttp_client(app) resp = await client.get('/') assert resp.status == 200 text = await resp.text() @@ -107,11 +109,11 @@ app test client:: body='value: {}'.format(request.app['value']).encode('utf-8')) @pytest.fixture - def cli(loop, test_client): + def cli(loop, aiohttp_client): app = web.Application() app.router.add_get('/', previous) app.router.add_post('/', previous) - return loop.run_until_complete(test_client(app)) + return loop.run_until_complete(aiohttp_client(app)) async def test_set_value(cli): resp = await cli.post('/', data={'value': 'foo'}) @@ -128,38 +130,48 @@ app test client:: Pytest tooling has the following fixtures: -.. data:: test_server(app, **kwargs) +.. data:: aiohttp_server(app, *, port=None, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.TestServer`:: - async def test_f(test_server): + async def test_f(aiohttp_server): app = web.Application() # fill route table - server = await test_server(app) + server = await aiohttp_server(app) The server will be destroyed on exit from the test function. *app* is the :class:`aiohttp.web.Application` used to start the server. + *port* is optional; it is the port the server is run at. If + not provided, a random unused port is used. + + .. versionadded:: 3.0 + *kwargs* are parameters passed to :meth:`aiohttp.web.Application.make_handler` + .. versionchanged:: 3.0 + .. deprecated:: 3.2 -.. data:: test_client(app, **kwargs) - test_client(server, **kwargs) - test_client(raw_server, **kwargs) + The fixture was renamed from ``test_server`` to ``aiohttp_server``. + + +..
data:: aiohttp_client(app, server_kwargs=None, **kwargs) + aiohttp_client(server, **kwargs) + aiohttp_client(raw_server, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.TestClient` for access to the tested server:: - async def test_f(test_client): + async def test_f(aiohttp_client): app = web.Application() # fill route table - client = await test_client(app) + client = await aiohttp_client(app) resp = await client.get('/') *client* and responses are cleaned up after the test function finishes. @@ -168,27 +180,56 @@ Pytest tooling has the following fixtures: :class:`aiohttp.test_utils.TestServer` or :class:`aiohttp.test_utils.RawTestServer` instance. + *server_kwargs* are parameters passed to the test server if an app + is passed, else ignored. + *kwargs* are parameters passed to :class:`aiohttp.test_utils.TestClient` constructor. -.. data:: raw_test_server(handler, **kwargs) + .. versionchanged:: 3.0 + + The fixture was renamed from ``test_client`` to ``aiohttp_client``. + +.. data:: aiohttp_raw_server(handler, *, port=None, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.RawTestServer` instance from given web - handler. + handler:: - *handler* should be a coroutine which accepts a request and returns - response, e.g.:: - - async def test_f(raw_test_server, test_client): async def handler(request): return web.Response(text="OK") - raw_server = await raw_test_server(handler) - client = await test_client(raw_server) + raw_server = await aiohttp_raw_server(handler) + client = await aiohttp_client(raw_server) resp = await client.get('/') + *handler* should be a coroutine which accepts a request and returns + a response; see the example above. + + *port* is optional; it is the port the server is run at. If + not provided, a random unused port is used. + + .. versionadded:: 3.0 + +.. data:: aiohttp_unused_port() + + Function to return an unused port number for the IPv4 TCP protocol:: + + async def test_f(aiohttp_client, aiohttp_unused_port): + port = aiohttp_unused_port() + app = web.Application() + # fill route table + + client = await aiohttp_client(app, server_kwargs={'port': port}) + ... + + .. versionchanged:: 3.0 + + The fixture was renamed from ``unused_port`` to ``aiohttp_unused_port``. + + .. _aiohttp-testing-unittest-example: .. _aiohttp-testing-unittest-style: @@ -196,6 +237,9 @@ Pytest tooling has the following fixtures: Unittest ~~~~~~~~ +.. currentmodule:: aiohttp.test_utils + + To test applications with the standard library's unittest or unittest-based functionality, the AioHTTPTestCase is provided:: @@ -208,23 +252,28 @@ functionality, the AioHTTPTestCase is provided:: """ Override the get_app method to return your application.
""" - return web.Application() + async def hello(request): + return web.Response(text='Hello, world') + + app = web.Application() + app.router.add_get('/', hello) + return app # the unittest_run_loop decorator can be used in tandem with # the AioHTTPTestCase to simplify running # tests that are asynchronous @unittest_run_loop async def test_example(self): - request = await self.client.request("GET", "/") - assert request.status == 200 - text = await request.text() + resp = await self.client.request("GET", "/") + assert resp.status == 200 + text = await resp.text() assert "Hello, world" in text # a vanilla example - def test_example(self): + def test_example_vanilla(self): async def test_get_route(): - url = root + "/" - resp = await self.client.request("GET", url, loop=loop) + url = "/" + resp = await self.client.request("GET", url) assert resp.status == 200 text = await resp.text() assert "Hello, world" in text @@ -243,15 +292,41 @@ functionality, the AioHTTPTestCase is provided:: an aiohttp test client, :class:`TestClient` instance. + .. attribute:: server + + an aiohttp test server, :class:`TestServer` instance. + + .. versionadded:: 2.3 + .. attribute:: loop The event loop in which the application and server are running. + .. deprecated:: 3.5 + .. attribute:: app The application returned by :meth:`get_app` (:class:`aiohttp.web.Application` instance). + .. comethod:: get_client() + + This async method can be overridden to return the :class:`TestClient` + object used in the test. + + :return: :class:`TestClient` instance. + + .. versionadded:: 2.3 + + .. comethod:: get_server() + + This async method can be overridden to return the :class:`TestServer` + object used in the test. + + :return: :class:`TestServer` instance. + + .. versionadded:: 2.3 + .. comethod:: get_application() This async method should be overridden @@ -260,6 +335,20 @@ functionality, the AioHTTPTestCase is provided:: :return: :class:`aiohttp.web.Application` instance. + .. comethod:: setUpAsync() + + This async method do nothing by default and can be overridden to execute + asynchronous code during the ``setUp`` stage of the ``TestCase``. + + .. versionadded:: 2.3 + + .. comethod:: tearDownAsync() + + This async method do nothing by default and can be overridden to execute + asynchronous code during the ``tearDown`` stage of the ``TestCase``. + + .. versionadded:: 2.3 + .. method:: setUp() Standard test initialization method. @@ -331,12 +420,13 @@ conditions that hard to reproduce on real server:: version=HttpVersion(1, 1), \ closing=False, \ app=None, \ + match_info=sentinel, \ reader=sentinel, \ writer=sentinel, \ transport=sentinel, \ payload=sentinel, \ sslcontext=None, \ - secure_proxy_ssl_header=None) + loop=...) Creates mocked web.Request testing purposes. @@ -353,6 +443,9 @@ conditions that hard to reproduce on real server:: by the multidict.CIMultiDict constructor. :type headers: dict, multidict.CIMultiDict, list of pairs + :param match_info: mapping containing the info to match with url parameters. 
+ :type match_info: dict + :param version: namedtuple with encoded HTTP version :type version: aiohttp.protocol.HttpVersion @@ -364,23 +457,24 @@ conditions that hard to reproduce on real server:: :type app: aiohttp.web.Application :param writer: object for managing outcoming data - :type wirter: aiohttp.streams.StreamWriter + :type writer: aiohttp.StreamWriter :param transport: asyncio transport instance :type transport: asyncio.transports.Transport :param payload: raw payload reader object - :type payload: aiohttp.streams.FlowControlStreamReader + :type payload: aiohttp.StreamReader :param sslcontext: ssl.SSLContext object, for HTTPS connection :type sslcontext: ssl.SSLContext - :param secure_proxy_ssl_header: A tuple representing a HTTP header/value - combination that signifies a request is secure. - :type secure_proxy_ssl_header: tuple + :param loop: An event loop instance, mocked loop by default. + :type loop: :class:`asyncio.AbstractEventLoop` :return: :class:`aiohttp.web.Request` object. + .. versionadded:: 2.3 + *match_info* parameter. .. _aiohttp-testing-writing-testable-services: @@ -392,14 +486,14 @@ Framework Agnostic Utilities High level test creation:: - from aiohttp.test_utils import TestClient, loop_context + from aiohttp.test_utils import TestClient, TestServer, loop_context from aiohttp import request # loop_context is provided as a utility. You can use any - # asyncio.BaseEventLoop class in it's place. + # asyncio.BaseEventLoop class in its place. with loop_context() as loop: app = _create_example_app() - with TestClient(app, loop=loop) as client: + with TestClient(TestServer(app), loop=loop) as client: async def test_get_route(): nonlocal client @@ -414,11 +508,11 @@ High level test creation:: If it's preferred to handle the creation / teardown on a more granular basis, the TestClient object can be used directly:: - from aiohttp.test_utils import TestClient + from aiohttp.test_utils import TestClient, TestServer with loop_context() as loop: app = _create_example_app() - client = TestClient(app, loop=loop) + client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) root = "http://127.0.0.1:{}".format(port) @@ -435,73 +529,6 @@ basis, the TestClient object can be used directly:: A full list of the utilities provided can be found at the :data:`api reference ` -Writing testable services -------------------------- - -Some libraries like motor, aioes and others depend on the asyncio loop for -executing the code. When running your normal program, these libraries pick -the main event loop by doing ``asyncio.get_event_loop``. The problem during -testing is that there is no main loop assigned because an independent -loop for each test is created without assigning it as the main one. - -This raises a problem when those libraries try to find it. Luckily, the ones -that are well written, allow passing the loop explicitly. Let's have a look -at the aioes client signature:: - - def __init__(self, endpoints, *, loop=None, **kwargs) - -As you can see, there is an optional ``loop`` kwarg. Of course, we are not -going to test directly the aioes client but our service that depends on it -will. 
So, if we want our ``AioESService`` to be easily testable, we should -define it as follows:: - - import asyncio - - from aioes import Elasticsearch - - - class AioESService: - - def __init__(self, loop=None): - self.es = Elasticsearch(["127.0.0.1:9200"], loop=loop) - - async def get_info(self): - cluster_info = await self.es.info() - print(cluster_info) - - if __name__ == "__main__": - client = AioESService() - loop = asyncio.get_event_loop() - loop.run_until_complete(client.get_info()) - - -Note that it is accepting an optional ``loop`` kwarg. For the normal flow of -execution it won't affect because we can still call the service without passing -the loop explicitly having a main loop available. The problem comes when you -try to do a test like:: - - import pytest - - from main import AioESService - - - class TestAioESService: - - async def test_get_info(self): - cluster_info = await AioESService().get_info() - assert isinstance(cluster_info, dict) - -If you try to run the test, it will fail with a similar error:: - - ... - RuntimeError: There is no current event loop in thread 'MainThread'. - - -If you check the stack trace, you will see aioes is complaining that there is -no current event loop in the main thread. Pass explicit loop to solve it. - -If you rely on code which works with *implicit* loops only you may try -to use hackish approach from :ref:`FAQ `. Testing API Reference --------------------- Test server usually works in conjunction with :class:`aiohttp.test_utils.TestClient` which provides handy client methods for accessing the server. -.. class:: BaseTestServer(*, scheme='http', host='127.0.0.1') +.. class:: BaseTestServer(*, scheme='http', host='127.0.0.1', port=None) Base class for test servers. :param str scheme: HTTP scheme, non-protected ``"http"`` by default. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. + :param int port: optional port for TCP socket, if not provided a + random unused port is used. + + .. versionadded:: 3.0 .. attribute:: scheme A *scheme* for tested application, ``'http'`` for non-protected - run and ``'htttps'`` for TLS encrypted server. + run and ``'https'`` for TLS encrypted server. .. attribute:: host *host* used to start a test server. .. attribute:: port - A random *port* used to start a server. + *port* used to start the test server. .. attribute:: handler @@ -587,6 +618,11 @@ for accessing to the server. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. + :param int port: optional port for TCP socket, if not provided a + random unused port is used. + + .. versionadded:: 3.0 + .. class:: TestServer(app, *, scheme="http", host='127.0.0.1') @@ -600,6 +636,10 @@ for accessing to the server. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. + :param int port: optional port for TCP socket, if not provided a + random unused port is used. + + .. versionadded:: 3.0 .. attribute:: app @@ -618,10 +658,9 @@ Test Client :param app_or_server: :class:`BaseTestServer` instance for making client requests to it. - If the parameter is - :class:`aiohttp.web.Application` the tool - creates :class:`TestServer` implicitly for - serving the application. + In order to pass a :class:`aiohttp.web.Application` + you need to convert it to a :class:`TestServer` + first with ``TestServer(app)``.
:param cookie_jar: an optional :class:`aiohttp.CookieJar` instance, may be useful with ``CookieJar(unsafe=True)`` @@ -637,7 +676,7 @@ Test Client .. attribute:: scheme A *scheme* for tested application, ``'http'`` for non-protected - run and ``'htttps'`` for TLS encrypted server. + run and ``'https'`` for TLS encrypted server. .. attribute:: host *host* used to start a test server. .. attribute:: port - A random *port* used to start a server. + *port* used to start the server. .. attribute:: server :class:`BaseTestServer` test server instance used in conjunction with client. + .. attribute:: app + + An alias for :attr:`self.server.app`. Returns ``None`` if + ``self.server`` is not a :class:`TestServer` + instance (e.g. a :class:`RawTestServer` instance for a low-level test server). + .. attribute:: session An internal :class:`aiohttp.ClientSession`. @@ -677,7 +722,7 @@ Test Client Routes a request to the tested HTTP server. The interface is identical to - :meth:`asyncio.ClientSession.request`, except the loop kwarg is + :meth:`aiohttp.ClientSession.request`, except the loop kwarg is overridden by the instance used by the test server. .. comethod:: get(path, *args, **kwargs) @@ -756,6 +801,20 @@ Utilities The caller should also call teardown_test_loop, once they are done with the loop. + .. note:: + + As a side effect the function changes the asyncio *default loop* via an + :func:`asyncio.set_event_loop` call. + + The previous default loop is not restored. + + It should not be a problem for a test suite: every test expects a + new test loop instance anyway. + + .. versionchanged:: 3.1 + + The function installs the created event loop as *default*. + .. function:: teardown_test_loop(loop) Teardown and cleanup an event_loop created by setup_test_loop. @@ -767,7 +826,3 @@ Utilities .. _pytest: http://pytest.org/latest/ .. _pytest-aiohttp: https://pypi.python.org/pypi/pytest-aiohttp - - -.. disqus:: - :title: aiohttp testing
diff --git a/docs/third_party.rst b/docs/third_party.rst index fb158e9520d..d5bcb3df86a 100644 --- a/docs/third_party.rst +++ b/docs/third_party.rst @@ -1,3 +1,5 @@ +.. _aiohttp-3rd-party: + Third-Party libraries ===================== @@ -12,9 +14,9 @@ This page is a list of these tools. Please feel free to add your open sourced library if it's not enlisted yet by making a Pull Request to https://github.com/aio-libs/aiohttp/ -- Q. Why do you might want to include your awesome library into the list? +* Why might you want to include your awesome library in the list? -- A. Just because the list increases your library visibility. People +* Just because the list increases your library's visibility. People will have an easy way to find it. @@ -32,11 +34,31 @@ aiohttp extensions provides sessions for :mod:`aiohttp.web`. - `aiohttp-debugtoolbar `_ - is a library for *debug toolbar* support for :mod:`aiohttp.web`. + is a library for *debug toolbar* support for :mod:`aiohttp.web`. - `aiohttp-security `_ - auth and permissions for :mod:`aiohttp.web`. + auth and permissions for :mod:`aiohttp.web`. + +- `aiohttp-devtools `_ + provides development tools for :mod:`aiohttp.web` applications. + +- `aiohttp-cors `_ CORS + support for aiohttp. + +- `aiohttp-sse `_ Server-sent + events support for aiohttp. + +- `pytest-aiohttp `_ + pytest plugin for aiohttp support. + +- `aiohttp-mako `_ Mako + template renderer for aiohttp.web. + +- `aiohttp-jinja2 `_ Jinja2 + template renderer for aiohttp.web. +- `aiozipkin `_ distributed + tracing instrumentation for `aiohttp` client and server.
Database drivers ^^^^^^^^^^^^^^^^ @@ -47,6 +69,15 @@ Database drivers - `aioredis `_ Redis async driver. +Other tools +^^^^^^^^^^^ + +- `aiodocker `_ Python Docker + API client based on asyncio and aiohttp. + +- `aiobotocore `_ asyncio + support for botocore library using aiohttp. + Approved third-party libraries ------------------------------ @@ -71,17 +102,163 @@ Database drivers Others ------ -The list of libs which are exists but not enlisted in former categories. +The list of libraries which exist but are not enlisted in the former categories. -They are may be perfect or not -- we don't know. +They may be perfect or not -- we don't know. Please add your library reference here first and after some time -period ask to raise he status. +period ask to raise the status. - `aiohttp-cache `_ A cache system for aiohttp server. + - `aiocache `_ Caching for asyncio with multiple backends (framework agnostic) -- `aiohttp-devtools `_ - provides development tools for :mod:`aiohttp.web` applications +- `gain `_ Web crawling framework + based on asyncio for everyone. + +- `aiohttp-swagger `_ + Swagger API Documentation builder for aiohttp server. + +- `aiohttp-swaggerify `_ + Library to automatically generate swagger2.0 definition for aiohttp endpoints. + +- `aiohttp-validate `_ + Simple library that helps you validate your API endpoint requests/responses with json schema. + +- `aiohttp-pydantic `_ + An ``aiohttp.View`` to validate the HTTP request's body, query-string, and headers regarding function annotations and generate Open API doc. Python 3.8+ required. + +- `raven-aiohttp `_ An + aiohttp transport for raven-python (Sentry client). + +- `webargs `_ A friendly library + for parsing HTTP request arguments, with built-in support for + popular web frameworks, including Flask, Django, Bottle, Tornado, + Pyramid, webapp2, Falcon, and aiohttp. + +- `aioauth-client `_ OAuth + client for aiohttp. + +- `aiohttpretty + `_ A simple + asyncio compatible httpretty mock using aiohttp. + +- `aioresponses `_ a + helper for mock/fake web requests in python aiohttp package. + +- `aiohttp-transmute + `_ A transmute + implementation for aiohttp. + +- `aiohttp_apiset `_ + Package to build routes using swagger specification. + +- `aiohttp-login `_ + Registration and authorization (including social) for aiohttp + applications. + +- `aiohttp_utils `_ Handy + utilities for building aiohttp.web applications. + +- `aiohttpproxy `_ Simple + aiohttp HTTP proxy. + +- `aiohttp_traversal `_ + Traversal based router for aiohttp.web. + +- `aiohttp_autoreload + `_ Makes aiohttp + server auto-reload on source code change. + +- `gidgethub `_ An async + GitHub API library for Python. + +- `aiohttp_jrpc `_ aiohttp + JSON-RPC service. + +- `fbemissary `_ A bot + framework for the Facebook Messenger platform, built on asyncio and + aiohttp. + +- `aioslacker `_ slacker + wrapper for asyncio. + +- `aioreloader `_ Port of + tornado reloader to asyncio. + +- `aiohttp_babel `_ Babel + localization support for aiohttp. + +- `python-mocket `_ a + socket mock framework - for all kinds of socket animals, web-clients + included. + +- `aioraft `_ asyncio RAFT + algorithm based on aiohttp. + +- `home-assistant `_ + Open-source home automation platform running on Python 3. + +- `discord.py `_ Discord client library. + +- `aiogram `_ + A fully asynchronous library for Telegram Bot API written with asyncio and aiohttp. + +- `vk.py `_ + An extremely fast Python 3.6+ toolkit for creating applications that work with the VK API.
+ +- `aiohttp-graphql `_ + GraphQL and GraphIQL interface for aiohttp. + +- `aiohttp-sentry `_ + An aiohttp middleware for reporting errors to Sentry. Python 3.5+ is required. + +- `aiohttp-datadog `_ + An aiohttp middleware for reporting metrics to DataDog. Python 3.5+ is required. + +- `async-v20 `_ + Asynchronous FOREX client for OANDA's v20 API. Python 3.6+ + +- `aiohttp-jwt `_ + An aiohttp middleware for JWT (JSON Web Token) support. Python 3.5+ is required. + +- `AWS Xray Python SDK `_ + Native tracing support for Aiohttp applications. + +- `GINO `_ + An asyncio ORM on top of SQLAlchemy core, delivered with an aiohttp extension. + +- `aiohttp-apispec `_ + Build and document REST APIs with ``aiohttp`` and ``apispec``. + +- `eider-py `_ Python implementation of + the `Eider RPC protocol `_. + +- `asynapplicationinsights `_ + A client for `Azure Application Insights + `_ implemented using + ``aiohttp`` client, including a middleware for ``aiohttp`` servers to collect web apps + telemetry. + +- `aiogmaps `_ + Asynchronous client for Google Maps API Web Services. Python 3.6+ required. + +- `DBGR `_ + Terminal based tool to test and debug HTTP APIs with ``aiohttp``. + +- `rororo `_ + Implement ``aiohttp.web`` OpenAPI 3 server applications with a schema-first + approach. Python 3.6+ required. + +- `aiohttp-middlewares `_ + Collection of useful middlewares for ``aiohttp.web`` applications. Python + 3.6+ required. + +- `aiohttp-tus `_ + `tus.io `_ protocol implementation for ``aiohttp.web`` + applications. Python 3.6+ required. + +- `aiohttp-sse-client `_ + A Server-Sent Events Python client based on aiohttp. Python 3.6+ required.
diff --git a/docs/tracing_reference.rst b/docs/tracing_reference.rst new file mode 100644 index 00000000000..772b485ddcb --- /dev/null +++ b/docs/tracing_reference.rst @@ -0,0 +1,494 @@ +.. _aiohttp-client-tracing-reference: + +Tracing Reference +================= + +.. currentmodule:: aiohttp + +.. versionadded:: 3.0 + +A reference for the client tracing API. + +.. seealso:: :ref:`aiohttp-client-tracing` for tracing usage instructions. + + +Request life cycle +------------------ + +A request goes through the following stages and corresponding fallbacks. + + +Overview +^^^^^^^^ + +.. blockdiag:: + :desctable: + + + blockdiag { + orientation = portrait; + + start[shape=beginpoint, description="on_request_start"]; + redirect[description="on_request_redirect"]; + end[shape=endpoint, description="on_request_end"]; + exception[shape=flowchart.terminator, description="on_request_exception"]; + + acquire_connection[description="Connection acquiring"]; + headers_received; + headers_sent; + chunk_sent[description="on_request_chunk_sent"]; + chunk_received[description="on_response_chunk_received"]; + + start -> acquire_connection; + acquire_connection -> headers_sent; + headers_sent -> headers_received; + headers_sent -> chunk_sent; + chunk_sent -> chunk_sent; + chunk_sent -> headers_received; + headers_received -> chunk_received; + chunk_received -> chunk_received; + chunk_received -> end; + headers_received -> redirect; + headers_received -> end; + redirect -> headers_sent; + chunk_received -> exception; + chunk_sent -> exception; + headers_sent -> exception; + + } + + +Connection acquiring +^^^^^^^^^^^^^^^^^^^^ + +..
blockdiag:: + :desctable: + + blockdiag { + orientation = portrait; + + begin[shape=beginpoint]; + end[shape=endpoint]; + exception[shape=flowchart.terminator, description="Exception raised"]; + + queued_start[description="on_connection_queued_start"]; + queued_end[description="on_connection_queued_end"]; + create_start[description="on_connection_create_start"]; + create_end[description="on_connection_create_end"]; + reuseconn[description="on_connection_reuseconn"]; + + resolve_dns[description="DNS resolving"]; + sock_connect[description="Connection establishment"]; + + begin -> reuseconn; + begin -> create_start; + create_start -> resolve_dns; + resolve_dns -> exception; + resolve_dns -> sock_connect; + sock_connect -> exception; + sock_connect -> create_end -> end; + begin -> queued_start; + queued_start -> queued_end; + queued_end -> reuseconn; + queued_end -> create_start; + reuseconn -> end; + + } + +DNS resolving +^^^^^^^^^^^^^ + +.. blockdiag:: + :desctable: + + blockdiag { + orientation = portrait; + + begin[shape=beginpoint]; + end[shape=endpoint]; + exception[shape=flowchart.terminator, description="Exception raised"]; + + resolve_start[description="on_dns_resolvehost_start"]; + resolve_end[description="on_dns_resolvehost_end"]; + cache_hit[description="on_dns_cache_hit"]; + cache_miss[description="on_dns_cache_miss"]; + + begin -> cache_hit -> end; + begin -> cache_miss -> resolve_start; + resolve_start -> resolve_end -> end; + resolve_start -> exception; + + } + + +TraceConfig +----------- + + +.. class:: TraceConfig(trace_config_ctx_factory=SimpleNamespace) + + Trace config is the configuration object used to trace requests + launched by a :class:`ClientSession` object using different events + related to different parts of the request flow. + + :param trace_config_ctx_factory: factory used to create trace contexts, + default class used :class:`types.SimpleNamespace` + + .. method:: trace_config_ctx(trace_request_ctx=None) + + :param trace_request_ctx: Will be used to pass as a kw for the + ``trace_config_ctx_factory``. + + Build a new trace context from the config. + + Every signal handler should have the following signature:: + + async def on_signal(session, context, params): ... + + where ``session`` is :class:`ClientSession` instance, ``context`` is an + object returned by :meth:`trace_config_ctx` call and ``params`` is a + data class with signal parameters. The type of ``params`` depends on + subscribed signal and described below. + + .. attribute:: on_request_start + + Property that gives access to the signals that will be executed + when a request starts. + + ``params`` is :class:`aiohttp.TraceRequestStartParams` instance. + + .. attribute:: on_request_chunk_sent + + + Property that gives access to the signals that will be executed + when a chunk of request body is sent. + + ``params`` is :class:`aiohttp.TraceRequestChunkSentParams` instance. + + .. versionadded:: 3.1 + + .. attribute:: on_response_chunk_received + + + Property that gives access to the signals that will be executed + when a chunk of response body is received. + + ``params`` is :class:`aiohttp.TraceResponseChunkReceivedParams` instance. + + .. versionadded:: 3.1 + + .. attribute:: on_request_redirect + + Property that gives access to the signals that will be executed when a + redirect happens during a request flow. + + ``params`` is :class:`aiohttp.TraceRequestRedirectParams` instance. + + .. 
attribute:: on_request_end + + Property that gives access to the signals that will be executed when a + request ends. + + ``params`` is :class:`aiohttp.TraceRequestEndParams` instance. + + .. attribute:: on_request_exception + + Property that gives access to the signals that will be executed when a + request finishes with an exception. + + ``params`` is :class:`aiohttp.TraceRequestExceptionParams` instance. + + .. attribute:: on_connection_queued_start + + Property that gives access to the signals that will be executed when a + request has been queued waiting for an available connection. + + ``params`` is :class:`aiohttp.TraceConnectionQueuedStartParams` + instance. + + .. attribute:: on_connection_queued_end + + Property that gives access to the signals that will be executed when a + request that was queued already has an available connection. + + ``params`` is :class:`aiohttp.TraceConnectionQueuedEndParams` + instance. + + .. attribute:: on_connection_create_start + + Property that gives access to the signals that will be executed when a + request creates a new connection. + + ``params`` is :class:`aiohttp.TraceConnectionCreateStartParams` + instance. + + .. attribute:: on_connection_create_end + + Property that gives access to the signals that will be executed when a + request that created a new connection finishes its creation. + + ``params`` is :class:`aiohttp.TraceConnectionCreateEndParams` + instance. + + .. attribute:: on_connection_reuseconn + + Property that gives access to the signals that will be executed when a + request reuses a connection. + + ``params`` is :class:`aiohttp.TraceConnectionReuseconnParams` + instance. + + .. attribute:: on_dns_resolvehost_start + + Property that gives access to the signals that will be executed when a + request starts to resolve the domain related with the request. + + ``params`` is :class:`aiohttp.TraceDnsResolveHostStartParams` + instance. + + .. attribute:: on_dns_resolvehost_end + + Property that gives access to the signals that will be executed when a + request finishes to resolve the domain related with the request. + + ``params`` is :class:`aiohttp.TraceDnsResolveHostEndParams` instance. + + .. attribute:: on_dns_cache_hit + + Property that gives access to the signals that will be executed when a + request was able to use a cached DNS resolution for the domain related + with the request. + + ``params`` is :class:`aiohttp.TraceDnsCacheHitParams` instance. + + .. attribute:: on_dns_cache_miss + + Property that gives access to the signals that will be executed when a + request was not able to use a cached DNS resolution for the domain related + with the request. + + ``params`` is :class:`aiohttp.TraceDnsCacheMissParams` instance. + + +TraceRequestStartParams +----------------------- + +.. class:: TraceRequestStartParams + + See :attr:`TraceConfig.on_request_start` for details. + + .. attribute:: method + + Method that will be used to make the request. + + .. attribute:: url + + URL that will be used for the request. + + .. attribute:: headers + + Headers that will be used for the request, can be mutated. + + +TraceRequestChunkSentParams +--------------------------- + +.. class:: TraceRequestChunkSentParams + + .. versionadded:: 3.1 + + See :attr:`TraceConfig.on_request_chunk_sent` for details. + + .. attribute:: method + + Method that will be used to make the request. + + .. attribute:: url + + URL that will be used for the request. + + .. 
attribute:: chunk + + Bytes of chunk sent + + +TraceResponseChunkReceivedParams +-------------------------------- + +.. class:: TraceResponseChunkReceivedParams + + .. versionadded:: 3.1 + + See :attr:`TraceConfig.on_response_chunk_received` for details. + + .. attribute:: method + + Method that will be used to make the request. + + .. attribute:: url + + URL that will be used for the request. + + .. attribute:: chunk + + Bytes of chunk received + + +TraceRequestEndParams +--------------------- + +.. class:: TraceRequestEndParams + + See :attr:`TraceConfig.on_request_end` for details. + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: response + + Response :class:`ClientResponse`. + + +TraceRequestExceptionParams +--------------------------- + +.. class:: TraceRequestExceptionParams + + See :attr:`TraceConfig.on_request_exception` for details. + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: exception + + Exception raised during the request. + +TraceRequestRedirectParams +-------------------------- + +.. class:: TraceRequestRedirectParams + + See :attr:`TraceConfig.on_request_redirect` for details. + + .. attribute:: method + + Method used to get this redirect request. + + .. attribute:: url + + URL used for this redirect request. + + .. attribute:: headers + + Headers used for this redirect. + + .. attribute:: response + + Response :class:`ClientResponse` got from the redirect. + +TraceConnectionQueuedStartParams +-------------------------------- + +.. class:: TraceConnectionQueuedStartParams + + See :attr:`TraceConfig.on_connection_queued_start` for details. + + There are no attributes right now. + +TraceConnectionQueuedEndParams +------------------------------ + +.. class:: TraceConnectionQueuedEndParams + + See :attr:`TraceConfig.on_connection_queued_end` for details. + + There are no attributes right now. + +TraceConnectionCreateStartParams +-------------------------------- + +.. class:: TraceConnectionCreateStartParams + + See :attr:`TraceConfig.on_connection_create_start` for details. + + There are no attributes right now. + +TraceConnectionCreateEndParams +------------------------------ + +.. class:: TraceConnectionCreateEndParams + + See :attr:`TraceConfig.on_connection_create_end` for details. + + There are no attributes right now. + +TraceConnectionReuseconnParams +------------------------------ + +.. class:: TraceConnectionReuseconnParams + + See :attr:`TraceConfig.on_connection_reuseconn` for details. + + There are no attributes right now. + +TraceDnsResolveHostStartParams +------------------------------ + +.. class:: TraceDnsResolveHostStartParams + + See :attr:`TraceConfig.on_dns_resolvehost_start` for details. + + .. attribute:: host + + Host that will be resolved. + +TraceDnsResolveHostEndParams +---------------------------- + +.. class:: TraceDnsResolveHostEndParams + + See :attr:`TraceConfig.on_dns_resolvehost_end` for details. + + .. attribute:: host + + Host that has been resolved. + +TraceDnsCacheHitParams +---------------------- + +.. class:: TraceDnsCacheHitParams + + See :attr:`TraceConfig.on_dns_cache_hit` for details. + + .. attribute:: host + + Host found in the cache. + +TraceDnsCacheMissParams +----------------------- + +.. 
class:: TraceDnsCacheMissParams + + See :attr:`TraceConfig.on_dns_cache_miss` for details. + + .. attribute:: host + + Host didn't find the cache. diff --git a/docs/tutorial.rst b/docs/tutorial.rst deleted file mode 100644 index fb1e6459fb7..00000000000 --- a/docs/tutorial.rst +++ /dev/null @@ -1,467 +0,0 @@ -.. _aiohttp-tutorial: - -Server Tutorial -=============== - -Are you going to learn *aiohttp* but don't where to start? We have -example for you. Polls application is a great example for getting -started with aiohttp. - -If you want the full source code in advance or for comparison, check out -the `demo source`_. - -.. _demo source: - https://github.com/aio-libs/aiohttp/tree/master/demos/polls/ - - -.. _aiohttp-tutorial-setup: - -Setup your environment ----------------------- - -First of all check you python version: - -.. code-block:: shell - - $ python -V - Python 3.5.0 - -Tutorial requires Python 3.5.0 or newer. - -We’ll assume that you have already installed *aiohttp* library. You can check -aiohttp is installed and which version by running the following -command: - -.. code-block:: shell - - $ python3 -c 'import aiohttp; print(aiohttp.__version__)' - 2.0.5 - -Project structure looks very similar to other python based web projects: - -.. code-block:: none - - . - ├── README.rst - └── polls - ├── Makefile - ├── README.rst - ├── aiohttpdemo_polls - │ ├── __init__.py - │ ├── __main__.py - │ ├── db.py - │ ├── main.py - │ ├── routes.py - │ ├── templates - │ ├── utils.py - │ └── views.py - ├── config - │ └── polls.yaml - ├── images - │ └── example.png - ├── setup.py - ├── sql - │ ├── create_tables.sql - │ ├── install.sh - │ └── sample_data.sql - └── static - └── style.css - - -.. _aiohttp-tutorial-introduction: - -Getting started with aiohttp first app --------------------------------------- - -This tutorial based on Django polls tutorial. - - -Application ------------ - -All aiohttp server is built around :class:`aiohttp.web.Application` instance. -It is used for registering *startup*/*cleanup* signals, connecting routes etc. - -The following code creates an application:: - - from aiohttp import web - - - app = web.Application() - web.run_app(app, host='127.0.0.1', port=8080) - -Save it under ``aiohttpdemo_polls/main.py`` and start the server: - -.. code-block:: shell - - $ python3 main.py - -You'll see the following output on the command line: - -.. code-block:: shell - - ======== Running on http://127.0.0.1:8080 ======== - (Press CTRL+C to quit) - -Open ``http://127.0.0.1:8080`` in browser or do - -.. code-block:: shell - - $ curl -X GET localhost:8080 - -Alas, for now both return only ``404: Not Found``. -To show something more meaningful let's create a route and a view. - -.. _aiohttp-tutorial-views: - -Views ------ - -Let's start from first views. Create the file ``aiohttpdemo_polls/views.py`` with the following:: - - from aiohttp import web - - - async def index(request): - return web.Response(text='Hello Aiohttp!') - -This is the simplest view possible in Aiohttp. -Now we should create a route for this ``index`` view. Put this into ``aiohttpdemo_polls/routes.py`` (it is a good practice to separate views, routes, models etc. 
You'll have more of each, and it is nice to have them in different places):: - - from views import index - - - def setup_routes(app): - app.router.add_get('/', index) - - -Also, we should call ``setup_routes`` function somewhere, and the best place is in the ``main.py`` :: - - from aiohttp import web - from routes import setup_routes - - - app = web.Application() - setup_routes(app) - web.run_app(app, host='127.0.0.1', port=8080) - -Start server again. Now if we open browser we can see: - -.. code-block:: shell - - $ curl -X GET localhost:8080 - Hello Aiohttp! - -Success! For now your working directory should look like this: - -.. code-block:: none - - . - ├── .. - └── polls - ├── aiohttpdemo_polls - │ ├── main.py - │ ├── routes.py - │ └── views.py - -.. _aiohttp-tutorial-config: - -Configuration files -------------------- - -aiohttp is configuration agnostic. It means the library doesn't -require any configuration approach and doesn't have builtin support -for any config schema. - -But please take into account these facts: - - 1. 99% of servers have configuration files. - - 2. Every product (except Python-based solutions like Django and - Flask) doesn't store config files as part as source code. - - For example Nginx has own configuration files stored by default - under ``/etc/nginx`` folder. - - Mongo pushes config as ``/etc/mongodb.conf``. - - 3. Config files validation is good idea, strong checks may prevent - silly errors during product deployment. - -Thus we **suggest** to use the following approach: - - 1. Pushing configs as ``yaml`` files (``json`` or ``ini`` is also - good but ``yaml`` is the best). - - 2. Loading ``yaml`` config from a list of predefined locations, - e.g. ``./config/app_cfg.yaml``, ``/etc/app_cfg.yaml``. - - 3. Keeping ability to override config file by command line - parameter, e.g. ``./run_app --config=/opt/config/app_cfg.yaml``. - - 4. Applying strict validation checks to loaded dict. `trafaret - `_, `colander - `_ - or `JSON schema - `_ are good - candidates for such job. - - -Load config and push into application:: - - # load config from yaml file in current dir - conf = load_config(str(pathlib.Path('.') / 'config' / 'polls.yaml')) - app['config'] = conf - -.. _aiohttp-tutorial-database: - -Database --------- - -Setup -^^^^^ - -In this tutorial we will use the latest PostgreSQL database. You can install -PostgreSQL using this instruction http://www.postgresql.org/download/ - -Database schema -^^^^^^^^^^^^^^^ - -We use SQLAlchemy to describe database schemas. 
-For this tutorial we can use two simple models ``question`` and ``choice``:: - - import sqlalchemy as sa - - meta = sa.MetaData() - - question = sa.Table( - 'question', meta, - sa.Column('id', sa.Integer, nullable=False), - sa.Column('question_text', sa.String(200), nullable=False), - sa.Column('pub_date', sa.Date, nullable=False), - - # Indexes # - sa.PrimaryKeyConstraint('id', name='question_id_pkey')) - - choice = sa.Table( - 'choice', meta, - sa.Column('id', sa.Integer, nullable=False), - sa.Column('question_id', sa.Integer, nullable=False), - sa.Column('choice_text', sa.String(200), nullable=False), - sa.Column('votes', sa.Integer, server_default="0", nullable=False), - - # Indexes # - sa.PrimaryKeyConstraint('id', name='choice_id_pkey'), - sa.ForeignKeyConstraint(['question_id'], [question.c.id], - name='choice_question_id_fkey', - ondelete='CASCADE'), - ) - - - -You can find below description of tables in database: - -First table is question: - -+---------------+ -| question | -+===============+ -| id | -+---------------+ -| question_text | -+---------------+ -| pub_date | -+---------------+ - -and second table is choice table: - -+---------------+ -| choice | -+===============+ -| id | -+---------------+ -| choice_text | -+---------------+ -| votes | -+---------------+ -| question_id | -+---------------+ - -Creating connection engine -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For making DB queries we need an engine instance. Assuming ``conf`` is -a :class:`dict` with configuration info Postgres connection could be -done by the following coroutine:: - - async def init_pg(app): - conf = app['config'] - engine = await aiopg.sa.create_engine( - database=conf['database'], - user=conf['user'], - password=conf['password'], - host=conf['host'], - port=conf['port'], - minsize=conf['minsize'], - maxsize=conf['maxsize'], - loop=app.loop) - app['db'] = engine - -The best place for connecting to DB is -:attr:`~aiohtp.web.Application.on_startup` signal:: - - app.on_startup.append(init_pg) - - -Graceful shutdown -^^^^^^^^^^^^^^^^^ - -There is a good practice to close all resources on program exit. - -Let's close DB connection in :attr:`~aiohtp.web.Application.on_cleanup` signal:: - - async def close_pg(app): - app['db'].close() - await app['db'].wait_closed() - - - app.on_cleanup.append(close_pg) - - - -.. _aiohttp-tutorial-templates: - -Templates ---------- - -Let's add more useful views:: - - @aiohttp_jinja2.template('detail.html') - async def poll(request): - async with request['db'].acquire() as conn: - question_id = request.match_info['question_id'] - try: - question, choices = await db.get_question(conn, - question_id) - except db.RecordNotFound as e: - raise web.HTTPNotFound(text=str(e)) - return { - 'question': question, - 'choices': choices - } - -Templates are very convenient way for web page writing. We return a -dict with page content, ``aiohttp_jinja2.template`` decorator -processes it by jinja2 template renderer. - -For setting up template engine we need to install ``aiohttp_jinja2`` -library first: - -.. code-block:: shell - - $ pip install aiohttp_jinja2 - -After installing we need to setup the library:: - - import aiohttp_jinja2 - import jinja2 - - aiohttp_jinja2.setup( - app, loader=jinja2.PackageLoader('aiohttpdemo_polls', 'templates')) - - -In the tutorial we push template files under -``polls/aiohttpdemo_polls/templates`` folder. - - -.. _aiohttp-tutorial-static: - -Static files ------------- - -Any web site has static files: images, JavaScript sources, CSS files etc. 
- -The best way to handle static in production is setting up reverse -proxy like NGINX or using CDN services. - -But for development handling static files by aiohttp server is very convenient. - -Fortunately it can be done easy by single call:: - - app.router.add_static('/static/', - path=str(project_root / 'static'), - name='static') - - -where ``project_root`` is the path to root folder. - - -.. _aiohttp-tutorial-middlewares: - -Middlewares ------------ - -Middlewares are stacked around every web-handler. They are called -*before* handler for pre-processing request and *after* getting -response back for post-processing given response. - -Here we'll add a simple middleware for displaying pretty looking pages -for *404 Not Found* and *500 Internal Error*. - -Middlewares could be registered in ``app`` by adding new middleware to -``app.middlewares`` list:: - - def setup_middlewares(app): - error_middleware = error_pages({404: handle_404, - 500: handle_500}) - app.middlewares.append(error_middleware) - -Middleware itself is a factory which accepts *application* and *next -handler* (the following middleware or *web-handler* in case of the -latest middleware in the list). - -The factory returns *middleware handler* which has the same signature -as regular *web-handler* -- it accepts *request* and returns -*response*. - -Middleware for processing HTTP exceptions:: - - def error_pages(overrides): - async def middleware(app, handler): - async def middleware_handler(request): - try: - response = await handler(request) - override = overrides.get(response.status) - if override is None: - return response - else: - return await override(request, response) - except web.HTTPException as ex: - override = overrides.get(ex.status) - if override is None: - raise - else: - return await override(request, ex) - return middleware_handler - return middleware - -Registered overrides are trivial Jinja2 template renderers:: - - - async def handle_404(request, response): - response = aiohttp_jinja2.render_template('404.html', - request, - {}) - return response - - - async def handle_500(request, response): - response = aiohttp_jinja2.render_template('500.html', - request, - {}) - return response - -.. seealso:: :ref:`aiohttp-web-middlewares` - -.. disqus:: - :title: aiohttp server tutorial diff --git a/docs/utilities.rst b/docs/utilities.rst new file mode 100644 index 00000000000..c328244224e --- /dev/null +++ b/docs/utilities.rst @@ -0,0 +1,20 @@ +.. _aiohttp-utilities: + +Utilities +========= + +Miscellaneous API Shared between Client And Server. + +.. currentmodule:: aiohttp + + +.. toctree:: + :name: utilities + + abc + multipart + multipart_reference + streams + signals + structures + websocket_utilities diff --git a/docs/web.rst b/docs/web.rst index c6c7c4d03b6..4fab23d0067 100644 --- a/docs/web.rst +++ b/docs/web.rst @@ -1,1327 +1,21 @@ .. _aiohttp-web: -Server Usage -============ +Server +====== -.. currentmodule:: aiohttp.web +.. module:: aiohttp.web +The page contains all information about aiohttp Server API: -Run a Simple Web Server ------------------------ -In order to implement a web server, first create a -:ref:`request handler `. +.. 
toctree:: + :name: server -A request handler is a :ref:`coroutine ` or regular function that -accepts a :class:`Request` instance as its only parameter and returns a -:class:`Response` instance:: - - from aiohttp import web - - async def hello(request): - return web.Response(text="Hello, world") - -Next, create an :class:`Application` instance and register the -request handler with the application's :class:`router ` on a -particular *HTTP method* and *path*:: - - app = web.Application() - app.router.add_get('/', hello) - -After that, run the application by :func:`run_app` call:: - - web.run_app(app) - -That's it. Now, head over to ``http://localhost:8080/`` to see the results. - -.. seealso:: - - :ref:`aiohttp-web-graceful-shutdown` section explains what :func:`run_app` - does and how to implement complex server initialization/finalization - from scratch. - - -.. _aiohttp-web-cli: - -Command Line Interface (CLI) ----------------------------- -:mod:`aiohttp.web` implements a basic CLI for quickly serving an -:class:`Application` in *development* over TCP/IP: - -.. code-block:: shell - - $ python -m aiohttp.web -H localhost -P 8080 package.module:init_func - -``package.module:init_func`` should be an importable :term:`callable` that -accepts a list of any non-parsed command-line arguments and returns an -:class:`Application` instance after setting it up:: - - def init_func(argv): - app = web.Application() - app.router.add_get("/", index_handler) - return app - - -.. _aiohttp-web-handler: - -Handler -------- - -A request handler can be any :term:`callable` that accepts a :class:`Request` -instance as its only argument and returns a :class:`StreamResponse` derived -(e.g. :class:`Response`) instance:: - - def handler(request): - return web.Response() - -A handler **may** also be a :ref:`coroutine`, in which case -:mod:`aiohttp.web` will ``await`` the handler:: - - async def handler(request): - return web.Response() - -Handlers are setup to handle requests by registering them with the -:attr:`Application.router` on a particular route (*HTTP method* and -*path* pair) using methods like :class:`UrlDispatcher.add_get` and -:class:`UrlDispatcher.add_post`:: - - app.router.add_get('/', handler) - app.router.add_post('/post', post_handler) - app.router.add_put('/put', put_handler) - -:meth:`~UrlDispatcher.add_route` also supports the wildcard *HTTP method*, -allowing a handler to serve incoming requests on a *path* having **any** -*HTTP method*:: - - app.router.add_route('*', '/path', all_handler) - -The *HTTP method* can be queried later in the request handler using the -:attr:`Request.method` property. - -By default endpoints added with :meth:`~UrlDispatcher.add_get` will accept -``HEAD`` requests and return the same response headers as they would -for a ``GET`` request. You can also deny ``HEAD`` requests on a route:: - - app.router.add_get('/', handler, allow_head=False) - -Here ``handler`` won't be called and the server will response with ``405``. - -.. note:: - - This is a change as of **aiohttp v2.0** to act in accordance with - `RFC 7231 `_. - - Previous version always returned ``405`` for ``HEAD`` requests - to routes added with :meth:`~UrlDispatcher.add_get`. - -If you have handlers which perform lots of processing to write the response -body you may wish to improve performance by skipping that processing -in the case of ``HEAD`` requests while still taking care to respond with -the same headers as with ``GET`` requests. - -.. 
_aiohttp-web-resource-and-route: - -Resources and Routes --------------------- - -Internally *router* is a list of *resources*. - -Resource is an entry in *route table* which corresponds to requested URL. - -Resource in turn has at least one *route*. - -Route corresponds to handling *HTTP method* by calling *web handler*. - -:meth:`UrlDispatcher.add_get` / :meth:`UrlDispatcher.add_post` and -family are plain shortcuts for :meth:`UrlDispatcher.add_route`. - -:meth:`UrlDispatcher.add_route` in turn is just a shortcut for pair of -:meth:`UrlDispatcher.add_resource` and :meth:`Resource.add_route`:: - - resource = app.router.add_resource(path, name=name) - route = resource.add_route(method, handler) - return route - -.. seealso:: - - :ref:`aiohttp-router-refactoring-021` for more details - -.. versionadded:: 0.21.0 - - Introduce resources. - - -.. _aiohttp-web-custom-resource: - -Custom resource implementation ------------------------------- - -To register custom resource use :meth:`UrlDispatcher.register_resource`. -Resource instance must implement `AbstractResource` interface. - -.. versionadded:: 1.2.1 - - -.. _aiohttp-web-variable-handler: - -Variable Resources -^^^^^^^^^^^^^^^^^^ - -Resource may have *variable path* also. For instance, a resource with -the path ``'/a/{name}/c'`` would match all incoming requests with -paths such as ``'/a/b/c'``, ``'/a/1/c'``, and ``'/a/etc/c'``. - -A variable *part* is specified in the form ``{identifier}``, where the -``identifier`` can be used later in a -:ref:`request handler ` to access the matched value for -that *part*. This is done by looking up the ``identifier`` in the -:attr:`Request.match_info` mapping:: - - async def variable_handler(request): - return web.Response( - text="Hello, {}".format(request.match_info['name'])) - - resource = app.router.add_resource('/{name}') - resource.add_route('GET', variable_handler) - -By default, each *part* matches the regular expression ``[^{}/]+``. - -You can also specify a custom regex in the form ``{identifier:regex}``:: - - resource = app.router.add_resource(r'/{name:\d+}') - -.. note:: - - Regex should match against *percent encoded* URL - (``request.rel_url_raw_path``). E.g. *space character* is encoded - as ``%20``. - - According to - `RFC 3986 `_ - allowed in path symbols are:: - - allowed = unreserved / pct-encoded / sub-delims - / ":" / "@" / "/" - - pct-encoded = "%" HEXDIG HEXDIG - - unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" - - sub-delims = "!" / "$" / "&" / "'" / "(" / ")" - / "*" / "+" / "," / ";" / "=" - -.. _aiohttp-web-named-routes: - -Reverse URL Constructing using Named Resources -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Routes can also be given a *name*:: - - resource = app.router.add_resource('/root', name='root') - -Which can then be used to access and build a *URL* for that resource later (e.g. -in a :ref:`request handler `):: - - >>> request.app.router['root'].url_for().with_query({"a": "b", "c": "d"}) - URL('/root?a=b&c=d') - -A more interesting example is building *URLs* for :ref:`variable -resources `:: - - app.router.add_resource(r'/{user}/info', name='user-info') - - -In this case you can also pass in the *parts* of the route:: - - >>> request.app.router['user-info'].url_for(user='john_doe')\ - ... 
.with_query("a=b") - '/john_doe/info?a=b' - - -Organizing Handlers in Classes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -As discussed above, :ref:`handlers ` can be first-class -functions or coroutines:: - - async def hello(request): - return web.Response(text="Hello, world") - - app.router.add_get('/', hello) - -But sometimes it's convenient to group logically similar handlers into a Python -*class*. - -Since :mod:`aiohttp.web` does not dictate any implementation details, -application developers can organize handlers in classes if they so wish:: - - class Handler: - - def __init__(self): - pass - - def handle_intro(self, request): - return web.Response(text="Hello, world") - - async def handle_greeting(self, request): - name = request.match_info.get('name', "Anonymous") - txt = "Hello, {}".format(name) - return web.Response(text=txt) - - handler = Handler() - app.router.add_get('/intro', handler.handle_intro) - app.router.add_get('/greet/{name}', handler.handle_greeting) - - -.. _aiohttp-web-class-based-views: - -Class Based Views -^^^^^^^^^^^^^^^^^ - -:mod:`aiohttp.web` has support for django-style class based views. - -You can derive from :class:`View` and define methods for handling http -requests:: - - class MyView(web.View): - async def get(self): - return await get_resp(self.request) - - async def post(self): - return await post_resp(self.request) - -Handlers should be coroutines accepting self only and returning -response object as regular :term:`web-handler`. Request object can be -retrieved by :attr:`View.request` property. - -After implementing the view (``MyView`` from example above) should be -registered in application's router:: - - app.router.add_route('*', '/path/to', MyView) - -Example will process GET and POST requests for */path/to* but raise -*405 Method not allowed* exception for unimplemented HTTP methods. - -Resource Views -^^^^^^^^^^^^^^ - -*All* registered resources in a router can be viewed using the -:meth:`UrlDispatcher.resources` method:: - - for resource in app.router.resources(): - print(resource) - -Similarly, a *subset* of the resources that were registered with a *name* can be -viewed using the :meth:`UrlDispatcher.named_resources` method:: - - for name, resource in app.router.named_resources().items(): - print(name, resource) - - - -.. versionadded:: 0.18 - :meth:`UrlDispatcher.routes` - -.. versionadded:: 0.19 - :meth:`UrlDispatcher.named_routes` - -.. deprecated:: 0.21 - - Use :meth:`UrlDispatcher.named_resources` / - :meth:`UrlDispatcher.resources` instead of - :meth:`UrlDispatcher.named_routes` / :meth:`UrlDispatcher.routes`. - -Custom Routing Criteria ------------------------ - -Sometimes you need to register :ref:`handlers ` on -more complex criteria than simply a *HTTP method* and *path* pair. - -Although :class:`UrlDispatcher` does not support any extra criteria, routing -based on custom conditions can be accomplished by implementing a second layer -of routing in your application. 
- -The following example shows custom routing based on the *HTTP Accept* header:: - - class AcceptChooser: - - def __init__(self): - self._accepts = {} - - async def do_route(self, request): - for accept in request.headers.getall('ACCEPT', []): - acceptor = self._accepts.get(accept) - if acceptor is not None: - return (await acceptor(request)) - raise HTTPNotAcceptable() - - def reg_acceptor(self, accept, handler): - self._accepts[accept] = handler - - - async def handle_json(request): - # do json handling - - async def handle_xml(request): - # do xml handling - - chooser = AcceptChooser() - app.router.add_get('/', chooser.do_route) - - chooser.reg_acceptor('application/json', handle_json) - chooser.reg_acceptor('application/xml', handle_xml) - -.. _aiohttp-web-static-file-handling: - -Static file handling --------------------- - -The best way to handle static files (images, JavaScripts, CSS files -etc.) is using `Reverse Proxy`_ like `nginx`_ or `CDN`_ services. - -.. _Reverse Proxy: https://en.wikipedia.org/wiki/Reverse_proxy -.. _nginx: https://nginx.org/ -.. _CDN: https://en.wikipedia.org/wiki/Content_delivery_network - -But for development it's very convenient to handle static files by -aiohttp server itself. - -To do it just register a new static route by -:meth:`UrlDispatcher.add_static` call:: - - app.router.add_static('/prefix', path_to_static_folder) - -When a directory is accessed within a static route then the server responses -to client with ``HTTP/403 Forbidden`` by default. Displaying folder index -instead could be enabled with ``show_index`` parameter set to ``True``:: - - app.router.add_static('/prefix', path_to_static_folder, show_index=True) - -When a symlink from the static directory is accessed, the server responses to -client with ``HTTP/404 Not Found`` by default. To allow the server to follow -symlinks, parameter ``follow_symlinks`` should be set to ``True``:: - - app.router.add_static('/prefix', path_to_static_folder, follow_symlinks=True) - -Template Rendering ------------------- - -:mod:`aiohttp.web` does not support template rendering out-of-the-box. - -However, there is a third-party library, :mod:`aiohttp_jinja2`, which is -supported by the *aiohttp* authors. - -Using it is rather simple. First, setup a *jinja2 environment* with a call -to :func:`aiohttp_jinja2.setup`:: - - app = web.Application(loop=self.loop) - aiohttp_jinja2.setup(app, - loader=jinja2.FileSystemLoader('/path/to/templates/folder')) - -After that you may use the template engine in your -:ref:`handlers `. The most convenient way is to simply -wrap your handlers with the :func:`aiohttp_jinja2.template` decorator:: - - @aiohttp_jinja2.template('tmpl.jinja2') - def handler(request): - return {'name': 'Andrew', 'surname': 'Svetlov'} - -If you prefer the `Mako`_ template engine, please take a look at the -`aiohttp_mako`_ library. - -.. _Mako: http://www.makotemplates.org/ - -.. _aiohttp_mako: https://github.com/aio-libs/aiohttp_mako - - -JSON Response -------------- - -It is a common case to return JSON data in response, :mod:`aiohttp.web` -provides a shortcut for returning JSON -- :func:`aiohttp.web.json_response`:: - - def handler(request): - data = {'some': 'data'} - return web.json_response(data) - -The shortcut method returns :class:`aiohttp.web.Response` instance -so you can for example set cookies before returning it from handler. - - -User Sessions -------------- - -Often you need a container for storing user data across requests. The concept -is usually called a *session*. 
- -:mod:`aiohttp.web` has no built-in concept of a *session*, however, there is a -third-party library, :mod:`aiohttp_session`, that adds *session* support:: - - import asyncio - import time - import base64 - from cryptography import fernet - from aiohttp import web - from aiohttp_session import setup, get_session, session_middleware - from aiohttp_session.cookie_storage import EncryptedCookieStorage - - async def handler(request): - session = await get_session(request) - last_visit = session['last_visit'] if 'last_visit' in session else None - text = 'Last visited: {}'.format(last_visit) - return web.Response(text=text) - - def make_app(): - app = web.Application() - # secret_key must be 32 url-safe base64-encoded bytes - fernet_key = fernet.Fernet.generate_key() - secret_key = base64.urlsafe_b64decode(fernet_key) - setup(app, EncryptedCookieStorage(secret_key)) - app.router.add_route('GET', '/', handler) - return app - - web.run_app(make_app()) - - -.. _aiohttp-web-expect-header: - -*Expect* Header ---------------- - -:mod:`aiohttp.web` supports *Expect* header. By default it sends -``HTTP/1.1 100 Continue`` line to client, or raises -:exc:`HTTPExpectationFailed` if header value is not equal to -"100-continue". It is possible to specify custom *Expect* header -handler on per route basis. This handler gets called if *Expect* -header exist in request after receiving all headers and before -processing application's :ref:`aiohttp-web-middlewares` and -route handler. Handler can return *None*, in that case the request -processing continues as usual. If handler returns an instance of class -:class:`StreamResponse`, *request handler* uses it as response. Also -handler can raise a subclass of :exc:`HTTPException`. In this case all -further processing will not happen and client will receive appropriate -http response. - -.. note:: - A server that does not understand or is unable to comply with any of the - expectation values in the Expect field of a request MUST respond with - appropriate error status. The server MUST respond with a 417 - (Expectation Failed) status if any of the expectations cannot be met or, - if there are other problems with the request, some other 4xx status. - - http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20 - -If all checks pass, the custom handler *must* write a *HTTP/1.1 100 Continue* -status code before returning. - -The following example shows how to setup a custom handler for the *Expect* -header:: - - async def check_auth(request): - if request.version != aiohttp.HttpVersion11: - return - - if request.headers.get('EXPECT') != '100-continue': - raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect) - - if request.headers.get('AUTHORIZATION') is None: - raise HTTPForbidden() - - request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") - - async def hello(request): - return web.Response(body=b"Hello, world") - - app = web.Application() - app.router.add_get('/', hello, expect_handler=check_auth) - -.. _aiohttp-web-forms: - -HTTP Forms ----------- - -HTTP Forms are supported out of the box. - -If form's method is ``"GET"`` (``
<form method="get">``) use
-:attr:`Request.query` for getting form data.
-
-To access form data with the ``"POST"`` method use
-:meth:`Request.post` or :meth:`Request.multipart`.
-
-:meth:`Request.post` accepts both
-``'application/x-www-form-urlencoded'`` and ``'multipart/form-data'``
-form data encodings (e.g. ``<form enctype="multipart/form-data">``).
-It stores file data in a temporary directory. If the request's payload
-size exceeds ``client_max_size``, ``post`` raises a :exc:`ValueError`.
-For efficiency, use :meth:`Request.multipart`; it is especially effective
-for uploading large files (:ref:`aiohttp-web-file-upload`).
-
-Values submitted by the following form:
-
-.. code-block:: html
-
-   <form action="/login" method="post" accept-charset="utf-8"
-         enctype="application/x-www-form-urlencoded">
-
-       <label for="login">Login</label>
-       <input id="login" name="login" type="text" value="" autofocus/>
-
-       <label for="password">Password</label>
-       <input id="password" name="password" type="password" value=""/>
-
-       <input type="submit" value="login"/>
-   </form>
    - -could be accessed as:: - - async def do_login(request): - data = await request.post() - login = data['login'] - password = data['password'] - - -.. _aiohttp-web-file-upload: - -File Uploads ------------- - -:mod:`aiohttp.web` has built-in support for handling files uploaded from the -browser. - -First, make sure that the HTML ``
<form>`` element has its *enctype* attribute
-set to ``enctype="multipart/form-data"``. As an example, here is a form that
-accepts an MP3 file:
-
-.. code-block:: html
-
-   <form action="/store/mp3" method="post" accept-charset="utf-8"
-         enctype="multipart/form-data">
-
-       <label for="mp3">Mp3</label>
-       <input id="mp3" name="mp3" type="file" value=""/>
-
-       <input type="submit" value="submit"/>
-   </form>
-
-Then, in the :ref:`request handler <aiohttp-web-handler>` you can access the
-file input field as a :class:`FileField` instance. :class:`FileField` is simply
-a container for the file as well as some of its metadata::
-
-    async def store_mp3_handler(request):
-
-        # WARNING: don't do that if you plan to receive large files!
-        data = await request.post()
-
-        mp3 = data['mp3']
-
-        # .filename contains the name of the file in string format.
-        filename = mp3.filename
-
-        # .file contains the actual file data that needs to be stored somewhere.
-        mp3_file = data['mp3'].file
-
-        content = mp3_file.read()
-
-        return web.Response(body=content,
-                            headers=MultiDict(
-                                {'CONTENT-DISPOSITION':
-                                 'attachment; filename="{}"'.format(filename)}))
-
-
-You might have noticed the big warning in the example above. The general issue
-is that :meth:`Request.post` reads the whole payload into memory, which may
-result in an :abbr:`OOM (Out Of Memory)` error. To avoid this, for multipart
-uploads, you should use :meth:`Request.multipart` which returns a
-:ref:`multipart reader <aiohttp-multipart>`::
-
-    async def store_mp3_handler(request):
-
-        reader = await request.multipart()
-
-        # /!\ Don't forget to validate your inputs /!\
-
-        mp3 = await reader.next()
-
-        filename = mp3.filename
-
-        # You cannot rely on Content-Length if transfer is chunked.
-        size = 0
-        with open(os.path.join('/spool/yarrr-media/mp3/', filename), 'wb') as f:
-            while True:
-                chunk = await mp3.read_chunk()  # 8192 bytes by default.
-                if not chunk:
-                    break
-                size += len(chunk)
-                f.write(chunk)
-
-        return web.Response(text='{} of size {} successfully stored'
-                                 ''.format(filename, size))
-
-.. _aiohttp-web-websockets:
-
-WebSockets
-----------
-
-:mod:`aiohttp.web` supports *WebSockets* out-of-the-box.
-
-To set up a *WebSocket*, create a :class:`WebSocketResponse` in a
-:ref:`request handler <aiohttp-web-handler>` and then use it to communicate
-with the peer::
-
-    async def websocket_handler(request):
-
-        ws = web.WebSocketResponse()
-        await ws.prepare(request)
-
-        async for msg in ws:
-            if msg.type == aiohttp.WSMsgType.TEXT:
-                if msg.data == 'close':
-                    await ws.close()
-                else:
-                    ws.send_str(msg.data + '/answer')
-            elif msg.type == aiohttp.WSMsgType.ERROR:
-                print('ws connection closed with exception %s' %
-                      ws.exception())
-
-        print('websocket connection closed')
-
-        return ws
-
-.. _aiohttp-web-websocket-read-same-task:
-
-Reading from the *WebSocket* (``await ws.receive()``) **must only** be done
-inside the request handler *task*; however, writing (``ws.send_str(...)``) to the
-*WebSocket*, closing (``await ws.close()``) and canceling the handler
-task may be delegated to other tasks. See also :ref:`FAQ section
-<aiohttp-faq-terminate-websocket>`.
-
-*aiohttp.web* creates an implicit :class:`asyncio.Task` for handling every
-incoming request.
-
-.. note::
-
-   While :mod:`aiohttp.web` itself only supports *WebSockets* without
-   downgrading to *LONG-POLLING*, etc., our team supports SockJS_, an
-   aiohttp-based library for implementing SockJS-compatible server
-   code.
-
-.. _SockJS: https://github.com/aio-libs/sockjs
-
-
-.. warning::
-
-   Parallel reads from the websocket are forbidden; there is no
-   way to call :meth:`aiohttp.web.WebSocketResponse.receive`
-   from two tasks.
-
-   See :ref:`FAQ section <aiohttp-faq-parallel-event-sources>` for
-   instructions on how to solve the problem.
-
-
-.. _aiohttp-web-exceptions:
-
-Exceptions
-----------
-
-:mod:`aiohttp.web` defines a set of exceptions for every *HTTP status code*.
-
-Each exception is a subclass of :class:`~HTTPException` and relates to a single
-HTTP status code.
- -The exceptions are also a subclass of :class:`Response`, allowing you to either -``raise`` or ``return`` them in a -:ref:`request handler ` for the same effect. - -The following snippets are the same:: - - async def handler(request): - return aiohttp.web.HTTPFound('/redirect') - -and:: - - async def handler(request): - raise aiohttp.web.HTTPFound('/redirect') - - -Each exception class has a status code according to :rfc:`2068`: -codes with 100-300 are not really errors; 400s are client errors, -and 500s are server errors. - -HTTP Exception hierarchy chart:: - - Exception - HTTPException - HTTPSuccessful - * 200 - HTTPOk - * 201 - HTTPCreated - * 202 - HTTPAccepted - * 203 - HTTPNonAuthoritativeInformation - * 204 - HTTPNoContent - * 205 - HTTPResetContent - * 206 - HTTPPartialContent - HTTPRedirection - * 300 - HTTPMultipleChoices - * 301 - HTTPMovedPermanently - * 302 - HTTPFound - * 303 - HTTPSeeOther - * 304 - HTTPNotModified - * 305 - HTTPUseProxy - * 307 - HTTPTemporaryRedirect - * 308 - HTTPPermanentRedirect - HTTPError - HTTPClientError - * 400 - HTTPBadRequest - * 401 - HTTPUnauthorized - * 402 - HTTPPaymentRequired - * 403 - HTTPForbidden - * 404 - HTTPNotFound - * 405 - HTTPMethodNotAllowed - * 406 - HTTPNotAcceptable - * 407 - HTTPProxyAuthenticationRequired - * 408 - HTTPRequestTimeout - * 409 - HTTPConflict - * 410 - HTTPGone - * 411 - HTTPLengthRequired - * 412 - HTTPPreconditionFailed - * 413 - HTTPRequestEntityTooLarge - * 414 - HTTPRequestURITooLong - * 415 - HTTPUnsupportedMediaType - * 416 - HTTPRequestRangeNotSatisfiable - * 417 - HTTPExpectationFailed - * 421 - HTTPMisdirectedRequest - * 426 - HTTPUpgradeRequired - * 428 - HTTPPreconditionRequired - * 429 - HTTPTooManyRequests - * 431 - HTTPRequestHeaderFieldsTooLarge - * 451 - HTTPUnavailableForLegalReasons - HTTPServerError - * 500 - HTTPInternalServerError - * 501 - HTTPNotImplemented - * 502 - HTTPBadGateway - * 503 - HTTPServiceUnavailable - * 504 - HTTPGatewayTimeout - * 505 - HTTPVersionNotSupported - * 506 - HTTPVariantAlsoNegotiates - * 510 - HTTPNotExtended - * 511 - HTTPNetworkAuthenticationRequired - -All HTTP exceptions have the same constructor signature:: - - HTTPNotFound(*, headers=None, reason=None, - body=None, text=None, content_type=None) - -If not directly specified, *headers* will be added to the *default -response headers*. - -Classes :class:`HTTPMultipleChoices`, :class:`HTTPMovedPermanently`, -:class:`HTTPFound`, :class:`HTTPSeeOther`, :class:`HTTPUseProxy`, -:class:`HTTPTemporaryRedirect` have the following constructor signature:: - - HTTPFound(location, *, headers=None, reason=None, - body=None, text=None, content_type=None) - -where *location* is value for *Location HTTP header*. - -:class:`HTTPMethodNotAllowed` is constructed by providing the incoming -unsupported method and list of allowed methods:: - - HTTPMethodNotAllowed(method, allowed_methods, *, - headers=None, reason=None, - body=None, text=None, content_type=None) - - -.. _aiohttp-web-data-sharing: - -Data Sharing aka No Singletons Please -------------------------------------- - -:mod:`aiohttp.web` discourages the use of *global variables*, aka *singletons*. -Every variable should have its own context that is *not global*. - -So, :class:`aiohttp.web.Application` and :class:`aiohttp.web.Request` -support a :class:`collections.abc.MutableMapping` interface (i.e. they are -dict-like objects), allowing them to be used as data stores. 
- -For storing *global-like* variables, feel free to save them in an -:class:`~.Application` instance:: - - app['my_private_key'] = data - -and get it back in the :term:`web-handler`:: - - async def handler(request): - data = request.app['my_private_key'] - -Variables that are only needed for the lifetime of a :class:`~.Request`, can be -stored in a :class:`~.Request`:: - - async def handler(request): - request['my_private_key'] = "data" - ... - -This is mostly useful for :ref:`aiohttp-web-middlewares` and -:ref:`aiohttp-web-signals` handlers to store data for further processing by the -next handlers in the chain. - -To avoid clashing with other *aiohttp* users and third-party libraries, please -choose a unique key name for storing data. - -If your code is published on PyPI, then the project name is most likely unique -and safe to use as the key. -Otherwise, something based on your company name/url would be satisfactory (i.e. -``org.company.app``). - - -.. _aiohttp-web-middlewares: - -Middlewares ------------ - -:mod:`aiohttp.web` provides a powerful mechanism for customizing -:ref:`request handlers` via *middlewares*. - -*Middlewares* are setup by providing a sequence of *middleware factories* to -the keyword-only ``middlewares`` parameter when creating an -:class:`Application`:: - - app = web.Application(middlewares=[middleware_factory_1, - middleware_factory_2]) - -A *middleware factory* is simply a coroutine that implements the logic of a -*middleware*. For example, here's a trivial *middleware factory*:: - - async def middleware_factory(app, handler): - async def middleware_handler(request): - return await handler(request) - return middleware_handler - -Every *middleware factory* should accept two parameters, an -:class:`app ` instance and a *handler*, and return a new handler. - -The *handler* passed in to a *middleware factory* is the handler returned by -the **next** *middleware factory*. The last *middleware factory* always receives -the :ref:`request handler ` selected by the router itself -(by :meth:`UrlDispatcher.resolve`). - -*Middleware factories* should return a new handler that has the same signature -as a :ref:`request handler `. That is, it should accept a -single :class:`Request` instance and return a :class:`Response`, or raise an -exception. - -Internally, a single :ref:`request handler ` is constructed -by applying the middleware chain to the original handler in reverse order, -and is called by the :class:`RequestHandler` as a regular *handler*. - -Since *middleware factories* are themselves coroutines, they may perform extra -``await`` calls when creating a new handler, e.g. call database etc. - -*Middlewares* usually call the inner handler, but they may choose to ignore it, -e.g. displaying *403 Forbidden page* or raising :exc:`HTTPForbidden` exception -if user has no permissions to access the underlying resource. -They may also render errors raised by the handler, perform some pre- or -post-processing like handling *CORS* and so on. - - -Example -^^^^^^^ - -A common use of middlewares is to implement custom error pages. 
The following -example will render 404 errors using a JSON response, as might be appropriate -a JSON REST service:: - - import json - from aiohttp import web - - def json_error(message): - return web.Response( - body=json.dumps({'error': message}).encode('utf-8'), - content_type='application/json') - - async def error_middleware(app, handler): - async def middleware_handler(request): - try: - response = await handler(request) - if response.status == 404: - return json_error(response.message) - return response - except web.HTTPException as ex: - if ex.status == 404: - return json_error(ex.reason) - raise - return middleware_handler - - app = web.Application(middlewares=[error_middleware]) - -.. _aiohttp-web-signals: - -Signals -------- - -.. versionadded:: 0.18 - -Although :ref:`middlewares ` can customize -:ref:`request handlers` before or after a :class:`Response` -has been prepared, they can't customize a :class:`Response` **while** it's -being prepared. For this :mod:`aiohttp.web` provides *signals*. - -For example, a middleware can only change HTTP headers for *unprepared* -responses (see :meth:`~aiohttp.web.StreamResponse.prepare`), but sometimes we -need a hook for changing HTTP headers for streamed responses and WebSockets. -This can be accomplished by subscribing to the -:attr:`~aiohttp.web.Application.on_response_prepare` signal:: - - async def on_prepare(request, response): - response.headers['My-Header'] = 'value' - - app.on_response_prepare.append(on_prepare) - - -Signal handlers should not return a value but may modify incoming mutable -parameters. - - -.. warning:: - - Signals API has provisional status, meaning it may be changed in future - releases. - - Signal subscription and sending will most likely be the same, but signal - object creation is subject to change. As long as you are not creating new - signals, but simply reusing existing ones, you will not be affected. - -.. _aiohttp-web-nested-applications: - -Nested applications -------------------- - -Sub applications are designed for solving the problem of the big -monolithic code base. -Let's assume we have a project with own business logic and tools like -administration panel and debug toolbar. - -Administration panel is a separate application by its own nature but all -toolbar URLs are served by prefix like ``/admin``. - -Thus we'll create a totally separate application named ``admin`` and -connect it to main app with prefix by -:meth:`~aiohttp.web.Application.add_subapp`:: - - admin = web.Application() - # setup admin routes, signals and middlewares - - app.add_subapp('/admin/', admin) - -Middlewares and signals from ``app`` and ``admin`` are chained. - -It means that if URL is ``'/admin/something'`` middlewares from -``app`` are applied first and ``admin.middlewares`` are the next in -the call chain. - -The same is going for -:attr:`~aiohttp.web.Application.on_response_prepare` signal -- the -signal is delivered to both top level ``app`` and ``admin`` if -processing URL is routed to ``admin`` sub-application. - -Common signals like :attr:`~aiohttp.web.Application.on_startup`, -:attr:`~aiohttp.web.Application.on_shutdown` and -:attr:`~aiohttp.web.Application.on_cleanup` are delivered to all -registered sub-applications. The passed parameter is sub-application -instance, not top-level application. - - -Third level sub-applications can be nested into second level ones -- -there are no limitation for nesting level. - -Url reversing for sub-applications should generate urls with proper prefix. 
- -But for getting URL sub-application's router should be used:: - - admin = web.Application() - admin.router.add_get('/resource', handler, name='name') - - app.add_subapp('/admin/', admin) - - url = admin.router['name'].url_for() - -The generated ``url`` from example will have a value -``URL('/admin/resource')``. - -If main application should do URL reversing for sub-application it could -use the following explicit technique:: - - admin = web.Application() - admin.router.add_get('/resource', handler, name='name') - - app.add_subapp('/admin/', admin) - app['admin'] = admin - - async def handler(request): # main application's handler - admin = request.app['admin'] - url = admin.router['name'].url_for() - -.. _aiohttp-web-flow-control: - -Flow control ------------- - -:mod:`aiohttp.web` has sophisticated flow control for underlying TCP -sockets write buffer. - -The problem is: by default TCP sockets use `Nagle's algorithm -`_ for output -buffer which is not optimal for streaming data protocols like HTTP. - -Web server response may have one of the following states: - -1. **CORK** (:attr:`~StreamResponse.tcp_cork` is ``True``). - Don't send out partial TCP/IP frames. All queued partial frames - are sent when the option is cleared again. Optimal for sending big - portion of data since data will be sent using minimum - frames count. - - If OS doesn't support **CORK** mode (neither ``socket.TCP_CORK`` - nor ``socket.TCP_NOPUSH`` exists) the mode is equal to *Nagle's - enabled* one. The most widespread OS without **CORK** support is - *Windows*. - -2. **NODELAY** (:attr:`~StreamResponse.tcp_nodelay` is - ``True``). Disable the Nagle algorithm. This means that small - data pieces are always sent as soon as possible, even if there is - only a small amount of data. Optimal for transmitting short messages. - -3. Nagle's algorithm enabled (both - :attr:`~StreamResponse.tcp_cork` and - :attr:`~StreamResponse.tcp_nodelay` are ``False``). - Data is buffered until there is a sufficient amount to send out. - Avoid using this mode for sending HTTP data until you have no doubts. - -By default streaming data (:class:`StreamResponse`), regular responses -(:class:`Response` and http exceptions derived from it) and websockets -(:class:`WebSocketResponse`) use **NODELAY** mode, static file -handlers work in **CORK** mode. - -To manual mode switch :meth:`~StreamResponse.set_tcp_cork` and -:meth:`~StreamResponse.set_tcp_nodelay` methods can be used. It may -be helpful for better streaming control for example. - - -.. _aiohttp-web-graceful-shutdown: - -Graceful shutdown ------------------- - -Stopping *aiohttp web server* by just closing all connections is not -always satisfactory. - -The problem is: if application supports :term:`websocket`\s or *data -streaming* it most likely has open connections at server -shutdown time. - -The *library* has no knowledge how to close them gracefully but -developer can help by registering :attr:`Application.on_shutdown` -signal handler and call the signal on *web server* closing. - -Developer should keep a list of opened connections -(:class:`Application` is a good candidate). - -The following :term:`websocket` snippet shows an example for websocket -handler:: - - app = web.Application() - app['websockets'] = [] - - async def websocket_handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - - request.app['websockets'].append(ws) - try: - async for msg in ws: - ... 
- finally: - request.app['websockets'].remove(ws) - - return ws - -Signal handler may look like:: - - async def on_shutdown(app): - for ws in app['websockets']: - await ws.close(code=WSCloseCode.GOING_AWAY, - message='Server shutdown') - - app.on_shutdown.append(on_shutdown) - -Proper finalization procedure has three steps: - - 1. Stop accepting new client connections by - :meth:`asyncio.Server.close` and - :meth:`asyncio.Server.wait_closed` calls. - - 2. Fire :meth:`Application.shutdown` event. - - 3. Close accepted connections from clients by - :meth:`Server.shutdown` call with - reasonable small delay. - - 4. Call registered application finalizers by :meth:`Application.cleanup`. - -The following code snippet performs proper application start, run and -finalizing. It's pretty close to :func:`run_app` utility function:: - - loop = asyncio.get_event_loop() - handler = app.make_handler() - f = loop.create_server(handler, '0.0.0.0', 8080) - srv = loop.run_until_complete(f) - print('serving on', srv.sockets[0].getsockname()) - try: - loop.run_forever() - except KeyboardInterrupt: - pass - finally: - srv.close() - loop.run_until_complete(srv.wait_closed()) - loop.run_until_complete(app.shutdown()) - loop.run_until_complete(handler.shutdown(60.0)) - loop.run_until_complete(app.cleanup()) - loop.close() - -.. _aiohttp-web-background-tasks: - -Background tasks ------------------ - -Sometimes there's a need to perform some asynchronous operations just -after application start-up. - -Even more, in some sophisticated systems there could be a need to run some -background tasks in the event loop along with the application's request -handler. Such as listening to message queue or other network message/event -sources (e.g. ZeroMQ, Redis Pub/Sub, AMQP, etc.) to react to received messages -within the application. - -For example the background task could listen to ZeroMQ on :data:`zmq.SUB` socket, -process and forward retrieved messages to clients connected via WebSocket -that are stored somewhere in the application -(e.g. in the :obj:`application['websockets']` list). - -To run such short and long running background tasks aiohttp provides an -ability to register :attr:`Application.on_startup` signal handler(s) that -will run along with the application's request handler. - -For example there's a need to run one quick task and two long running -tasks that will live till the application is alive. The appropriate -background tasks could be registered as an :attr:`Application.on_startup` -signal handlers as shown in the example below:: - - - async def listen_to_redis(app): - try: - sub = await aioredis.create_redis(('localhost', 6379), loop=app.loop) - ch, *_ = await sub.subscribe('news') - async for msg in ch.iter(encoding='utf-8'): - # Forward message to all connected websockets: - for ws in app['websockets']: - ws.send_str('{}: {}'.format(ch.name, msg)) - except asyncio.CancelledError: - pass - finally: - await sub.unsubscribe(ch.name) - await sub.quit() - - - async def start_background_tasks(app): - app['redis_listener'] = app.loop.create_task(listen_to_redis(app)) - - - async def cleanup_background_tasks(app): - app['redis_listener'].cancel() - await app['redis_listener'] - - - app = web.Application() - app.on_startup.append(start_background_tasks) - app.on_cleanup.append(cleanup_background_tasks) - web.run_app(app) - - -The task :func:`listen_to_redis` will run forever. -To shut it down correctly :attr:`Application.on_cleanup` signal handler -may be used to send a cancellation to it. 
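-
-Awaiting the cancelled task in ``cleanup_background_tasks`` matters: it lets
-the ``except``/``finally`` blocks of :func:`listen_to_redis` run, so the
-subscription and connection are always released. The same pattern works for
-any periodic job; here is a minimal sketch (``do_work`` is a hypothetical
-helper, not part of aiohttp)::
-
-    import asyncio
-
-    async def periodic(app):
-        try:
-            while True:
-                await do_work(app)  # hypothetical unit of work
-                await asyncio.sleep(60)
-        except asyncio.CancelledError:
-            pass  # swallow cancellation so the awaiting cleanup returns
-
-    async def start_periodic(app):
-        app['periodic'] = app.loop.create_task(periodic(app))
-
-    async def stop_periodic(app):
-        app['periodic'].cancel()
-        await app['periodic']  # returns once the task has actually finished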
- - -Handling error pages --------------------- - -Pages like *404 Not Found* and *500 Internal Error* could be handled -by custom middleware, see :ref:`aiohttp-tutorial-middlewares` for -details. - -Swagger support ---------------- - -`aiohttp-swagger `_ is a -library that allow to add Swagger documentation and embed the -Swagger-UI into your :mod:`aiohttp.web` project. - -CORS support ------------- - -:mod:`aiohttp.web` itself does not support `Cross-Origin Resource -Sharing `_, but -there is an aiohttp plugin for it: -`aiohttp_cors `_. - - -Debug Toolbar -------------- - -aiohttp_debugtoolbar_ is a very useful library that provides a debugging toolbar -while you're developing an :mod:`aiohttp.web` application. - -Install it via ``pip``: - -.. code-block:: shell - - $ pip install aiohttp_debugtoolbar - - -After that attach the :mod:`aiohttp_debugtoolbar` middleware to your -:class:`aiohttp.web.Application` and call :func:`aiohttp_debugtoolbar.setup`:: - - import aiohttp_debugtoolbar - from aiohttp_debugtoolbar import toolbar_middleware_factory - - app = web.Application(loop=loop, - middlewares=[toolbar_middleware_factory]) - aiohttp_debugtoolbar.setup(app) - -The toolbar is ready to use. Enjoy!!! - -.. _aiohttp_debugtoolbar: https://github.com/aio-libs/aiohttp_debugtoolbar - - -Dev Tools ---------- - -aiohttp-devtools_ provides a couple of tools to simplify development of -:mod:`aiohttp.web` applications. - - -Install via ``pip``: - -.. code-block:: shell - - $ pip install aiohttp-devtools - - * ``runserver`` provides a development server with auto-reload, live-reload, static file serving and - aiohttp_debugtoolbar_ integration. - * ``start`` is a `cookiecutter command which does the donkey work of creating new :mod:`aiohttp.web` - Applications. - -Documentation and a complete tutorial of creating and running an app locally are available at -aiohttp-devtools_. - -.. _aiohttp-devtools: https://github.com/samuelcolvin/aiohttp-devtools - - -.. disqus:: - :title: aiohttp server usage + Tutorial + Quickstart + Advanced Usage + Low Level + Reference + Logging + Testing + Deployment diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst new file mode 100644 index 00000000000..01a33410825 --- /dev/null +++ b/docs/web_advanced.rst @@ -0,0 +1,1006 @@ +.. _aiohttp-web-advanced: + +Web Server Advanced +=================== + +.. currentmodule:: aiohttp.web + + +Unicode support +--------------- + +*aiohttp* does :term:`requoting` of incoming request path. + +Unicode (non-ASCII) symbols are processed transparently on both *route +adding* and *resolving* (internally everything is converted to +:term:`percent-encoding` form by :term:`yarl` library). + +But in case of custom regular expressions for +:ref:`aiohttp-web-variable-handler` please take care that URL is +*percent encoded*: if you pass Unicode patterns they don't match to +*requoted* path. + +Peer disconnection +------------------ + +When a client peer is gone a subsequent reading or writing raises :exc:`OSError` +or more specific exception like :exc:`ConnectionResetError`. + +The reason for disconnection is vary; it can be a network issue or explicit +socket closing on the peer side without reading the whole server response. 
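+
+A disconnect can also surface while *writing*, e.g. when streaming a large
+response body. A minimal sketch catching it on the write side
+(``generate_chunks`` is a hypothetical data source, not part of aiohttp)::
+
+   async def stream_handler(request):
+       resp = web.StreamResponse()
+       await resp.prepare(request)
+       try:
+           for chunk in generate_chunks():  # hypothetical data source
+               await resp.write(chunk)
+       except ConnectionResetError:
+           # the peer went away while we were writing
+           pass
+       return resp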
+
+*aiohttp* handles disconnection properly, but you can also handle it
+explicitly on the read side, e.g.::
+
+   async def handler(request):
+       try:
+           text = await request.text()
+       except OSError:
+           # disconnected
+           raise
+
+Passing a coroutine into run_app and Gunicorn
+---------------------------------------------
+
+:func:`run_app` accepts either an application instance or a coroutine for
+making an application. The coroutine-based approach allows performing
+async I/O before making an app::
+
+   async def app_factory():
+       await pre_init()
+       app = web.Application()
+       app.router.add_get(...)
+       return app
+
+   web.run_app(app_factory())
+
+The Gunicorn worker supports a factory as well. For Gunicorn the factory
+should accept zero parameters::
+
+   async def my_web_app():
+       app = web.Application()
+       app.router.add_get(...)
+       return app
+
+Start Gunicorn:
+
+.. code-block:: shell
+
+   $ gunicorn my_app_module:my_web_app --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker
+
+.. versionadded:: 3.1
+
+Custom Routing Criteria
+-----------------------
+
+Sometimes you need to register :ref:`handlers <aiohttp-web-handler>` on
+more complex criteria than simply an *HTTP method* and *path* pair.
+
+Although :class:`UrlDispatcher` does not support any extra criteria, routing
+based on custom conditions can be accomplished by implementing a second layer
+of routing in your application.
+
+The following example shows custom routing based on the *HTTP Accept* header::
+
+   class AcceptChooser:
+
+       def __init__(self):
+           self._accepts = {}
+
+       async def do_route(self, request):
+           for accept in request.headers.getall('ACCEPT', []):
+               acceptor = self._accepts.get(accept)
+               if acceptor is not None:
+                   return (await acceptor(request))
+           raise HTTPNotAcceptable()
+
+       def reg_acceptor(self, accept, handler):
+           self._accepts[accept] = handler
+
+
+   async def handle_json(request):
+       # do json handling
+       ...
+
+   async def handle_xml(request):
+       # do xml handling
+       ...
+
+   chooser = AcceptChooser()
+   app.add_routes([web.get('/', chooser.do_route)])
+
+   chooser.reg_acceptor('application/json', handle_json)
+   chooser.reg_acceptor('application/xml', handle_xml)
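+
+With the chooser registered, the client picks a handler purely via the
+``Accept`` header. A minimal usage sketch with the aiohttp client (the
+host and port are assumed)::
+
+   import aiohttp
+
+   async def fetch_json():
+       async with aiohttp.ClientSession() as session:
+           async with session.get('http://localhost:8080/',
+                                  headers={'Accept': 'application/json'}) as resp:
+               return await resp.text()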
+
+.. _aiohttp-web-static-file-handling:
+
+Static file handling
+--------------------
+
+The best way to handle static files (images, JavaScripts, CSS files
+etc.) is using a `Reverse Proxy`_ like `nginx`_ or `CDN`_ services.
+
+.. _Reverse Proxy: https://en.wikipedia.org/wiki/Reverse_proxy
+.. _nginx: https://nginx.org/
+.. _CDN: https://en.wikipedia.org/wiki/Content_delivery_network
+
+But for development it's very convenient to handle static files with the
+aiohttp server itself.
+
+To do it just register a new static route by a
+:meth:`RouteTableDef.static` or :func:`static` call::
+
+   app.add_routes([web.static('/prefix', path_to_static_folder)])
+
+   routes.static('/prefix', path_to_static_folder)
+
+When a directory is accessed within a static route the server responds to
+the client with ``HTTP/403 Forbidden`` by default. Displaying a folder index
+instead can be enabled with the ``show_index`` parameter set to ``True``::
+
+   web.static('/prefix', path_to_static_folder, show_index=True)
+
+When a symlink from the static directory is accessed, the server responds to
+the client with ``HTTP/404 Not Found`` by default. To allow the server to
+follow symlinks, the ``follow_symlinks`` parameter should be set to ``True``::
+
+   web.static('/prefix', path_to_static_folder, follow_symlinks=True)
+
+When you want to enable cache busting, the ``append_version`` parameter can
+be set to ``True``.
+
+Cache busting is the process of appending some form of file version hash
+to the filename of resources like JavaScript and CSS files.
+The performance advantage of doing this is that we can tell the browser
+to cache these files indefinitely without worrying about the client not getting
+the latest version when the file changes::
+
+   web.static('/prefix', path_to_static_folder, append_version=True)
+
+Template Rendering
+------------------
+
+:mod:`aiohttp.web` does not support template rendering out-of-the-box.
+
+However, there is a third-party library, :mod:`aiohttp_jinja2`, which is
+supported by the *aiohttp* authors.
+
+Using it is rather simple. First, set up a *jinja2 environment* with a call
+to :func:`aiohttp_jinja2.setup`::
+
+   app = web.Application()
+   aiohttp_jinja2.setup(app,
+       loader=jinja2.FileSystemLoader('/path/to/templates/folder'))
+
+After that you may use the template engine in your
+:ref:`handlers <aiohttp-web-handler>`. The most convenient way is to simply
+wrap your handlers with the :func:`aiohttp_jinja2.template` decorator::
+
+   @aiohttp_jinja2.template('tmpl.jinja2')
+   async def handler(request):
+       return {'name': 'Andrew', 'surname': 'Svetlov'}
+
+If you prefer the `Mako`_ template engine, please take a look at the
+`aiohttp_mako`_ library.
+
+.. warning::
+
+   :func:`aiohttp_jinja2.template` should be applied **before**
+   :meth:`RouteTableDef.get` decorator and family, i.e. it must be
+   the *innermost* (bottom-most) decorator in the chain::
+
+
+      @routes.get('/path')
+      @aiohttp_jinja2.template('tmpl.jinja2')
+      async def handler(request):
+          return {'name': 'Andrew', 'surname': 'Svetlov'}
+
+
+.. _Mako: http://www.makotemplates.org/
+
+.. _aiohttp_mako: https://github.com/aio-libs/aiohttp_mako
+
+
+.. _aiohttp-web-websocket-read-same-task:
+
+Reading from the same task in WebSockets
+----------------------------------------
+
+Reading from the *WebSocket* (``await ws.receive()``) **must only** be
+done inside the request handler *task*; however, writing
+(``ws.send_str(...)``) to the *WebSocket*, closing (``await
+ws.close()``) and canceling the handler task may be delegated to other
+tasks. See also :ref:`FAQ section
+<aiohttp-faq-terminate-websocket>`.
+
+:mod:`aiohttp.web` creates an implicit :class:`asyncio.Task` for
+handling every incoming request.
+
+.. note::
+
+   While :mod:`aiohttp.web` itself only supports *WebSockets* without
+   downgrading to *LONG-POLLING*, etc., our team supports SockJS_, an
+   aiohttp-based library for implementing SockJS-compatible server
+   code.
+
+.. _SockJS: https://github.com/aio-libs/sockjs
+
+
+.. warning::
+
+   Parallel reads from the websocket are forbidden; there is no
+   way to call :meth:`WebSocketResponse.receive`
+   from two tasks.
+
+   See :ref:`FAQ section <aiohttp-faq-parallel-event-sources>` for
+   instructions on how to solve the problem.
+
+
+.. _aiohttp-web-data-sharing:
+
+Data Sharing aka No Singletons Please
+-------------------------------------
+
+:mod:`aiohttp.web` discourages the use of *global variables*, aka *singletons*.
+Every variable should have its own context that is *not global*.
+
+So, :class:`Application` and :class:`Request`
+support a :class:`collections.abc.MutableMapping` interface (i.e. they are
+dict-like objects), allowing them to be used as data stores.
+
+
+.. 
_aiohttp-web-data-sharing-app-config: + +Application's config +^^^^^^^^^^^^^^^^^^^^ + +For storing *global-like* variables, feel free to save them in an +:class:`Application` instance:: + + app['my_private_key'] = data + +and get it back in the :term:`web-handler`:: + + async def handler(request): + data = request.app['my_private_key'] + +In case of :ref:`nested applications +` the desired lookup strategy could +be the following: + +1. Search the key in the current nested application. +2. If the key is not found continue searching in the parent application(s). + +For this please use :attr:`Request.config_dict` read-only property:: + + async def handler(request): + data = request.config_dict['my_private_key'] + + +Request's storage +^^^^^^^^^^^^^^^^^ + +Variables that are only needed for the lifetime of a :class:`Request`, can be +stored in a :class:`Request`:: + + async def handler(request): + request['my_private_key'] = "data" + ... + +This is mostly useful for :ref:`aiohttp-web-middlewares` and +:ref:`aiohttp-web-signals` handlers to store data for further processing by the +next handlers in the chain. + +Response's storage +^^^^^^^^^^^^^^^^^^ + +:class:`StreamResponse` and :class:`Response` objects +also support :class:`collections.abc.MutableMapping` interface. This is useful +when you want to share data with signals and middlewares once all the work in +the handler is done:: + + async def handler(request): + [ do all the work ] + response['my_metric'] = 123 + return response + + +Naming hint +^^^^^^^^^^^ + +To avoid clashing with other *aiohttp* users and third-party libraries, please +choose a unique key name for storing data. + +If your code is published on PyPI, then the project name is most likely unique +and safe to use as the key. +Otherwise, something based on your company name/url would be satisfactory (i.e. +``org.company.app``). + + +.. _aiohttp-web-contextvars: + + +ContextVars support +------------------- + +Starting from Python 3.7 asyncio has :mod:`Context Variables ` as a +context-local storage (a generalization of thread-local concept that works with asyncio +tasks also). + + +*aiohttp* server supports it in the following way: + +* A server inherits the current task's context used when creating it. + :func:`aiohttp.web.run_app()` runs a task for handling all underlying jobs running + the app, but alternatively :ref:`aiohttp-web-app-runners` can be used. + +* Application initialization / finalization events (:attr:`Application.cleanup_ctx`, + :attr:`Application.on_startup` and :attr:`Application.on_shutdown`, + :attr:`Application.on_cleanup`) are executed inside the same context. + + E.g. all context modifications made on application startup are visible on teardown. + +* On every request handling *aiohttp* creates a context copy. :term:`web-handler` has + all variables installed on initialization stage. But the context modification made by + a handler or middleware is invisible to another HTTP request handling call. 
+ +An example of context vars usage:: + + from contextvars import ContextVar + + from aiohttp import web + + VAR = ContextVar('VAR', default='default') + + + async def coro(): + return VAR.get() + + + async def handler(request): + var = VAR.get() + VAR.set('handler') + ret = await coro() + return web.Response(text='\n'.join([var, + ret])) + + + async def on_startup(app): + print('on_startup', VAR.get()) + VAR.set('on_startup') + + + async def on_cleanup(app): + print('on_cleanup', VAR.get()) + VAR.set('on_cleanup') + + + async def init(): + print('init', VAR.get()) + VAR.set('init') + app = web.Application() + app.router.add_get('/', handler) + + app.on_startup.append(on_startup) + app.on_cleanup.append(on_cleanup) + return app + + + web.run_app(init()) + print('done', VAR.get()) + +.. versionadded:: 3.5 + + +.. _aiohttp-web-middlewares: + +Middlewares +----------- + +:mod:`aiohttp.web` provides a powerful mechanism for customizing +:ref:`request handlers` via *middlewares*. + +A *middleware* is a coroutine that can modify either the request or +response. For example, here's a simple *middleware* which appends +``' wink'`` to the response:: + + from aiohttp.web import middleware + + @middleware + async def middleware(request, handler): + resp = await handler(request) + resp.text = resp.text + ' wink' + return resp + +.. note:: + + The example won't work with streamed responses or websockets + +Every *middleware* should accept two parameters, a :class:`request +` instance and a *handler*, and return the response or raise +an exception. If the exception is not an instance of +:exc:`HTTPException` it is converted to ``500`` +:exc:`HTTPInternalServerError` after processing the +middlewares chain. + +.. warning:: + + Second argument should be named *handler* exactly. + +When creating an :class:`Application`, these *middlewares* are passed to +the keyword-only ``middlewares`` parameter:: + + app = web.Application(middlewares=[middleware_1, + middleware_2]) + +Internally, a single :ref:`request handler ` is constructed +by applying the middleware chain to the original handler in reverse order, +and is called by the :class:`RequestHandler` as a regular *handler*. + +Since *middlewares* are themselves coroutines, they may perform extra +``await`` calls when creating a new handler, e.g. call database etc. + +*Middlewares* usually call the handler, but they may choose to ignore it, +e.g. displaying *403 Forbidden page* or raising :exc:`HTTPForbidden` exception +if the user does not have permissions to access the underlying resource. +They may also render errors raised by the handler, perform some pre- or +post-processing like handling *CORS* and so on. 
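+
+For example, here is a minimal sketch of a middleware that refuses to call
+the handler at all; the ``X-Auth-Token`` header and the ``is_valid_token()``
+helper are hypothetical, shown only for illustration::
+
+   @web.middleware
+   async def auth_middleware(request, handler):
+       # Short-circuit: the downstream handler is never called
+       # when the (hypothetical) token check fails.
+       token = request.headers.get('X-Auth-Token')
+       if token is None or not is_valid_token(token):
+           raise web.HTTPForbidden()
+       return await handler(request)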
+
+The following code demonstrates middleware execution order::
+
+   from aiohttp import web
+
+   async def test(request):
+       print('Handler function called')
+       return web.Response(text="Hello")
+
+   @web.middleware
+   async def middleware1(request, handler):
+       print('Middleware 1 called')
+       response = await handler(request)
+       print('Middleware 1 finished')
+       return response
+
+   @web.middleware
+   async def middleware2(request, handler):
+       print('Middleware 2 called')
+       response = await handler(request)
+       print('Middleware 2 finished')
+       return response
+
+
+   app = web.Application(middlewares=[middleware1, middleware2])
+   app.router.add_get('/', test)
+   web.run_app(app)
+
+Produced output::
+
+   Middleware 1 called
+   Middleware 2 called
+   Handler function called
+   Middleware 2 finished
+   Middleware 1 finished
+
+Example
+^^^^^^^
+
+A common use of middlewares is to implement custom error pages. The following
+example will render 404 errors using a JSON response, as might be appropriate
+for a JSON REST service::
+
+   from aiohttp import web
+
+   @web.middleware
+   async def error_middleware(request, handler):
+       try:
+           response = await handler(request)
+           if response.status != 404:
+               return response
+           message = response.message
+       except web.HTTPException as ex:
+           if ex.status != 404:
+               raise
+           message = ex.reason
+       return web.json_response({'error': message})
+
+   app = web.Application(middlewares=[error_middleware])
+
+
+Middleware Factory
+^^^^^^^^^^^^^^^^^^
+
+A *middleware factory* is a function that creates a middleware with the
+passed arguments. For example, here's a trivial *middleware factory*::
+
+   def middleware_factory(text):
+       @middleware
+       async def sample_middleware(request, handler):
+           resp = await handler(request)
+           resp.text = resp.text + text
+           return resp
+       return sample_middleware
+
+Remember that, contrary to regular middlewares, you need the result of a
+middleware factory, not the factory itself. So when passing a middleware
+factory to an app you actually need to call it::
+
+   app = web.Application(middlewares=[middleware_factory(' wink')])
+
+.. _aiohttp-web-signals:
+
+Signals
+-------
+
+Although :ref:`middlewares <aiohttp-web-middlewares>` can customize
+:ref:`request handlers <aiohttp-web-handler>` before or after a :class:`Response`
+has been prepared, they can't customize a :class:`Response` **while** it's
+being prepared. For this :mod:`aiohttp.web` provides *signals*.
+
+For example, a middleware can only change HTTP headers for *unprepared*
+responses (see :meth:`StreamResponse.prepare`), but sometimes we
+need a hook for changing HTTP headers for streamed responses and WebSockets.
+This can be accomplished by subscribing to the
+:attr:`Application.on_response_prepare` signal, which is called after default
+headers have been computed and directly before headers are sent::
+
+   async def on_prepare(request, response):
+       response.headers['My-Header'] = 'value'
+
+   app.on_response_prepare.append(on_prepare)
+
+
+Additionally, the :attr:`Application.on_startup` and
+:attr:`Application.on_cleanup` signals can be subscribed to for
+application component setup and teardown, respectively.
+
+The following example will properly initialize and dispose an aiopg connection
+engine::
+
+   from aiopg.sa import create_engine
+
+   async def create_aiopg(app):
+       app['pg_engine'] = await create_engine(
+           user='postgre',
+           database='postgre',
+           host='localhost',
+           port=5432,
+           password=''
+       )
+
+   async def dispose_aiopg(app):
+       app['pg_engine'].close()
+       await app['pg_engine'].wait_closed()
+
+   app.on_startup.append(create_aiopg)
+   app.on_cleanup.append(dispose_aiopg)
+
+
+Signal handlers should not return a value but may modify incoming mutable
+parameters.
+
+Signal handlers will be run sequentially, in the order they were
+added. All handlers must be asynchronous since *aiohttp* 3.0.
+
+.. _aiohttp-web-cleanup-ctx:
+
+Cleanup Context
+---------------
+
+A bare :attr:`Application.on_startup` / :attr:`Application.on_cleanup`
+pair still has a pitfall: signal handlers are independent of each other.
+
+E.g. we have ``[create_pg, create_redis]`` in the *startup* signal and
+``[dispose_pg, dispose_redis]`` in *cleanup*.
+
+If, for example, the ``create_pg(app)`` call fails, ``create_redis(app)``
+is not called. But on application cleanup both ``dispose_pg(app)`` and
+``dispose_redis(app)`` are still called: the *cleanup signal* has no
+knowledge about startup/cleanup pairs and their execution state.
+
+
+The solution is to use :attr:`Application.cleanup_ctx`::
+
+   async def pg_engine(app):
+       app['pg_engine'] = await create_engine(
+           user='postgre',
+           database='postgre',
+           host='localhost',
+           port=5432,
+           password=''
+       )
+       yield
+       app['pg_engine'].close()
+       await app['pg_engine'].wait_closed()
+
+   app.cleanup_ctx.append(pg_engine)
+
+The attribute is a list of *asynchronous generators*; the code *before*
+``yield`` is the initialization stage (called on *startup*), the code
+*after* ``yield`` is executed on *cleanup*. The generator must have only
+one ``yield``.
+
+*aiohttp* guarantees that *cleanup code* is called if and only if
+*startup code* was successfully finished.
+
+Asynchronous generators are supported by Python 3.6+; on Python 3.5
+please use the ``async_generator`` library.
+
+.. versionadded:: 3.1
+
+.. _aiohttp-web-nested-applications:
+
+Nested applications
+-------------------
+
+Sub-applications are designed to solve the problem of a big
+monolithic code base.
+Let's assume we have a project with its own business logic, plus tools like
+an administration panel and a debug toolbar.
+
+The administration panel is a separate application by its nature, but all
+its URLs are served under a prefix like ``/admin``.
+
+Thus we'll create a totally separate application named ``admin`` and
+connect it to the main app under the prefix with
+:meth:`Application.add_subapp`::
+
+   admin = web.Application()
+   # setup admin routes, signals and middlewares
+
+   app.add_subapp('/admin/', admin)
+
+Middlewares and signals from ``app`` and ``admin`` are chained.
+
+This means that if the URL is ``'/admin/something'``, middlewares from
+``app`` are applied first, and ``admin.middlewares`` are next in
+the call chain.
+
+The same goes for the
+:attr:`Application.on_response_prepare` signal -- the
+signal is delivered to both the top level ``app`` and ``admin`` if
+the processed URL is routed to the ``admin`` sub-application.
+
+Common signals like :attr:`Application.on_startup`,
+:attr:`Application.on_shutdown` and
+:attr:`Application.on_cleanup` are delivered to all
+registered sub-applications. The passed parameter is the sub-application
+instance, not the top-level application.
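+
+For instance, a startup handler appended to the ``admin`` sub-application
+receives ``admin`` itself, so sub-application state can be initialized
+without touching the main app (a minimal sketch; ``connect_admin_db()``
+is a hypothetical helper)::
+
+   async def admin_on_startup(app):
+       # Here ``app`` is the ``admin`` sub-application,
+       # not the top-level application.
+       app['admin_db'] = await connect_admin_db()
+
+   admin.on_startup.append(admin_on_startup)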
+
+
+Third level sub-applications can be nested into second level ones --
+there is no limit on the nesting level.
+
+URL reversing for sub-applications generates URLs with the proper prefix.
+
+But the sub-application's router should be used for getting the URL::
+
+   admin = web.Application()
+   admin.add_routes([web.get('/resource', handler, name='name')])
+
+   app.add_subapp('/admin/', admin)
+
+   url = admin.router['name'].url_for()
+
+The generated ``url`` from the example will have the value
+``URL('/admin/resource')``.
+
+If the main application should do URL reversing for a sub-application, it
+could use the following explicit technique::
+
+   admin = web.Application()
+   admin.add_routes([web.get('/resource', handler, name='name')])
+
+   app.add_subapp('/admin/', admin)
+   app['admin'] = admin
+
+   async def handler(request):  # main application's handler
+       admin = request.app['admin']
+       url = admin.router['name'].url_for()
+
+.. _aiohttp-web-expect-header:
+
+*Expect* Header
+---------------
+
+:mod:`aiohttp.web` supports the *Expect* header. By default it sends an
+``HTTP/1.1 100 Continue`` line to the client, or raises
+:exc:`HTTPExpectationFailed` if the header value is not equal to
+"100-continue". It is possible to specify a custom *Expect* header
+handler on a per-route basis. This handler gets called if the *Expect*
+header exists in the request, after receiving all headers and before
+processing the application's :ref:`aiohttp-web-middlewares` and
+route handler. The handler can return *None*; in that case the request
+processing continues as usual. If the handler returns an instance of class
+:class:`StreamResponse`, the *request handler* uses it as the response.
+The handler can also raise a subclass of :exc:`HTTPException`. In this
+case all further processing will not happen and the client will receive
+an appropriate HTTP response.
+
+.. note::
+    A server that does not understand or is unable to comply with any of the
+    expectation values in the Expect field of a request MUST respond with
+    appropriate error status. The server MUST respond with a 417
+    (Expectation Failed) status if any of the expectations cannot be met or,
+    if there are other problems with the request, some other 4xx status.
+
+    http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20
+
+If all checks pass, the custom handler *must* write a *HTTP/1.1 100 Continue*
+status code before returning.
+
+The following example shows how to set up a custom handler for the *Expect*
+header::
+
+   async def check_auth(request):
+       if request.version != aiohttp.HttpVersion11:
+           return
+
+       expect = request.headers.get('EXPECT')
+       if expect != '100-continue':
+           raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)
+
+       if request.headers.get('AUTHORIZATION') is None:
+           raise HTTPForbidden()
+
+       request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
+
+   async def hello(request):
+       return web.Response(body=b"Hello, world")
+
+   app = web.Application()
+   app.add_routes([web.get('/', hello, expect_handler=check_auth)])
+
+.. _aiohttp-web-custom-resource:
+
+Custom resource implementation
+------------------------------
+
+To register a custom resource, use :meth:`UrlDispatcher.register_resource`.
+The resource instance must implement the `AbstractResource` interface.
+
+.. _aiohttp-web-app-runners:
+
+Application runners
+-------------------
+
+:func:`run_app` provides a simple *blocking* API for running an
+:class:`Application`.
+
+For starting the application *asynchronously* or serving on multiple
+HOST/PORT pairs, :class:`AppRunner` exists.
+ +The simple startup code for serving HTTP site on ``'localhost'``, port +``8080`` looks like:: + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8080) + await site.start() + + while True: + await asyncio.sleep(3600) # sleep forever + +To stop serving call :meth:`AppRunner.cleanup`:: + + await runner.cleanup() + +.. versionadded:: 3.0 + +.. _aiohttp-web-graceful-shutdown: + +Graceful shutdown +------------------ + +Stopping *aiohttp web server* by just closing all connections is not +always satisfactory. + +The problem is: if application supports :term:`websocket`\s or *data +streaming* it most likely has open connections at server +shutdown time. + +The *library* has no knowledge how to close them gracefully but +developer can help by registering :attr:`Application.on_shutdown` +signal handler and call the signal on *web server* closing. + +Developer should keep a list of opened connections +(:class:`Application` is a good candidate). + +The following :term:`websocket` snippet shows an example for websocket +handler:: + + from aiohttp import web + import weakref + + app = web.Application() + app['websockets'] = weakref.WeakSet() + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + request.app['websockets'].add(ws) + try: + async for msg in ws: + ... + finally: + request.app['websockets'].discard(ws) + + return ws + +Signal handler may look like:: + + from aiohttp import WSCloseCode + + async def on_shutdown(app): + for ws in set(app['websockets']): + await ws.close(code=WSCloseCode.GOING_AWAY, + message='Server shutdown') + + app.on_shutdown.append(on_shutdown) + +Both :func:`run_app` and :meth:`AppRunner.cleanup` call shutdown +signal handlers. + +.. _aiohttp-web-background-tasks: + +Background tasks +----------------- + +Sometimes there's a need to perform some asynchronous operations just +after application start-up. + +Even more, in some sophisticated systems there could be a need to run some +background tasks in the event loop along with the application's request +handler. Such as listening to message queue or other network message/event +sources (e.g. ZeroMQ, Redis Pub/Sub, AMQP, etc.) to react to received messages +within the application. + +For example the background task could listen to ZeroMQ on +:data:`zmq.SUB` socket, process and forward retrieved messages to +clients connected via WebSocket that are stored somewhere in the +application (e.g. in the :obj:`application['websockets']` list). + +To run such short and long running background tasks aiohttp provides an +ability to register :attr:`Application.on_startup` signal handler(s) that +will run along with the application's request handler. + +For example there's a need to run one quick task and two long running +tasks that will live till the application is alive. 
The appropriate
+background tasks could be registered as :attr:`Application.on_startup`
+signal handlers, as shown in the example below::
+
+
+   async def listen_to_redis(app):
+       try:
+           sub = await aioredis.create_redis(('localhost', 6379))
+           ch, *_ = await sub.subscribe('news')
+           async for msg in ch.iter(encoding='utf-8'):
+               # Forward message to all connected websockets:
+               for ws in app['websockets']:
+                   await ws.send_str('{}: {}'.format(ch.name, msg))
+       except asyncio.CancelledError:
+           pass
+       finally:
+           await sub.unsubscribe(ch.name)
+           await sub.quit()
+
+
+   async def start_background_tasks(app):
+       app['redis_listener'] = asyncio.create_task(listen_to_redis(app))
+
+
+   async def cleanup_background_tasks(app):
+       app['redis_listener'].cancel()
+       await app['redis_listener']
+
+
+   app = web.Application()
+   app.on_startup.append(start_background_tasks)
+   app.on_cleanup.append(cleanup_background_tasks)
+   web.run_app(app)
+
+
+The task :func:`listen_to_redis` will run forever.
+To shut it down correctly, the :attr:`Application.on_cleanup` signal
+handler may be used to send a cancellation to it.
+
+Handling error pages
+--------------------
+
+Pages like *404 Not Found* and *500 Internal Error* could be handled
+by custom middleware; see the *polls demo* for an example.
+
+.. _aiohttp-web-forwarded-support:
+
+Deploying behind a Proxy
+------------------------
+
+As discussed in :ref:`aiohttp-deployment`, the preferable way for
+production usage is deploying the *aiohttp* web server behind a
+*Reverse Proxy Server* like :term:`nginx`.
+
+When deployed this way, properties like :attr:`BaseRequest.scheme`,
+:attr:`BaseRequest.host` and :attr:`BaseRequest.remote` are
+incorrect: they describe the proxy connection, not the original client.
+
+The real values should be given by the proxy server, usually via either
+the ``Forwarded`` header or the old-fashioned ``X-Forwarded-For``,
+``X-Forwarded-Host`` and ``X-Forwarded-Proto`` HTTP headers.
+
+*aiohttp* does not take *forwarded* headers into account by default
+because doing so would produce a *security issue*: an HTTP client might
+add these headers too, pushing non-trusted data values.
+
+That's why an *aiohttp server* should set up *forwarded* headers in custom
+middleware, in tight conjunction with the *reverse proxy configuration*.
+
+For changing :attr:`BaseRequest.scheme`, :attr:`BaseRequest.host` and
+:attr:`BaseRequest.remote`, the middleware might use
+:meth:`BaseRequest.clone`.
+
+.. seealso::
+
+   https://github.com/aio-libs/aiohttp-remotes provides secure helpers
+   for modifying *scheme*, *host* and *remote* attributes according
+   to ``Forwarded`` and ``X-Forwarded-*`` HTTP headers.
+
+Swagger support
+---------------
+
+*aiohttp-swagger* is a library that allows adding Swagger documentation
+and embedding the Swagger-UI into your :mod:`aiohttp.web` project.
+
+CORS support
+------------
+
+:mod:`aiohttp.web` itself does not support `Cross-Origin Resource
+Sharing <https://en.wikipedia.org/wiki/Cross-origin_resource_sharing>`_, but
+there is an aiohttp plugin for it:
+`aiohttp_cors <https://github.com/aio-libs/aiohttp-cors>`_.
+
+
+Debug Toolbar
+-------------
+
+`aiohttp-debugtoolbar`_ is a very useful library that provides a
+debugging toolbar while you're developing an :mod:`aiohttp.web`
+application.
+
+Install it with ``pip``:
+
+.. code-block:: shell
+
+   $ pip install aiohttp_debugtoolbar
+
+
+Just call :func:`aiohttp_debugtoolbar.setup`::
+
+   import aiohttp_debugtoolbar
+
+   app = web.Application()
+   aiohttp_debugtoolbar.setup(app)
+
+The toolbar is ready to use. Enjoy!
+
+.. 
_aiohttp-debugtoolbar: https://github.com/aio-libs/aiohttp_debugtoolbar + + +Dev Tools +--------- + +`aiohttp-devtools`_ provides a couple of tools to simplify development of +:mod:`aiohttp.web` applications. + + +Install with ``pip``: + +.. code-block:: shell + + $ pip install aiohttp-devtools + +* ``runserver`` provides a development server with auto-reload, + live-reload, static file serving and `aiohttp-debugtoolbar`_ + integration. +* ``start`` is a `cookiecutter command which does the donkey work + of creating new :mod:`aiohttp.web` Applications. + +Documentation and a complete tutorial of creating and running an app +locally are available at `aiohttp-devtools`_. + +.. _aiohttp-devtools: https://github.com/aio-libs/aiohttp-devtools diff --git a/docs/web_lowlevel.rst b/docs/web_lowlevel.rst index 14960f000fc..696c58d38e1 100644 --- a/docs/web_lowlevel.rst +++ b/docs/web_lowlevel.rst @@ -30,8 +30,7 @@ parameter and performs one of the following actions: 2. Create a :class:`StreamResponse`, send headers by :meth:`StreamResponse.prepare` call, send data chunks by - :meth:`StreamResponse.write` / :meth:`StreamResponse.drain`, - return finished response. + :meth:`StreamResponse.write` and return finished response. 3. Raise :class:`HTTPException` derived exception (see :ref:`aiohttp-web-exceptions` section). @@ -56,9 +55,13 @@ The following code demonstrates very trivial usage example:: return web.Response(text="OK") - async def main(loop): + async def main(): server = web.Server(handler) - await loop.create_server(server, "127.0.0.1", 8080) + runner = web.ServerRunner(server) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8080) + await site.start() + print("======= Serving on http://127.0.0.1:8080/ ======") # pause here for very long time by serving HTTP requests and @@ -69,7 +72,7 @@ The following code demonstrates very trivial usage example:: loop = asyncio.get_event_loop() try: - loop.run_until_complete(main(loop)) + loop.run_until_complete(main()) except KeyboardInterrupt: pass loop.close() @@ -80,14 +83,11 @@ In the snippet we have ``handler`` which returns a regular This *handler* is processed by ``server`` (:class:`Server` which acts as *protocol factory*). Network communication is created by -``loop.create_server`` call to serve ``http://127.0.0.1:8080/``. +:ref:`runners API ` to serve +``http://127.0.0.1:8080/``. -The handler should process every request: ``GET``, ``POST``, -Web-Socket for every *path*. +The handler should process every request for every *path*, e.g. +``GET``, ``POST``, Web-Socket. The example is very basic: it always return ``200 OK`` response, real -life code should be much more complex. - - -.. disqus:: - :title: aiohttp.web low-level server +life code is much more complex usually. diff --git a/docs/web_quickstart.rst b/docs/web_quickstart.rst new file mode 100644 index 00000000000..1db1d6823e7 --- /dev/null +++ b/docs/web_quickstart.rst @@ -0,0 +1,759 @@ +.. _aiohttp-web-quickstart: + +Web Server Quickstart +===================== + +.. currentmodule:: aiohttp.web + + +Run a Simple Web Server +----------------------- + +In order to implement a web server, first create a +:ref:`request handler `. 
+
+A request handler must be a :ref:`coroutine ` that
+accepts a :class:`Request` instance as its only parameter and returns a
+:class:`Response` instance::
+
+   from aiohttp import web
+
+   async def hello(request):
+       return web.Response(text="Hello, world")
+
+Next, create an :class:`Application` instance and register the
+request handler on a particular *HTTP method* and *path*::
+
+   app = web.Application()
+   app.add_routes([web.get('/', hello)])
+
+After that, run the application by calling :func:`run_app`::
+
+   web.run_app(app)
+
+That's it. Now, head over to ``http://localhost:8080/`` to see the results.
+
+Alternatively, if you prefer *route decorators*, create a *route table*
+and register a :term:`web-handler`::
+
+   routes = web.RouteTableDef()
+
+   @routes.get('/')
+   async def hello(request):
+       return web.Response(text="Hello, world")
+
+   app = web.Application()
+   app.add_routes(routes)
+   web.run_app(app)
+
+Both ways essentially do the same work; the difference is only a matter of
+taste: do you prefer the *Django style* with its famous ``urls.py``, or
+*Flask* with shiny route decorators.
+
+The *aiohttp* server documentation uses both ways in code snippets to
+emphasize their equivalence; switching from one style to the other is
+trivial.
+
+.. seealso::
+
+   The :ref:`aiohttp-web-graceful-shutdown` section explains what
+   :func:`run_app` does and how to implement complex server
+   initialization/finalization from scratch.
+
+   See :ref:`aiohttp-web-app-runners` for handling more complex cases
+   like *asynchronous* web application serving and multiple hosts
+   support.
+
+.. _aiohttp-web-cli:
+
+Command Line Interface (CLI)
+----------------------------
+:mod:`aiohttp.web` implements a basic CLI for quickly serving an
+:class:`Application` in *development* over TCP/IP:
+
+.. code-block:: shell
+
+   $ python -m aiohttp.web -H localhost -P 8080 package.module:init_func
+
+``package.module:init_func`` should be an importable :term:`callable` that
+accepts a list of any non-parsed command-line arguments and returns an
+:class:`Application` instance after setting it up::
+
+   def init_func(argv):
+       app = web.Application()
+       app.router.add_get("/", index_handler)
+       return app
+
+
+.. _aiohttp-web-handler:
+
+Handler
+-------
+
+A request handler must be a :ref:`coroutine` that accepts a
+:class:`Request` instance as its only argument and returns a
+:class:`StreamResponse` derived (e.g. :class:`Response`) instance::
+
+   async def handler(request):
+       return web.Response()
+
+Handlers are set up to handle requests by registering them with
+:meth:`Application.add_routes` on a particular route (*HTTP method* and
+*path* pair) using helpers like :func:`get` and
+:func:`post`::
+
+   app.add_routes([web.get('/', handler),
+                   web.post('/post', post_handler),
+                   web.put('/put', put_handler)])
+
+Or use *route decorators*::
+
+   routes = web.RouteTableDef()
+
+   @routes.get('/')
+   async def get_handler(request):
+       ...
+
+   @routes.post('/post')
+   async def post_handler(request):
+       ...
+
+   @routes.put('/put')
+   async def put_handler(request):
+       ...
+
+   app.add_routes(routes)
+
+
+A wildcard *HTTP method* is also supported by :func:`route` or
+:meth:`RouteTableDef.route`, allowing a handler to serve incoming
+requests on a *path* having **any** *HTTP method*::
+
+   app.add_routes([web.route('*', '/path', all_handler)])
+
+The *HTTP method* can be queried later in the request handler using the
+:attr:`Request.method` property.
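+
+For example, a single wildcard handler might branch on the queried
+method (a minimal sketch)::
+
+   async def all_handler(request):
+       # ``request.method`` holds the HTTP method actually used.
+       if request.method == 'GET':
+           return web.Response(text="GET handled")
+       if request.method == 'POST':
+           return web.Response(text="POST handled")
+       raise web.HTTPMethodNotAllowed(request.method, ['GET', 'POST'])
+
+   app.add_routes([web.route('*', '/path', all_handler)])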
+
+By default, endpoints added with the ``GET`` method will accept
+``HEAD`` requests and return the same response headers as they would
+for a ``GET`` request. You can also deny ``HEAD`` requests on a route::
+
+   web.get('/', handler, allow_head=False)
+
+Here ``handler`` won't be called on a ``HEAD`` request and the server
+will respond with ``405: Method Not Allowed``.
+
+.. _aiohttp-web-resource-and-route:
+
+Resources and Routes
+--------------------
+
+Internally routes are served by :attr:`Application.router`
+(a :class:`UrlDispatcher` instance).
+
+The *router* is a list of *resources*.
+
+A resource is an entry in the *route table* which corresponds to the
+requested URL.
+
+A resource in turn has at least one *route*.
+
+A route handles a particular *HTTP method* by calling the *web handler*.
+
+Thus when you add a *route*, the *resource* object is created under the hood.
+
+The library implementation **merges** all subsequent route additions
+for the same path, adding a single resource for all HTTP methods.
+
+Consider two examples::
+
+   app.add_routes([web.get('/path1', get_1),
+                   web.post('/path1', post_1),
+                   web.get('/path2', get_2),
+                   web.post('/path2', post_2)])
+
+and::
+
+   app.add_routes([web.get('/path1', get_1),
+                   web.get('/path2', get_2),
+                   web.post('/path2', post_2),
+                   web.post('/path1', post_1)])
+
+The first one is *optimized*: it registers all routes for each path
+together, so each path gets a single resource. You get the idea.
+
+.. _aiohttp-web-variable-handler:
+
+Variable Resources
+^^^^^^^^^^^^^^^^^^
+
+A resource may also have a *variable path*. For instance, a resource with
+the path ``'/a/{name}/c'`` would match all incoming requests with
+paths such as ``'/a/b/c'``, ``'/a/1/c'``, and ``'/a/etc/c'``.
+
+A variable *part* is specified in the form ``{identifier}``, where the
+``identifier`` can be used later in a
+:ref:`request handler <aiohttp-web-handler>` to access the matched value for
+that *part*. This is done by looking up the ``identifier`` in the
+:attr:`Request.match_info` mapping::
+
+   @routes.get('/{name}')
+   async def variable_handler(request):
+       return web.Response(
+           text="Hello, {}".format(request.match_info['name']))
+
+By default, each *part* matches the regular expression ``[^{}/]+``.
+
+You can also specify a custom regex in the form ``{identifier:regex}``::
+
+   web.get(r'/{name:\d+}', handler)
+
+
+.. _aiohttp-web-named-routes:
+
+Reverse URL Constructing using Named Resources
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Routes can also be given a *name*::
+
+   @routes.get('/root', name='root')
+   async def handler(request):
+       ...
+
+The name can then be used to access and build a *URL* for that resource
+later (e.g. in a :ref:`request handler <aiohttp-web-handler>`)::
+
+   url = request.app.router['root'].url_for().with_query({"a": "b", "c": "d"})
+   assert url == URL('/root?a=b&c=d')
+
+A more interesting example is building *URLs* for :ref:`variable
+resources <aiohttp-web-variable-handler>`::
+
+   app.router.add_resource(r'/{user}/info', name='user-info')
+
+
+In this case you can also pass in the *parts* of the route::
+
+   url = request.app.router['user-info'].url_for(user='john_doe')
+   url_with_qs = url.with_query("a=b")
+   assert url_with_qs == '/john_doe/info?a=b'
+
+
+Organizing Handlers in Classes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+As discussed above, :ref:`handlers <aiohttp-web-handler>` can be first-class
+coroutines::
+
+   async def hello(request):
+       return web.Response(text="Hello, world")
+
+   app.router.add_get('/', hello)
+
+But sometimes it's convenient to group logically similar handlers into a
+Python *class*.
+
+Since :mod:`aiohttp.web` does not dictate any implementation details,
+application developers can organize handlers in classes if they so wish::
+
+   class Handler:
+
+       def __init__(self):
+           pass
+
+       async def handle_intro(self, request):
+           return web.Response(text="Hello, world")
+
+       async def handle_greeting(self, request):
+           name = request.match_info.get('name', "Anonymous")
+           txt = "Hello, {}".format(name)
+           return web.Response(text=txt)
+
+   handler = Handler()
+   app.add_routes([web.get('/intro', handler.handle_intro),
+                   web.get('/greet/{name}', handler.handle_greeting)])
+
+
+.. _aiohttp-web-class-based-views:
+
+Class Based Views
+^^^^^^^^^^^^^^^^^
+
+:mod:`aiohttp.web` has support for *class based views*.
+
+You can derive from :class:`View` and define methods for handling HTTP
+requests::
+
+   class MyView(web.View):
+       async def get(self):
+           return await get_resp(self.request)
+
+       async def post(self):
+           return await post_resp(self.request)
+
+Handlers should be coroutines that accept only *self* and return a
+response object, just like a regular :term:`web-handler`. The request
+object can be retrieved via the :attr:`View.request` property.
+
+After implementing it, the view (``MyView`` from the example above) should
+be registered with the application's router::
+
+   web.view('/path/to', MyView)
+
+or::
+
+   @routes.view('/path/to')
+   class MyView(web.View):
+       ...
+
+The example will process GET and POST requests for */path/to*, but raises
+a *405 Method Not Allowed* exception for unimplemented HTTP methods.
+
+Resource Views
+^^^^^^^^^^^^^^
+
+*All* registered resources in a router can be viewed using the
+:meth:`UrlDispatcher.resources` method::
+
+   for resource in app.router.resources():
+       print(resource)
+
+A *subset* of the resources that were registered with a *name* can be
+viewed using the :meth:`UrlDispatcher.named_resources` method::
+
+   for name, resource in app.router.named_resources().items():
+       print(name, resource)
+
+
+.. _aiohttp-web-alternative-routes-definition:
+
+Alternative ways for registering routes
+---------------------------------------
+
+Code examples shown above use the *imperative* style for adding new
+routes: they call ``app.router.add_get(...)`` etc.
+
+There are two alternatives: route tables and route decorators.
+
+Route tables look like the Django way::
+
+   async def handle_get(request):
+       ...
+
+
+   async def handle_post(request):
+       ...
+
+   app.router.add_routes([web.get('/get', handle_get),
+                          web.post('/post', handle_post)])
+
+
+The snippet calls :meth:`~aiohttp.web.UrlDispatcher.add_routes` to
+register a list of *route definitions* (:class:`aiohttp.web.RouteDef`
+instances) created by the :func:`aiohttp.web.get` or
+:func:`aiohttp.web.post` functions.
+
+.. seealso:: :ref:`aiohttp-web-route-def` reference.
+
+Route decorators are closer to the Flask approach::
+
+   routes = web.RouteTableDef()
+
+   @routes.get('/get')
+   async def handle_get(request):
+       ...
+
+
+   @routes.post('/post')
+   async def handle_post(request):
+       ...
+
+   app.router.add_routes(routes)
+
+It is also possible to use decorators with class-based views::
+
+   routes = web.RouteTableDef()
+
+   @routes.view("/view")
+   class MyView(web.View):
+       async def get(self):
+           ...
+
+       async def post(self):
+           ...
+
+   app.router.add_routes(routes)
+
+The example creates a :class:`aiohttp.web.RouteTableDef` container first.
+
+The container is a list-like object with additional decorators
+:meth:`aiohttp.web.RouteTableDef.get`,
+:meth:`aiohttp.web.RouteTableDef.post`, etc., for registering new
+routes.
+
+After filling the container,
+:meth:`~aiohttp.web.UrlDispatcher.add_routes` is used for adding the
+registered *route definitions* into the application's router.
+
+.. seealso:: :ref:`aiohttp-web-route-table-def` reference.
+
+All three ways (imperative calls, route tables and decorators) are
+equivalent; you can use whichever you prefer, or even mix them.
+
+.. versionadded:: 2.3
+
+
+JSON Response
+-------------
+
+Returning JSON data in a response is a common case, so :mod:`aiohttp.web`
+provides a shortcut for returning JSON -- :func:`aiohttp.web.json_response`::
+
+   async def handler(request):
+       data = {'some': 'data'}
+       return web.json_response(data)
+
+The shortcut method returns an :class:`aiohttp.web.Response` instance,
+so you can, for example, set cookies before returning it from the handler.
+
+
+User Sessions
+-------------
+
+Often you need a container for storing user data across requests. The concept
+is usually called a *session*.
+
+:mod:`aiohttp.web` has no built-in concept of a *session*, however, there is a
+third-party library, :mod:`aiohttp_session`, that adds *session* support::
+
+   import base64
+   from cryptography import fernet
+   from aiohttp import web
+   from aiohttp_session import setup, get_session
+   from aiohttp_session.cookie_storage import EncryptedCookieStorage
+
+   async def handler(request):
+       session = await get_session(request)
+       last_visit = session['last_visit'] if 'last_visit' in session else None
+       text = 'Last visited: {}'.format(last_visit)
+       return web.Response(text=text)
+
+   async def make_app():
+       app = web.Application()
+       # secret_key must be 32 url-safe base64-encoded bytes
+       fernet_key = fernet.Fernet.generate_key()
+       secret_key = base64.urlsafe_b64decode(fernet_key)
+       setup(app, EncryptedCookieStorage(secret_key))
+       app.add_routes([web.get('/', handler)])
+       return app
+
+   web.run_app(make_app())
+
+
+.. _aiohttp-web-forms:
+
+HTTP Forms
+----------
+
+HTTP forms are supported out of the box.
+
+If the form's method is ``"GET"`` (``<form method="get">``) use
+:attr:`Request.query` for getting form data.
+
+To access form data with the ``"POST"`` method, use
+:meth:`Request.post` or :meth:`Request.multipart`.
+
+:meth:`Request.post` accepts both
+``'application/x-www-form-urlencoded'`` and ``'multipart/form-data'``
+form data encodings (e.g. ``<form enctype="multipart/form-data">``).
+It stores file data in a temporary directory. If the submitted data
+exceeds ``client_max_size``, :meth:`Request.post` raises a
+:exc:`ValueError` exception.
+For efficiency, use :meth:`Request.multipart`; it is especially effective
+for uploading large files (:ref:`aiohttp-web-file-upload`).
+
+Values submitted by the following form:
+
+.. code-block:: html
+
+   <form action="/login" method="post" accept-charset="utf-8"
+         enctype="application/x-www-form-urlencoded">
+
+       <label for="login">Login</label>
+       <input id="login" name="login" type="text" value="" autofocus/>
+
+       <label for="password">Password</label>
+       <input id="password" name="password" type="password" value=""/>
+
+       <input type="submit" value="login"/>
+   </form>
+
+could be accessed as::
+
+   async def do_login(request):
+       data = await request.post()
+       login = data['login']
+       password = data['password']
+
+
+.. _aiohttp-web-file-upload:
+
+File Uploads
+------------
+
+:mod:`aiohttp.web` has built-in support for handling files uploaded from the
+browser.
+
+First, make sure that the HTML ``<form>`` element has its *enctype* attribute
+set to ``enctype="multipart/form-data"``. As an example, here is a form that
+accepts an MP3 file:
+
+.. code-block:: html
+
+   <form action="/store/mp3" method="post" accept-charset="utf-8"
+         enctype="multipart/form-data">
+
+       <label for="mp3">Mp3</label>
+       <input id="mp3" name="mp3" type="file" value=""/>
+
+       <input type="submit" value="submit"/>
+   </form>
+
+Then, in the :ref:`request handler <aiohttp-web-handler>` you can access the
+file input field as a :class:`FileField` instance. :class:`FileField` is simply
+a container for the file as well as some of its metadata::
+
+   async def store_mp3_handler(request):
+
+       # WARNING: don't do that if you plan to receive large files!
+       data = await request.post()
+
+       mp3 = data['mp3']
+
+       # .filename contains the name of the file in string format.
+       filename = mp3.filename
+
+       # .file contains the actual file data that needs to be stored somewhere.
+       mp3_file = data['mp3'].file
+
+       content = mp3_file.read()
+
+       return web.Response(
+           body=content,
+           headers=MultiDict(
+               {'CONTENT-DISPOSITION':
+                'attachment; filename="{}"'.format(filename)}))
+
+
+You might have noticed a big warning in the example above. The general issue is
+that :meth:`Request.post` reads the whole payload into memory,
+resulting in possible
+:abbr:`OOM (Out Of Memory)` errors. To avoid this, for multipart uploads, you
+should use :meth:`Request.multipart` which returns a :ref:`multipart reader
+<aiohttp-multipart>`::
+
+   async def store_mp3_handler(request):
+
+       reader = await request.multipart()
+
+       # /!\ Don't forget to validate your inputs /!\
+
+       # reader.next() will `yield` the fields of your form
+
+       field = await reader.next()
+       assert field.name == 'name'
+       name = await field.read(decode=True)
+
+       field = await reader.next()
+       assert field.name == 'mp3'
+       filename = field.filename
+       # You cannot rely on Content-Length if transfer is chunked.
+       size = 0
+       with open(os.path.join('/spool/yarrr-media/mp3/', filename), 'wb') as f:
+           while True:
+               chunk = await field.read_chunk()  # 8192 bytes by default.
+               if not chunk:
+                   break
+               size += len(chunk)
+               f.write(chunk)
+
+       return web.Response(text='{} of size {} successfully stored'
+                                ''.format(filename, size))
+
+.. _aiohttp-web-websockets:
+
+WebSockets
+----------
+
+:mod:`aiohttp.web` supports *WebSockets* out-of-the-box.
+
+To set up a *WebSocket*, create a :class:`WebSocketResponse` in a
+:ref:`request handler <aiohttp-web-handler>` and then use it to communicate
+with the peer::
+
+   async def websocket_handler(request):
+
+       ws = web.WebSocketResponse()
+       await ws.prepare(request)
+
+       async for msg in ws:
+           if msg.type == aiohttp.WSMsgType.TEXT:
+               if msg.data == 'close':
+                   await ws.close()
+               else:
+                   await ws.send_str(msg.data + '/answer')
+           elif msg.type == aiohttp.WSMsgType.ERROR:
+               print('ws connection closed with exception %s' %
+                     ws.exception())
+
+       print('websocket connection closed')
+
+       return ws
+
+The handler should be registered as an HTTP ``GET`` processor::
+
+   app.add_routes([web.get('/ws', websocket_handler)])
+
+.. 
_aiohttp-web-redirects: + +Redirects +--------- + +To redirect user to another endpoint - raise :class:`HTTPFound` with +an absolute URL, relative URL or view name (the argument from router):: + + raise web.HTTPFound('/redirect') + +The following example shows redirect to view named 'login' in routes:: + + async def handler(request): + location = request.app.router['login'].url_for() + raise web.HTTPFound(location=location) + + router.add_get('/handler', handler) + router.add_get('/login', login_handler, name='login') + +Example with login validation:: + + @aiohttp_jinja2.template('login.html') + async def login(request): + + if request.method == 'POST': + form = await request.post() + error = validate_login(form) + if error: + return {'error': error} + else: + # login form is valid + location = request.app.router['index'].url_for() + raise web.HTTPFound(location=location) + + return {} + + app.router.add_get('/', index, name='index') + app.router.add_get('/login', login, name='login') + app.router.add_post('/login', login, name='login') + +.. _aiohttp-web-exceptions: + +Exceptions +---------- + +:mod:`aiohttp.web` defines a set of exceptions for every *HTTP status code*. + +Each exception is a subclass of :class:`~HTTPException` and relates to a single +HTTP status code:: + + async def handler(request): + raise aiohttp.web.HTTPFound('/redirect') + +.. warning:: + + Returning :class:`~HTTPException` or its subclasses is deprecated and will + be removed in subsequent aiohttp versions. + +Each exception class has a status code according to :rfc:`2068`: +codes with 100-300 are not really errors; 400s are client errors, +and 500s are server errors. + +HTTP Exception hierarchy chart:: + + Exception + HTTPException + HTTPSuccessful + * 200 - HTTPOk + * 201 - HTTPCreated + * 202 - HTTPAccepted + * 203 - HTTPNonAuthoritativeInformation + * 204 - HTTPNoContent + * 205 - HTTPResetContent + * 206 - HTTPPartialContent + HTTPRedirection + * 300 - HTTPMultipleChoices + * 301 - HTTPMovedPermanently + * 302 - HTTPFound + * 303 - HTTPSeeOther + * 304 - HTTPNotModified + * 305 - HTTPUseProxy + * 307 - HTTPTemporaryRedirect + * 308 - HTTPPermanentRedirect + HTTPError + HTTPClientError + * 400 - HTTPBadRequest + * 401 - HTTPUnauthorized + * 402 - HTTPPaymentRequired + * 403 - HTTPForbidden + * 404 - HTTPNotFound + * 405 - HTTPMethodNotAllowed + * 406 - HTTPNotAcceptable + * 407 - HTTPProxyAuthenticationRequired + * 408 - HTTPRequestTimeout + * 409 - HTTPConflict + * 410 - HTTPGone + * 411 - HTTPLengthRequired + * 412 - HTTPPreconditionFailed + * 413 - HTTPRequestEntityTooLarge + * 414 - HTTPRequestURITooLong + * 415 - HTTPUnsupportedMediaType + * 416 - HTTPRequestRangeNotSatisfiable + * 417 - HTTPExpectationFailed + * 421 - HTTPMisdirectedRequest + * 422 - HTTPUnprocessableEntity + * 424 - HTTPFailedDependency + * 426 - HTTPUpgradeRequired + * 428 - HTTPPreconditionRequired + * 429 - HTTPTooManyRequests + * 431 - HTTPRequestHeaderFieldsTooLarge + * 451 - HTTPUnavailableForLegalReasons + HTTPServerError + * 500 - HTTPInternalServerError + * 501 - HTTPNotImplemented + * 502 - HTTPBadGateway + * 503 - HTTPServiceUnavailable + * 504 - HTTPGatewayTimeout + * 505 - HTTPVersionNotSupported + * 506 - HTTPVariantAlsoNegotiates + * 507 - HTTPInsufficientStorage + * 510 - HTTPNotExtended + * 511 - HTTPNetworkAuthenticationRequired + +All HTTP exceptions have the same constructor signature:: + + HTTPNotFound(*, headers=None, reason=None, + body=None, text=None, content_type=None) + +If not directly specified, *headers* 
will be added to the *default +response headers*. + +Classes :class:`HTTPMultipleChoices`, :class:`HTTPMovedPermanently`, +:class:`HTTPFound`, :class:`HTTPSeeOther`, :class:`HTTPUseProxy`, +:class:`HTTPTemporaryRedirect` have the following constructor signature:: + + HTTPFound(location, *, headers=None, reason=None, + body=None, text=None, content_type=None) + +where *location* is value for *Location HTTP header*. + +:class:`HTTPMethodNotAllowed` is constructed by providing the incoming +unsupported method and list of allowed methods:: + + HTTPMethodNotAllowed(method, allowed_methods, *, + headers=None, reason=None, + body=None, text=None, content_type=None) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 7c0a43effb0..4073eb21321 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -3,8 +3,6 @@ Server Reference ================ -.. module:: aiohttp.web - .. currentmodule:: aiohttp.web .. _aiohttp-web-request: @@ -16,11 +14,11 @@ Request and Base Request The Request object contains all the information about an incoming HTTP request. :class:`BaseRequest` is used for :ref:`Low-Level -Servers` (which have no applications, routers, signals -and middlewares) and :class:`Request` has an *application* and *match -info* attributes. +Servers` (which have no applications, routers, +signals and middlewares). :class:`Request` has an :attr:`Request.app` +and :attr:`Request.match_info` attributes. -A :class:`BaseRequest`/:class:`Request` are :obj:`dict`-like objects, +A :class:`BaseRequest` / :class:`Request` are :obj:`dict` like objects, allowing them to be used for :ref:`sharing data` among :ref:`aiohttp-web-middlewares` and :ref:`aiohttp-web-signals` handlers. @@ -67,23 +65,89 @@ and :ref:`aiohttp-web-signals` handlers. A string representing the scheme of the request. The scheme is ``'https'`` if transport for request handling is - *SSL* or ``secure_proxy_ssl_header`` is matching. + *SSL*, ``'http'`` otherwise. - ``'http'`` otherwise. + The value could be overridden by :meth:`~BaseRequest.clone`. Read-only :class:`str` property. - .. seealso:: :meth:`Application.make_handler` + .. versionchanged:: 2.3 + + *Forwarded* and *X-Forwarded-Proto* are not used anymore. + + Call ``.clone(scheme=new_scheme)`` for setting up the value + explicitly. + + .. seealso:: :ref:`aiohttp-web-forwarded-support` + + .. attribute:: secure + + Shorthand for ``request.url.scheme == 'https'`` + + Read-only :class:`bool` property. + + .. seealso:: :attr:`scheme` + + .. attribute:: forwarded + + A tuple containing all parsed Forwarded header(s). + + Makes an effort to parse Forwarded headers as specified by :rfc:`7239`: - .. deprecated:: 1.1 + - It adds one (immutable) dictionary per Forwarded ``field-value``, i.e. + per proxy. The element corresponds to the data in the Forwarded + ``field-value`` added by the first proxy encountered by the client. + Each subsequent item corresponds to those added by later proxies. + - It checks that every value has valid syntax in general as specified + in :rfc:`7239#section-4`: either a ``token`` or a ``quoted-string``. + - It un-escapes ``quoted-pairs``. + - It does NOT validate 'by' and 'for' contents as specified in + :rfc:`7239#section-6`. + - It does NOT validate ``host`` contents (Host ABNF). + - It does NOT validate ``proto`` contents for valid URI scheme names. - Use :attr:`url` (``request.url.scheme``) instead. + Returns a tuple containing one or more ``MappingProxy`` objects + + .. seealso:: :attr:`scheme` + + .. seealso:: :attr:`host` .. 
attribute:: host - *HOST* header of request, Read-only property. + Host name of the request, resolved in this order: + + - Overridden value by :meth:`~BaseRequest.clone` call. + - *Host* HTTP header + - :func:`socket.gtfqdn` + + Read-only :class:`str` property. + + .. versionchanged:: 2.3 + + *Forwarded* and *X-Forwarded-Host* are not used anymore. + + Call ``.clone(host=new_host)`` for setting up the value + explicitly. + + .. seealso:: :ref:`aiohttp-web-forwarded-support` + + .. attribute:: remote + + Originating IP address of a client initiated HTTP request. + + The IP is resolved through the following headers, in this order: - Returns :class:`str` or ``None`` if HTTP request has no *HOST* header. + - Overridden value by :meth:`~BaseRequest.clone` call. + - Peer name of opened socket. + + Read-only :class:`str` property. + + Call ``.clone(remote=new_remote)`` for setting up the value + explicitly. + + .. versionadded:: 2.3 + + .. seealso:: :ref:`aiohttp-web-forwarded-support` .. attribute:: path_qs @@ -95,7 +159,7 @@ and :ref:`aiohttp-web-signals` handlers. .. attribute:: path The URL including *PATH INFO* without the host or scheme. e.g., - ``/app/blog``. The path is URL-unquoted. For raw path info see + ``/app/blog``. The path is URL-decoded. For raw path info see :attr:`raw_path`. Read-only :class:`str` property. @@ -103,10 +167,11 @@ and :ref:`aiohttp-web-signals` handlers. .. attribute:: raw_path The URL including raw *PATH INFO* without the host or scheme. - Warning, the path may be quoted and may contains non valid URL characters, e.g. + Warning, the path may be URL-encoded and may contain invalid URL + characters, e.g. ``/my%2Fpath%7Cwith%21some%25strange%24characters``. - For unquoted version please take a look on :attr:`path`. + For URL-decoded version please take a look on :attr:`path`. Read-only :class:`str` property. @@ -142,7 +207,7 @@ and :ref:`aiohttp-web-signals` handlers. .. attribute:: transport - An :ref:`transport` used to process request, + A :ref:`transport` used to process request. Read-only property. The property can be used, for example, for getting IP address of @@ -152,6 +217,14 @@ and :ref:`aiohttp-web-signals` handlers. if peername is not None: host, port = peername + .. attribute:: loop + + An event loop instance used by HTTP request handling. + + Read-only :class:`asyncio.AbstractEventLoop` property. + + .. deprecated:: 3.5 + .. attribute:: cookies A multidict of all request's cookies. @@ -165,13 +238,31 @@ and :ref:`aiohttp-web-signals` handlers. Read-only property. - .. attribute:: has_body + .. attribute:: body_exists Return ``True`` if request has *HTTP BODY*, ``False`` otherwise. Read-only :class:`bool` property. - .. versionadded:: 0.16 + .. versionadded:: 2.3 + + .. attribute:: can_read_body + + Return ``True`` if request's *HTTP BODY* can be read, ``False`` otherwise. + + Read-only :class:`bool` property. + + .. versionadded:: 2.3 + + .. attribute:: has_body + + Return ``True`` if request's *HTTP BODY* can be read, ``False`` otherwise. + + Read-only :class:`bool` property. + + .. deprecated:: 2.3 + + Use :meth:`can_read_body` instead. .. attribute:: content_type @@ -225,8 +316,6 @@ and :ref:`aiohttp-web-signals` handlers. return buffer[request.http_range] - .. versionadded:: 1.2 - .. attribute:: if_modified_since Read-only property that returns the date specified in the @@ -236,6 +325,28 @@ and :ref:`aiohttp-web-signals` handlers. *If-Modified-Since* header is absent or is not a valid HTTP date. + .. 
attribute:: if_unmodified_since + + Read-only property that returns the date specified in the + *If-Unmodified-Since* header. + + Returns :class:`datetime.datetime` or ``None`` if + *If-Unmodified-Since* header is absent or is not a valid + HTTP date. + + .. versionadded:: 3.1 + + .. attribute:: if_range + + Read-only property that returns the date specified in the + *If-Range* header. + + Returns :class:`datetime.datetime` or ``None`` if + *If-Range* header is absent or is not a valid + HTTP date. + + .. versionadded:: 3.1 + .. method:: clone(*, method=..., rel_url=..., headers=...) Clone itself with replacement some attributes. @@ -248,12 +359,24 @@ and :ref:`aiohttp-web-signals` handlers. :param rel_url: url to use, :class:`str` or :class:`~yarl.URL` - :param headers: :class:`~multidict.CIMultidict` or compatible + :param headers: :class:`~multidict.CIMultiDict` or compatible headers container. :return: a cloned :class:`Request` instance. - .. coroutinemethod:: read() + .. method:: get_extra_info(name, default=None) + + Reads extra information from the protocol's transport. + If no value associated with ``name`` is found, ``default`` is returned. + + :param str name: The key to look up in the transport extra information. + + :param default: Default value to be used when no value for ``name`` is + found (default is ``None``). + + .. versionadded:: 3.7 + + .. comethod:: read() Read request body, returns :class:`bytes` object with body content. @@ -262,7 +385,7 @@ and :ref:`aiohttp-web-signals` handlers. The method **does** store read data internally, subsequent :meth:`~Request.read` call will return the same value. - .. coroutinemethod:: text() + .. comethod:: text() Read request body, decode it using :attr:`charset` encoding or ``UTF-8`` if no encoding was specified in *MIME-type*. @@ -274,7 +397,7 @@ and :ref:`aiohttp-web-signals` handlers. The method **does** store read data internally, subsequent :meth:`~Request.text` call will return the same value. - .. coroutinemethod:: json(*, loads=json.loads) + .. comethod:: json(*, loads=json.loads) Read request body decoded as *json*. @@ -296,7 +419,7 @@ and :ref:`aiohttp-web-signals` handlers. :meth:`~Request.json` call will return the same value. - .. coroutinemethod:: multipart(*, reader=aiohttp.multipart.MultipartReader) + .. comethod:: multipart() Returns :class:`aiohttp.multipart.MultipartReader` which processes incoming *multipart* request. @@ -317,7 +440,11 @@ and :ref:`aiohttp-web-signals` handlers. .. seealso:: :ref:`aiohttp-multipart` - .. coroutinemethod:: post() + .. versionchanged:: 3.4 + + Dropped *reader* parameter. + + .. comethod:: post() A :ref:`coroutine ` that reads POST parameters from request body. @@ -335,7 +462,7 @@ and :ref:`aiohttp-web-signals` handlers. The method **does** store read data internally, subsequent :meth:`~Request.post` call will return the same value. - .. coroutinemethod:: release() + .. comethod:: release() Release request. @@ -350,7 +477,7 @@ and :ref:`aiohttp-web-signals` handlers. .. class:: Request - An request used for receiving request's information by *web handler*. + A request used for receiving request's information by *web handler*. Every :ref:`handler` accepts a request instance as the first positional parameter. @@ -374,6 +501,16 @@ and :ref:`aiohttp-web-signals` handlers. An :class:`Application` instance used to call :ref:`request handler `, Read-only property. + .. 
attribute:: config_dict + + A :class:`aiohttp.ChainMapProxy` instance for mapping all properties + from the current application returned by :attr:`app` property + and all its parents. + + .. seealso:: :ref:`aiohttp-web-data-sharing-app-config` + + .. versionadded:: 3.2 + .. note:: You should never create the :class:`Request` instance manually @@ -383,6 +520,7 @@ and :ref:`aiohttp-web-signals` handlers. + .. _aiohttp-web-response: @@ -409,9 +547,19 @@ The common case for sending an answer from :ref:`web-handler` is returning a :class:`Response` instance:: - def handler(request): - return Response("All right!") + async def handler(request): + return Response(text="All right!") + +Response classes are :obj:`dict` like objects, +allowing them to be used for :ref:`sharing +data` among :ref:`aiohttp-web-middlewares` +and :ref:`aiohttp-web-signals` handlers:: + + resp['key'] = value + +.. versionadded:: 3.0 + Dict-like interface support. StreamResponse ^^^^^^^^^^^^^^ @@ -446,8 +594,6 @@ StreamResponse Read-only :class:`bool` property, ``True`` if :meth:`prepare` has been called, ``False`` otherwise. - .. versionadded:: 0.18 - .. attribute:: task A task that serves HTTP request handling. @@ -455,8 +601,6 @@ StreamResponse May be useful for graceful shutdown of long-running requests (streaming, long polling or web-socket). - .. versionadded:: 1.2 - .. attribute:: status Read-only property for *HTTP response status code*, :class:`int`. @@ -515,7 +659,7 @@ StreamResponse .. method:: enable_chunked_encoding Enables :attr:`chunked` encoding for response. There are no ways to - disable it back. With enabled :attr:`chunked` encoding each `write()` + disable it back. With enabled :attr:`chunked` encoding each :meth:`write` operation encoded in separate chunk. .. warning:: chunked encoding can be enabled for ``HTTP/1.1`` only. @@ -527,7 +671,7 @@ StreamResponse .. attribute:: headers - :class:`~aiohttp.CIMultiiDct` instance + :class:`~multidict.CIMultiDict` instance for *outgoing* *HTTP headers*. .. attribute:: cookies @@ -545,7 +689,8 @@ StreamResponse .. method:: set_cookie(name, value, *, path='/', expires=None, \ domain=None, max_age=None, \ - secure=None, httponly=None, version=None) + secure=None, httponly=None, version=None, \ + samesite=None) Convenient way for setting :attr:`cookies`, allows to specify some additional properties like *max_age* in a single call. @@ -590,6 +735,14 @@ StreamResponse specification the cookie conforms. (Optional, *version=1* by default) + :param str samesite: Asserts that a cookie must not be sent with + cross-origin requests, providing some protection + against cross-site request forgery attacks. + Generally the value should be one of: ``None``, + ``Lax`` or ``Strict``. (optional) + + .. versionadded:: 3.7 + .. warning:: In HTTP version 1.1, ``expires`` was deprecated and replaced with @@ -606,11 +759,6 @@ StreamResponse :param str path: optional cookie path, ``'/'`` by default - .. versionchanged:: 1.0 - - Fixed cookie expiration support for - Internet Explorer (version less than 11). - .. attribute:: content_length *Content-Length* for outgoing response. @@ -634,38 +782,7 @@ StreamResponse as an :class:`int` or a :class:`float` object, and the value ``None`` to unset the header. - .. attribute:: tcp_cork - - :const:`~socket.TCP_CORK` (linux) or :const:`~socket.TCP_NOPUSH` - (FreeBSD and MacOSX) is applied to underlying transport if the - property is ``True``. - - Use :meth:`set_tcp_cork` to assign new value to the property. - - Default value is ``False``. - - .. 
method:: set_tcp_cork(value) - - Set :attr:`tcp_cork` property to *value*. - - Clear :attr:`tcp_nodelay` if *value* is ``True``. - - .. attribute:: tcp_nodelay - - :const:`~socket.TCP_NODELAY` is applied to underlying transport - if the property is ``True``. - - Use :meth:`set_tcp_nodelay` to assign new value to the property. - - Default value is ``True``. - - .. method:: set_tcp_nodelay(value) - - Set :attr:`tcp_nodelay` property to *value*. - - Clear :attr:`tcp_cork` if *value* is ``True``. - - .. coroutinemethod:: prepare(request) + .. comethod:: prepare(request) :param aiohttp.web.Request request: HTTP request object, that the response answers. @@ -674,15 +791,16 @@ StreamResponse calling this method. The coroutine calls :attr:`~aiohttp.web.Application.on_response_prepare` - signal handlers. + signal handlers after default headers have been computed and directly + before headers are sent. - .. versionadded:: 0.18 + .. comethod:: write(data) - .. method:: write(data) + Send byte-ish data as the part of *response BODY*:: - Send byte-ish data as the part of *response BODY*. + await resp.write(data) - :meth:`prepare` must be called before. + :meth:`prepare` must be invoked before the call. Raises :exc:`TypeError` if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview` instance. @@ -691,23 +809,7 @@ StreamResponse Raises :exc:`RuntimeError` if :meth:`write_eof` has been called. - .. coroutinemethod:: drain() - - A :ref:`coroutine` to let the write buffer of the - underlying transport a chance to be flushed. - - The intended use is to write:: - - resp.write(data) - await resp.drain() - - Yielding from :meth:`drain` gives the opportunity for the loop - to schedule the write operation and flush the buffer. It should - especially be used when a possibly large amount of data is - written to the transport, and the coroutine does not yield-from - between calls to :meth:`write`. - - .. coroutinemethod:: write_eof() + .. comethod:: write_eof() A :ref:`coroutine` *may* be called as a mark of the *HTTP response* processing finish. @@ -722,8 +824,9 @@ StreamResponse Response ^^^^^^^^ -.. class:: Response(*, status=200, headers=None, content_type=None, \ - charset=None, body=None, text=None) +.. class:: Response(*, body=None, status=200, reason=None, text=None, \ + headers=None, content_type=None, charset=None, \ + zlib_executor_size=sentinel, zlib_executor=None) The most usable response class, inherited from :class:`StreamResponse`. @@ -748,6 +851,15 @@ Response :param str charset: response's charset. ``'utf-8'`` if *text* is passed also, ``None`` otherwise. + :param int zlib_executor_size: length in bytes which will trigger zlib compression + of body to happen in an executor + + .. versionadded:: 3.5 + + :param int zlib_executor: executor to use for zlib compression + + .. versionadded:: 3.5 + .. attribute:: body @@ -757,6 +869,11 @@ Response Setting :attr:`body` also recalculates :attr:`~StreamResponse.content_length` value. + Assigning :class:`str` to :attr:`body` will make the :attr:`body` + type of :class:`aiohttp.payload.StringPayload`, which tries to encode + the given data based on *Content-Type* HTTP header, while defaulting + to ``UTF-8``. + Resetting :attr:`body` (assigning ``None``) sets :attr:`~StreamResponse.content_length` to ``None`` too, dropping *Content-Length* HTTP header. @@ -778,8 +895,9 @@ Response WebSocketResponse ^^^^^^^^^^^^^^^^^ -.. 
class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, autoclose=True, \ - autoping=True, heartbeat=None, protocols=()) +.. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ + autoclose=True, autoping=True, heartbeat=None, \ + protocols=(), compress=True, max_msg_size=4194304) Class for handling server-side websockets, inherited from :class:`StreamResponse`. @@ -789,11 +907,10 @@ WebSocketResponse communicate with websocket client by :meth:`send_str`, :meth:`receive` and others. - .. versionadded:: 1.3.0 - To enable back-pressure from slow websocket clients treat methods - `ping()`, `pong()`, `send_str()`, `send_bytes()`, `send_json()` as coroutines. - By default write buffer size is set to 64k. + :meth:`ping()`, :meth:`pong()`, :meth:`send_str()`, + :meth:`send_bytes()`, :meth:`send_json()` as coroutines. By + default write buffer size is set to 64k. :param bool autoping: Automatically send :const:`~aiohttp.WSMsgType.PONG` on @@ -801,33 +918,40 @@ WebSocketResponse message from client, and handle :const:`~aiohttp.WSMsgType.PONG` responses from client. - Note that server doesn't send + Note that server does not send :const:`~aiohttp.WSMsgType.PING` requests, you need to do this explicitly using :meth:`ping` method. - .. versionadded:: 1.3.0 + :param float heartbeat: Send `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + + :param float receive_timeout: Timeout value for `receive` + operations. Default value is None + (no timeout for receive operation) - :param float heartbeat: Send `ping` message every `heartbeat` seconds - and wait `pong` response, close connection if `pong` response - is not received. + :param bool compress: Enable per-message deflate extension support. + False for disabled, default value is True. - :param float receive_timeout: Timeout value for `receive` operations. - Default value is None (no timeout for receive operation) + :param int max_msg_size: maximum size of read websocket message, 4 + MB by default. To disable the size limit use ``0``. - .. versionadded:: 0.19 + .. versionadded:: 3.3 - The class supports ``async for`` statement for iterating over - incoming messages:: - ws = web.WebSocketResponse() - await ws.prepare(request) + The class supports ``async for`` statement for iterating over + incoming messages:: - async for msg in ws: - print(msg.data) + ws = web.WebSocketResponse() + await ws.prepare(request) + async for msg in ws: + print(msg.data) - .. coroutinemethod:: prepare(request) + + .. comethod:: prepare(request) Starts websocket. After the call you can use websocket methods. @@ -837,8 +961,6 @@ WebSocketResponse :raises HTTPException: if websocket handshake has failed. - .. versionadded:: 0.18 - .. method:: can_prepare(request) Performs checks for *request* data to figure out if websocket @@ -873,7 +995,7 @@ WebSocketResponse Read-only property, close code from peer. It is set to ``None`` on opened connection. - .. attribute:: protocol + .. attribute:: ws_protocol Websocket *subprotocol* chosen after :meth:`start` call. @@ -884,7 +1006,7 @@ WebSocketResponse Returns last occurred exception or None. - .. method:: ping(message=b'') + .. comethod:: ping(message=b'') Send :const:`~aiohttp.WSMsgType.PING` to peer. @@ -894,7 +1016,11 @@ WebSocketResponse :raise RuntimeError: if connections is not started or closing. - .. method:: pong(message=b'') + .. 
versionchanged:: 3.0 + + The method is converted into :term:`coroutine` + + .. comethod:: pong(message=b'') Send *unsolicited* :const:`~aiohttp.WSMsgType.PONG` to peer. @@ -904,33 +1030,59 @@ WebSocketResponse :raise RuntimeError: if connections is not started or closing. - .. coroutinemethod:: send_str(data) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine` + + .. comethod:: send_str(data, compress=None) Send *data* to peer as :const:`~aiohttp.WSMsgType.TEXT` message. :param str data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :raise RuntimeError: if connection is not started or closing :raise TypeError: if data is not :class:`str` - .. coroutinemethod:: send_bytes(data) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + + .. comethod:: send_bytes(data, compress=None) Send *data* to peer as :const:`~aiohttp.WSMsgType.BINARY` message. :param data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :raise RuntimeError: if connection is not started or closing :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. - .. coroutinemethod:: send_json(data, *, dumps=json.loads) + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + + .. comethod:: send_json(data, compress=None, *, dumps=json.dumps) Send *data* to peer as JSON string. :param data: data to send. + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + :param callable dumps: any :term:`callable` that accepts an object and returns a JSON string (:func:`json.dumps` by default). @@ -941,22 +1093,27 @@ WebSocketResponse :raise TypeError: if value returned by ``dumps`` param is not :class:`str` - .. coroutinemethod:: close(*, code=1000, message=b'') + .. versionchanged:: 3.0 + + The method is converted into :term:`coroutine`, + *compress* parameter added. + + .. comethod:: close(*, code=1000, message=b'') A :ref:`coroutine` that initiates closing handshake by sending :const:`~aiohttp.WSMsgType.CLOSE` message. - It is save to call `close()` from different task. + It is safe to call `close()` from different task. :param int code: closing code - :param message: optional payload of *pong* message, + :param message: optional payload of *close* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. :raise RuntimeError: if connection is not started - .. coroutinemethod:: receive(timeout=None) + .. comethod:: receive(timeout=None) A :ref:`coroutine` that waits upcoming *data* message from peer and returns it. @@ -974,13 +1131,14 @@ WebSocketResponse Can only be called by the request handling task. :param timeout: timeout for `receive` operation. - timeout value overrides response`s receive_timeout attribute. + + timeout value overrides response`s receive_timeout attribute. :return: :class:`~aiohttp.WSMessage` :raise RuntimeError: if connection is not started - .. coroutinemethod:: receive_str(*, timeout=None) + .. comethod:: receive_str(*, timeout=None) A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.WSMsgType.TEXT`. @@ -990,13 +1148,14 @@ WebSocketResponse Can only be called by the request handling task. 
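      A small usage sketch, assuming ``ws`` is an already prepared
      :class:`WebSocketResponse` and the code runs inside the request
      handling task::

          text = await ws.receive_str()   # raises TypeError for BINARY frames
          await ws.send_str(text.upper())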
:param timeout: timeout for `receive` operation. - timeout value overrides response`s receive_timeout attribute. + + timeout value overrides response`s receive_timeout attribute. :return str: peer's message content. :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. - .. coroutinemethod:: receive_bytes(*, timeout=None) + .. comethod:: receive_bytes(*, timeout=None) A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is @@ -1007,13 +1166,14 @@ WebSocketResponse Can only be called by the request handling task. :param timeout: timeout for `receive` operation. - timeout value overrides response`s receive_timeout attribute. + + timeout value overrides response`s receive_timeout attribute. :return bytes: peer's message content. :raise TypeError: if message is :const:`~aiohttp.WSMsgType.TEXT`. - .. coroutinemethod:: receive_json(*, loads=json.loads, timeout=None) + .. comethod:: receive_json(*, loads=json.loads, timeout=None) A :ref:`coroutine` that calls :meth:`receive_str` and loads the JSON string to a Python dict. @@ -1027,19 +1187,18 @@ WebSocketResponse with parsed JSON (:func:`json.loads` by default). - :param timeout: timeout for `receive` operation. - timeout value overrides response`s receive_timeout attribute. + :param timeout: timeout for `receive` operation. + + timeout value overrides response`s receive_timeout attribute. :return dict: loaded JSON content :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. :raise ValueError: if message is not valid JSON. - .. versionadded:: 0.22 - .. seealso:: :ref:`WebSockets handling` - + WebSocketReady ^^^^^^^^^^^^^^ @@ -1068,7 +1227,7 @@ WebSocketReady json_response -------------- +^^^^^^^^^^^^^ .. function:: json_response([data], *, text=None, body=None, \ status=200, reason=None, headers=None, \ @@ -1079,6 +1238,40 @@ Return :class:`Response` with predefined ``'application/json'`` content type and *data* encoded by ``dumps`` parameter (:func:`json.dumps` by default). +HTTP Exceptions +^^^^^^^^^^^^^^^ +Errors can also be returned by raising a HTTP exception instance from within +the handler. + +.. class:: HTTPException(*, headers=None, reason=None, text=None, content_type=None) + + Low-level HTTP failure. + + :param headers: headers for the response + :type headers: dict or multidict.CIMultiDict + + :param str reason: reason included in the response + + :param str text: response's body + + :param str content_type: response's content type. This is passed through + to the :class:`Response` initializer. + + Sub-classes of ``HTTPException`` exist for the standard HTTP response codes + as described in :ref:`aiohttp-web-exceptions` and the expected usage is to + simply raise the appropriate exception type to respond with a specific HTTP + response code. + + Since ``HTTPException`` is a sub-class of :class:`Response`, it contains the + methods and properties that allow you to directly manipulate details of the + response. + + .. attribute:: status_code + + HTTP status code for this exception class. This attribute is usually + defined at the class level. ``self.status_code`` is passed to the + :class:`Response` initializer. + .. _aiohttp-web-app-and-router: @@ -1092,10 +1285,8 @@ Application Application is a synonym for web-server. To get fully working example, you have to make *application*, register -supported urls in *router* and create a *server socket* with -:class:`~aiohttp.web.Server` as a *protocol -factory*. *Server* could be constructed with -:meth:`Application.make_handler`. 
+supported urls in *router* and pass it to :func:`aiohttp.web.run_app` +or :class:`aiohttp.web.AppRunner`. *Application* contains a *router* instance and a list of callbacks that will be called during application finishing. @@ -1115,25 +1306,48 @@ properties for later access from a :ref:`handler` via the Although :class:`Application` is a :obj:`dict`-like object, it can't be duplicated like one using :meth:`Application.copy`. -.. class:: Application(*, router=None, logger=, \ - middlewares=(), debug=False, **kwargs) +.. class:: Application(*, logger=, router=None,middlewares=(), \ + handler_args=None, client_max_size=1024**2, \ + loop=None, debug=...) The class inherits :class:`dict`. + :param logger: :class:`logging.Logger` instance for storing application logs. + + By default the value is ``logging.getLogger("aiohttp.web")`` + :param router: :class:`aiohttp.abc.AbstractRouter` instance, the system creates :class:`UrlDispatcher` by default if *router* is ``None``. - :param logger: :class:`logging.Logger` instance for storing application logs. + .. deprecated:: 3.3 - By default the value is ``logging.getLogger("aiohttp.web")`` + The custom routers support is deprecated, the parameter will + be removed in 4.0. :param middlewares: :class:`list` of middleware factories, see :ref:`aiohttp-web-middlewares` for details. + :param handler_args: dict-like object that overrides keyword arguments of + :meth:`Application.make_handler` + + :param client_max_size: client's maximum size in a request, in + bytes. If a POST request exceeds this + value, it raises an + `HTTPRequestEntityTooLarge` exception. + + :param loop: event loop + + .. deprecated:: 2.0 + + The parameter is deprecated. Loop is get set during freeze + stage. + :param debug: Switches debug mode. - :param loop: loop parameter is deprecated. loop is get set during freeze stage. + .. deprecated:: 3.5 + + Use asyncio :ref:`asyncio-debug-mode` instead. .. attribute:: router @@ -1147,17 +1361,23 @@ duplicated like one using :meth:`Application.copy`. :ref:`event loop` used for processing HTTP requests. + .. deprecated:: 3.5 .. attribute:: debug Boolean value indicating whether the debug mode is turned on or off. + .. deprecated:: 3.5 + + Use asyncio :ref:`asyncio-debug-mode` instead. + .. attribute:: on_response_prepare - A :class:`~aiohttp.signals.Signal` that is fired at the beginning + A :class:`~aiohttp.Signal` that is fired near the end of :meth:`StreamResponse.prepare` with parameters *request* and *response*. It can be used, for example, to add custom headers to each - response before sending. + response, or to modify the default headers computed by the application, + directly before sending the headers to the client. Signal handlers should have the following signature:: @@ -1166,7 +1386,7 @@ duplicated like one using :meth:`Application.copy`. .. attribute:: on_startup - A :class:`~aiohttp.signals.Signal` that is fired on application start-up. + A :class:`~aiohttp.Signal` that is fired on application start-up. Subscribers may use the signal to run background tasks in the event loop along with the application's request handler just after the @@ -1177,11 +1397,11 @@ duplicated like one using :meth:`Application.copy`. async def on_startup(app): pass - .. seealso:: :ref:`aiohttp-web-background-tasks`. + .. seealso:: :ref:`aiohttp-web-signals`. .. attribute:: on_shutdown - A :class:`~aiohttp.signals.Signal` that is fired on application shutdown. + A :class:`~aiohttp.Signal` that is fired on application shutdown. 
Subscribers may use the signal for gracefully closing long running connections, e.g. websockets and data streaming. @@ -1201,7 +1421,7 @@ duplicated like one using :meth:`Application.copy`. .. attribute:: on_cleanup - A :class:`~aiohttp.signals.Signal` that is fired on application cleanup. + A :class:`~aiohttp.Signal` that is fired on application cleanup. Subscribers may use the signal for gracefully closing connections to database server etc. @@ -1211,70 +1431,125 @@ duplicated like one using :meth:`Application.copy`. async def on_cleanup(app): pass - .. seealso:: :ref:`aiohttp-web-graceful-shutdown` and :attr:`on_shutdown`. + .. seealso:: :ref:`aiohttp-web-signals` and :attr:`on_shutdown`. - .. method:: make_handler(loop=None, **kwargs) + .. attribute:: cleanup_ctx + + A list of *context generators* for *startup*/*cleanup* handling. + + Signal handlers should have the following signature:: + + async def context(app): + # do startup stuff + yield + # do cleanup + + .. versionadded:: 3.1 + + .. seealso:: :ref:`aiohttp-web-cleanup-ctx`. + + .. method:: add_subapp(prefix, subapp) + + Register nested sub-application under given path *prefix*. + + In resolving process if request's path starts with *prefix* then + further resolving is passed to *subapp*. + + :param str prefix: path's prefix for the resource. + + :param Application subapp: nested application attached under *prefix*. + + :returns: a :class:`PrefixedSubAppResource` instance. + + .. method:: add_domain(domain, subapp) + + Register nested sub-application that serves + the domain name or domain name mask. + + In resolving process if request.headers['host'] + matches the pattern *domain* then + further resolving is passed to *subapp*. + + :param str domain: domain or mask of domain for the resource. + + :param Application subapp: nested application. + + :returns: a :class:`MatchedSubAppResource` instance. + + .. method:: add_routes(routes_table) + + Register route definitions from *routes_table*. + + The table is a :class:`list` of :class:`RouteDef` items or + :class:`RouteTableDef`. + + :returns: :class:`list` of registered :class:`AbstractRoute` instances. - Creates HTTP protocol factory for handling requests. + The method is a shortcut for + ``app.router.add_routes(routes_table)``, see also + :meth:`UrlDispatcher.add_routes`. - :param loop: :ref:`event loop` used - for processing HTTP requests. + .. versionadded:: 3.1 - If param is ``None`` :func:`asyncio.get_event_loop` - used for getting default event loop. + .. versionchanged:: 3.7 - :param tuple secure_proxy_ssl_header: Secure proxy SSL header. Can - be used to detect request scheme, - e.g. ``secure_proxy_ssl_header=('X-Forwarded-Proto', 'https')``. + Return value updated from ``None`` to :class:`list` of + :class:`AbstractRoute` instances. - Default: ``None``. - :param bool tcp_keepalive: Enable TCP Keep-Alive. Default: ``True``. - :param int keepalive_timeout: Number of seconds before closing Keep-Alive - connection. Default: ``75`` seconds (NGINX's default value). - :param slow_request_timeout: Slow request timeout. Default: ``0``. - :param logger: Custom logger object. Default: - :data:`aiohttp.log.server_logger`. - :param access_log: Custom logging object. Default: - :data:`aiohttp.log.access_logger`. - :param str access_log_format: Access log format string. Default: - :attr:`helpers.AccessLogger.LOG_FORMAT`. - :param bool debug: Switches debug mode. Default: ``False``. + .. method:: make_handler(loop=None, **kwargs) + + Creates HTTP protocol factory for handling requests. 
+ + :param loop: :ref:`event loop` used + for processing HTTP requests. + + If param is ``None`` :func:`asyncio.get_event_loop` + used for getting default event loop. - .. deprecated:: 1.0 + .. deprecated:: 2.0 - The usage of ``debug`` parameter in :meth:`Application.make_handler` - is deprecated in favor of :attr:`Application.debug`. - The :class:`Application`'s debug mode setting should be used - as a single point to setup a debug mode. + :param bool tcp_keepalive: Enable TCP Keep-Alive. Default: ``True``. + :param int keepalive_timeout: Number of seconds before closing Keep-Alive + connection. Default: ``75`` seconds (NGINX's default value). + :param logger: Custom logger object. Default: + :data:`aiohttp.log.server_logger`. + :param access_log: Custom logging object. Default: + :data:`aiohttp.log.access_logger`. + :param access_log_class: Class for `access_logger`. Default: + :data:`aiohttp.helpers.AccessLogger`. + Must to be a subclass of :class:`aiohttp.abc.AbstractAccessLogger`. + :param str access_log_format: Access log format string. Default: + :attr:`helpers.AccessLogger.LOG_FORMAT`. + :param int max_line_size: Optional maximum header line size. Default: + ``8190``. + :param int max_headers: Optional maximum header size. Default: ``32768``. + :param int max_field_size: Optional maximum header field size. Default: + ``8190``. - :param int max_line_size: Optional maximum header line size. Default: - ``8190``. - :param int max_headers: Optional maximum header size. Default: ``32768``. - :param int max_field_size: Optional maximum header field size. Default: - ``8190``. + :param float lingering_time: Maximum time during which the server + reads and ignores additional data coming from the client when + lingering close is on. Use ``0`` to disable lingering on + server channel closing. - :param float lingering_time: maximum time during which the server - reads and ignore additional data coming from the client when - lingering close is on. Use ``0`` for disabling lingering on - server channel closing. + You should pass result of the method as *protocol_factory* to + :meth:`~asyncio.AbstractEventLoop.create_server`, e.g.:: - :param float lingering_timeout: maximum waiting time for more - client data to arrive when lingering close is in effect + loop = asyncio.get_event_loop() - You should pass result of the method as *protocol_factory* to - :meth:`~asyncio.AbstractEventLoop.create_server`, e.g.:: + app = Application() - loop = asyncio.get_event_loop() + # setup route table + # app.router.add_route(...) - app = Application(loop=loop) + await loop.create_server(app.make_handler(), + '0.0.0.0', 8080) - # setup route table - # app.router.add_route(...) + .. deprecated:: 3.2 - await loop.create_server(app.make_handler(), - '0.0.0.0', 8080) + The method is deprecated and will be removed in future + aiohttp versions. Please use :ref:`aiohttp-web-app-runners` instead. - .. coroutinemethod:: startup() + .. comethod:: startup() A :ref:`coroutine` that will be called along with the application's request handler. @@ -1282,7 +1557,7 @@ duplicated like one using :meth:`Application.copy`. The purpose of the method is calling :attr:`on_startup` signal handlers. - .. coroutinemethod:: shutdown() + .. comethod:: shutdown() A :ref:`coroutine` that should be called on server stopping but before :meth:`cleanup()`. @@ -1290,7 +1565,7 @@ duplicated like one using :meth:`Application.copy`. The purpose of the method is calling :attr:`on_shutdown` signal handlers. - .. coroutinemethod:: cleanup() + .. 
comethod:: cleanup() A :ref:`coroutine` that should be called on server stopping but after :meth:`shutdown`. @@ -1321,12 +1596,12 @@ Server A protocol factory compatible with :meth:`~asyncio.AbstreactEventLoop.create_server`. - .. class:: Server +.. class:: Server The class is responsible for creating HTTP protocol objects that can handle HTTP connections. - .. attribute:: Server.connections + .. attribute:: connections List of all currently opened connections. @@ -1334,26 +1609,11 @@ A protocol factory compatible with Amount of processed requests. - .. versionadded:: 1.0 - - .. coroutinemethod:: Server.shutdown(timeout) + .. comethod:: Server.shutdown(timeout) A :ref:`coroutine` that should be called to close all opened connections. - .. coroutinemethod:: Server.finish_connections(timeout) - - .. deprecated:: 1.2 - - A deprecated alias for :meth:`shutdown`. - - .. versionchanged:: 1.2 - - ``Server`` was called ``RequestHandlerFactory`` before ``aiohttp==1.2``. - - The rename has no deprecation period but it's safe: no user - should instantiate the class by hands. - Router ^^^^^^ @@ -1432,48 +1692,75 @@ Router is any object that implements :class:`AbstractRouter` interface. :returns: new :class:`PlainRoute` or :class:`DynamicRoute` instance. - .. method:: add_get(path, *args, **kwargs) + .. method:: add_routes(routes_table) + + Register route definitions from *routes_table*. + + The table is a :class:`list` of :class:`RouteDef` items or + :class:`RouteTableDef`. + + :returns: :class:`list` of registered :class:`AbstractRoute` instances. + + .. versionadded:: 2.3 + + .. versionchanged:: 3.7 + + Return value updated from ``None`` to :class:`list` of + :class:`AbstractRoute` instances. + + .. method:: add_get(path, handler, *, name=None, allow_head=True, **kwargs) Shortcut for adding a GET handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'GET'``. - .. versionadded:: 1.0 + If *allow_head* is ``True`` (default) the route for method HEAD + is added with the same handler as for GET. - .. method:: add_post(path, *args, **kwargs) + If *name* is provided the name for HEAD route is suffixed with + ``'-head'``. For example ``router.add_get(path, handler, + name='route')`` call adds two routes: first for GET with name + ``'route'`` and second for HEAD with name ``'route-head'``. + + .. method:: add_post(path, handler, **kwargs) Shortcut for adding a POST handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'POST'``. - .. versionadded:: 1.0 + .. method:: add_head(path, handler, **kwargs) + + Shortcut for adding a HEAD handler. Calls the :meth:`add_route` with \ + ``method`` equals to ``'HEAD'``. - .. method:: add_put(path, *args, **kwargs) + .. method:: add_put(path, handler, **kwargs) Shortcut for adding a PUT handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'PUT'``. - .. versionadded:: 1.0 - - .. method:: add_patch(path, *args, **kwargs) + .. method:: add_patch(path, handler, **kwargs) Shortcut for adding a PATCH handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'PATCH'``. - .. versionadded:: 1.0 - - .. method:: add_delete(path, *args, **kwargs) + .. method:: add_delete(path, handler, **kwargs) Shortcut for adding a DELETE handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'DELETE'``. - .. versionadded:: 1.0 + .. method:: add_view(path, handler, **kwargs) + + Shortcut for adding a class-based view handler. Calls the \ + :meth:`add_route` with ``method`` equals to ``'*'``. + + .. versionadded:: 3.0 .. 
method:: add_static(prefix, path, *, name=None, expect_handler=None, \ chunk_size=256*1024, \ response_factory=StreamResponse, \ show_index=False, \ - follow_symlinks=False) + follow_symlinks=False, \ + append_version=False) Adds a router and a handler for returning static files. @@ -1495,17 +1782,6 @@ Router is any object that implements :class:`AbstractRouter` interface. static content should be processed by web servers like *nginx* or *apache*. - .. versionchanged:: 0.18.0 - Transfer files using the ``sendfile`` system call on supported - platforms. - - .. versionchanged:: 0.19.0 - Disable ``sendfile`` by setting environment variable - ``AIOHTTP_NOSENDFILE=1`` - - .. versionchanged:: 1.2.0 - Send gzip version if file path + ``.gz`` exists. - :param str prefix: URL path prefix for handled static files :param path: path to the folder in file system that contains @@ -1522,15 +1798,6 @@ Router is any object that implements :class:`AbstractRouter` interface. say, 1Mb may increase file downloading speed but consumes more memory. - .. versionadded:: 0.16 - - :param callable response_factory: factory to use to generate a new - response, defaults to - :class:`StreamResponse` and should - expose a compatible API. - - .. versionadded:: 0.17 - :param bool show_index: flag for allowing to show indexes of a directory, by default it's not allowed and HTTP/403 will be returned on directory access. @@ -1539,24 +1806,16 @@ Router is any object that implements :class:`AbstractRouter` interface. a directory, by default it's not allowed and HTTP/404 will be returned on access. - :returns: new :class:`StaticRoute` instance. - - .. method:: add_subapp(prefix, subapp) - - Register nested sub-application under given path *prefix*. - - In resolving process if request's path starts with *prefix* then - further resolving is passed to *subapp*. - - :param str prefix: path's prefix for the resource. + :param bool append_version: flag for adding file version (hash) + to the url query string, this value will + be used as default when you call to + :meth:`StaticRoute.url` and + :meth:`StaticRoute.url_for` methods. - :param Application subapp: nested application attached under *prefix*. - - :returns: a :class:`PrefixedSubAppResource` instance. - .. versionadded:: 1.1 + :returns: new :class:`StaticRoute` instance. - .. coroutinemethod:: resolve(request) + .. comethod:: resolve(request) A :ref:`coroutine` that returns :class:`AbstractMatchInfo` for *request*. @@ -1597,14 +1856,10 @@ Router is any object that implements :class:`AbstractRouter` interface. route in app.router.resources() - .. versionadded:: 0.21.1 - .. method:: routes() The method returns a *view* for *all* registered routes. - .. versionadded:: 0.18 - .. method:: named_resources() Returns a :obj:`dict`-like :class:`types.MappingProxyType` *view* over @@ -1624,26 +1879,6 @@ Router is any object that implements :class:`AbstractRouter` interface. app.router.named_resources()["name"] - .. versionadded:: 0.21 - - .. method:: named_routes() - - An alias for :meth:`named_resources` starting from aiohttp 0.21. - - .. versionadded:: 0.19 - - .. versionchanged:: 0.21 - - The method is an alias for :meth:`named_resources`, so it - iterates over resources instead of routes. - - .. deprecated:: 0.21 - - Please use named **resources** instead of named **routes**. - - Several routes which belongs to the same resource shares the - resource name. - .. _aiohttp-web-resource: @@ -1709,14 +1944,18 @@ Resource classes hierarchy:: Read-only *name* of resource or ``None``. - .. 
coroutinemethod:: resolve(method, path) + .. attribute:: canonical + + Read-only *canonical path* associate with the resource. For example + ``/path/to`` or ``/path/{to}`` + + .. versionadded:: 3.3 + + .. comethod:: resolve(request) Resolve resource by finding appropriate :term:`web-handler` for ``(method, path)`` combination. - :param str method: requested HTTP method. - :param str path: *path* part of request. - :return: (*match_info*, *allowed_methods*) pair. *allowed_methods* is a :class:`set` or HTTP methods accepted by @@ -1741,21 +1980,6 @@ Resource classes hierarchy:: :return: :class:`~yarl.URL` -- resulting URL instance. - .. versionadded:: 1.1 - - .. method:: url(**kwargs) - - Construct an URL for route with additional params. - - **kwargs** depends on a list accepted by inherited resource - class parameters. - - :return: :class:`str` -- resulting URL string. - - .. deprecated:: 1.1 - - Use :meth:`url_for` instead. - .. class:: Resource @@ -1790,13 +2014,17 @@ Resource classes hierarchy:: The class corresponds to resources with plain-text matching, ``'/path/to'`` for example. + .. attribute:: canonical + + Read-only *canonical path* associate with the resource. Returns the path + used to create the PlainResource. For example ``/path/to`` + + .. versionadded:: 3.3 .. method:: url_for() Returns a :class:`~yarl.URL` for the resource. - .. versionadded:: 1.1 - .. class:: DynamicResource @@ -1806,6 +2034,13 @@ Resource classes hierarchy:: :ref:`variable ` matching, e.g. ``'/path/{to}/{param}'`` etc. + .. attribute:: canonical + + Read-only *canonical path* associate with the resource. Returns the + formatter obtained from the path used to create the DynamicResource. + For example, from a path ``/get/{num:^\d+}``, it returns ``/get/{num}`` + + .. versionadded:: 3.3 .. method:: url_for(**params) @@ -1817,8 +2052,6 @@ Resource classes hierarchy:: be called as ``resource.url_for(to='val1', param='val2')`` - .. versionadded:: 1.1 - .. class:: StaticResource A resource, inherited from :class:`Resource`. @@ -1826,7 +2059,14 @@ Resource classes hierarchy:: The class corresponds to resources for :ref:`static file serving `. - .. method:: url_for(filename) + .. attribute:: canonical + + Read-only *canonical path* associate with the resource. Returns the prefix + used to create the StaticResource. For example ``/prefix`` + + .. versionadded:: 3.3 + + .. method:: url_for(filename, append_version=None) Returns a :class:`~yarl.URL` for file path under resource prefix. @@ -1837,14 +2077,29 @@ Resource classes hierarchy:: E.g. an URL for ``'/prefix/dir/file.txt'`` should be generated as ``resource.url_for(filename='dir/file.txt')`` - .. versionadded:: 1.1 + :param bool append_version: -- a flag for adding file version + (hash) to the url query string for + cache boosting + + By default has value from a constructor (``False`` by default) + When set to ``True`` - ``v=FILE_HASH`` query string param will be added + When set to ``False`` has no impact + + if file not found has no impact + .. class:: PrefixedSubAppResource A resource for serving nested applications. The class instance is returned by :class:`~aiohttp.web.Application.add_subapp` call. - .. versionadded:: 1.1 + .. attribute:: canonical + + Read-only *canonical path* associate with the resource. Returns the + prefix used to create the PrefixedSubAppResource. + For example ``/prefix`` + + .. versionadded:: 3.3 .. method:: url_for(**kwargs) @@ -1898,7 +2153,7 @@ and *405 Method Not Allowed*. 
Actually it's a shortcut for ``route.resource.url_for(...)``.

-   .. coroutinemethod:: handle_expect_header(request)
+   .. comethod:: handle_expect_header(request)

      ``100-continue`` handler.
@@ -1921,6 +2176,291 @@ and *405 Method Not Allowed*.

       HTTP status reason

+.. _aiohttp-web-route-def:
+
+
+RouteDef and StaticDef
+^^^^^^^^^^^^^^^^^^^^^^
+
+A route definition is a description of a route that has not been
+registered yet.
+
+It can be used for filling the route table by providing a list of
+route definitions (Django style).
+
+The definition is created by functions like :func:`get` or
+:func:`post`; a list of definitions can be added to a router by a
+:meth:`UrlDispatcher.add_routes` call::
+
+   from aiohttp import web
+
+   async def handle_get(request):
+       ...
+
+
+   async def handle_post(request):
+       ...
+
+   app.router.add_routes([web.get('/get', handle_get),
+                          web.post('/post', handle_post)])
+
+.. class:: AbstractRouteDef
+
+   A base class for route definitions.
+
+   Inherited from :class:`abc.ABC`.
+
+   .. versionadded:: 3.1
+
+   .. method:: register(router)
+
+      Register itself into :class:`UrlDispatcher`.
+
+      Abstract method, should be overridden by subclasses.
+
+      :returns: :class:`list` of registered :class:`AbstractRoute` objects.
+
+      .. versionchanged:: 3.7
+
+         Return value updated from ``None`` to :class:`list` of
+         :class:`AbstractRoute` instances.
+
+
+.. class:: RouteDef
+
+   A definition of a route that is not registered yet.
+
+   Implements :class:`AbstractRouteDef`.
+
+   .. versionadded:: 2.3
+
+   .. versionchanged:: 3.1
+
+      The class implements :class:`AbstractRouteDef` interface.
+
+   .. attribute:: method
+
+      HTTP method (``GET``, ``POST`` etc.) (:class:`str`).
+
+   .. attribute:: path
+
+      Path to resource, e.g. ``/path/to``. Could contain ``{}``
+      brackets for :ref:`variable resources
+      ` (:class:`str`).
+
+   .. attribute:: handler
+
+      An async function to handle an HTTP request.
+
+   .. attribute:: kwargs
+
+      A :class:`dict` of additional arguments.
+
+
+.. class:: StaticDef
+
+   A definition of a static file resource.
+
+   Implements :class:`AbstractRouteDef`.
+
+   .. versionadded:: 3.1
+
+   .. attribute:: prefix
+
+      A prefix used for static file handling, e.g. ``/static``.
+
+   .. attribute:: path
+
+      File system directory to serve, :class:`str` or
+      :class:`pathlib.Path`
+      (e.g. ``'/home/web-service/path/to/static'``).
+
+   .. attribute:: kwargs
+
+      A :class:`dict` of additional arguments, see
+      :meth:`UrlDispatcher.add_static` for a list of supported
+      options.
+
+
+.. function:: get(path, handler, *, name=None, allow_head=True, \
+              expect_handler=None)
+
+   Return :class:`RouteDef` for processing ``GET`` requests. See
+   :meth:`UrlDispatcher.add_get` for information about parameters.
+
+   .. versionadded:: 2.3
+
+.. function:: post(path, handler, *, name=None, expect_handler=None)
+
+   Return :class:`RouteDef` for processing ``POST`` requests. See
+   :meth:`UrlDispatcher.add_post` for information about parameters.
+
+   .. versionadded:: 2.3
+
+.. function:: head(path, handler, *, name=None, expect_handler=None)
+
+   Return :class:`RouteDef` for processing ``HEAD`` requests. See
+   :meth:`UrlDispatcher.add_head` for information about parameters.
+
+   .. versionadded:: 2.3
+
+.. function:: put(path, handler, *, name=None, expect_handler=None)
+
+   Return :class:`RouteDef` for processing ``PUT`` requests. See
+   :meth:`UrlDispatcher.add_put` for information about parameters.
+
+   .. versionadded:: 2.3
+
+.. function:: patch(path, handler, *, name=None, expect_handler=None)
+
+   Return :class:`RouteDef` for processing ``PATCH`` requests.
See + :meth:`UrlDispatcher.add_patch` for information about parameters. + + .. versionadded:: 2.3 + +.. function:: delete(path, handler, *, name=None, expect_handler=None) + + Return :class:`RouteDef` for processing ``DELETE`` requests. See + :meth:`UrlDispatcher.add_delete` for information about parameters. + + .. versionadded:: 2.3 + +.. function:: view(path, handler, *, name=None, expect_handler=None) + + Return :class:`RouteDef` for processing ``ANY`` requests. See + :meth:`UrlDispatcher.add_view` for information about parameters. + + .. versionadded:: 3.0 + +.. function:: static(prefix, path, *, name=None, expect_handler=None, \ + chunk_size=256*1024, \ + show_index=False, follow_symlinks=False, \ + append_version=False) + + Return :class:`StaticDef` for processing static files. + + See :meth:`UrlDispatcher.add_static` for information + about supported parameters. + + .. versionadded:: 3.1 + +.. function:: route(method, path, handler, *, name=None, expect_handler=None) + + Return :class:`RouteDef` for processing requests that decided by + ``method``. See :meth:`UrlDispatcher.add_route` for information + about parameters. + + .. versionadded:: 2.3 + + +.. _aiohttp-web-route-table-def: + +RouteTableDef +^^^^^^^^^^^^^ + +A routes table definition used for describing routes by decorators +(Flask style):: + + from aiohttp import web + + routes = web.RouteTableDef() + + @routes.get('/get') + async def handle_get(request): + ... + + + @routes.post('/post') + async def handle_post(request): + ... + + app.router.add_routes(routes) + + + @routes.view("/view") + class MyView(web.View): + async def get(self): + ... + + async def post(self): + ... + +.. class:: RouteTableDef() + + A sequence of :class:`RouteDef` instances (implements + :class:`abc.collections.Sequence` protocol). + + In addition to all standard :class:`list` methods the class + provides also methods like ``get()`` and ``post()`` for adding new + route definition. + + .. versionadded:: 2.3 + + .. decoratormethod:: get(path, *, allow_head=True, \ + name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``GET`` web-handler. + + See :meth:`UrlDispatcher.add_get` for information about parameters. + + .. decoratormethod:: post(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``POST`` web-handler. + + See :meth:`UrlDispatcher.add_post` for information about parameters. + + .. decoratormethod:: head(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``HEAD`` web-handler. + + See :meth:`UrlDispatcher.add_head` for information about parameters. + + .. decoratormethod:: put(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``PUT`` web-handler. + + See :meth:`UrlDispatcher.add_put` for information about parameters. + + .. decoratormethod:: patch(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``PATCH`` web-handler. + + See :meth:`UrlDispatcher.add_patch` for information about parameters. + + .. decoratormethod:: delete(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``DELETE`` web-handler. + + See :meth:`UrlDispatcher.add_delete` for information about parameters. + + .. decoratormethod:: view(path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering ``ANY`` methods + against a class-based view. 
+ + See :meth:`UrlDispatcher.add_view` for information about parameters. + + .. versionadded:: 3.0 + + .. method:: static(prefix, path, *, name=None, expect_handler=None, \ + chunk_size=256*1024, \ + show_index=False, follow_symlinks=False, \ + append_version=False) + + + Add a new :class:`StaticDef` item for registering static files processor. + + See :meth:`UrlDispatcher.add_static` for information about + supported parameters. + + .. versionadded:: 3.1 + + .. decoratormethod:: route(method, path, *, name=None, expect_handler=None) + + Add a new :class:`RouteDef` item for registering a web-handler + for arbitrary HTTP method. + + See :meth:`UrlDispatcher.add_route` for information about parameters. + MatchInfo ^^^^^^^^^ @@ -1973,10 +2513,10 @@ View resp = await post_response(self.request) return resp - app.router.add_route('*', '/view', MyView) + app.router.add_view('/view', MyView) The view raises *405 Method Not allowed* - (:class:`HTTPMethodNowAllowed`) if requested web verb is not + (:class:`HTTPMethodNotAllowed`) if requested web verb is not supported. :param request: instance of :class:`Request` that has initiated a view @@ -1995,6 +2535,267 @@ View .. seealso:: :ref:`aiohttp-web-class-based-views` +.. _aiohttp-web-app-runners-reference: + +Running Applications +-------------------- + +To start web application there is ``AppRunner`` and site classes. + +Runner is a storage for running application, sites are for running +application on specific TCP or Unix socket, e.g.:: + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8080) + await site.start() + # wait for finish signal + await runner.cleanup() + + +.. versionadded:: 3.0 + + :class:`AppRunner` / :class:`ServerRunner` and :class:`TCPSite` / + :class:`UnixSite` / :class:`SockSite` are added in aiohttp 3.0 + + +.. class:: BaseRunner + + A base class for runners. Use :class:`AppRunner` for serving + :class:`Application`, :class:`ServerRunner` for low-level + :class:`Server`. + + .. attribute:: server + + Low-level web :class:`Server` for handling HTTP requests, + read-only attribute. + + .. attribute:: addresses + + A :class:`list` of served sockets addresses. + + See :meth:`socket.getsockname` for items type. + + .. versionadded:: 3.3 + + .. attribute:: sites + + A read-only :class:`set` of served sites (:class:`TCPSite` / + :class:`UnixSite` / :class:`NamedPipeSite` / :class:`SockSite` instances). + + .. comethod:: setup() + + Initialize the server. Should be called before adding sites. + + .. comethod:: cleanup() + + Stop handling all registered sites and cleanup used resources. + + +.. class:: AppRunner(app, *, handle_signals=False, **kwargs) + + A runner for :class:`Application`. Used with conjunction with sites + to serve on specific port. + + Inherited from :class:`BaseRunner`. + + :param Application app: web application instance to serve. + + :param bool handle_signals: add signal handlers for + :data:`signal.SIGINT` and + :data:`signal.SIGTERM` (``False`` by + default). + + :param kwargs: named parameters to pass into + web protocol. + + Supported *kwargs*: + + :param bool tcp_keepalive: Enable TCP Keep-Alive. Default: ``True``. + :param int keepalive_timeout: Number of seconds before closing Keep-Alive + connection. Default: ``75`` seconds (NGINX's default value). + :param logger: Custom logger object. Default: + :data:`aiohttp.log.server_logger`. + :param access_log: Custom logging object. Default: + :data:`aiohttp.log.access_logger`. 
+ :param access_log_class: Class for `access_logger`. Default: + :data:`aiohttp.helpers.AccessLogger`. + Must to be a subclass of :class:`aiohttp.abc.AbstractAccessLogger`. + :param str access_log_format: Access log format string. Default: + :attr:`helpers.AccessLogger.LOG_FORMAT`. + :param int max_line_size: Optional maximum header line size. Default: + ``8190``. + :param int max_headers: Optional maximum header size. Default: ``32768``. + :param int max_field_size: Optional maximum header field size. Default: + ``8190``. + + :param float lingering_time: Maximum time during which the server + reads and ignores additional data coming from the client when + lingering close is on. Use ``0`` to disable lingering on + server channel closing. + :param int read_bufsize: Size of the read buffer (:attr:`BaseRequest.content`). + ``None`` by default, + it means that the session global value is used. + + .. versionadded:: 3.7 + + + .. attribute:: app + + Read-only attribute for accessing to :class:`Application` served + instance. + + .. comethod:: setup() + + Initialize application. Should be called before adding sites. + + The method calls :attr:`Application.on_startup` registered signals. + + .. comethod:: cleanup() + + Stop handling all registered sites and cleanup used resources. + + :attr:`Application.on_shutdown` and + :attr:`Application.on_cleanup` signals are called internally. + + +.. class:: ServerRunner(web_server, *, handle_signals=False, **kwargs) + + A runner for low-level :class:`Server`. Used with conjunction with sites + to serve on specific port. + + Inherited from :class:`BaseRunner`. + + :param Server web_server: low-level web server instance to serve. + + :param bool handle_signals: add signal handlers for + :data:`signal.SIGINT` and + :data:`signal.SIGTERM` (``False`` by + default). + + :param kwargs: named parameters to pass into + web protocol. + + .. seealso:: + + :ref:`aiohttp-web-lowlevel` demonstrates low-level server usage + +.. class:: BaseSite + + An abstract class for handled sites. + + .. attribute:: name + + An identifier for site, read-only :class:`str` property. Could + be a handled URL or UNIX socket path. + + .. comethod:: start() + + Start handling a site. + + .. comethod:: stop() + + Stop handling a site. + + +.. class:: TCPSite(runner, host=None, port=None, *, \ + shutdown_timeout=60.0, ssl_context=None, \ + backlog=128, reuse_address=None, \ + reuse_port=None) + + Serve a runner on TCP socket. + + :param runner: a runner to serve. + + :param str host: HOST to listen on, all interfaces if ``None`` (default). + + :param int port: PORT to listed on, ``8080`` if ``None`` (default). + + :param float shutdown_timeout: a timeout for closing opened + connections on :meth:`BaseSite.stop` + call. + + :param ssl_context: a :class:`ssl.SSLContext` instance for serving + SSL/TLS secure server, ``None`` for plain HTTP + server (default). + + :param int backlog: a number of unaccepted connections that the + system will allow before refusing new + connections, see :meth:`socket.listen` for details. + + ``128`` by default. + + :param bool reuse_address: tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its + natural timeout to expire. If not specified + will automatically be set to True on UNIX. + + :param bool reuse_port: tells the kernel to allow this endpoint to be + bound to the same port as other existing + endpoints are bound to, so long as they all set + this flag when being created. This option is not + supported on Windows. + +.. 
class:: UnixSite(runner, path, *, \
+               shutdown_timeout=60.0, ssl_context=None, \
+               backlog=128)
+
+   Serve a runner on a UNIX socket.
+
+   :param runner: a runner to serve.
+
+   :param str path: PATH to the UNIX socket to listen on.
+
+   :param float shutdown_timeout: a timeout for closing open
+                                  connections on :meth:`BaseSite.stop`
+                                  call.
+
+   :param ssl_context: a :class:`ssl.SSLContext` instance for serving
+                       SSL/TLS secure server, ``None`` for plain HTTP
+                       server (default).
+
+   :param int backlog: a number of unaccepted connections that the
+                       system will allow before refusing new
+                       connections, see :meth:`socket.listen` for details.
+
+                       ``128`` by default.
+
+.. class:: NamedPipeSite(runner, path, *, shutdown_timeout=60.0)
+
+   Serve a runner on a Named Pipe in Windows.
+
+   :param runner: a runner to serve.
+
+   :param str path: PATH of the named pipe to listen on.
+
+   :param float shutdown_timeout: a timeout for closing open
+                                  connections on :meth:`BaseSite.stop`
+                                  call.
+
+.. class:: SockSite(runner, sock, *, \
+           shutdown_timeout=60.0, ssl_context=None, \
+           backlog=128)
+
+   Serve a runner on an already existing socket object.
+
+   :param runner: a runner to serve.
+
+   :param sock: :class:`socket.socket` to listen on.
+
+   :param float shutdown_timeout: a timeout for closing open
+                                  connections on :meth:`BaseSite.stop`
+                                  call.
+
+   :param ssl_context: a :class:`ssl.SSLContext` instance for serving
+                       SSL/TLS secure server, ``None`` for plain HTTP
+                       server (default).
+
+   :param int backlog: a number of unaccepted connections that the
+                       system will allow before refusing new
+                       connections, see :meth:`socket.listen` for details.
+
+                       ``128`` by default.
+
 Utilities
 ---------
@@ -2023,10 +2824,14 @@ Utilities

 .. function:: run_app(app, *, host=None, port=None, path=None, \
-                      loop=None, shutdown_timeout=60.0, \
+                      sock=None, shutdown_timeout=60.0, \
                       ssl_context=None, print=print, backlog=128, \
-                      access_log_format=None, \
-                      access_log=aiohttp.log.access_logger)
+                      access_log_class=aiohttp.helpers.AccessLogger, \
+                      access_log_format=aiohttp.helpers.AccessLogger.LOG_FORMAT, \
+                      access_log=aiohttp.log.access_logger, \
+                      handle_signals=True, \
+                      reuse_address=None, \
+                      reuse_port=None)

    A utility function for running an application, serving it until
    keyboard interrupt and performing a
@@ -2036,8 +2841,6 @@ Utilities

    Perhaps a production config will use a more sophisticated runner,
    but it is good enough at least at the very beginning stage.

-   The function uses *app.loop* as event loop to run.
-
    The server will listen on any host or Unix domain socket path you
    supply. If no hosts or paths are supplied, or only a port is
    supplied, a TCP server listening on 0.0.0.0 (all hosts) will be
    launched.
@@ -2047,7 +2850,8 @@ Utilities
    handled on the same event loop. See :doc:`deployment` for ways of
    distributing work for increased performance.

-   :param app: :class:`Application` instance to run
+   :param app: :class:`Application` instance to run or a *coroutine*
+               that returns an application.

    :param str host: TCP/IP host or a sequence of hosts for HTTP server.
                     Default is ``'0.0.0.0'`` if *port* has been specified
@@ -2079,12 +2883,17 @@ Utilities
                       ``None`` for HTTP connection.

    :param print: a callable compatible with :func:`print`. May be used
-                 to override STDOUT output or suppress it.
+                 to override STDOUT output or suppress it. Passing `None`
+                 disables output.

    :param int backlog: the number of unaccepted connections that the
                        system will allow before refusing new connections
                        (``128`` by default).

+   :param access_log_class: class for `access_logger`.
Default: + :data:`aiohttp.helpers.AccessLogger`. + Must to be a subclass of :class:`aiohttp.abc.AbstractAccessLogger`. + :param access_log: :class:`logging.Logger` instance used for saving access logs. Use ``None`` for disabling logs for sake of speedup. @@ -2093,6 +2902,29 @@ Utilities :ref:`aiohttp-logging-access-log-format-spec` for details. + :param bool handle_signals: override signal TERM handling to gracefully + exit the application. + + :param bool reuse_address: tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its + natural timeout to expire. If not specified + will automatically be set to True on UNIX. + + :param bool reuse_port: tells the kernel to allow this endpoint to be + bound to the same port as other existing + endpoints are bound to, so long as they all set + this flag when being created. This option is not + supported on Windows. + + .. versionadded:: 3.0 + + Support *access_log_class* parameter. + + Support *reuse_address*, *reuse_port* parameter. + + .. versionadded:: 3.1 + + Accept a coroutine as *app* parameter. Constants --------- @@ -2109,7 +2941,7 @@ Constants *GZIP compression* -.. attribute:: identity + .. attribute:: identity *no compression* @@ -2120,26 +2952,42 @@ Middlewares Normalize path middleware ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: normalize_path_middleware(*, append_slash=True, merge_slashes=True) +.. function:: normalize_path_middleware(*, \ + append_slash=True, \ + remove_slash=False, \ + merge_slashes=True, \ + redirect_class=HTTPPermanentRedirect) + + Middleware factory which produces a middleware that normalizes + the path of a request. By normalizing it means: + + - Add or remove a trailing slash to the path. + - Double slashes are replaced by one. + + The middleware returns as soon as it finds a path that resolves + correctly. The order if both merge and append/remove are enabled is: + + 1. *merge_slashes* + 2. *append_slash* or *remove_slash* + 3. both *merge_slashes* and *append_slash* or *remove_slash* - Middleware that normalizes the path of a request. By normalizing - it means: + If the path resolves with at least one of those conditions, it will + redirect to the new path. - - Add a trailing slash to the path. - - Double slashes are replaced by one. + Only one of *append_slash* and *remove_slash* can be enabled. If both are + ``True`` the factory will raise an ``AssertionError`` - The middleware returns as soon as it finds a path that resolves - correctly. The order if all enabled is 1) merge_slashes, 2) append_slash - and 3) both merge_slashes and append_slash. If the path resolves with - at least one of those conditions, it will redirect to the new path. + If *append_slash* is ``True`` the middleware will append a slash when + needed. If a resource is defined with trailing slash and the request + comes without it, it will append it automatically. - If append_slash is True append slash when needed. If a resource is - defined with trailing slash and the request comes without it, it will - append it automatically. + If *remove_slash* is ``True``, *append_slash* must be ``False``. When enabled + the middleware will remove trailing slashes and redirect if the resource is + defined. - If merge_slashes is True, merge multiple consecutive slashes in the - path into one. + If *merge_slashes* is ``True``, merge multiple consecutive slashes in the + path into one. + .. versionadded:: 3.4 -.. 
disqus:: - :title: aiohttp server reference + Support for *remove_slash* diff --git a/docs/api.rst b/docs/websocket_utilities.rst similarity index 82% rename from docs/api.rst rename to docs/websocket_utilities.rst index 5aa054cd059..fca08e1ba13 100644 --- a/docs/api.rst +++ b/docs/websocket_utilities.rst @@ -1,16 +1,6 @@ -.. _aiohttp-api: - -Helpers API -=========== - -All public names from submodules ``client``, ``multipart``, -``protocol`` and ``utils`` are exported into -``aiohttp`` namespace. - WebSocket utilities -------------------- +=================== -.. module:: aiohttp .. currentmodule:: aiohttp .. class:: WSCloseCode @@ -66,7 +56,7 @@ WebSocket utilities An endpoint (client) is terminating the connection because it has expected the server to negotiate one or - more extension, but the server didn't return them in the response + more extension, but the server did not return them in the response message of the WebSocket handshake. The list of extensions that are needed should appear in the /reason/ part of the Close frame. Note that this status code is not used by the server, because it @@ -163,42 +153,4 @@ WebSocket utilities Returns parsed JSON data. - .. versionadded:: 0.22 - :param loads: optional JSON decoder function. - - .. attribute:: tp - - Deprecated alias for :attr:`type`. - - .. deprecated:: 1.0 - - -aiohttp.helpers module ----------------------- - -.. automodule:: aiohttp.helpers - :members: - :undoc-members: - :exclude-members: BasicAuth - :show-inheritance: - -aiohttp.multipart module ------------------------- - -.. automodule:: aiohttp.multipart - :members: - :undoc-members: - :show-inheritance: - -aiohttp.signals module ----------------------- - -.. automodule:: aiohttp.signals - :members: - :undoc-members: - :show-inheritance: - - -.. disqus:: - :title: aiohttp helpers api diff --git a/docs/whats_new_1_1.rst b/docs/whats_new_1_1.rst index 0e3d257cf9f..db71e10e8b1 100644 --- a/docs/whats_new_1_1.rst +++ b/docs/whats_new_1_1.rst @@ -34,7 +34,7 @@ Reverse URL processing for *router* has been changed. The main API is :class:`aiohttp.web.Request.url_for(name, **kwargs)` which returns a :class:`yarl.URL` instance for named resource. It -doesn't support *query args* but adding *args* is trivial: +does not support *query args* but adding *args* is trivial: ``request.url_for('named_resource', param='a').with_query(arg='val')``. The method returns a *relative* URL, absolute URL may be constructed by diff --git a/docs/whats_new_3_0.rst b/docs/whats_new_3_0.rst new file mode 100644 index 00000000000..7c4b5844de3 --- /dev/null +++ b/docs/whats_new_3_0.rst @@ -0,0 +1,82 @@ +.. _aiohttp_whats_new_3_0: + +========================= +What's new in aiohttp 3.0 +========================= + +async/await everywhere +====================== + +The main change is dropping ``yield from`` support and using +``async``/``await`` everywhere. Farewell, Python 3.4. + +The minimal supported Python version is **3.5.3** now. + +Why not *3.5.0*? Because *3.5.3* has a crucial change: +:func:`asyncio.get_event_loop()` returns the running loop instead of +*default*, which may be different, e.g.:: + + loop = asyncio.new_event_loop() + loop.run_until_complete(f()) + +Note, :func:`asyncio.set_event_loop` was not called and default loop +is not equal to actually executed one. + +Application Runners +=================== + +People constantly asked about ability to run aiohttp servers together +with other asyncio code, but :func:`aiohttp.web.run_app` is blocking +synchronous call. 
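With 3.0, embedding a server in already-running asyncio code boils down
to a few awaits; a sketch, assuming an existing ``app`` object and an
illustrative host and port::

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, 'localhost', 8080)
    await site.start()
    # other asyncio code may run here while the server is serving
    await runner.cleanup()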
+ +Client Tracing +============== + +Another long-awaited feature is tracing the client request life cycle to +figure out when and why a client request spends time waiting for +connection establishment, getting server response headers, etc. + +This is now possible by registering special signal handlers on every +request processing stage. :ref:`aiohttp-client-tracing` provides more +info about the feature.
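A minimal sketch of registering such handlers; the timing bookkeeping and the target URL are illustrative assumptions, not part of this changeset::

    import asyncio

    import aiohttp

    async def on_request_start(session, ctx, params):
        # Remember when the request started on the per-request context.
        ctx.start = asyncio.get_event_loop().time()

    async def on_request_end(session, ctx, params):
        elapsed = asyncio.get_event_loop().time() - ctx.start
        print(f"{params.url} took {elapsed:.3f}s")

    async def main():
        trace_config = aiohttp.TraceConfig()
        trace_config.on_request_start.append(on_request_start)
        trace_config.on_request_end.append(on_request_end)
        async with aiohttp.ClientSession(trace_configs=[trace_config]) as session:
            async with session.get("http://httpbin.org/get") as resp:
                await resp.text()

    asyncio.get_event_loop().run_until_complete(main())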
+ +HTTPS support +============= + +Unfortunately asyncio has a bug with checking SSL certificates for +non-ASCII site DNS names, e.g. `https://историк.рф <https://историк.рф>`_ or +`https://雜草工作室.香港 <https://雜草工作室.香港>`_. + +The bug has been fixed only in the upcoming Python 3.7 (the change +requires breaking backward compatibility in the :mod:`ssl` API). + +aiohttp installs a fix for older Python versions (3.5 and 3.6). + + +Dropped obsolete API +==================== + +A switch to a new major version is a great chance to drop already +deprecated features. + +The release dropped a lot; see :ref:`aiohttp_changes` for details. + +All removals were already marked as deprecated or related to very +low-level implementation details. + +If user code did not raise :exc:`DeprecationWarning`, it is most likely +compatible with aiohttp 3.0. + + +Summary +======= + +Enjoy the aiohttp 3.0 release! + +The full change log is here: :ref:`aiohttp_changes`. diff --git a/examples/background_tasks.py b/examples/background_tasks.py index c2ae90d0950..2a1ec12afae 100755 --- a/examples/background_tasks.py +++ b/examples/background_tasks.py @@ -3,64 +3,64 @@ import asyncio import aioredis -from aiohttp.web import Application, WebSocketResponse, run_app + +from aiohttp import web async def websocket_handler(request): - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(request) - request.app['websockets'].append(ws) + request.app["websockets"].append(ws) try: async for msg in ws: print(msg) await asyncio.sleep(1) finally: - request.app['websockets'].remove(ws) + request.app["websockets"].remove(ws) return ws async def on_shutdown(app): - for ws in app['websockets']: - await ws.close(code=999, message='Server shutdown') + for ws in app["websockets"]: + await ws.close(code=999, message="Server shutdown") async def listen_to_redis(app): try: - sub = await aioredis.create_redis(('localhost', 6379), loop=app.loop) - ch, *_ = await sub.subscribe('news') - async for msg in ch.iter(encoding='utf-8'): + sub = await aioredis.create_redis(("localhost", 6379), loop=app.loop) + ch, *_ = await sub.subscribe("news") + async for msg in ch.iter(encoding="utf-8"): # Forward message to all connected websockets: - for ws in app['websockets']: - ws.send_str('{}: {}'.format(ch.name, msg)) - print("message in {}: {}".format(ch.name, msg)) + for ws in app["websockets"]: + await ws.send_str(f"{ch.name}: {msg}") + print(f"message in {ch.name}: {msg}") except asyncio.CancelledError: pass finally: - print('Cancel Redis listener: close connection...') + print("Cancel Redis listener: close connection...") await sub.unsubscribe(ch.name) await sub.quit() - print('Redis connection closed.') + print("Redis connection closed.") async def start_background_tasks(app): - app['redis_listener'] = app.loop.create_task(listen_to_redis(app)) + app["redis_listener"] = app.loop.create_task(listen_to_redis(app)) async def cleanup_background_tasks(app): - print('cleanup background tasks...') - app['redis_listener'].cancel() - await app['redis_listener'] + print("cleanup background tasks...") + app["redis_listener"].cancel() + await app["redis_listener"] -async def init(loop): - app = Application() - app['websockets'] = [] - app.router.add_get('/news', websocket_handler) +def init(): + app = web.Application() + app["websockets"] = [] + app.router.add_get("/news", websocket_handler) app.on_startup.append(start_background_tasks) app.on_cleanup.append(cleanup_background_tasks) app.on_shutdown.append(on_shutdown) return app -loop = asyncio.get_event_loop() -app = loop.run_until_complete(init(loop)) -run_app(app) + +web.run_app(init()) diff --git a/examples/basic_srv.py b/examples/basic_srv.py deleted file mode 100755 index d4f20da1ac7..00000000000 --- a/examples/basic_srv.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 -"""Basic HTTP server with minimal setup""" - -import asyncio -from urllib.parse import parse_qsl, urlparse - -import aiohttp -import aiohttp.server -from aiohttp import MultiDict - - -class HttpRequestHandler(aiohttp.server.ServerHttpProtocol): - - @asyncio.coroutine - def handle_request(self, message, payload): - response = aiohttp.Response( - self.writer, 200, http_version=message.version) - get_params = MultiDict(parse_qsl(urlparse(message.path).query)) - if message.method == 'POST': - post_params = yield from payload.read() - else: - post_params = None - content = "
<html><body><h1>It Works!</h1>" - if get_params: - content += "<h2>Get params</h2><p>" + str(get_params) + "</p>" - if post_params: - content += "<h2>Post params</h2><p>" + str(post_params) + "</p>
    " - bcontent = content.encode('utf-8') - response.add_header('Content-Type', 'text/html; charset=UTF-8') - response.add_header('Content-Length', str(len(bcontent))) - response.send_headers() - response.write(bcontent) - yield from response.write_eof() - - -if __name__ == '__main__': - loop = asyncio.get_event_loop() - f = loop.create_server( - lambda: HttpRequestHandler(debug=True, keep_alive=75), - '0.0.0.0', 8080) - srv = loop.run_until_complete(f) - print('serving on', srv.sockets[0].getsockname()) - try: - loop.run_forever() - except KeyboardInterrupt: - pass diff --git a/examples/cli_app.py b/examples/cli_app.py index 1d3ea088615..9fbd3b76049 100755 --- a/examples/cli_app.py +++ b/examples/cli_app.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 """ Example of serving an Application using the `aiohttp.web` CLI. @@ -14,13 +15,13 @@ from argparse import ArgumentParser -from aiohttp.web import Application, Response +from aiohttp import web def display_message(req): args = req.app["args"] text = "\n".join([args.message] * args.repeat) - return Response(text=text) + return web.Response(text=text) def init(argv): @@ -29,27 +30,22 @@ def init(argv): ) # Positional argument - arg_parser.add_argument( - "message", - help="message to print" - ) + arg_parser.add_argument("message", help="message to print") # Optional argument arg_parser.add_argument( - "--repeat", - help="number of times to repeat message", type=int, default="1" + "--repeat", help="number of times to repeat message", type=int, default="1" ) # Avoid conflict with -h from `aiohttp.web` CLI parser arg_parser.add_argument( - "--app-help", - help="show this message and exit", action="help" + "--app-help", help="show this message and exit", action="help" ) args = arg_parser.parse_args(argv) - app = Application() + app = web.Application() app["args"] = args - app.router.add_get('/', display_message) + app.router.add_get("/", display_message) return app diff --git a/examples/client_auth.py b/examples/client_auth.py index 4e0b7341601..6513de20e5c 100755 --- a/examples/client_auth.py +++ b/examples/client_auth.py @@ -1,12 +1,12 @@ +#!/usr/bin/env python3 import asyncio import aiohttp async def fetch(session): - print('Query http://httpbin.org/basic-auth/andrew/password') - async with session.get( - 'http://httpbin.org/basic-auth/andrew/password') as resp: + print("Query http://httpbin.org/basic-auth/andrew/password") + async with session.get("http://httpbin.org/basic-auth/andrew/password") as resp: print(resp.status) body = await resp.text() print(body) @@ -14,8 +14,8 @@ async def fetch(session): async def go(loop): async with aiohttp.ClientSession( - auth=aiohttp.BasicAuth('andrew', 'password'), - loop=loop) as session: + auth=aiohttp.BasicAuth("andrew", "password"), loop=loop + ) as session: await fetch(session) diff --git a/examples/client_json.py b/examples/client_json.py index db6d9982b01..e54edeaddb6 100755 --- a/examples/client_json.py +++ b/examples/client_json.py @@ -1,12 +1,12 @@ +#!/usr/bin/env python3 import asyncio import aiohttp async def fetch(session): - print('Query http://httpbin.org/get') - async with session.get( - 'http://httpbin.org/get') as resp: + print("Query http://httpbin.org/get") + async with session.get("http://httpbin.org/get") as resp: print(resp.status) data = await resp.json() print(data) diff --git a/examples/client_ws.py b/examples/client_ws.py index d054dec3482..ec48eccc9ad 100755 --- a/examples/client_ws.py +++ b/examples/client_ws.py @@ -7,73 +7,67 @@ import aiohttp -try: - import selectors -except 
ImportError: - from asyncio import selectors - -def start_client(loop, url): - name = input('Please enter your name: ') - - # send request - ws = yield from aiohttp.ws_connect(url, autoclose=False, autoping=False) +async def start_client(loop, url): + name = input("Please enter your name: ") # input reader def stdin_callback(): - line = sys.stdin.buffer.readline().decode('utf-8') + line = sys.stdin.buffer.readline().decode("utf-8") if not line: loop.stop() else: - ws.send_str(name + ': ' + line) + ws.send_str(name + ": " + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) - @asyncio.coroutine - def dispatch(): + async def dispatch(): while True: - msg = yield from ws.receive() + msg = await ws.receive() if msg.type == aiohttp.WSMsgType.TEXT: - print('Text: ', msg.data.strip()) + print("Text: ", msg.data.strip()) elif msg.type == aiohttp.WSMsgType.BINARY: - print('Binary: ', msg.data) + print("Binary: ", msg.data) elif msg.type == aiohttp.WSMsgType.PING: ws.pong() elif msg.type == aiohttp.WSMsgType.PONG: - print('Pong received') + print("Pong received") else: if msg.type == aiohttp.WSMsgType.CLOSE: - yield from ws.close() + await ws.close() elif msg.type == aiohttp.WSMsgType.ERROR: - print('Error during receive %s' % ws.exception()) + print("Error during receive %s" % ws.exception()) elif msg.type == aiohttp.WSMsgType.CLOSED: pass break - yield from dispatch() + # send request + async with aiohttp.ws_connect(url, autoclose=False, autoping=False) as ws: + await dispatch() ARGS = argparse.ArgumentParser( - description="websocket console client for wssrv.py example.") + description="websocket console client for wssrv.py example." +) ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') + "--host", action="store", dest="host", default="127.0.0.1", help="Host name" +) ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') + "--port", action="store", dest="port", default=8080, type=int, help="Port number" +) -if __name__ == '__main__': +if __name__ == "__main__": args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) + if ":" in args.host: + args.host, port = args.host.split(":", 1) args.port = int(port) - url = 'http://{}:{}'.format(args.host, args.port) + url = f"http://{args.host}:{args.port}" - loop = asyncio.SelectorEventLoop(selectors.SelectSelector()) - asyncio.set_event_loop(loop) + loop = asyncio.get_event_loop() loop.add_signal_handler(signal.SIGINT, loop.stop) - asyncio.Task(start_client(loop, url)) + loop.create_task(start_client(loop, url)) loop.run_forever() diff --git a/examples/curl.py b/examples/curl.py index 9f3d51c364e..a39639af34e 100755 --- a/examples/curl.py +++ b/examples/curl.py @@ -6,28 +6,28 @@ import aiohttp -def curl(url): - session = aiohttp.ClientSession() - response = yield from session.request('GET', url) - print(repr(response)) +async def curl(url): + async with aiohttp.ClientSession() as session: + async with session.request("GET", url) as response: + print(repr(response)) + chunk = await response.content.read() + print("Downloaded: %s" % len(chunk)) - chunk = yield from response.content.read() - print('Downloaded: %s' % len(chunk)) - response.close() - yield from session.close() - - -if __name__ == '__main__': +if __name__ == "__main__": ARGS = argparse.ArgumentParser(description="GET url example") - ARGS.add_argument('url', nargs=1, metavar='URL', - help="URL to download") - ARGS.add_argument('--iocp', default=False, 
action="store_true", - help="Use ProactorEventLoop on Windows") + ARGS.add_argument("url", nargs=1, metavar="URL", help="URL to download") + ARGS.add_argument( + "--iocp", + default=False, + action="store_true", + help="Use ProactorEventLoop on Windows", + ) options = ARGS.parse_args() if options.iocp: from asyncio import events, windows_events + el = windows_events.ProactorEventLoop() events.set_event_loop(el) diff --git a/examples/fake_server.py b/examples/fake_server.py index bde9c9646a5..007d96ba027 100755 --- a/examples/fake_server.py +++ b/examples/fake_server.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 import asyncio import pathlib import socket @@ -9,50 +10,8 @@ from aiohttp.test_utils import unused_port -def http_method(method, path): - def wrapper(func): - func.__method__ = method - func.__path__ = path - return func - return wrapper - - -def head(path): - return http_method('HEAD', path) - - -def get(path): - return http_method('GET', path) - - -def delete(path): - return http_method('DELETE', path) - - -def options(path): - return http_method('OPTIONS', path) - - -def patch(path): - return http_method('PATCH', path) - - -def post(path): - return http_method('POST', path) - - -def put(path): - return http_method('PUT', path) - - -def trace(path): - return http_method('TRACE', path) - - class FakeResolver: - _LOCAL_HOST = {0: '127.0.0.1', - socket.AF_INET: '127.0.0.1', - socket.AF_INET6: '::1'} + _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1", socket.AF_INET6: "::1"} def __init__(self, fakes, *, loop): """fakes -- dns -> port dict""" @@ -62,83 +21,72 @@ def __init__(self, fakes, *, loop): async def resolve(self, host, port=0, family=socket.AF_INET): fake_port = self._fakes.get(host) if fake_port is not None: - return [{'hostname': host, - 'host': self._LOCAL_HOST[family], 'port': fake_port, - 'family': family, 'proto': 0, - 'flags': socket.AI_NUMERICHOST}] + return [ + { + "hostname": host, + "host": self._LOCAL_HOST[family], + "port": fake_port, + "family": family, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] else: return await self._resolver.resolve(host, port, family) class FakeFacebook: - def __init__(self, *, loop): self.loop = loop self.app = web.Application(loop=loop) - for name in dir(self.__class__): - func = getattr(self.__class__, name) - if hasattr(func, '__method__'): - self.app.router.add_route(func.__method__, - func.__path__, - getattr(self, name)) - self.handler = None - self.server = None + self.app.router.add_routes( + [ + web.get("/v2.7/me", self.on_me), + web.get("/v2.7/me/friends", self.on_my_friends), + ] + ) + self.runner = None here = pathlib.Path(__file__) - ssl_cert = here.parent / 'server.crt' - ssl_key = here.parent / 'server.key' + ssl_cert = here.parent / "server.crt" + ssl_key = here.parent / "server.key" self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.ssl_context.load_cert_chain(str(ssl_cert), str(ssl_key)) async def start(self): port = unused_port() - self.handler = self.app.make_handler() - self.server = await self.loop.create_server(self.handler, - '127.0.0.1', port, - ssl=self.ssl_context) - return {'graph.facebook.com': port} + self.runner = web.AppRunner(self.app) + await self.runner.setup() + site = web.TCPSite(self.runner, "127.0.0.1", port, ssl_context=self.ssl_context) + await site.start() + return {"graph.facebook.com": port} async def stop(self): - self.server.close() - await self.server.wait_closed() - await self.app.shutdown() - await self.handler.shutdown() - await self.app.cleanup() 
+ await self.runner.cleanup() - @get('/v2.7/me') async def on_me(self, request): - return web.json_response({ - "name": "John Doe", - "id": "12345678901234567" - }) + return web.json_response({"name": "John Doe", "id": "12345678901234567"}) - @get('/v2.7/me/friends') async def on_my_friends(self, request): - return web.json_response({ - "data": [ - { - "name": "Bill Doe", - "id": "233242342342" - }, - { - "name": "Mary Doe", - "id": "2342342343222" - }, - { - "name": "Alex Smith", - "id": "234234234344" - }, - ], - "paging": { - "cursors": { - "before": "QVFIUjRtc2c5NEl0ajN", - "after": "QVFIUlpFQWM0TmVuaDRad0dt", + return web.json_response( + { + "data": [ + {"name": "Bill Doe", "id": "233242342342"}, + {"name": "Mary Doe", "id": "2342342343222"}, + {"name": "Alex Smith", "id": "234234234344"}, + ], + "paging": { + "cursors": { + "before": "QVFIUjRtc2c5NEl0ajN", + "after": "QVFIUlpFQWM0TmVuaDRad0dt", + }, + "next": ( + "https://graph.facebook.com/v2.7/12345678901234567/" + "friends?access_token=EAACEdEose0cB" + ), }, - "next": ("https://graph.facebook.com/v2.7/12345678901234567/" - "friends?access_token=EAACEdEose0cB") - }, - "summary": { - "total_count": 3 - }}) + "summary": {"total_count": 3}, + } + ) async def main(loop): @@ -147,17 +95,17 @@ async def main(loop): fake_facebook = FakeFacebook(loop=loop) info = await fake_facebook.start() resolver = FakeResolver(info, loop=loop) - connector = aiohttp.TCPConnector(loop=loop, resolver=resolver, - verify_ssl=False) + connector = aiohttp.TCPConnector(loop=loop, resolver=resolver, verify_ssl=False) - async with aiohttp.ClientSession(connector=connector, - loop=loop) as session: - async with session.get('https://graph.facebook.com/v2.7/me', - params={'access_token': token}) as resp: + async with aiohttp.ClientSession(connector=connector, loop=loop) as session: + async with session.get( + "https://graph.facebook.com/v2.7/me", params={"access_token": token} + ) as resp: print(await resp.json()) - async with session.get('https://graph.facebook.com/v2.7/me/friends', - params={'access_token': token}) as resp: + async with session.get( + "https://graph.facebook.com/v2.7/me/friends", params={"access_token": token} + ) as resp: print(await resp.json()) await fake_facebook.stop() diff --git a/examples/legacy/crawl.py b/examples/legacy/crawl.py index 90600025a4a..c8029b48545 100755 --- a/examples/legacy/crawl.py +++ b/examples/legacy/crawl.py @@ -11,7 +11,6 @@ class Crawler: - def __init__(self, rooturl, loop, maxtasks=100): self.rooturl = rooturl self.loop = loop @@ -24,49 +23,46 @@ def __init__(self, rooturl, loop, maxtasks=100): # connector stores cookies between requests and uses connection pool self.session = aiohttp.ClientSession(loop=loop) - @asyncio.coroutine - def run(self): - t = asyncio.ensure_future(self.addurls([(self.rooturl, '')]), - loop=self.loop) - yield from asyncio.sleep(1, loop=self.loop) + async def run(self): + t = asyncio.ensure_future(self.addurls([(self.rooturl, "")]), loop=self.loop) + await asyncio.sleep(1, loop=self.loop) while self.busy: - yield from asyncio.sleep(1, loop=self.loop) + await asyncio.sleep(1, loop=self.loop) - yield from t - yield from self.session.close() + await t + await self.session.close() self.loop.stop() - @asyncio.coroutine - def addurls(self, urls): + async def addurls(self, urls): for url, parenturl in urls: url = urllib.parse.urljoin(parenturl, url) url, frag = urllib.parse.urldefrag(url) - if (url.startswith(self.rooturl) and - url not in self.busy and - url not in self.done and - url not in 
self.todo): + if ( + url.startswith(self.rooturl) + and url not in self.busy + and url not in self.done + and url not in self.todo + ): self.todo.add(url) - yield from self.sem.acquire() + await self.sem.acquire() task = asyncio.ensure_future(self.process(url), loop=self.loop) task.add_done_callback(lambda t: self.sem.release()) task.add_done_callback(self.tasks.remove) self.tasks.add(task) - @asyncio.coroutine - def process(self, url): - print('processing:', url) + async def process(self, url): + print("processing:", url) self.todo.remove(url) self.busy.add(url) try: - resp = yield from self.session.get(url) + resp = await self.session.get(url) except Exception as exc: - print('...', url, 'has error', repr(str(exc))) + print("...", url, "has error", repr(str(exc))) self.done[url] = False else: - if (resp.status == 200 and - ('text/html' in resp.headers.get('content-type'))): - data = (yield from resp.read()).decode('utf-8', 'replace') + if resp.status == 200 and ("text/html" in resp.headers.get("content-type")): + data = (await resp.read()).decode("utf-8", "replace") urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) asyncio.Task(self.addurls([(u, url) for u in urls])) @@ -74,8 +70,13 @@ def process(self, url): self.done[url] = True self.busy.remove(url) - print(len(self.done), 'completed tasks,', len(self.tasks), - 'still pending, todo', len(self.todo)) + print( + len(self.done), + "completed tasks,", + len(self.tasks), + "still pending, todo", + len(self.todo), + ) def main(): @@ -89,17 +90,18 @@ def main(): except RuntimeError: pass loop.run_forever() - print('todo:', len(c.todo)) - print('busy:', len(c.busy)) - print('done:', len(c.done), '; ok:', sum(c.done.values())) - print('tasks:', len(c.tasks)) + print("todo:", len(c.todo)) + print("busy:", len(c.busy)) + print("done:", len(c.done), "; ok:", sum(c.done.values())) + print("tasks:", len(c.tasks)) -if __name__ == '__main__': - if '--iocp' in sys.argv: +if __name__ == "__main__": + if "--iocp" in sys.argv: from asyncio import events, windows_events - sys.argv.remove('--iocp') - logging.info('using iocp') + + sys.argv.remove("--iocp") + logging.info("using iocp") el = windows_events.ProactorEventLoop() events.set_event_loop(el) diff --git a/examples/legacy/srv.py b/examples/legacy/srv.py index 0941273db6d..628b6f332f1 100755 --- a/examples/legacy/srv.py +++ b/examples/legacy/srv.py @@ -17,21 +17,22 @@ class HttpRequestHandler(aiohttp.server.ServerHttpProtocol): - - @asyncio.coroutine - def handle_request(self, message, payload): - print('method = {!r}; path = {!r}; version = {!r}'.format( - message.method, message.path, message.version)) + async def handle_request(self, message, payload): + print( + "method = {!r}; path = {!r}; version = {!r}".format( + message.method, message.path, message.version + ) + ) path = message.path - if (not (path.isprintable() and path.startswith('/')) or '/.' in path): - print('bad path', repr(path)) + if not (path.isprintable() and path.startswith("/")) or "/." in path: + print("bad path", repr(path)) path = None else: - path = '.' + path + path = "." 
+ path if not os.path.exists(path): - print('no file', repr(path)) + print("no file", repr(path)) path = None else: isdir = os.path.isdir(path) @@ -42,106 +43,115 @@ def handle_request(self, message, payload): for hdr, val in message.headers.items(): print(hdr, val) - if isdir and not path.endswith('/'): - path = path + '/' + if isdir and not path.endswith("/"): + path = path + "/" raise aiohttp.HttpProcessingError( - code=302, headers=(('URI', path), ('Location', path))) + code=302, headers=(("URI", path), ("Location", path)) + ) - response = aiohttp.Response( - self.writer, 200, http_version=message.version) - response.add_header('Transfer-Encoding', 'chunked') + response = aiohttp.Response(self.writer, 200, http_version=message.version) + response.add_header("Transfer-Encoding", "chunked") # content encoding - accept_encoding = message.headers.get('accept-encoding', '').lower() - if 'deflate' in accept_encoding: - response.add_header('Content-Encoding', 'deflate') - response.add_compression_filter('deflate') - elif 'gzip' in accept_encoding: - response.add_header('Content-Encoding', 'gzip') - response.add_compression_filter('gzip') + accept_encoding = message.headers.get("accept-encoding", "").lower() + if "deflate" in accept_encoding: + response.add_header("Content-Encoding", "deflate") + response.add_compression_filter("deflate") + elif "gzip" in accept_encoding: + response.add_header("Content-Encoding", "gzip") + response.add_compression_filter("gzip") response.add_chunking_filter(1025) if isdir: - response.add_header('Content-type', 'text/html') + response.add_header("Content-type", "text/html") response.send_headers() - response.write(b'
<ul>\r\n') + response.write(b"<ul>\r\n") for name in sorted(os.listdir(path)): - if name.isprintable() and not name.startswith('.'): + if name.isprintable() and not name.startswith("."): try: - bname = name.encode('ascii') + bname = name.encode("ascii") except UnicodeError: pass else: if os.path.isdir(os.path.join(path, name)): - response.write(b'<li><a href="' + bname + b'/">' + bname + b'/</a></li>\r\n') + response.write( + b'<li><a href="' + + bname + + b'/">' + + bname + + b"/</a></li>\r\n" + ) else: - response.write(b'<li><a href="' + bname + b'">' + bname + b'</a></li>\r\n') - response.write(b'</ul>') + response.write( + b'<li><a href="' + + bname + + b'">' + + bname + + b"</a></li>\r\n" + ) + response.write(b"</ul>
    ") else: - response.add_header('Content-type', 'text/plain') + response.add_header("Content-type", "text/plain") response.send_headers() try: - with open(path, 'rb') as fp: + with open(path, "rb") as fp: chunk = fp.read(8192) while chunk: response.write(chunk) chunk = fp.read(8192) except OSError: - response.write(b'Cannot open') + response.write(b"Cannot open") - yield from response.write_eof() + await response.write_eof() if response.keep_alive(): self.keep_alive(True) ARGS = argparse.ArgumentParser(description="Run simple HTTP server.") ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') + "--host", action="store", dest="host", default="127.0.0.1", help="Host name" +) ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') + "--port", action="store", dest="port", default=8080, type=int, help="Port number" +) # make iocp and ssl mutually exclusive because ProactorEventLoop is # incompatible with SSL group = ARGS.add_mutually_exclusive_group() group.add_argument( - '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') -group.add_argument( - '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') -ARGS.add_argument( - '--sslcert', action="store", dest='certfile', help='SSL cert file.') -ARGS.add_argument( - '--sslkey', action="store", dest='keyfile', help='SSL key file.') + "--iocp", action="store_true", dest="iocp", help="Windows IOCP event loop" +) +group.add_argument("--ssl", action="store_true", dest="ssl", help="Run ssl mode.") +ARGS.add_argument("--sslcert", action="store", dest="certfile", help="SSL cert file.") +ARGS.add_argument("--sslkey", action="store", dest="keyfile", help="SSL key file.") def main(): args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) + if ":" in args.host: + args.host, port = args.host.split(":", 1) args.port = int(port) if args.iocp: from asyncio import windows_events - sys.argv.remove('--iocp') - logging.info('using iocp') + + sys.argv.remove("--iocp") + logging.info("using iocp") el = windows_events.ProactorEventLoop() asyncio.set_event_loop(el) if args.ssl: - here = os.path.join(os.path.dirname(__file__), 'tests') + here = os.path.join(os.path.dirname(__file__), "tests") if args.certfile: - certfile = args.certfile or os.path.join(here, 'sample.crt') - keyfile = args.keyfile or os.path.join(here, 'sample.key') + certfile = args.certfile or os.path.join(here, "sample.crt") + keyfile = args.keyfile or os.path.join(here, "sample.key") else: - certfile = os.path.join(here, 'sample.crt') - keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, "sample.crt") + keyfile = os.path.join(here, "sample.key") sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.load_cert_chain(certfile, keyfile) @@ -151,16 +161,18 @@ def main(): loop = asyncio.get_event_loop() f = loop.create_server( lambda: HttpRequestHandler(debug=True, keep_alive=75), - args.host, args.port, - ssl=sslcontext) + args.host, + args.port, + ssl=sslcontext, + ) svr = loop.run_until_complete(f) socks = svr.sockets - print('serving on', socks[0].getsockname()) + print("serving on", socks[0].getsockname()) try: loop.run_forever() except KeyboardInterrupt: pass -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/legacy/tcp_protocol_parser.py b/examples/legacy/tcp_protocol_parser.py index a9e2df096dd..ca49db7d8f9 100755 --- a/examples/legacy/tcp_protocol_parser.py +++ 
b/examples/legacy/tcp_protocol_parser.py @@ -12,12 +12,12 @@ signal = None -MSG_TEXT = b'text:' -MSG_PING = b'ping:' -MSG_PONG = b'pong:' -MSG_STOP = b'stop:' +MSG_TEXT = b"text:" +MSG_PING = b"ping:" +MSG_PONG = b"pong:" +MSG_STOP = b"stop:" -Message = collections.namedtuple('Message', ('tp', 'data')) +Message = collections.namedtuple("Message", ("tp", "data")) def my_protocol_parser(out, buf): @@ -34,41 +34,38 @@ def my_protocol_parser(out, buf): tp = yield from buf.read(5) if tp in (MSG_PING, MSG_PONG): # skip line - yield from buf.skipuntil(b'\r\n') + yield from buf.skipuntil(b"\r\n") out.feed_data(Message(tp, None)) elif tp == MSG_STOP: out.feed_data(Message(tp, None)) elif tp == MSG_TEXT: # read text - text = yield from buf.readuntil(b'\r\n') - out.feed_data(Message(tp, text.strip().decode('utf-8'))) + text = yield from buf.readuntil(b"\r\n") + out.feed_data(Message(tp, text.strip().decode("utf-8"))) else: - raise ValueError('Unknown protocol prefix.') + raise ValueError("Unknown protocol prefix.") class MyProtocolWriter: - def __init__(self, transport): self.transport = transport def ping(self): - self.transport.write(b'ping:\r\n') + self.transport.write(b"ping:\r\n") def pong(self): - self.transport.write(b'pong:\r\n') + self.transport.write(b"pong:\r\n") def stop(self): - self.transport.write(b'stop:\r\n') + self.transport.write(b"stop:\r\n") def send_text(self, text): - self.transport.write( - 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + self.transport.write(f"text:{text.strip()}\r\n".encode("utf-8")) class EchoServer(asyncio.Protocol): - def connection_made(self, transport): - print('Connection made') + print("Connection made") self.transport = transport self.stream = aiohttp.StreamParser() asyncio.Task(self.dispatch()) @@ -80,55 +77,52 @@ def eof_received(self): self.stream.feed_eof() def connection_lost(self, exc): - print('Connection lost') + print("Connection lost") - @asyncio.coroutine - def dispatch(self): + async def dispatch(self): reader = self.stream.set_parser(my_protocol_parser) writer = MyProtocolWriter(self.transport) while True: try: - msg = yield from reader.read() + msg = await reader.read() except aiohttp.ConnectionError: # client has been disconnected break - print('Message received: {}'.format(msg)) + print(f"Message received: {msg}") if msg.type == MSG_PING: writer.pong() elif msg.type == MSG_TEXT: - writer.send_text('Re: ' + msg.data) + writer.send_text("Re: " + msg.data) elif msg.type == MSG_STOP: self.transport.close() break -@asyncio.coroutine -def start_client(loop, host, port): - transport, stream = yield from loop.create_connection( - aiohttp.StreamProtocol, host, port) +async def start_client(loop, host, port): + transport, stream = await loop.create_connection(aiohttp.StreamProtocol, host, port) reader = stream.reader.set_parser(my_protocol_parser) writer = MyProtocolWriter(transport) writer.ping() - message = 'This is the message. It will be echoed.' + message = "This is the message. It will be echoed." 
while True: try: - msg = yield from reader.read() + msg = await reader.read() except aiohttp.ConnectionError: - print('Server has been disconnected.') + print("Server has been disconnected.") break - print('Message received: {}'.format(msg)) + print(f"Message received: {msg}") if msg.type == MSG_PONG: writer.send_text(message) - print('data sent:', message) + print("data sent:", message) elif msg.type == MSG_TEXT: writer.stop() - print('stop sent') + print("stop sent") break transport.close() @@ -138,34 +132,34 @@ def start_server(loop, host, port): f = loop.create_server(EchoServer, host, port) srv = loop.run_until_complete(f) x = srv.sockets[0] - print('serving on', x.getsockname()) + print("serving on", x.getsockname()) loop.run_forever() ARGS = argparse.ArgumentParser(description="Protocol parser example.") ARGS.add_argument( - '--server', action="store_true", dest='server', - default=False, help='Run tcp server') + "--server", action="store_true", dest="server", default=False, help="Run tcp server" +) ARGS.add_argument( - '--client', action="store_true", dest='client', - default=False, help='Run tcp client') + "--client", action="store_true", dest="client", default=False, help="Run tcp client" +) ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') + "--host", action="store", dest="host", default="127.0.0.1", help="Host name" +) ARGS.add_argument( - '--port', action="store", dest='port', - default=9999, type=int, help='Port number') + "--port", action="store", dest="port", default=9999, type=int, help="Port number" +) -if __name__ == '__main__': +if __name__ == "__main__": args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) + if ":" in args.host: + args.host, port = args.host.split(":", 1) args.port = int(port) if (not (args.server or args.client)) or (args.server and args.client): - print('Please specify --server or --client\n') + print("Please specify --server or --client\n") ARGS.print_help() else: loop = asyncio.get_event_loop() diff --git a/examples/lowlevel_srv.py b/examples/lowlevel_srv.py index 6699e08b2c7..5a003f40f42 100644 --- a/examples/lowlevel_srv.py +++ b/examples/lowlevel_srv.py @@ -1,4 +1,5 @@ import asyncio + from aiohttp import web @@ -13,7 +14,7 @@ async def main(loop): # pause here for very long time by serving HTTP requests and # waiting for keyboard interruption - await asyncio.sleep(100*3600) + await asyncio.sleep(100 * 3600) loop = asyncio.get_event_loop() diff --git a/examples/server_simple.py b/examples/server_simple.py new file mode 100644 index 00000000000..d464383d269 --- /dev/null +++ b/examples/server_simple.py @@ -0,0 +1,31 @@ +# server_simple.py +from aiohttp import web + + +async def handle(request): + name = request.match_info.get("name", "Anonymous") + text = "Hello, " + name + return web.Response(text=text) + + +async def wshandle(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == web.WSMsgType.text: + await ws.send_str(f"Hello, {msg.data}") + elif msg.type == web.WSMsgType.binary: + await ws.send_bytes(msg.data) + elif msg.type == web.WSMsgType.close: + break + + return ws + + +app = web.Application() +app.add_routes( + [web.get("/", handle), web.get("/echo", wshandle), web.get("/{name}", handle)] +) + +web.run_app(app) diff --git a/examples/static_files.py b/examples/static_files.py index 426242a8514..65f6bb9c764 100755 --- a/examples/static_files.py +++ b/examples/static_files.py @@ -1,8 +1,9 @@ 
+#!/usr/bin/env python3 import pathlib from aiohttp import web app = web.Application() -app.router.add_static('/', pathlib.Path(__file__).parent, show_index=True) +app.router.add_static("/", pathlib.Path(__file__).parent, show_index=True) web.run_app(app) diff --git a/examples/web_classview.py b/examples/web_classview.py new file mode 100755 index 00000000000..0f65f7d7f43 --- /dev/null +++ b/examples/web_classview.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +"""Example for aiohttp.web class based views """ + + +import functools +import json + +from aiohttp import web + + +class MyView(web.View): + async def get(self): + return web.json_response( + { + "method": "get", + "args": dict(self.request.GET), + "headers": dict(self.request.headers), + }, + dumps=functools.partial(json.dumps, indent=4), + ) + + async def post(self): + data = await self.request.post() + return web.json_response( + { + "method": "post", + "args": dict(self.request.GET), + "data": dict(data), + "headers": dict(self.request.headers), + }, + dumps=functools.partial(json.dumps, indent=4), + ) + + +async def index(request): + txt = """ + <html> + <head> + <title>Class based view example</title> + </head> + <body> +
<h1>Class based view example</h1> + <ul> + <li><a href="/">/</a> This page + <li><a href="/get">/get</a> Returns GET data. + <li><a href="/post">/post</a> Returns POST data.
    + + + """ + return web.Response(text=txt, content_type="text/html") + + +def init(): + app = web.Application() + app.router.add_get("/", index) + app.router.add_get("/get", MyView) + app.router.add_post("/post", MyView) + return app + + +web.run_app(init()) diff --git a/examples/web_classview1.py b/examples/web_classview1.py deleted file mode 100755 index 2dd9f09380f..00000000000 --- a/examples/web_classview1.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -"""Example for aiohttp.web class based views -""" - - -import asyncio -import functools -import json - -from aiohttp.web import Application, Response, View, json_response, run_app - - -class MyView(View): - - async def get(self): - return json_response({ - 'method': 'get', - 'args': dict(self.request.GET), - 'headers': dict(self.request.headers), - }, dumps=functools.partial(json.dumps, indent=4)) - - async def post(self): - data = await self.request.post() - return json_response({ - 'method': 'post', - 'args': dict(self.request.GET), - 'data': dict(data), - 'headers': dict(self.request.headers), - }, dumps=functools.partial(json.dumps, indent=4)) - - -async def index(request): - txt = """ - - - Class based view example - - -
<h1>Class based view example</h1> - <ul> - <li><a href="/">/</a> This page - <li><a href="/get">/get</a> Returns GET data. - <li><a href="/post">/post</a> Returns POST data.
- </ul> - </body> - </html> - """ - return Response(text=txt, content_type='text/html') - - -async def init(loop): - app = Application(loop=loop) - app.router.add_get('/', index) - app.router.add_get('/get', MyView) - app.router.add_post('/post', MyView) - return app - - -loop = asyncio.get_event_loop() -app = loop.run_until_complete(init(loop)) -run_app(app) diff --git a/examples/web_cookies.py b/examples/web_cookies.py index adb1fb92021..e7a4a595d77 100755 --- a/examples/web_cookies.py +++ b/examples/web_cookies.py @@ -2,52 +2,44 @@ """Example for aiohttp.web basic server with cookies. """ -import asyncio from pprint import pformat from aiohttp import web - -tmpl = '''\ +tmpl = """\ <html> <body> <a href="/login">Login</a><br/> <a href="/logout">Logout</a><br/> <pre>{}</pre> </body> </html>
    -''' +""" -@asyncio.coroutine -def root(request): - resp = web.Response(content_type='text/html') +async def root(request): + resp = web.Response(content_type="text/html") resp.text = tmpl.format(pformat(request.cookies)) return resp -@asyncio.coroutine -def login(request): - resp = web.HTTPFound(location='/') - resp.set_cookie('AUTH', 'secret') +async def login(request): + resp = web.HTTPFound(location="/") + resp.set_cookie("AUTH", "secret") return resp -@asyncio.coroutine -def logout(request): - resp = web.HTTPFound(location='/') - resp.del_cookie('AUTH') +async def logout(request): + resp = web.HTTPFound(location="/") + resp.del_cookie("AUTH") return resp -@asyncio.coroutine def init(loop): app = web.Application(loop=loop) - app.router.add_get('/', root) - app.router.add_get('/login', login) - app.router.add_get('/logout', logout) + app.router.add_get("/", root) + app.router.add_get("/login", login) + app.router.add_get("/logout", logout) return app -loop = asyncio.get_event_loop() -app = loop.run_until_complete(init(loop)) -web.run_app(app) +web.run_app(init()) diff --git a/examples/web_rewrite_headers_middleware.py b/examples/web_rewrite_headers_middleware.py index 4e8c64493c0..20799a3a7c2 100755 --- a/examples/web_rewrite_headers_middleware.py +++ b/examples/web_rewrite_headers_middleware.py @@ -3,38 +3,28 @@ Example for rewriting response headers by middleware. """ -import asyncio +from aiohttp import web -from aiohttp.web import Application, HTTPException, Response, run_app +async def handler(request): + return web.Response(text="Everything is fine") -@asyncio.coroutine -def handler(request): - return Response(text="Everything is fine") +@web.middleware +async def middleware(request, handler): + try: + response = await handler(request) + except web.HTTPException as exc: + raise exc + if not response.prepared: + response.headers["SERVER"] = "Secured Server Software" + return response -@asyncio.coroutine -def middleware_factory(app, next_handler): - @asyncio.coroutine - def middleware(request): - try: - response = yield from next_handler(request) - except HTTPException as exc: - response = exc - if not response.prepared: - response.headers['SERVER'] = "Secured Server Software" - return response - - return middleware - - -def init(loop): - app = Application(loop=loop, middlewares=[middleware_factory]) - app.router.add_get('/', handler) +def init(): + app = web.Application(middlewares=[middleware]) + app.router.add_get("/", handler) return app -loop = asyncio.get_event_loop() -app = init(loop) -run_app(app) +web.run_app(init()) diff --git a/examples/web_srv.py b/examples/web_srv.py index 2ee42c28e24..b572326a3a2 100755 --- a/examples/web_srv.py +++ b/examples/web_srv.py @@ -2,58 +2,58 @@ """Example for aiohttp.web basic server """ -import asyncio import textwrap -from aiohttp.web import Application, Response, StreamResponse, run_app +from aiohttp import web async def intro(request): - txt = textwrap.dedent("""\ + txt = textwrap.dedent( + """\ Type {url}/hello/John {url}/simple or {url}/change_body in browser url bar - """).format(url='127.0.0.1:8080') - binary = txt.encode('utf8') - resp = StreamResponse() + """ + ).format(url="127.0.0.1:8080") + binary = txt.encode("utf8") + resp = web.StreamResponse() resp.content_length = len(binary) - resp.content_type = 'text/plain' + resp.content_type = "text/plain" await resp.prepare(request) - resp.write(binary) + await resp.write(binary) return resp async def simple(request): - return Response(text="Simple answer") + return 
web.Response(text="Simple answer") async def change_body(request): - resp = Response() + resp = web.Response() resp.body = b"Body changed" - resp.content_type = 'text/plain' + resp.content_type = "text/plain" return resp async def hello(request): - resp = StreamResponse() - name = request.match_info.get('name', 'Anonymous') - answer = ('Hello, ' + name).encode('utf8') + resp = web.StreamResponse() + name = request.match_info.get("name", "Anonymous") + answer = ("Hello, " + name).encode("utf8") resp.content_length = len(answer) - resp.content_type = 'text/plain' + resp.content_type = "text/plain" await resp.prepare(request) - resp.write(answer) + await resp.write(answer) await resp.write_eof() return resp -async def init(loop): - app = Application() - app.router.add_get('/', intro) - app.router.add_get('/simple', simple) - app.router.add_get('/change_body', change_body) - app.router.add_get('/hello/{name}', hello) - app.router.add_get('/hello', hello) +def init(): + app = web.Application() + app.router.add_get("/", intro) + app.router.add_get("/simple", simple) + app.router.add_get("/change_body", change_body) + app.router.add_get("/hello/{name}", hello) + app.router.add_get("/hello", hello) return app -loop = asyncio.get_event_loop() -app = loop.run_until_complete(init(loop)) -run_app(app) + +web.run_app(init()) diff --git a/examples/web_srv_route_deco.py b/examples/web_srv_route_deco.py new file mode 100644 index 00000000000..332990362cc --- /dev/null +++ b/examples/web_srv_route_deco.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +"""Example for aiohttp.web basic server +with decorator definition for routes +""" + +import textwrap + +from aiohttp import web + +routes = web.RouteTableDef() + + +@routes.get("/") +async def intro(request): + txt = textwrap.dedent( + """\ + Type {url}/hello/John {url}/simple or {url}/change_body + in browser url bar + """ + ).format(url="127.0.0.1:8080") + binary = txt.encode("utf8") + resp = web.StreamResponse() + resp.content_length = len(binary) + resp.content_type = "text/plain" + await resp.prepare(request) + await resp.write(binary) + return resp + + +@routes.get("/simple") +async def simple(request): + return web.Response(text="Simple answer") + + +@routes.get("/change_body") +async def change_body(request): + resp = web.Response() + resp.body = b"Body changed" + resp.content_type = "text/plain" + return resp + + +@routes.get("/hello") +async def hello(request): + resp = web.StreamResponse() + name = request.match_info.get("name", "Anonymous") + answer = ("Hello, " + name).encode("utf8") + resp.content_length = len(answer) + resp.content_type = "text/plain" + await resp.prepare(request) + await resp.write(answer) + await resp.write_eof() + return resp + + +def init(): + app = web.Application() + app.router.add_routes(routes) + return app + + +web.run_app(init()) diff --git a/examples/web_srv_route_table.py b/examples/web_srv_route_table.py new file mode 100644 index 00000000000..f53142adad4 --- /dev/null +++ b/examples/web_srv_route_table.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +"""Example for aiohttp.web basic server +with table definition for routes +""" + +import textwrap + +from aiohttp import web + + +async def intro(request): + txt = textwrap.dedent( + """\ + Type {url}/hello/John {url}/simple or {url}/change_body + in browser url bar + """ + ).format(url="127.0.0.1:8080") + binary = txt.encode("utf8") + resp = web.StreamResponse() + resp.content_length = len(binary) + resp.content_type = "text/plain" + await resp.prepare(request) + await 
resp.write(binary) + return resp + + +async def simple(request): + return web.Response(text="Simple answer") + + +async def change_body(request): + resp = web.Response() + resp.body = b"Body changed" + resp.content_type = "text/plain" + return resp + + +async def hello(request): + resp = web.StreamResponse() + name = request.match_info.get("name", "Anonymous") + answer = ("Hello, " + name).encode("utf8") + resp.content_length = len(answer) + resp.content_type = "text/plain" + await resp.prepare(request) + await resp.write(answer) + await resp.write_eof() + return resp + + +def init(): + app = web.Application() + app.router.add_routes( + [ + web.get("/", intro), + web.get("/simple", simple), + web.get("/change_body", change_body), + web.get("/hello/{name}", hello), + web.get("/hello", hello), + ] + ) + return app + + +web.run_app(init()) diff --git a/examples/web_ws.py b/examples/web_ws.py index bb808fef710..970f1506be3 100755 --- a/examples/web_ws.py +++ b/examples/web_ws.py @@ -2,59 +2,57 @@ """Example for aiohttp.web websocket server """ -import asyncio import os -from aiohttp.web import (Application, Response, WebSocketResponse, WSMsgType, - run_app) +from aiohttp import web -WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') +WS_FILE = os.path.join(os.path.dirname(__file__), "websocket.html") async def wshandler(request): - resp = WebSocketResponse() - ok, protocol = resp.can_prepare(request) - if not ok: - with open(WS_FILE, 'rb') as fp: - return Response(body=fp.read(), content_type='text/html') + resp = web.WebSocketResponse() + available = resp.can_prepare(request) + if not available: + with open(WS_FILE, "rb") as fp: + return web.Response(body=fp.read(), content_type="text/html") await resp.prepare(request) + await resp.send_str("Welcome!!!") + try: - print('Someone joined.') - for ws in request.app['sockets']: - ws.send_str('Someone joined') - request.app['sockets'].append(resp) + print("Someone joined.") + for ws in request.app["sockets"]: + await ws.send_str("Someone joined") + request.app["sockets"].append(resp) async for msg in resp: - if msg.type == WSMsgType.TEXT: - for ws in request.app['sockets']: + if msg.type == web.WSMsgType.TEXT: + for ws in request.app["sockets"]: if ws is not resp: - ws.send_str(msg.data) + await ws.send_str(msg.data) else: return resp return resp finally: - request.app['sockets'].remove(resp) - print('Someone disconnected.') - for ws in request.app['sockets']: - ws.send_str('Someone disconnected.') + request.app["sockets"].remove(resp) + print("Someone disconnected.") + for ws in request.app["sockets"]: + await ws.send_str("Someone disconnected.") async def on_shutdown(app): - for ws in app['sockets']: + for ws in app["sockets"]: await ws.close() -async def init(loop): - app = Application() - app['sockets'] = [] - app.router.add_get('/', wshandler) +def init(): + app = web.Application() + app["sockets"] = [] + app.router.add_get("/", wshandler) app.on_shutdown.append(on_shutdown) return app -loop = asyncio.get_event_loop() -app = loop.run_until_complete(init(loop)) -run_app(app) +web.run_app(init()) diff --git a/examples/websocket.html b/examples/websocket.html index 6bad7f74647..2ba9ff367d6 100644 --- a/examples/websocket.html +++ b/examples/websocket.html @@ -39,7 +39,6 @@ } } function update_ui() { - var msg = ''; if (conn == null) { $('#status').text('disconnected'); $('#connect').html('Connect'); diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..e666dfc174e --- /dev/null +++ b/pyproject.toml 
@@ -0,0 +1,7 @@ +[tool.towncrier] +package = "aiohttp" +filename = "CHANGES.rst" +directory = "CHANGES/" +title_format = "{version} ({project_date})" +template = "CHANGES/.TEMPLATE.rst" +issue_format = "`#{issue} `_" diff --git a/requirements-ci.txt b/requirements-ci.txt deleted file mode 100644 index c0a6695645a..00000000000 --- a/requirements-ci.txt +++ /dev/null @@ -1,24 +0,0 @@ -pip==9.0.1 -flake8==3.3.0 -pyflakes==1.5.0 -coverage==4.3.4 -cchardet==1.1.3 -sphinx==1.5.3 -cython==0.25.2 -chardet==2.3.0 -isort==4.2.5 -tox==2.6.0 -multidict==2.1.4 -async-timeout==1.2.0 -sphinxcontrib-asyncio==0.2.0 -sphinxcontrib-newsfeed==0.1.4 -pytest==3.0.7 -pytest-cov==2.4.0 -pytest-mock==1.5.0 -pytest-timeout==1.2.0 -gunicorn==19.7.0 -pygments>=2.1 -#aiodns # Enable from .travis.yml as required c-ares will not build on windows -twine==1.8.1 -yarl==0.10.0 --e . diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 4cb5f08f2f4..00000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,5 +0,0 @@ --r requirements-ci.txt -ipdb==0.10.2 -pytest-sugar==0.8.0 -ipython==5.3.0 -aiodns==1.1.1 diff --git a/requirements-wheel.txt b/requirements-wheel.txt deleted file mode 100644 index d6004703112..00000000000 --- a/requirements-wheel.txt +++ /dev/null @@ -1,3 +0,0 @@ -cython==0.25.2 -pytest==3.0.7 -twine==1.8.1 diff --git a/requirements/base.txt b/requirements/base.txt new file mode 100644 index 00000000000..ffd04d12ae9 --- /dev/null +++ b/requirements/base.txt @@ -0,0 +1,14 @@ +-r multidict.txt +# required c-ares will not build on windows and has build problems on Macos Python<3.7 +aiodns==2.0.0; sys_platform=="linux" or sys_platform=="darwin" and python_version>="3.7" +async-generator==1.10 +async-timeout==3.0.1 +attrs==20.3.0 +brotlipy==0.7.0 +cchardet==2.1.7 +chardet==4.0.0 +gunicorn==20.0.4 +idna-ssl==1.1.0; python_version<"3.7" +typing_extensions==3.7.4.3 +uvloop==0.14.0; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.9" # MagicStack/uvloop#14 +yarl==1.6.3 diff --git a/requirements/cython.txt b/requirements/cython.txt new file mode 100644 index 00000000000..e478589498f --- /dev/null +++ b/requirements/cython.txt @@ -0,0 +1,2 @@ +-r multidict.txt +cython==0.29.21 diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 00000000000..fc7aee6945c --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,4 @@ +-r lint.txt +-r test.txt +-r doc.txt +cherry_picker==1.3.2; python_version>="3.6" diff --git a/requirements/doc-spelling.txt b/requirements/doc-spelling.txt new file mode 100644 index 00000000000..699f7e3f49e --- /dev/null +++ b/requirements/doc-spelling.txt @@ -0,0 +1,2 @@ +-r doc.txt +sphinxcontrib-spelling==7.1.0; platform_system!="Windows" # We only use it in Travis CI diff --git a/requirements/doc.txt b/requirements/doc.txt new file mode 100644 index 00000000000..09d666a9f2c --- /dev/null +++ b/requirements/doc.txt @@ -0,0 +1,6 @@ +aiohttp-theme==0.1.6 +pygments==2.7.3 +sphinx==3.3.1 +sphinxcontrib-asyncio==0.3.0 +sphinxcontrib-blockdiag==2.0.0 +towncrier==19.2.0 diff --git a/requirements/lint.txt b/requirements/lint.txt new file mode 100644 index 00000000000..bcae22d6763 --- /dev/null +++ b/requirements/lint.txt @@ -0,0 +1,6 @@ +black==20.8b1; implementation_name=="cpython" +flake8==3.8.4 +flake8-pyi==20.10.0 +isort==5.6.4 +mypy==0.790; implementation_name=="cpython" +pre-commit==2.9.3 diff --git a/requirements/multidict.txt b/requirements/multidict.txt new file mode 100644 index 00000000000..7357d4643f0 --- 
/dev/null +++ b/requirements/multidict.txt @@ -0,0 +1 @@ +multidict==5.1.0 diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 00000000000..3085dd5881f --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,13 @@ + +-r base.txt +coverage==5.3 +cryptography==3.2.1; platform_machine!="i686" and python_version<"3.9" # no 32-bit wheels; no python 3.9 wheels yet +freezegun==1.0.0 +mypy==0.790; implementation_name=="cpython" +mypy-extensions==0.4.3; implementation_name=="cpython" +pytest==6.1.2 +pytest-cov==2.10.1 +pytest-mock==3.3.1 +re-assert==1.1.0 +setuptools-git==1.2 +trustme==0.6.0; platform_machine!="i686" # no 32-bit wheels diff --git a/run_docker.sh b/run_docker.sh deleted file mode 100755 index 30d3e3cadff..00000000000 --- a/run_docker.sh +++ /dev/null @@ -1,25 +0,0 @@ -if [ ! -z $TRAVIS_TAG ] && [ -z $PYTHONASYNCIODEBUG ] && [ -z $AIOHTTP_NO_EXTENSIONS] ;then - echo "x86_64" - docker pull quay.io/pypa/manylinux1_x86_64 - docker run --rm -v `pwd`:/io quay.io/pypa/manylinux1_x86_64 /io/build-wheels.sh - echo "Dist folder content is:" - for f in dist/aiohttp*manylinux1_x86_64.whl - do - echo "Upload $f" - python -m twine upload $f --username andrew.svetlov --password $PYPI_PASSWD - done - echo "Cleanup" - docker run --rm -v `pwd`:/io quay.io/pypa/manylinux1_x86_64 rm -rf /io/dist - - echo "i686" - docker pull quay.io/pypa/manylinux1_i686 - docker run --rm -v `pwd`:/io quay.io/pypa/manylinux1_i686 linux32 /io/build-wheels.sh - echo "Dist folder content is:" - for f in dist/aiohttp*manylinux1_i686.whl - do - echo "Upload $f" - python -m twine upload $f --username andrew.svetlov --password $PYPI_PASSWD - done - echo "Cleanup" - docker run --rm -v `pwd`:/io quay.io/pypa/manylinux1_i686 rm -rf /io/dist -fi diff --git a/setup.cfg b/setup.cfg index 0196d2d13de..df8fbc3152f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,9 @@ +[aliases] +test = pytest + +[metadata] +license_file = LICENSE.txt + [pep8] max-line-length=79 @@ -5,17 +11,93 @@ max-line-length=79 zip_ok = false [flake8] -ignore = N801,N802,N803,E226 -max-line-length=79 - -[tool:pytest] -timeout = 4 +ignore = N801,N802,N803,E203,E226,E305,W504,E252,E301,E302,E704,W503,W504,F811 +max-line-length = 88 [isort] -known_third_party=jinja2 +line_length=88 +include_trailing_comma=True +multi_line_output=3 +force_grid_wrap=0 +combine_as_imports=True + +known_third_party=jinja2,pytest,multidict,yarl,gunicorn,freezegun,async_generator known_first_party=aiohttp,aiohttp_jinja2,aiopg [report] exclude_lines = @abc.abstractmethod @abstractmethod + +[coverage:run] +branch = True +source = aiohttp, tests +omit = site-packages + +[tool:pytest] +addopts = --cov=aiohttp -v -rxXs --durations 10 +filterwarnings = + error + ignore:module 'ssl' has no attribute 'OP_NO_COMPRESSION'. The Python interpreter is compiled against OpenSSL < 1.0.0. Ref. 
https.//docs.python.org/3/library/ssl.html#ssl.OP_NO_COMPRESSION:UserWarning +junit_suite_name = aiohttp_test_suite +norecursedirs = dist docs build .tox .eggs +minversion = 3.8.2 +testpaths = tests/ +junit_family=xunit2 +xfail_strict = true + +[mypy] +follow_imports = silent +strict_optional = True +warn_redundant_casts = True +warn_unused_ignores = True + +# uncomment next lines +# to enable strict mypy mode +# +check_untyped_defs = True +disallow_any_generics = True +disallow_untyped_defs = True + + +[mypy-pytest] +ignore_missing_imports = true + + +[mypy-uvloop] +ignore_missing_imports = true + + +[mypy-tokio] +ignore_missing_imports = true + + +[mypy-async_generator] +ignore_missing_imports = true + + +[mypy-aiodns] +ignore_missing_imports = true + + +[mypy-gunicorn.config] +ignore_missing_imports = true + +[mypy-gunicorn.workers] +ignore_missing_imports = true + + +[mypy-brotli] +ignore_missing_imports = true + + +[mypy-chardet] +ignore_missing_imports = true + + +[mypy-cchardet] +ignore_missing_imports = true + + +[mypy-idna_ssl] +ignore_missing_imports = true diff --git a/setup.py b/setup.py index a9cc9296261..428df5d4e95 100644 --- a/setup.py +++ b/setup.py @@ -1,30 +1,41 @@ -import codecs -import os +import pathlib import re import sys from distutils.command.build_ext import build_ext -from distutils.errors import (CCompilerError, DistutilsExecError, - DistutilsPlatformError) +from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError from setuptools import Extension, setup -from setuptools.command.test import test as TestCommand -try: - from Cython.Build import cythonize - USE_CYTHON = True -except ImportError: - USE_CYTHON = False +if sys.version_info < (3, 6): + raise RuntimeError("aiohttp 3.7+ requires Python 3.6+") + +here = pathlib.Path(__file__).parent + -ext = '.pyx' if USE_CYTHON else '.c' +if (here / ".git").exists() and not (here / "vendor/http-parser/README.md").exists(): + print("Install submodules when building from git clone", file=sys.stderr) + print("Hint:", file=sys.stderr) + print(" git submodule update --init", file=sys.stderr) + sys.exit(2) -extensions = [Extension('aiohttp._websocket', ['aiohttp/_websocket' + ext]), - Extension('aiohttp._http_parser', - ['aiohttp/_http_parser' + ext, - 'vendor/http-parser/http_parser.c'],)] +# NOTE: makefile cythonizes all Cython modules -if USE_CYTHON: - extensions = cythonize(extensions) +extensions = [ + Extension("aiohttp._websocket", ["aiohttp/_websocket.c"]), + Extension( + "aiohttp._http_parser", + [ + "aiohttp/_http_parser.c", + "vendor/http-parser/http_parser.c", + "aiohttp/_find_header.c", + ], + define_macros=[("HTTP_PARSER_STRICT", 0)], + ), + Extension("aiohttp._frozenlist", ["aiohttp/_frozenlist.c"]), + Extension("aiohttp._helpers", ["aiohttp/_helpers.c"]), + Extension("aiohttp._http_writer", ["aiohttp/_http_writer.c"]), +] class BuildFailed(Exception): @@ -43,76 +54,99 @@ def run(self): def build_extension(self, ext): try: build_ext.build_extension(self, ext) - except (CCompilerError, DistutilsExecError, - DistutilsPlatformError, ValueError): + except (CCompilerError, DistutilsExecError, DistutilsPlatformError, ValueError): raise BuildFailed() -with codecs.open(os.path.join(os.path.abspath(os.path.dirname( - __file__)), 'aiohttp', '__init__.py'), 'r', 'latin1') as fp: - try: - version = re.findall(r"^__version__ = '([^']+)'\r?$", - fp.read(), re.M)[0] - except IndexError: - raise RuntimeError('Unable to determine version.') - - -install_requires = ['chardet', 'multidict>=2.1.4', - 
'async_timeout>=1.2.0', 'yarl>=0.10.0,<0.11'] +txt = (here / "aiohttp" / "__init__.py").read_text("utf-8") +try: + version = re.findall(r'^__version__ = "([^"]+)"\r?$', txt, re.M)[0] +except IndexError: + raise RuntimeError("Unable to determine version.") -if sys.version_info < (3, 4, 2): - raise RuntimeError("aiohttp requires Python 3.4.2+") +install_requires = [ + "attrs>=17.3.0", + "chardet>=2.0,<4.0", + "multidict>=4.5,<7.0", + "async_timeout>=3.0,<4.0", + "yarl>=1.0,<2.0", + 'idna-ssl>=1.0; python_version<"3.7"', + "typing_extensions>=3.6.5", +] def read(f): - return open(os.path.join(os.path.dirname(__file__), f)).read().strip() - + return (here / f).read_text("utf-8").strip() -class PyTest(TestCommand): - user_options = [] - - def run(self): - import subprocess - import sys - errno = subprocess.call([sys.executable, '-m', 'pytest', 'tests']) - raise SystemExit(errno) +NEEDS_PYTEST = {"pytest", "test"}.intersection(sys.argv) +pytest_runner = ["pytest-runner"] if NEEDS_PYTEST else [] -tests_require = install_requires + ['pytest', 'gunicorn', 'pytest-timeout'] +tests_require = [ + "pytest", + "gunicorn", + "pytest-timeout", + "async-generator", + "pytest-xdist", +] args = dict( - name='aiohttp', + name="aiohttp", version=version, - description='Async http client/server framework (asyncio)', - long_description='\n\n'.join((read('README.rst'), read('CHANGES.rst'))), + description="Async http client/server framework (asyncio)", + long_description="\n\n".join((read("README.rst"), read("CHANGES.rst"))), classifiers=[ - 'License :: OSI Approved :: Apache Software License', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Development Status :: 5 - Production/Stable', - 'Operating System :: POSIX', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: Microsoft :: Windows', - 'Topic :: Internet :: WWW/HTTP'], - author='Nikolay Kim', - author_email='fafhrd91@gmail.com', - maintainer=', '.join(('Nikolay Kim ', - 'Andrew Svetlov ')), - maintainer_email='aio-libs@googlegroups.com', - url='https://github.com/aio-libs/aiohttp/', - license='Apache 2', - packages=['aiohttp'], + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Development Status :: 5 - Production/Stable", + "Operating System :: POSIX", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Topic :: Internet :: WWW/HTTP", + "Framework :: AsyncIO", + ], + author="Nikolay Kim", + author_email="fafhrd91@gmail.com", + maintainer=", ".join( + ( + "Nikolay Kim ", + "Andrew Svetlov ", + ) + ), + maintainer_email="aio-libs@googlegroups.com", + url="https://github.com/aio-libs/aiohttp", + project_urls={ + "Chat: Gitter": "https://gitter.im/aio-libs/Lobby", + "CI: Azure Pipelines": "https://dev.azure.com/aio-libs/aiohttp/_build", + "Coverage: codecov": "https://codecov.io/github/aio-libs/aiohttp", + "Docs: RTD": "https://docs.aiohttp.org", + "GitHub: issues": "https://github.com/aio-libs/aiohttp/issues", + "GitHub: repo": "https://github.com/aio-libs/aiohttp", + }, + license="Apache 2", + packages=["aiohttp"], + 
python_requires=">=3.6", install_requires=install_requires, + extras_require={ + "speedups": [ + "aiodns", + "brotlipy", + "cchardet", + ], + }, tests_require=tests_require, + setup_requires=pytest_runner, include_package_data=True, ext_modules=extensions, - cmdclass=dict(build_ext=ve_build_ext, - test=PyTest)) + cmdclass=dict(build_ext=ve_build_ext), +) try: setup(**args) @@ -120,6 +154,6 @@ def run(self): print("************************************************************") print("Cannot compile C accelerator module, use pure python version") print("************************************************************") - del args['ext_modules'] - del args['cmdclass'] + del args["ext_modules"] + del args["cmdclass"] setup(**args) diff --git a/tests/aiohttp.jpg b/tests/aiohttp.jpg index 2dfbbd41541..e8c51c341db 100644 Binary files a/tests/aiohttp.jpg and b/tests/aiohttp.jpg differ diff --git a/tests/aiohttp.png b/tests/aiohttp.png index 1a3c9498119..db272a4f8ea 100644 Binary files a/tests/aiohttp.png and b/tests/aiohttp.png differ diff --git a/tests/autobahn/client.py b/tests/autobahn/client.py index 6bfc6bdb2c2..513a4ee39fc 100644 --- a/tests/autobahn/client.py +++ b/tests/autobahn/client.py @@ -5,46 +5,47 @@ import aiohttp -def client(loop, url, name): - ws = yield from aiohttp.ws_connect(url + '/getCaseCount') - num_tests = int((yield from ws.receive()).data) - print('running %d cases' % num_tests) - yield from ws.close() +async def client(loop, url, name): + ws = await aiohttp.ws_connect(url + "/getCaseCount") + num_tests = int((await ws.receive()).data) + print("running %d cases" % num_tests) + await ws.close() for i in range(1, num_tests + 1): - print('running test case:', i) - text_url = url + '/runCase?case=%d&agent=%s' % (i, name) - ws = yield from aiohttp.ws_connect(text_url) + print("running test case:", i) + text_url = url + "/runCase?case=%d&agent=%s" % (i, name) + ws = await aiohttp.ws_connect(text_url) while True: - msg = yield from ws.receive() + msg = await ws.receive() if msg.type == aiohttp.WSMsgType.text: - ws.send_str(msg.data) + await ws.send_str(msg.data) elif msg.type == aiohttp.WSMsgType.binary: - ws.send_bytes(msg.data) + await ws.send_bytes(msg.data) elif msg.type == aiohttp.WSMsgType.close: - yield from ws.close() + await ws.close() break else: break - url = url + '/updateReports?agent=%s' % name - ws = yield from aiohttp.ws_connect(url) - yield from ws.close() + url = url + "/updateReports?agent=%s" % name + ws = await aiohttp.ws_connect(url) + await ws.close() -def run(loop, url, name): +async def run(loop, url, name): try: - yield from client(loop, url, name) - except: + await client(loop, url, name) + except Exception: import traceback + traceback.print_exc() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() try: - loop.run_until_complete(run(loop, 'http://localhost:9001', 'aiohttp')) + loop.run_until_complete(run(loop, "http://localhost:9001", "aiohttp")) except KeyboardInterrupt: pass finally: diff --git a/tests/autobahn/server.py b/tests/autobahn/server.py index 447e152d147..3d39d6c9d53 100644 --- a/tests/autobahn/server.py +++ b/tests/autobahn/server.py @@ -6,24 +6,23 @@ from aiohttp import web -@asyncio.coroutine -def wshandler(request): +async def wshandler(request): ws = web.WebSocketResponse(autoclose=False) - ok, protocol = ws.can_start(request) - if not ok: + is_ws = ws.can_prepare(request) + if not is_ws: return web.HTTPBadRequest() - yield from ws.prepare(request) + await ws.prepare(request) while True: - msg = yield 
from ws.receive() + msg = await ws.receive() if msg.type == web.WSMsgType.text: - ws.send_str(msg.data) + await ws.send_str(msg.data) elif msg.type == web.WSMsgType.binary: - ws.send_bytes(msg.data) + await ws.send_bytes(msg.data) elif msg.type == web.WSMsgType.close: - yield from ws.close() + await ws.close() break else: break @@ -31,30 +30,27 @@ def wshandler(request): return ws -@asyncio.coroutine -def main(loop): +async def main(loop): app = web.Application() - app.router.add_route('GET', '/', wshandler) + app.router.add_route("GET", "/", wshandler) - handler = app.make_handler() - srv = yield from loop.create_server(handler, '127.0.0.1', 9001) + handler = app._make_handler() + srv = await loop.create_server(handler, "127.0.0.1", 9001) print("Server started at http://127.0.0.1:9001") return app, srv, handler -@asyncio.coroutine -def finish(app, srv, handler): +async def finish(app, srv, handler): srv.close() - yield from handler.finish_connections() - yield from srv.wait_closed() + await handler.shutdown() + await srv.wait_closed() -if __name__ == '__main__': - loop = asyncio.get_event_loop() - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(levelname)s %(message)s') - +if __name__ == "__main__": loop = asyncio.get_event_loop() + logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s" + ) app, srv, handler = loop.run_until_complete(main(loop)) try: loop.run_forever() diff --git a/tests/conftest.py b/tests/conftest.py index 2cfc3bdd0ca..09cbf6c9ed7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,84 +1,109 @@ -import collections -import logging +import asyncio +import hashlib +import pathlib +import shutil +import ssl import sys +import tempfile +import uuid import pytest -pytest_plugins = 'aiohttp.pytest_plugin' +from aiohttp.test_utils import loop_context +try: + import trustme -_LoggingWatcher = collections.namedtuple("_LoggingWatcher", - ["records", "output"]) + TRUSTME = True +except ImportError: + TRUSTME = False +pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] -class _CapturingHandler(logging.Handler): - """ - A logging handler capturing all (raw and formatted) logging output. + +@pytest.fixture +def shorttmpdir(): + """Provides a temporary directory with a shorter file system path than the + tmpdir fixture. 
""" + tmpdir = pathlib.Path(tempfile.mkdtemp()) + yield tmpdir + # str(tmpdir) is required, Python 3.5 doesn't have __fspath__ + # concept + shutil.rmtree(str(tmpdir), ignore_errors=True) - def __init__(self): - logging.Handler.__init__(self) - self.watcher = _LoggingWatcher([], []) - def flush(self): - pass +@pytest.fixture +def tls_certificate_authority(): + if not TRUSTME: + pytest.xfail("trustme fails on 32bit Linux") + return trustme.CA() - def emit(self, record): - self.watcher.records.append(record) - msg = self.format(record) - self.watcher.output.append(msg) +@pytest.fixture +def tls_certificate(tls_certificate_authority): + return tls_certificate_authority.issue_server_cert( + "localhost", + "127.0.0.1", + "::1", + ) -class _AssertLogsContext: - """A context manager used to implement TestCase.assertLogs().""" - LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" +@pytest.fixture +def ssl_ctx(tls_certificate): + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + tls_certificate.configure_cert(ssl_ctx) + return ssl_ctx - def __init__(self, logger_name=None, level=None): - self.logger_name = logger_name - if level: - self.level = logging._nameToLevel.get(level, level) - else: - self.level = logging.INFO - self.msg = None - def __enter__(self): - if isinstance(self.logger_name, logging.Logger): - logger = self.logger = self.logger_name +@pytest.fixture +def client_ssl_ctx(tls_certificate_authority): + ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) + tls_certificate_authority.configure_trust(ssl_ctx) + return ssl_ctx + + +@pytest.fixture +def tls_ca_certificate_pem_path(tls_certificate_authority): + with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: + yield ca_cert_pem + + +@pytest.fixture +def tls_certificate_pem_path(tls_certificate): + with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_certificate_pem_bytes(tls_certificate): + return tls_certificate.cert_chain_pems[0].bytes() + + +@pytest.fixture +def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes): + tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode()) + return hashlib.sha256(tls_cert_der).digest() + + +@pytest.fixture +def pipe_name(): + name = fr"\\.\pipe\{uuid.uuid4().hex}" + return name + + +@pytest.fixture +def selector_loop(): + if sys.version_info < (3, 7): + policy = asyncio.get_event_loop_policy() + policy._loop_factory = asyncio.SelectorEventLoop # type: ignore + else: + if sys.version_info >= (3, 8): + policy = asyncio.WindowsSelectorEventLoopPolicy() # type: ignore else: - logger = self.logger = logging.getLogger(self.logger_name) - formatter = logging.Formatter(self.LOGGING_FORMAT) - handler = _CapturingHandler() - handler.setFormatter(formatter) - self.watcher = handler.watcher - self.old_handlers = logger.handlers[:] - self.old_level = logger.level - self.old_propagate = logger.propagate - logger.handlers = [handler] - logger.setLevel(self.level) - logger.propagate = False - return handler.watcher - - def __exit__(self, exc_type, exc_value, tb): - self.logger.handlers = self.old_handlers - self.logger.propagate = self.old_propagate - self.logger.setLevel(self.old_level) - if exc_type is not None: - # let unexpected exceptions pass through - return False - if len(self.watcher.records) == 0: - __tracebackhide__ = True - assert 0, ("no logs of level {} or higher triggered on {}" - .format(logging.getLevelName(self.level), - self.logger.name)) - - -@pytest.yield_fixture -def 
log(): - yield _AssertLogsContext - - -def pytest_ignore_collect(path, config): - if 'test_py35' in str(path): - if sys.version_info < (3, 5, 0): - return True + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + + with loop_context(policy.new_event_loop) as _loop: + asyncio.set_event_loop(_loop) + yield _loop diff --git a/tests/data.zero_bytes b/tests/data.zero_bytes new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/sample.crt b/tests/sample.crt deleted file mode 100644 index 6a1e3f3c2e7..00000000000 --- a/tests/sample.crt +++ /dev/null @@ -1,14 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV -UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi -MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 -MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp -Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g -U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn -t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 -gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg -Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB -BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB -MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I -I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB ------END CERTIFICATE----- diff --git a/tests/sample.crt.der b/tests/sample.crt.der deleted file mode 100644 index ce22b75b9e0..00000000000 Binary files a/tests/sample.crt.der and /dev/null differ diff --git a/tests/sample.key b/tests/sample.key deleted file mode 100644 index edfea8dcab3..00000000000 --- a/tests/sample.key +++ /dev/null @@ -1,15 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx -/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi -qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB -AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y -bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 -iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb -DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP -lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL -21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF -ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 -zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u -GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq -V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof ------END RSA PRIVATE KEY----- diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py new file mode 100644 index 00000000000..531e55e6be4 --- /dev/null +++ b/tests/test_base_protocol.py @@ -0,0 +1,182 @@ +import asyncio +from contextlib import suppress +from unittest import mock + +import pytest + +from aiohttp.base_protocol import BaseProtocol + + +async def test_loop() -> None: + loop = asyncio.get_event_loop() + asyncio.set_event_loop(None) + pr = BaseProtocol(loop) + assert pr._loop is loop + + +async def test_pause_writing() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop) + assert not pr._paused + pr.pause_writing() + assert pr._paused + + +async def test_resume_writing_no_waiters() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + pr.pause_writing() + assert pr._paused + pr.resume_writing() + assert not pr._paused + + 
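
The drain tests that follow all exercise the same flow-control contract: `pause_writing()` marks the protocol paused, `_drain_helper()` parks the writer on a `_drain_waiter` future, and either `resume_writing()` or `connection_lost()` is responsible for waking it (with an exception if the connection died). A minimal sketch of that contract, inferred from the assertions in these tests and the standard asyncio `Protocol` callbacks — illustrative only, not aiohttp's actual `BaseProtocol` body:

```python
import asyncio
from typing import Optional


class FlowControlSketch(asyncio.Protocol):
    # Illustrative drain-helper pattern; aiohttp's BaseProtocol follows the
    # same contract that the tests in this file assert.

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._paused = False
        self._drain_waiter: Optional[asyncio.Future] = None
        self._connection_lost = False
        self.transport: Optional[asyncio.BaseTransport] = None

    def pause_writing(self) -> None:
        # Transport's write buffer crossed the high-water mark.
        self._paused = True

    def resume_writing(self) -> None:
        # Buffer drained below the low-water mark: wake the parked writer.
        self._paused = False
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.transport = transport

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        # Must also wake the writer here, or a paused _drain_helper() call
        # would hang forever (see test_lost_drain_waited_* above and below).
        self._connection_lost = True
        self.transport = None
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is not None and not waiter.done():
            if exc is None:
                waiter.set_result(None)
            else:
                waiter.set_exception(exc)

    async def _drain_helper(self) -> None:
        if self._connection_lost:
            raise ConnectionResetError("Cannot write to closing transport")
        if not self._paused:
            return
        assert self._drain_waiter is None  # one writer at a time in this sketch
        self._drain_waiter = self._loop.create_future()
        await self._drain_waiter
```
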
+async def test_connection_made() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + assert pr.transport is None + pr.connection_made(tr) + assert pr.transport is not None + + +async def test_connection_lost_not_paused() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + assert not pr._connection_lost + pr.connection_lost(None) + assert pr.transport is None + assert pr._connection_lost + + +async def test_connection_lost_paused_without_waiter() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + assert not pr._connection_lost + pr.pause_writing() + pr.connection_lost(None) + assert pr.transport is None + assert pr._connection_lost + + +async def test_drain_lost() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.connection_lost(None) + with pytest.raises(ConnectionResetError): + await pr._drain_helper() + + +async def test_drain_not_paused() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + assert pr._drain_waiter is None + await pr._drain_helper() + assert pr._drain_waiter is None + + +async def test_resume_drain_waited() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.pause_writing() + + t = loop.create_task(pr._drain_helper()) + await asyncio.sleep(0) + + assert pr._drain_waiter is not None + pr.resume_writing() + assert (await t) is None + assert pr._drain_waiter is None + + +async def test_lost_drain_waited_ok() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.pause_writing() + + t = loop.create_task(pr._drain_helper()) + await asyncio.sleep(0) + + assert pr._drain_waiter is not None + pr.connection_lost(None) + assert (await t) is None + assert pr._drain_waiter is None + + +async def test_lost_drain_waited_exception() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.pause_writing() + + t = loop.create_task(pr._drain_helper()) + await asyncio.sleep(0) + + assert pr._drain_waiter is not None + exc = RuntimeError() + pr.connection_lost(exc) + with pytest.raises(RuntimeError) as cm: + await t + assert cm.value is exc + assert pr._drain_waiter is None + + +async def test_lost_drain_cancelled() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.pause_writing() + + fut = loop.create_future() + + async def wait(): + fut.set_result(None) + await pr._drain_helper() + + t = loop.create_task(wait()) + await fut + t.cancel() + + assert pr._drain_waiter is not None + pr.connection_lost(None) + with suppress(asyncio.CancelledError): + await t + assert pr._drain_waiter is None + + +async def test_resume_drain_cancelled() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + tr = mock.Mock() + pr.connection_made(tr) + pr.pause_writing() + + fut = loop.create_future() + + async def wait(): + fut.set_result(None) + await pr._drain_helper() + + t = loop.create_task(wait()) + await fut + t.cancel() + + assert pr._drain_waiter is not None + pr.resume_writing() + with suppress(asyncio.CancelledError): + await t + assert pr._drain_waiter is None diff --git 
a/tests/test_classbasedview.py b/tests/test_classbasedview.py index a7f16493d12..0bee6db976b 100644 --- a/tests/test_classbasedview.py +++ b/tests/test_classbasedview.py @@ -1,4 +1,3 @@ -import asyncio from unittest import mock import pytest @@ -7,56 +6,50 @@ from aiohttp.web_urldispatcher import View -def test_ctor(): +def test_ctor() -> None: request = mock.Mock() view = View(request) assert view.request is request -@asyncio.coroutine -def test_render_ok(): - resp = web.Response(text='OK') +async def test_render_ok() -> None: + resp = web.Response(text="OK") class MyView(View): - @asyncio.coroutine - def get(self): + async def get(self): return resp request = mock.Mock() - request._method = 'GET' - resp2 = yield from MyView(request) + request.method = "GET" + resp2 = await MyView(request) assert resp is resp2 -@asyncio.coroutine -def test_render_unknown_method(): - +async def test_render_unknown_method() -> None: class MyView(View): - @asyncio.coroutine - def get(self): - return web.Response(text='OK') + async def get(self): + return web.Response(text="OK") + options = get request = mock.Mock() - request.method = 'UNKNOWN' + request.method = "UNKNOWN" with pytest.raises(web.HTTPMethodNotAllowed) as ctx: - yield from MyView(request) - assert ctx.value.headers['allow'] == 'GET,OPTIONS' + await MyView(request) + assert ctx.value.headers["allow"] == "GET,OPTIONS" assert ctx.value.status == 405 -@asyncio.coroutine -def test_render_unsupported_method(): - +async def test_render_unsupported_method() -> None: class MyView(View): - @asyncio.coroutine - def get(self): - return web.Response(text='OK') + async def get(self): + return web.Response(text="OK") + options = delete = get request = mock.Mock() - request.method = 'POST' + request.method = "POST" with pytest.raises(web.HTTPMethodNotAllowed) as ctx: - yield from MyView(request) - assert ctx.value.headers['allow'] == 'DELETE,GET,OPTIONS' + await MyView(request) + assert ctx.value.headers["allow"] == "DELETE,GET,OPTIONS" assert ctx.value.status == 405 diff --git a/tests/test_client_connection.py b/tests/test_client_connection.py index 86a9bbc5730..5a0739b6b0c 100644 --- a/tests/test_client_connection.py +++ b/tests/test_client_connection.py @@ -11,11 +11,6 @@ def key(): return object() -@pytest.fixture -def request(): - return mock.Mock() - - @pytest.fixture def loop(): return mock.Mock() @@ -31,15 +26,15 @@ def protocol(): return mock.Mock(should_close=False) -def test_ctor(connector, key, protocol, loop): +def test_ctor(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) - assert conn.loop is loop + with pytest.warns(DeprecationWarning): + assert conn.loop is loop assert conn.protocol is protocol - assert conn.writer is protocol.writer conn.close() -def test_callbacks_on_close(connector, key, protocol, loop): +def test_callbacks_on_close(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) notified = False @@ -52,7 +47,7 @@ def cb(): assert notified -def test_callbacks_on_release(connector, key, protocol, loop): +def test_callbacks_on_release(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) notified = False @@ -65,20 +60,7 @@ def cb(): assert notified -def test_callbacks_on_detach(connector, key, protocol, loop): - conn = Connection(connector, key, protocol, loop) - notified = False - - def cb(): - nonlocal notified - notified = True - - conn.add_callback(cb) - conn.detach() - assert notified - - -def 
test_callbacks_exception(connector, key, protocol, loop): +def test_callbacks_exception(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) notified = False @@ -95,7 +77,7 @@ def cb2(): assert notified -def test_del(connector, key, protocol, loop): +def test_del(connector, key, protocol, loop) -> None: loop.is_closed.return_value = False conn = Connection(connector, key, protocol, loop) exc_handler = mock.Mock() @@ -106,14 +88,16 @@ def test_del(connector, key, protocol, loop): gc.collect() connector._release.assert_called_with(key, protocol, should_close=True) - msg = {'client_connection': mock.ANY, # conn was deleted - 'message': 'Unclosed connection'} + msg = { + "client_connection": mock.ANY, # conn was deleted + "message": "Unclosed connection", + } if loop.get_debug(): - msg['source_traceback'] = mock.ANY + msg["source_traceback"] = mock.ANY loop.call_exception_handler.assert_called_with(msg) -def test_close(connector, key, protocol, loop): +def test_close(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) assert not conn.closed conn.close() @@ -122,7 +106,7 @@ def test_close(connector, key, protocol, loop): assert conn.closed -def test_release(connector, key, protocol, loop): +def test_release(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) assert not conn.closed conn.release() @@ -132,7 +116,7 @@ def test_release(connector, key, protocol, loop): assert conn.closed -def test_release_proto_should_close(connector, key, protocol, loop): +def test_release_proto_should_close(connector, key, protocol, loop) -> None: protocol.should_close = True conn = Connection(connector, key, protocol, loop) assert not conn.closed @@ -143,7 +127,7 @@ def test_release_proto_should_close(connector, key, protocol, loop): assert conn.closed -def test_release_released(connector, key, protocol, loop): +def test_release_released(connector, key, protocol, loop) -> None: conn = Connection(connector, key, protocol, loop) conn.release() connector._release.reset_mock() @@ -151,22 +135,3 @@ def test_release_released(connector, key, protocol, loop): assert not protocol.transport.close.called assert conn._protocol is None assert not connector._release.called - - -def test_detach(connector, key, protocol, loop): - conn = Connection(connector, key, protocol, loop) - assert not conn.closed - conn.detach() - assert conn._protocol is None - assert connector._release_acquired.called - assert not connector._release.called - assert conn.closed - - -def test_detach_closed(connector, key, protocol, loop): - conn = Connection(connector, key, protocol, loop) - conn.release() - conn.detach() - - assert not connector._release_acquired.called - assert conn._protocol is None diff --git a/tests/test_client_exceptions.py b/tests/test_client_exceptions.py index e1068f4b725..4268825897c 100644 --- a/tests/test_client_exceptions.py +++ b/tests/test_client_exceptions.py @@ -1,10 +1,326 @@ -"""Tests for http_exceptions.py""" +# Tests for client_exceptions.py -from aiohttp import client +import errno +import pickle +import sys +from unittest import mock +import pytest -def test_fingerprint_mismatch(): - err = client.ServerFingerprintMismatch('exp', 'got', 'host', 8888) - expected = ('') - assert expected == repr(err) +from aiohttp import client, client_reqrep + + +class TestClientResponseError: + request_info = client.RequestInfo( + url="http://example.com", + method="GET", + headers={}, + real_url="http://example.com", + 
) + + def test_default_status(self) -> None: + err = client.ClientResponseError(history=(), request_info=self.request_info) + assert err.status == 0 + + def test_status(self) -> None: + err = client.ClientResponseError( + status=400, history=(), request_info=self.request_info + ) + assert err.status == 400 + + def test_pickle(self) -> None: + err = client.ClientResponseError(request_info=self.request_info, history=()) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.request_info == self.request_info + assert err2.history == () + assert err2.status == 0 + assert err2.message == "" + assert err2.headers is None + + err = client.ClientResponseError( + request_info=self.request_info, + history=(), + status=400, + message="Something wrong", + headers={}, + ) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.request_info == self.request_info + assert err2.history == () + assert err2.status == 400 + assert err2.message == "Something wrong" + assert err2.headers == {} + assert err2.foo == "bar" + + def test_repr(self) -> None: + err = client.ClientResponseError(request_info=self.request_info, history=()) + assert repr(err) == (f"ClientResponseError({self.request_info!r}, ())") + + err = client.ClientResponseError( + request_info=self.request_info, + history=(), + status=400, + message="Something wrong", + headers={}, + ) + assert repr(err) == ( + "ClientResponseError(%r, (), status=400, " + "message='Something wrong', headers={})" % (self.request_info,) + ) + + def test_str(self) -> None: + err = client.ClientResponseError( + request_info=self.request_info, + history=(), + status=400, + message="Something wrong", + headers={}, + ) + assert str(err) == ( + "400, message='Something wrong', " "url='http://example.com'" + ) + + +def test_response_status() -> None: + request_info = mock.Mock(real_url="http://example.com") + err = client.ClientResponseError( + status=400, history=None, request_info=request_info + ) + assert err.status == 400 + + +def test_response_deprecated_code_property() -> None: + request_info = mock.Mock(real_url="http://example.com") + with pytest.warns(DeprecationWarning): + err = client.ClientResponseError( + code=400, history=None, request_info=request_info + ) + with pytest.warns(DeprecationWarning): + assert err.code == err.status + with pytest.warns(DeprecationWarning): + err.code = "404" + with pytest.warns(DeprecationWarning): + assert err.code == err.status + + +def test_response_both_code_and_status() -> None: + with pytest.raises(ValueError): + client.ClientResponseError( + code=400, status=400, history=None, request_info=None + ) + + +class TestClientConnectorError: + connection_key = client_reqrep.ConnectionKey( + host="example.com", + port=8080, + is_ssl=False, + ssl=None, + proxy=None, + proxy_auth=None, + proxy_headers_hash=None, + ) + + def test_ctor(self) -> None: + err = client.ClientConnectorError( + connection_key=self.connection_key, + os_error=OSError(errno.ENOENT, "No such file"), + ) + assert err.errno == errno.ENOENT + assert err.strerror == "No such file" + assert err.os_error.errno == errno.ENOENT + assert err.os_error.strerror == "No such file" + assert err.host == "example.com" + assert err.port == 8080 + assert err.ssl is None + + def test_pickle(self) -> None: + err = client.ClientConnectorError( + connection_key=self.connection_key, + os_error=OSError(errno.ENOENT, 
"No such file"), + ) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.errno == errno.ENOENT + assert err2.strerror == "No such file" + assert err2.os_error.errno == errno.ENOENT + assert err2.os_error.strerror == "No such file" + assert err2.host == "example.com" + assert err2.port == 8080 + assert err2.ssl is None + assert err2.foo == "bar" + + def test_repr(self) -> None: + os_error = OSError(errno.ENOENT, "No such file") + err = client.ClientConnectorError( + connection_key=self.connection_key, os_error=os_error + ) + assert repr(err) == ( + f"ClientConnectorError({self.connection_key!r}, {os_error!r})" + ) + + def test_str(self) -> None: + err = client.ClientConnectorError( + connection_key=self.connection_key, + os_error=OSError(errno.ENOENT, "No such file"), + ) + assert str(err) == ( + "Cannot connect to host example.com:8080 ssl:" "default [No such file]" + ) + + +class TestClientConnectorCertificateError: + connection_key = client_reqrep.ConnectionKey( + host="example.com", + port=8080, + is_ssl=False, + ssl=None, + proxy=None, + proxy_auth=None, + proxy_headers_hash=None, + ) + + def test_ctor(self) -> None: + certificate_error = Exception("Bad certificate") + err = client.ClientConnectorCertificateError( + connection_key=self.connection_key, certificate_error=certificate_error + ) + assert err.certificate_error == certificate_error + assert err.host == "example.com" + assert err.port == 8080 + assert err.ssl is False + + def test_pickle(self) -> None: + certificate_error = Exception("Bad certificate") + err = client.ClientConnectorCertificateError( + connection_key=self.connection_key, certificate_error=certificate_error + ) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.certificate_error.args == ("Bad certificate",) + assert err2.host == "example.com" + assert err2.port == 8080 + assert err2.ssl is False + assert err2.foo == "bar" + + def test_repr(self) -> None: + certificate_error = Exception("Bad certificate") + err = client.ClientConnectorCertificateError( + connection_key=self.connection_key, certificate_error=certificate_error + ) + assert repr(err) == ( + "ClientConnectorCertificateError(%r, %r)" + % (self.connection_key, certificate_error) + ) + + def test_str(self) -> None: + certificate_error = Exception("Bad certificate") + err = client.ClientConnectorCertificateError( + connection_key=self.connection_key, certificate_error=certificate_error + ) + assert str(err) == ( + "Cannot connect to host example.com:8080 ssl:False" + " [Exception: ('Bad certificate',)]" + ) + + +class TestServerDisconnectedError: + def test_ctor(self) -> None: + err = client.ServerDisconnectedError() + assert err.message == "Server disconnected" + + err = client.ServerDisconnectedError(message="No connection") + assert err.message == "No connection" + + def test_pickle(self) -> None: + err = client.ServerDisconnectedError(message="No connection") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.message == "No connection" + assert err2.foo == "bar" + + def test_repr(self) -> None: + err = client.ServerDisconnectedError() + if sys.version_info < (3, 7): + assert repr(err) == ("ServerDisconnectedError" "('Server disconnected',)") + else: + assert repr(err) == ("ServerDisconnectedError" "('Server 
disconnected')") + + err = client.ServerDisconnectedError(message="No connection") + if sys.version_info < (3, 7): + assert repr(err) == "ServerDisconnectedError('No connection',)" + else: + assert repr(err) == "ServerDisconnectedError('No connection')" + + def test_str(self) -> None: + err = client.ServerDisconnectedError() + assert str(err) == "Server disconnected" + + err = client.ServerDisconnectedError(message="No connection") + assert str(err) == "No connection" + + +class TestServerFingerprintMismatch: + def test_ctor(self) -> None: + err = client.ServerFingerprintMismatch( + expected=b"exp", got=b"got", host="example.com", port=8080 + ) + assert err.expected == b"exp" + assert err.got == b"got" + assert err.host == "example.com" + assert err.port == 8080 + + def test_pickle(self) -> None: + err = client.ServerFingerprintMismatch( + expected=b"exp", got=b"got", host="example.com", port=8080 + ) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.expected == b"exp" + assert err2.got == b"got" + assert err2.host == "example.com" + assert err2.port == 8080 + assert err2.foo == "bar" + + def test_repr(self) -> None: + err = client.ServerFingerprintMismatch(b"exp", b"got", "example.com", 8080) + assert repr(err) == ( + "" + ) + + +class TestInvalidURL: + def test_ctor(self) -> None: + err = client.InvalidURL(url=":wrong:url:") + assert err.url == ":wrong:url:" + + def test_pickle(self) -> None: + err = client.InvalidURL(url=":wrong:url:") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.url == ":wrong:url:" + assert err2.foo == "bar" + + def test_repr(self) -> None: + err = client.InvalidURL(url=":wrong:url:") + assert repr(err) == "" + + def test_str(self) -> None: + err = client.InvalidURL(url=":wrong:url:") + assert str(err) == ":wrong:url:" diff --git a/tests/test_client_fingerprint.py b/tests/test_client_fingerprint.py new file mode 100644 index 00000000000..753a9f367d8 --- /dev/null +++ b/tests/test_client_fingerprint.py @@ -0,0 +1,86 @@ +import hashlib +from unittest import mock + +import pytest + +import aiohttp +from aiohttp.client_reqrep import _merge_ssl_params + +ssl = pytest.importorskip("ssl") + + +def test_fingerprint_sha256() -> None: + sha256 = hashlib.sha256(b"12345678" * 64).digest() + fp = aiohttp.Fingerprint(sha256) + assert fp.fingerprint == sha256 + + +def test_fingerprint_sha1() -> None: + sha1 = hashlib.sha1(b"12345678" * 64).digest() + with pytest.raises(ValueError): + aiohttp.Fingerprint(sha1) + + +def test_fingerprint_md5() -> None: + md5 = hashlib.md5(b"12345678" * 64).digest() + with pytest.raises(ValueError): + aiohttp.Fingerprint(md5) + + +def test_fingerprint_check_no_ssl() -> None: + sha256 = hashlib.sha256(b"12345678" * 64).digest() + fp = aiohttp.Fingerprint(sha256) + transport = mock.Mock() + transport.get_extra_info.return_value = None + assert fp.check(transport) is None + + +def test__merge_ssl_params_verify_ssl() -> None: + with pytest.warns(DeprecationWarning): + assert _merge_ssl_params(None, False, None, None) is False + + +def test__merge_ssl_params_verify_ssl_conflict() -> None: + ctx = ssl.SSLContext() + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + _merge_ssl_params(ctx, False, None, None) + + +def test__merge_ssl_params_ssl_context() -> None: + ctx = ssl.SSLContext() + with pytest.warns(DeprecationWarning): + 
assert _merge_ssl_params(None, None, ctx, None) is ctx + + +def test__merge_ssl_params_ssl_context_conflict() -> None: + ctx1 = ssl.SSLContext() + ctx2 = ssl.SSLContext() + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + _merge_ssl_params(ctx1, None, ctx2, None) + + +def test__merge_ssl_params_fingerprint() -> None: + digest = hashlib.sha256(b"123").digest() + with pytest.warns(DeprecationWarning): + ret = _merge_ssl_params(None, None, None, digest) + assert ret.fingerprint == digest + + +def test__merge_ssl_params_fingerprint_conflict() -> None: + fingerprint = aiohttp.Fingerprint(hashlib.sha256(b"123").digest()) + ctx = ssl.SSLContext() + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + _merge_ssl_params(ctx, None, None, fingerprint) + + +def test__merge_ssl_params_ssl() -> None: + ctx = ssl.SSLContext() + assert ctx is _merge_ssl_params(ctx, None, None, None) + + +def test__merge_ssl_params_invlid() -> None: + with pytest.raises(TypeError): + _merge_ssl_params(object(), None, None, None) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index d761faa24a7..6bd8d44bb5a 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1,20 +1,23 @@ -"""HTTP client functional tests against aiohttp.web server""" +# HTTP client functional tests against aiohttp.web server import asyncio +import datetime import http.cookies import io import json import pathlib -import ssl +import socket from unittest import mock import pytest +from async_generator import async_generator, yield_ from multidict import MultiDict import aiohttp -from aiohttp import hdrs, web -from aiohttp.client import ServerFingerprintMismatch -from aiohttp.multipart import MultipartWriter +from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web +from aiohttp.abc import AbstractResolver +from aiohttp.client_exceptions import TooManyRedirects +from aiohttp.test_utils import unused_port @pytest.fixture @@ -23,1650 +26,1606 @@ def here(): @pytest.fixture -def ssl_ctx(here): - ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_ctx.load_cert_chain( - str(here / 'sample.crt'), - str(here / 'sample.key')) - return ssl_ctx +def fname(here): + return here / "conftest.py" -@pytest.fixture -def fname(here): - return here / 'sample.key' +async def test_keepalive_two_requests_success(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") + + app = web.Application() + app.router.add_route("GET", "/", handler) + + connector = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client(app, connector=connector) + + resp1 = await client.get("/") + await resp1.read() + resp2 = await client.get("/") + await resp2.read() + + assert 1 == len(client._session.connector._conns) + +async def test_keepalive_after_head_requests_success(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") -def ceil(val): - return val + cnt_conn_reuse = 0 + async def on_reuseconn(session, ctx, params): + nonlocal cnt_conn_reuse + cnt_conn_reuse += 1 -@asyncio.coroutine -def test_keepalive_two_requests_success(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - return web.Response(body=b'OK') + trace_config = aiohttp.TraceConfig() + trace_config._on_connection_reuseconn.append(on_reuseconn) app = web.Application() - 
app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - connector = aiohttp.TCPConnector(loop=loop, limit=1) - client = yield from test_client(app, connector=connector) + connector = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client( + app, connector=connector, trace_configs=[trace_config] + ) - resp1 = yield from client.get('/') - yield from resp1.read() - resp2 = yield from client.get('/') - yield from resp2.read() + resp1 = await client.head("/") + await resp1.read() + resp2 = await client.get("/") + await resp2.read() - assert 1 == len(client._session.connector._conns) + assert 1 == cnt_conn_reuse -@asyncio.coroutine -def test_keepalive_response_released(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - return web.Response(body=b'OK') +async def test_keepalive_response_released(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - connector = aiohttp.TCPConnector(loop=loop, limit=1) - client = yield from test_client(app, connector=connector) + connector = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client(app, connector=connector) - resp1 = yield from client.get('/') + resp1 = await client.get("/") resp1.release() - resp2 = yield from client.get('/') + resp2 = await client.get("/") resp2.release() assert 1 == len(client._session.connector._conns) -@asyncio.coroutine -def test_keepalive_server_force_close_connection(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - response = web.Response(body=b'OK') +async def test_keepalive_server_force_close_connection(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + response = web.Response(body=b"OK") response.force_close() return response app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - connector = aiohttp.TCPConnector(loop=loop, limit=1) - client = yield from test_client(app, connector=connector) + connector = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client(app, connector=connector) - resp1 = yield from client.get('/') + resp1 = await client.get("/") resp1.close() - resp2 = yield from client.get('/') + resp2 = await client.get("/") resp2.close() assert 0 == len(client._session.connector._conns) -@asyncio.coroutine -def test_release_early(loop, test_client): - @asyncio.coroutine - def handler(request): - yield from request.read() - return web.Response(body=b'OK') +async def test_release_early(aiohttp_client) -> None: + async def handler(request): + await request.read() + return web.Response(body=b"OK") app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - client = yield from test_client(app) - resp = yield from client.get('/') + client = await aiohttp_client(app) + resp = await client.get("/") assert resp.closed assert 1 == len(client._session.connector._conns) -@asyncio.coroutine -def test_HTTP_304(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body +async def test_HTTP_304(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body return 
web.Response(status=304) app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 304 - content = yield from resp.read() - assert content == b'' + content = await resp.read() + assert content == b"" -@asyncio.coroutine -def test_HTTP_304_WITH_BODY(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - return web.Response(body=b'test', status=304) +async def test_HTTP_304_WITH_BODY(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"test", status=304) app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 304 - content = yield from resp.read() - assert content == b'' + content = await resp.read() + assert content == b"" -@asyncio.coroutine -def test_auto_header_user_agent(loop, test_client): - @asyncio.coroutine - def handler(request): - assert 'aiohttp' in request.headers['user-agent'] +async def test_auto_header_user_agent(aiohttp_client) -> None: + async def handler(request): + assert "aiohttp" in request.headers["user-agent"] return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') - assert 200, resp.status + resp = await client.get("/") + assert 200 == resp.status -@asyncio.coroutine -def test_skip_auto_headers_user_agent(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_skip_auto_headers_user_agent(aiohttp_client) -> None: + async def handler(request): assert hdrs.USER_AGENT not in request.headers return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', - skip_auto_headers=['user-agent']) + resp = await client.get("/", skip_auto_headers=["user-agent"]) assert 200 == resp.status -@asyncio.coroutine -def test_skip_default_auto_headers_user_agent(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_skip_default_auto_headers_user_agent(aiohttp_client) -> None: + async def handler(request): assert hdrs.USER_AGENT not in request.headers return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app, skip_auto_headers=['user-agent']) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app, skip_auto_headers=["user-agent"]) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status -@asyncio.coroutine -def test_skip_auto_headers_content_type(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_skip_auto_headers_content_type(aiohttp_client) -> None: + async def handler(request): assert hdrs.CONTENT_TYPE not in request.headers return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from 
test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', - skip_auto_headers=['content-type']) + resp = await client.get("/", skip_auto_headers=["content-type"]) assert 200 == resp.status -@asyncio.coroutine -def test_post_data_bytesio(loop, test_client): - data = b'some buffer' +async def test_post_data_bytesio(aiohttp_client) -> None: + data = b"some buffer" - @asyncio.coroutine - def handler(request): + async def handler(request): assert len(data) == request.content_length - val = yield from request.read() + val = await request.read() assert data == val return web.Response() app = web.Application() - app.router.add_route('POST', '/', handler) - client = yield from test_client(app) + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=io.BytesIO(data)) + resp = await client.post("/", data=io.BytesIO(data)) assert 200 == resp.status -@asyncio.coroutine -def test_post_data_with_bytesio_file(loop, test_client): - data = b'some buffer' +async def test_post_data_with_bytesio_file(aiohttp_client) -> None: + data = b"some buffer" - @asyncio.coroutine - def handler(request): - post_data = yield from request.post() - assert ['file'] == list(post_data.keys()) - assert data == post_data['file'].file.read() + async def handler(request): + post_data = await request.post() + assert ["file"] == list(post_data.keys()) + assert data == post_data["file"].file.read() return web.Response() app = web.Application() - app.router.add_route('POST', '/', handler) - client = yield from test_client(app) + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data={'file': io.BytesIO(data)}) + resp = await client.post("/", data={"file": io.BytesIO(data)}) assert 200 == resp.status -@asyncio.coroutine -def test_post_data_stringio(loop, test_client): - data = 'some buffer' +async def test_post_data_stringio(aiohttp_client) -> None: + data = "some buffer" - @asyncio.coroutine - def handler(request): + async def handler(request): assert len(data) == request.content_length - assert request.headers['CONTENT-TYPE'] == 'text/plain; charset=utf-8' - val = yield from request.text() + assert request.headers["CONTENT-TYPE"] == "text/plain; charset=utf-8" + val = await request.text() assert data == val return web.Response() app = web.Application() - app.router.add_route('POST', '/', handler) - client = yield from test_client(app) + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=io.StringIO(data)) + resp = await client.post("/", data=io.StringIO(data)) assert 200 == resp.status -@asyncio.coroutine -def test_post_data_textio_encoding(loop, test_client): - data = 'текст' +async def test_post_data_textio_encoding(aiohttp_client) -> None: + data = "текст" - @asyncio.coroutine - def handler(request): - assert request.headers['CONTENT-TYPE'] == 'text/plain; charset=koi8-r' - val = yield from request.text() + async def handler(request): + assert request.headers["CONTENT-TYPE"] == "text/plain; charset=koi8-r" + val = await request.text() assert data == val return web.Response() app = web.Application() - app.router.add_route('POST', '/', handler) - client = yield from test_client(app) + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) - pl = aiohttp.TextIOPayload(io.StringIO(data), encoding='koi8-r') - resp = yield from 
client.post('/', data=pl) + pl = aiohttp.TextIOPayload(io.StringIO(data), encoding="koi8-r") + resp = await client.post("/", data=pl) assert 200 == resp.status -@asyncio.coroutine -def test_client_ssl(loop, ssl_ctx, test_server, test_client): - connector = aiohttp.TCPConnector(verify_ssl=False, loop=loop) +async def test_ssl_client( + aiohttp_server, + ssl_ctx, + aiohttp_client, + client_ssl_ctx, +) -> None: + connector = aiohttp.TCPConnector(ssl=client_ssl_ctx) - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='Test message') + async def handler(request): + return web.Response(text="Test message") app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_server(app, ssl=ssl_ctx) - client = yield from test_client(server, connector=connector) + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app, ssl=ssl_ctx) + client = await aiohttp_client(server, connector=connector) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - txt = yield from resp.text() - assert txt == 'Test message' - - -@pytest.mark.parametrize('fingerprint', [ - b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=', - b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9', - b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L' - b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b'], - ids=['md5', 'sha1', 'sha256']) -@asyncio.coroutine -def test_tcp_connector_fingerprint_ok(test_server, test_client, - loop, ssl_ctx, fingerprint): - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='Test message') - - # Test for deprecation warning on md5 and sha1 len digests. - if len(fingerprint) == 16 or len(fingerprint) == 20: - with pytest.warns(DeprecationWarning) as cm: - connector = aiohttp.TCPConnector(loop=loop, verify_ssl=False, - fingerprint=fingerprint) - assert 'Use sha256.' 
in str(cm[0].message) - else: - connector = aiohttp.TCPConnector(loop=loop, verify_ssl=False, - fingerprint=fingerprint) - app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_server(app, ssl=ssl_ctx) - client = yield from test_client(server, connector=connector) - - resp = yield from client.get('/') + txt = await resp.text() + assert txt == "Test message" + + +async def test_tcp_connector_fingerprint_ok( + aiohttp_server, + aiohttp_client, + ssl_ctx, + tls_certificate_fingerprint_sha256, +): + tls_fingerprint = Fingerprint(tls_certificate_fingerprint_sha256) + + async def handler(request): + return web.Response(text="Test message") + + connector = aiohttp.TCPConnector(ssl=tls_fingerprint) + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app, ssl=ssl_ctx) + client = await aiohttp_client(server, connector=connector) + + resp = await client.get("/") assert resp.status == 200 resp.close() -@pytest.mark.parametrize('fingerprint', [ - b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=', - b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9', - b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L' - b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b'], - ids=['md5', 'sha1', 'sha256']) -@asyncio.coroutine -def test_tcp_connector_fingerprint_fail(test_server, test_client, - loop, ssl_ctx, fingerprint): - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='Test message') +async def test_tcp_connector_fingerprint_fail( + aiohttp_server, + aiohttp_client, + ssl_ctx, + tls_certificate_fingerprint_sha256, +): + async def handler(request): + return web.Response(text="Test message") - bad_fingerprint = b'\x00' * len(fingerprint) + bad_fingerprint = b"\x00" * len(tls_certificate_fingerprint_sha256) - connector = aiohttp.TCPConnector(loop=loop, verify_ssl=False, - fingerprint=bad_fingerprint) + connector = aiohttp.TCPConnector(ssl=Fingerprint(bad_fingerprint)) app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_server(app, ssl=ssl_ctx) - client = yield from test_client(server, connector=connector) + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app, ssl=ssl_ctx) + client = await aiohttp_client(server, connector=connector) with pytest.raises(ServerFingerprintMismatch) as cm: - yield from client.get('/') + await client.get("/") exc = cm.value assert exc.expected == bad_fingerprint - assert exc.got == fingerprint + assert exc.got == tls_certificate_fingerprint_sha256 -@asyncio.coroutine -def test_format_task_get(test_server, loop): +async def test_format_task_get(aiohttp_server) -> None: + loop = asyncio.get_event_loop() - @asyncio.coroutine - def handler(request): - return web.Response(body=b'OK') + async def handler(request): + return web.Response(body=b"OK") app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_server(app) - client = aiohttp.ClientSession(loop=loop) - task = loop.create_task(client.get(server.make_url('/'))) - assert "{}".format(task)[:18] == " None: + async def handler(request): + assert "q=t est" in request.rel_url.query_string return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', params='q=t+est') + resp = await client.get("/", params="q=t+est") assert 200 == resp.status 
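
The two fingerprint tests above pin the server certificate by its SHA-256 digest instead of validating the CA chain. For reference, the same pinning from application code looks roughly like this — a sketch using only the APIs this patch itself exercises (`aiohttp.Fingerprint`, `TCPConnector(ssl=...)`, `ssl.PEM_cert_to_DER_cert`); the URL and PEM source are placeholders you would supply yourself:

```python
import hashlib
import ssl

import aiohttp


def pin_from_pem(pem_cert: str) -> bytes:
    # Derive the expected pin from a known-good certificate: SHA-256 over the
    # DER encoding, the same computation conftest.py performs above.
    return hashlib.sha256(ssl.PEM_cert_to_DER_cert(pem_cert)).digest()


async def fetch_pinned(url: str, expected_sha256: bytes) -> str:
    # Fingerprint() accepts only SHA-256 digests; MD5 and SHA-1 raise
    # ValueError (see test_fingerprint_sha1/md5 earlier in this patch).
    connector = aiohttp.TCPConnector(ssl=aiohttp.Fingerprint(expected_sha256))
    async with aiohttp.ClientSession(connector=connector) as session:
        async with session.get(url) as resp:
            # A certificate that doesn't match the pin never gets this far:
            # the connection attempt raises aiohttp.ServerFingerprintMismatch.
            return await resp.text()
```
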
-@asyncio.coroutine -def test_drop_params_on_redirect(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.Response(status=301, headers={'Location': '/ok?a=redirect'}) +async def test_drop_params_on_redirect(aiohttp_client) -> None: + async def handler_redirect(request): + return web.Response(status=301, headers={"Location": "/ok?a=redirect"}) - @asyncio.coroutine - def handler_ok(request): - assert request.rel_url.query_string == 'a=redirect' + async def handler_ok(request): + assert request.rel_url.query_string == "a=redirect" return web.Response(status=200) app = web.Application() - app.router.add_route('GET', '/ok', handler_ok) - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/ok", handler_ok) + app.router.add_route("GET", "/redirect", handler_redirect) + client = await aiohttp_client(app) - resp = yield from client.get('/redirect', params={'a': 'initial'}) + resp = await client.get("/redirect", params={"a": "initial"}) assert resp.status == 200 -@asyncio.coroutine -def test_drop_fragment_on_redirect(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.Response(status=301, headers={'Location': '/ok#fragment'}) +async def test_drop_fragment_on_redirect(aiohttp_client) -> None: + async def handler_redirect(request): + return web.Response(status=301, headers={"Location": "/ok#fragment"}) - @asyncio.coroutine - def handler_ok(request): + async def handler_ok(request): return web.Response(status=200) app = web.Application() - app.router.add_route('GET', '/ok', handler_ok) - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/ok", handler_ok) + app.router.add_route("GET", "/redirect", handler_redirect) + client = await aiohttp_client(app) - resp = yield from client.get('/redirect') + resp = await client.get("/redirect") assert resp.status == 200 - assert resp.url.path == '/ok' + assert resp.url.path == "/ok" -@asyncio.coroutine -def test_drop_fragment(loop, test_client): - @asyncio.coroutine - def handler_ok(request): +async def test_drop_fragment(aiohttp_client) -> None: + async def handler_ok(request): return web.Response(status=200) app = web.Application() - app.router.add_route('GET', '/ok', handler_ok) - client = yield from test_client(app) + app.router.add_route("GET", "/ok", handler_ok) + client = await aiohttp_client(app) - resp = yield from client.get('/ok#fragment') + resp = await client.get("/ok#fragment") assert resp.status == 200 - assert resp.url.path == '/ok' + assert resp.url.path == "/ok" -@asyncio.coroutine -def test_history(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.Response(status=301, headers={'Location': '/ok'}) +async def test_history(aiohttp_client) -> None: + async def handler_redirect(request): + return web.Response(status=301, headers={"Location": "/ok"}) - @asyncio.coroutine - def handler_ok(request): + async def handler_ok(request): return web.Response(status=200) app = web.Application() - app.router.add_route('GET', '/ok', handler_ok) - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/ok", handler_ok) + app.router.add_route("GET", "/redirect", handler_redirect) + client = await aiohttp_client(app) - resp = yield from client.get('/ok') + resp = await client.get("/ok") assert len(resp.history) == 0 assert resp.status 
== 200 - resp_redirect = yield from client.get('/redirect') + resp_redirect = await client.get("/redirect") assert len(resp_redirect.history) == 1 assert resp_redirect.history[0].status == 301 assert resp_redirect.status == 200 -@asyncio.coroutine -def test_keepalive_closed_by_server(loop, test_client): - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - resp = web.Response(body=b'OK') +async def test_keepalive_closed_by_server(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + resp = web.Response(body=b"OK") resp.force_close() return resp app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - connector = aiohttp.TCPConnector(loop=loop, limit=1) - client = yield from test_client(app, connector=connector) + connector = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client(app, connector=connector) - resp1 = yield from client.get('/') - val1 = yield from resp1.read() - assert val1 == b'OK' - resp2 = yield from client.get('/') - val2 = yield from resp2.read() - assert val2 == b'OK' + resp1 = await client.get("/") + val1 = await resp1.read() + assert val1 == b"OK" + resp2 = await client.get("/") + val2 = await resp2.read() + assert val2 == b"OK" assert 0 == len(client._session.connector._conns) -@asyncio.coroutine -def test_wait_for(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(body=b'OK') +async def test_wait_for(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + resp = await asyncio.wait_for(client.get("/"), 10) + assert resp.status == 200 + txt = await resp.text() + assert txt == "OK" - resp = yield from asyncio.wait_for(client.get('/'), 10, loop=loop) + +async def test_raw_headers(aiohttp_client) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") assert resp.status == 200 - txt = yield from resp.text() - assert txt == 'OK' + + raw_headers = tuple((bytes(h), bytes(v)) for h, v in resp.raw_headers) + assert raw_headers == ( + (b"Content-Length", b"0"), + (b"Content-Type", b"application/octet-stream"), + (b"Date", mock.ANY), + (b"Server", mock.ANY), + ) + resp.close() -@asyncio.coroutine -def test_raw_headers(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_host_header_first(aiohttp_client) -> None: + async def handler(request): + assert list(request.headers)[0] == hdrs.HOST return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.get('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert resp.status == 200 + + +async def test_empty_header_values(aiohttp_client) -> None: + async def handler(request): + resp = web.Response() + resp.headers["X-Empty"] = "" + return resp + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") assert resp.status == 200 - assert resp.raw_headers == ((b'Content-Length', b'0'), - (b'Content-Type', 
b'application/octet-stream'), - (b'Date', mock.ANY), - (b'Server', mock.ANY)) + raw_headers = tuple((bytes(h), bytes(v)) for h, v in resp.raw_headers) + assert raw_headers == ( + (b"X-Empty", b""), + (b"Content-Length", b"0"), + (b"Content-Type", b"application/octet-stream"), + (b"Date", mock.ANY), + (b"Server", mock.ANY), + ) resp.close() -@asyncio.coroutine -def test_204_with_gzipped_content_encoding(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_204_with_gzipped_content_encoding(aiohttp_client) -> None: + async def handler(request): resp = web.StreamResponse(status=204) resp.content_length = 0 - resp.content_type = 'application/json' + resp.content_type = "application/json" # resp.enable_compression(web.ContentCoding.gzip) - resp.headers['Content-Encoding'] = 'gzip' - yield from resp.prepare(request) + resp.headers["Content-Encoding"] = "gzip" + await resp.prepare(request) return resp app = web.Application() - app.router.add_route('DELETE', '/', handler) - client = yield from test_client(app) + app.router.add_route("DELETE", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.delete('/') + resp = await client.delete("/") assert resp.status == 204 assert resp.closed -@asyncio.coroutine -def test_timeout_on_reading_headers(loop, test_client, mocker): - mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - @asyncio.coroutine - def handler(request): +async def test_timeout_on_reading_headers(aiohttp_client, mocker) -> None: + async def handler(request): resp = web.StreamResponse() - yield from asyncio.sleep(0.1, loop=loop) - yield from resp.prepare(request) + await asyncio.sleep(0.1) + await resp.prepare(request) return resp app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) with pytest.raises(asyncio.TimeoutError): - yield from client.get('/', timeout=0.01) + await client.get("/", timeout=0.01) -@asyncio.coroutine -def test_timeout_on_conn_reading_headers(loop, test_client, mocker): +async def test_timeout_on_conn_reading_headers(aiohttp_client, mocker) -> None: # tests case where user did not set a connection timeout - mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - @asyncio.coroutine - def handler(request): + async def handler(request): resp = web.StreamResponse() - yield from asyncio.sleep(0.1, loop=loop) - yield from resp.prepare(request) + await asyncio.sleep(0.1) + await resp.prepare(request) return resp app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - conn = aiohttp.TCPConnector(loop=loop) - client = yield from test_client(app, connector=conn) + conn = aiohttp.TCPConnector() + client = await aiohttp_client(app, connector=conn) with pytest.raises(asyncio.TimeoutError): - yield from client.get('/', timeout=0.01) + await client.get("/", timeout=0.01) -@asyncio.coroutine -def test_timeout_on_session_read_timeout(loop, test_client, mocker): - mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - @asyncio.coroutine - def handler(request): +async def test_timeout_on_session_read_timeout(aiohttp_client, mocker) -> None: + async def handler(request): resp = web.StreamResponse() - yield from asyncio.sleep(0.1, loop=loop) - yield from resp.prepare(request) + await asyncio.sleep(0.1) + await resp.prepare(request) return resp app = web.Application() - app.router.add_route('GET', '/', handler) + 
app.router.add_route("GET", "/", handler) - conn = aiohttp.TCPConnector(loop=loop) - client = yield from test_client(app, connector=conn, read_timeout=0.01) + conn = aiohttp.TCPConnector() + client = await aiohttp_client( + app, connector=conn, timeout=aiohttp.ClientTimeout(sock_read=0.01) + ) with pytest.raises(asyncio.TimeoutError): - yield from client.get('/', timeout=None) + await client.get("/") + + +async def test_read_timeout_between_chunks(aiohttp_client, mocker) -> None: + async def handler(request): + resp = aiohttp.web.StreamResponse() + await resp.prepare(request) + # write data 4 times, with pauses. Total time 2 seconds. + for _ in range(4): + await asyncio.sleep(0.5) + await resp.write(b"data\n") + return resp + + app = web.Application() + app.add_routes([web.get("/", handler)]) + # A timeout of 1 second applies per read, so the 0.5-second pauses never trip it. + timeout = aiohttp.ClientTimeout(sock_read=1) + client = await aiohttp_client(app, timeout=timeout) -@asyncio.coroutine -def test_timeout_on_reading_data(loop, test_client): + res = b"" + async with await client.get("/") as resp: + res += await resp.read() - @asyncio.coroutine - def handler(request): - resp = web.StreamResponse(headers={'content-length': '100'}) - yield from resp.prepare(request) - yield from resp.drain() - yield from asyncio.sleep(0.2, loop=loop) + assert res == b"data\n" * 4 + + +async def test_read_timeout_on_reading_chunks(aiohttp_client, mocker) -> None: + async def handler(request): + resp = aiohttp.web.StreamResponse() + await resp.prepare(request) + await resp.write(b"data\n") + await asyncio.sleep(1) + await resp.write(b"data\n") return resp app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.add_routes([web.get("/", handler)]) + + # A timeout of 0.2 seconds should apply per read.
+ timeout = aiohttp.ClientTimeout(sock_read=0.2) + client = await aiohttp_client(app, timeout=timeout) + + async with await client.get("/") as resp: + assert (await resp.content.read(5)) == b"data\n" + with pytest.raises(asyncio.TimeoutError): + await resp.content.read() - resp = yield from client.get('/', timeout=0.05) + +async def test_timeout_on_reading_data(aiohttp_client, mocker) -> None: + loop = asyncio.get_event_loop() + + fut = loop.create_future() + + async def handler(request): + resp = web.StreamResponse(headers={"content-length": "100"}) + await resp.prepare(request) + fut.set_result(None) + await asyncio.sleep(0.2) + return resp + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/", timeout=1) + await fut with pytest.raises(asyncio.TimeoutError): - yield from resp.read() + await resp.read() -@asyncio.coroutine -def test_timeout_none(loop, test_client, mocker): - @asyncio.coroutine - def handler(request): +async def test_timeout_none(aiohttp_client, mocker) -> None: + async def handler(request): resp = web.StreamResponse() - yield from resp.prepare(request) + await resp.prepare(request) return resp app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', timeout=None) + resp = await client.get("/", timeout=None) assert resp.status == 200 -@asyncio.coroutine -def test_readline_error_on_conn_close(loop, test_client): +async def test_readline_error_on_conn_close(aiohttp_client) -> None: + loop = asyncio.get_event_loop() - @asyncio.coroutine - def handler(request): + async def handler(request): resp_ = web.StreamResponse() - yield from resp_.prepare(request) + await resp_.prepare(request) # make sure connection is closed by client. 
with pytest.raises(aiohttp.ServerDisconnectedError): for _ in range(10): - resp_.write(b'data\n') - yield from resp_.drain() - yield from asyncio.sleep(0.5, loop=loop) + await resp_.write(b"data\n") + await asyncio.sleep(0.5) return resp_ app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_client(app) + app.router.add_route("GET", "/", handler) + server = await aiohttp_client(app) - with aiohttp.ClientSession(loop=loop) as session: + session = aiohttp.ClientSession() + try: timer_started = False - url, headers = server.make_url('/'), {'Connection': 'Keep-alive'} - resp = yield from session.get(url, headers=headers) + url, headers = server.make_url("/"), {"Connection": "Keep-alive"} + resp = await session.get(url, headers=headers) with pytest.raises(aiohttp.ClientConnectionError): while True: - data = yield from resp.content.readline() + data = await resp.content.readline() data = data.strip() if not data: break - assert data == b'data' + assert data == b"data" if not timer_started: + def do_release(): loop.create_task(resp.release()) + loop.call_later(1.0, do_release) timer_started = True + finally: + await session.close() -@asyncio.coroutine -def test_no_error_on_conn_close_if_eof(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_no_error_on_conn_close_if_eof(aiohttp_client) -> None: + async def handler(request): resp_ = web.StreamResponse() - yield from resp_.prepare(request) - resp_.write(b'data\n') - yield from resp_.drain() - yield from asyncio.sleep(0.5, loop=loop) + await resp_.prepare(request) + await resp_.write(b"data\n") + await asyncio.sleep(0.5) return resp_ app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_client(app) + app.router.add_route("GET", "/", handler) + server = await aiohttp_client(app) - with aiohttp.ClientSession(loop=loop) as session: - url, headers = server.make_url('/'), {'Connection': 'Keep-alive'} - resp = yield from session.get(url, headers=headers) + session = aiohttp.ClientSession() + try: + url, headers = server.make_url("/"), {"Connection": "Keep-alive"} + resp = await session.get(url, headers=headers) while True: - data = yield from resp.content.readline() + data = await resp.content.readline() data = data.strip() if not data: break - assert data == b'data' + assert data == b"data" assert resp.content.exception() is None + finally: + await session.close() -@asyncio.coroutine -def test_error_not_overwrote_on_conn_close(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_error_not_overwrote_on_conn_close(aiohttp_client) -> None: + async def handler(request): resp_ = web.StreamResponse() - yield from resp_.prepare(request) + await resp_.prepare(request) return resp_ app = web.Application() - app.router.add_route('GET', '/', handler) - server = yield from test_client(app) + app.router.add_route("GET", "/", handler) + server = await aiohttp_client(app) - with aiohttp.ClientSession(loop=loop) as session: - url, headers = server.make_url('/'), {'Connection': 'Keep-alive'} - resp = yield from session.get(url, headers=headers) + session = aiohttp.ClientSession() + try: + url, headers = server.make_url("/"), {"Connection": "Keep-alive"} + resp = await session.get(url, headers=headers) resp.content.set_exception(ValueError()) + finally: + await session.close() assert isinstance(resp.content.exception(), ValueError) -@asyncio.coroutine -def test_HTTP_200_OK_METHOD(loop, test_client): - @asyncio.coroutine - def 
handler(request): +async def test_HTTP_200_OK_METHOD(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) app = web.Application() - for meth in ('get', 'post', 'put', 'delete', 'head', 'patch', 'options'): - app.router.add_route(meth.upper(), '/', handler) + for meth in ("get", "post", "put", "delete", "head", "patch", "options"): + app.router.add_route(meth.upper(), "/", handler) - client = yield from test_client(app) - for meth in ('get', 'post', 'put', 'delete', 'head', 'patch', 'options'): - resp = yield from client.request(meth, '/') + client = await aiohttp_client(app) + for meth in ("get", "post", "put", "delete", "head", "patch", "options"): + resp = await client.request(meth, "/") assert resp.status == 200 assert len(resp.history) == 0 - content1 = yield from resp.read() - content2 = yield from resp.read() + content1 = await resp.read() + content2 = await resp.read() assert content1 == content2 - content = yield from resp.text() + content = await resp.text() - if meth == 'head': - assert b'' == content1 + if meth == "head": + assert b"" == content1 else: assert meth.upper() == content -@asyncio.coroutine -def test_HTTP_200_OK_METHOD_connector(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_200_OK_METHOD_connector(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - conn = aiohttp.TCPConnector(resolve=True, loop=loop) + conn = aiohttp.TCPConnector() conn.clear_dns_cache() app = web.Application() - for meth in ('get', 'post', 'put', 'delete', 'head'): - app.router.add_route(meth.upper(), '/', handler) - client = yield from test_client(app, connector=conn, conn_timeout=0.2) + for meth in ("get", "post", "put", "delete", "head"): + app.router.add_route(meth.upper(), "/", handler) + client = await aiohttp_client(app, connector=conn) - for meth in ('get', 'post', 'put', 'delete', 'head'): - resp = yield from client.request(meth, '/') + for meth in ("get", "post", "put", "delete", "head"): + resp = await client.request(meth, "/") - content1 = yield from resp.read() - content2 = yield from resp.read() + content1 = await resp.read() + content2 = await resp.read() assert content1 == content2 - content = yield from resp.text() + content = await resp.text() assert resp.status == 200 - if meth == 'head': - assert b'' == content1 + if meth == "head": + assert b"" == content1 else: assert meth.upper() == content -@asyncio.coroutine -def test_HTTP_302_REDIRECT_GET(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_302_REDIRECT_GET(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - return web.HTTPFound(location='/') + async def redirect(request): + raise web.HTTPFound(location="/") app = web.Application() - app.router.add_get('/', handler) - app.router.add_get('/redirect', redirect) - client = yield from test_client(app) + app.router.add_get("/", handler) + app.router.add_get("/redirect", redirect) + client = await aiohttp_client(app) - resp = yield from client.get('/redirect') + resp = await client.get("/redirect") assert 200 == resp.status assert 1 == len(resp.history) resp.close() -@asyncio.coroutine -def test_HTTP_302_REDIRECT_HEAD(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_302_REDIRECT_HEAD(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - 
@asyncio.coroutine - def redirect(request): - return web.HTTPFound(location='/') + async def redirect(request): + raise web.HTTPFound(location="/") app = web.Application() - app.router.add_get('/', handler) - app.router.add_get('/redirect', redirect) - app.router.add_head('/', handler) - app.router.add_head('/redirect', redirect) - client = yield from test_client(app) + app.router.add_get("/", handler) + app.router.add_get("/redirect", redirect) + app.router.add_head("/", handler) + app.router.add_head("/redirect", redirect) + client = await aiohttp_client(app) - resp = yield from client.request('head', '/redirect') + resp = await client.request("head", "/redirect") assert 200 == resp.status assert 1 == len(resp.history) - assert resp.method == 'HEAD' + assert resp.method == "HEAD" resp.close() -@asyncio.coroutine -def test_HTTP_302_REDIRECT_NON_HTTP(loop, test_client): - - @asyncio.coroutine - def redirect(request): - return web.HTTPFound(location='ftp://127.0.0.1/test/') +async def test_HTTP_302_REDIRECT_NON_HTTP(aiohttp_client) -> None: + async def redirect(request): + raise web.HTTPFound(location="ftp://127.0.0.1/test/") app = web.Application() - app.router.add_get('/redirect', redirect) - client = yield from test_client(app) + app.router.add_get("/redirect", redirect) + client = await aiohttp_client(app) with pytest.raises(ValueError): - yield from client.get('/redirect') + await client.get("/redirect") + + +async def test_HTTP_302_REDIRECT_POST(aiohttp_client) -> None: + async def handler(request): + return web.Response(text=request.method) + + async def redirect(request): + raise web.HTTPFound(location="/") + + app = web.Application() + app.router.add_get("/", handler) + app.router.add_post("/redirect", redirect) + client = await aiohttp_client(app) + + resp = await client.post("/redirect") + assert 200 == resp.status + assert 1 == len(resp.history) + txt = await resp.text() + assert txt == "GET" + resp.close() -@asyncio.coroutine -def test_HTTP_302_REDIRECT_POST(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_302_REDIRECT_POST_with_content_length_hdr(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - return web.HTTPFound(location='/') + async def redirect(request): + await request.read() + raise web.HTTPFound(location="/") + data = json.dumps({"some": "data"}) app = web.Application() - app.router.add_get('/', handler) - app.router.add_post('/redirect', redirect) - client = yield from test_client(app) + app.router.add_get("/", handler) + app.router.add_post("/redirect", redirect) + client = await aiohttp_client(app) - resp = yield from client.post('/redirect') + resp = await client.post( + "/redirect", data=data, headers={"Content-Length": str(len(data))} + ) assert 200 == resp.status assert 1 == len(resp.history) - txt = yield from resp.text() - assert txt == 'GET' + txt = await resp.text() + assert txt == "GET" resp.close() -@asyncio.coroutine -def test_HTTP_302_REDIRECT_POST_with_content_length_header(loop, - test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_307_REDIRECT_POST(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - yield from request.read() - return web.HTTPFound(location='/') + async def redirect(request): + await request.read() + raise web.HTTPTemporaryRedirect(location="/") - data = json.dumps({'some': 'data'}) - app 
= web.Application(debug=True) - app.router.add_get('/', handler) - app.router.add_post('/redirect', redirect) - client = yield from test_client(app) + app = web.Application() + app.router.add_post("/", handler) + app.router.add_post("/redirect", redirect) + client = await aiohttp_client(app) - resp = yield from client.post('/redirect', data=data, - headers={'Content-Length': str(len(data))}) + resp = await client.post("/redirect", data={"some": "data"}) assert 200 == resp.status assert 1 == len(resp.history) - txt = yield from resp.text() - assert txt == 'GET' + txt = await resp.text() + assert txt == "POST" resp.close() -@asyncio.coroutine -def test_HTTP_307_REDIRECT_POST(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_308_PERMANENT_REDIRECT_POST(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - yield from request.read() - return web.HTTPTemporaryRedirect(location='/') + async def redirect(request): + await request.read() + raise web.HTTPPermanentRedirect(location="/") app = web.Application() - app.router.add_post('/', handler) - app.router.add_post('/redirect', redirect) - client = yield from test_client(app) + app.router.add_post("/", handler) + app.router.add_post("/redirect", redirect) + client = await aiohttp_client(app) - resp = yield from client.post('/redirect', data={'some': 'data'}) + resp = await client.post("/redirect", data={"some": "data"}) assert 200 == resp.status assert 1 == len(resp.history) - txt = yield from resp.text() - assert txt == 'POST' + txt = await resp.text() + assert txt == "POST" resp.close() -@asyncio.coroutine -def test_HTTP_302_max_redirects(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_HTTP_302_max_redirects(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - count = int(request.match_info['count']) + async def redirect(request): + count = int(request.match_info["count"]) if count: - return web.HTTPFound(location='/redirect/{}'.format(count-1)) + raise web.HTTPFound(location="/redirect/{}".format(count - 1)) else: - return web.HTTPFound(location='/') + raise web.HTTPFound(location="/") app = web.Application() - app.router.add_get('/', handler) - app.router.add_get(r'/redirect/{count:\d+}', redirect) - client = yield from test_client(app) + app.router.add_get("/", handler) + app.router.add_get(r"/redirect/{count:\d+}", redirect) + client = await aiohttp_client(app) - resp = yield from client.get('/redirect/5', max_redirects=2) - assert 302 == resp.status - assert 2 == len(resp.history) - resp.close() + with pytest.raises(TooManyRedirects) as ctx: + await client.get("/redirect/5", max_redirects=2) + assert 2 == len(ctx.value.history) + assert ctx.value.request_info.url.path == "/redirect/5" + assert ctx.value.request_info.method == "GET" -@asyncio.coroutine -def test_HTTP_200_GET_WITH_PARAMS(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(text='&'.join( - k+'='+v for k, v in request.query.items())) +async def test_HTTP_200_GET_WITH_PARAMS(aiohttp_client) -> None: + async def handler(request): + return web.Response( + text="&".join(k + "=" + v for k, v in request.query.items()) + ) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield 
from client.get('/', params={'q': 'test'}) + resp = await client.get("/", params={"q": "test"}) assert 200 == resp.status - txt = yield from resp.text() - assert txt == 'q=test' + txt = await resp.text() + assert txt == "q=test" resp.close() -@asyncio.coroutine -def test_HTTP_200_GET_WITH_MultiDict_PARAMS(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(text='&'.join( - k+'='+v for k, v in request.query.items())) +async def test_HTTP_200_GET_WITH_MultiDict_PARAMS(aiohttp_client) -> None: + async def handler(request): + return web.Response( + text="&".join(k + "=" + v for k, v in request.query.items()) + ) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', params=MultiDict([('q', 'test'), - ('q', 'test2')])) + resp = await client.get("/", params=MultiDict([("q", "test"), ("q", "test2")])) assert 200 == resp.status - txt = yield from resp.text() - assert txt == 'q=test&q=test2' + txt = await resp.text() + assert txt == "q=test&q=test2" resp.close() -@asyncio.coroutine -def test_HTTP_200_GET_WITH_MIXED_PARAMS(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(text='&'.join( - k+'='+v for k, v in request.query.items())) +async def test_HTTP_200_GET_WITH_MIXED_PARAMS(aiohttp_client) -> None: + async def handler(request): + return web.Response( + text="&".join(k + "=" + v for k, v in request.query.items()) + ) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/?test=true', params={'q': 'test'}) + resp = await client.get("/?test=true", params={"q": "test"}) assert 200 == resp.status - txt = yield from resp.text() - assert txt == 'test=true&q=test' + txt = await resp.text() + assert txt == "test=true&q=test" resp.close() -@asyncio.coroutine -def test_POST_DATA(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() +async def test_POST_DATA(aiohttp_client) -> None: + async def handler(request): + data = await request.post() return web.json_response(dict(data)) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data={'some': 'data'}) + resp = await client.post("/", data={"some": "data"}) assert 200 == resp.status - content = yield from resp.json() - assert content == {'some': 'data'} + content = await resp.json() + assert content == {"some": "data"} resp.close() -@asyncio.coroutine -def test_POST_DATA_with_explicit_formdata(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() +async def test_POST_DATA_with_explicit_formdata(aiohttp_client) -> None: + async def handler(request): + data = await request.post() return web.json_response(dict(data)) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) form = aiohttp.FormData() - form.add_field('name', 'text') + form.add_field("name", "text") - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.json() - assert content == {'name': 
'text'} + content = await resp.json() + assert content == {"name": "text"} resp.close() -@asyncio.coroutine -def test_POST_DATA_with_charset(loop, test_client): - @asyncio.coroutine - def handler(request): - mp = yield from request.multipart() - part = yield from mp.next() - text = yield from part.text() +async def test_POST_DATA_with_charset(aiohttp_client) -> None: + async def handler(request): + mp = await request.multipart() + part = await mp.next() + text = await part.text() return web.Response(text=text) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) form = aiohttp.FormData() - form.add_field('name', 'текст', content_type='text/plain; charset=koi8-r') + form.add_field("name", "текст", content_type="text/plain; charset=koi8-r") - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.text() - assert content == 'текст' + content = await resp.text() + assert content == "текст" resp.close() -@asyncio.coroutine -def test_POST_DATA_formdats_with_charset(loop, test_client): - @asyncio.coroutine - def handler(request): - mp = yield from request.post() - assert 'name' in mp - from pprint import pprint - pprint(dict(request.headers)) - return web.Response(text=mp['name']) +async def test_POST_DATA_formdats_with_charset(aiohttp_client) -> None: + async def handler(request): + mp = await request.post() + assert "name" in mp + return web.Response(text=mp["name"]) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - form = aiohttp.FormData(charset='koi8-r') - form.add_field('name', 'текст') + form = aiohttp.FormData(charset="koi8-r") + form.add_field("name", "текст") - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.text() - assert content == 'текст' + content = await resp.text() + assert content == "текст" resp.close() -@asyncio.coroutine -def test_POST_DATA_with_charset_post(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - return web.Response(text=data['name']) +async def test_POST_DATA_with_charset_post(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + return web.Response(text=data["name"]) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) form = aiohttp.FormData() - form.add_field('name', 'текст', content_type='text/plain; charset=koi8-r') + form.add_field("name", "текст", content_type="text/plain; charset=koi8-r") - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.text() - assert content == 'текст' + content = await resp.text() + assert content == "текст" resp.close() -@asyncio.coroutine -def test_POST_DATA_with_context_transfer_encoding(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['name'] == 'text' - return web.Response(text=data['name']) +async def test_POST_DATA_with_context_transfer_encoding(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert data["name"] == 
"text" + return web.Response(text=data["name"]) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) form = aiohttp.FormData() - form.add_field('name', 'text', content_transfer_encoding='base64') + form.add_field("name", "text", content_transfer_encoding="base64") - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.text() - assert content == 'text' + content = await resp.text() + assert content == "text" resp.close() -@asyncio.coroutine -def test_POST_DATA_with_content_type_context_transfer_encoding( - loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['name'] == 'text' - return web.Response(body=data['name']) +async def test_POST_DATA_with_content_type_context_transfer_encoding(aiohttp_client): + async def handler(request): + data = await request.post() + assert data["name"] == "text" + return web.Response(body=data["name"]) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) form = aiohttp.FormData() - form.add_field('name', 'text', - content_type='text/plain', - content_transfer_encoding='base64') + form.add_field( + "name", "text", content_type="text/plain", content_transfer_encoding="base64" + ) - resp = yield from client.post('/', data=form) + resp = await client.post("/", data=form) assert 200 == resp.status - content = yield from resp.text() - assert content == 'text' + content = await resp.text() + assert content == "text" resp.close() -@asyncio.coroutine -def test_POST_MultiDict(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data == MultiDict([('q', 'test1'), ('q', 'test2')]) +async def test_POST_MultiDict(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert data == MultiDict([("q", "test1"), ("q", "test2")]) return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=MultiDict( - [('q', 'test1'), ('q', 'test2')])) + resp = await client.post("/", data=MultiDict([("q", "test1"), ("q", "test2")])) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_DATA_DEFLATE(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() +async def test_POST_DATA_DEFLATE(aiohttp_client) -> None: + async def handler(request): + data = await request.post() return web.json_response(dict(data)) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data={'some': 'data'}, compress=True) + resp = await client.post("/", data={"some": "data"}, compress=True) assert 200 == resp.status - content = yield from resp.json() - assert content == {'some': 'data'} + content = await resp.json() + assert content == {"some": "data"} resp.close() -@asyncio.coroutine -def test_POST_FILES(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['some'].filename == fname.name - with 
fname.open('rb') as f: +async def test_POST_FILES(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + assert data["some"].filename == fname.name + with fname.open("rb") as f: content1 = f.read() - content2 = data['some'].file.read() + content2 = data["some"].file.read() assert content1 == content2 - assert data['test'].file.read() == b'data' - return web.HTTPOk() + assert data["test"].file.read() == b"data" + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post( - '/', data={'some': f, 'test': b'data'}, chunked=True) + with fname.open("rb") as f: + resp = await client.post("/", data={"some": f, "test": b"data"}, chunked=True) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_DEFLATE(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['some'].filename == fname.name - with fname.open('rb') as f: +async def test_POST_FILES_DEFLATE(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + assert data["some"].filename == fname.name + with fname.open("rb") as f: content1 = f.read() - content2 = data['some'].file.read() + content2 = data["some"].file.read() assert content1 == content2 - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data={'some': f}, - chunked=True, - compress='deflate') + with fname.open("rb") as f: + resp = await client.post( + "/", data={"some": f}, chunked=True, compress="deflate" + ) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_STR(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - with fname.open() as f: - content1 = f.read() - content2 = data['some'] +async def test_POST_bytes(aiohttp_client) -> None: + body = b"0" * 12345 + + async def handler(request): + data = await request.read() + assert body == data + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + resp = await client.post("/", data=body) + assert 200 == resp.status + resp.close() + + +async def test_POST_bytes_too_large(aiohttp_client) -> None: + body = b"0" * (2 ** 20 + 1) + + async def handler(request): + data = await request.content.read() + assert body == data + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=body) + + assert 200 == resp.status + resp.close() + + +async def test_POST_FILES_STR(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + with fname.open("rb") as f: + content1 = f.read().decode() + content2 = data["some"] assert content1 == content2 - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data={'some': f.read()}) + 
with fname.open("rb") as f: + resp = await client.post("/", data={"some": f.read().decode()}) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_STR_SIMPLE(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.read() - with fname.open('rb') as f: +async def test_POST_FILES_STR_SIMPLE(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.read() + with fname.open("rb") as f: content = f.read() assert content == data - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data=f.read()) + with fname.open("rb") as f: + resp = await client.post("/", data=f.read()) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_LIST(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert fname.name == data['some'].filename - with fname.open('rb') as f: +async def test_POST_FILES_LIST(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + assert fname.name == data["some"].filename + with fname.open("rb") as f: content = f.read() - assert content == data['some'].file.read() - return web.HTTPOk() + assert content == data["some"].file.read() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data=[('some', f)]) + with fname.open("rb") as f: + resp = await client.post("/", data=[("some", f)]) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_CT(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert fname.name == data['some'].filename - assert 'text/plain' == data['some'].content_type - with fname.open('rb') as f: +async def test_POST_FILES_CT(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + assert fname.name == data["some"].filename + assert "text/plain" == data["some"].content_type + with fname.open("rb") as f: content = f.read() - assert content == data['some'].file.read() - return web.HTTPOk() + assert content == data["some"].file.read() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: + with fname.open("rb") as f: form = aiohttp.FormData() - form.add_field('some', f, content_type='text/plain') - resp = yield from client.post('/', data=form) + form.add_field("some", f, content_type="text/plain") + resp = await client.post("/", data=form) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_SINGLE(loop, test_client, fname): - - @asyncio.coroutine - def handler(request): - data = yield from request.text() - with fname.open('r') as f: - content = f.read() +async def test_POST_FILES_SINGLE(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.text() + with fname.open("rb") as f: + content = f.read().decode() assert content == data - # if system cannot determine 'application/pgp-keys' MIME type - # 
then use 'application/octet-stream' default - assert request.content_type in ['application/pgp-keys', - 'text/plain', - 'application/octet-stream'] - assert 'content-disposition' not in request.headers + # if system cannot determine 'text/x-python' MIME type + # then use 'application/octet-stream' default + assert request.content_type in [ + "text/plain", + "application/octet-stream", + "text/x-python", + ] + assert "content-disposition" not in request.headers - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data=f) + with fname.open("rb") as f: + resp = await client.post("/", data=f) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_SINGLE_content_disposition(loop, test_client, fname): - - @asyncio.coroutine - def handler(request): - data = yield from request.text() - with fname.open('r') as f: - content = f.read() +async def test_POST_FILES_SINGLE_content_disposition(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.text() + with fname.open("rb") as f: + content = f.read().decode() assert content == data - # if system cannot determine 'application/pgp-keys' MIME type - # then use 'application/octet-stream' default - assert request.content_type in ['application/pgp-keys', - 'text/plain', - 'application/octet-stream'] - assert request.headers['content-disposition'] == ( - "inline; filename=\"sample.key\"; filename*=utf-8''sample.key") + # if system cannot determine 'text/x-python' MIME type + # then use 'application/octet-stream' default + assert request.content_type in [ + "text/plain", + "application/octet-stream", + "text/x-python", + ] + assert request.headers["content-disposition"] == ( + "inline; filename=\"conftest.py\"; filename*=utf-8''conftest.py" + ) - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post( - '/', data=aiohttp.get_payload(f, disposition='inline')) + with fname.open("rb") as f: + resp = await client.post("/", data=aiohttp.get_payload(f, disposition="inline")) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_SINGLE_BINARY(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.read() - with fname.open('rb') as f: +async def test_POST_FILES_SINGLE_BINARY(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.read() + with fname.open("rb") as f: content = f.read() assert content == data # if system cannot determine 'application/pgp-keys' MIME type # then use 'application/octet-stream' default - assert request.content_type in ['application/pgp-keys', - 'text/plain', - 'application/octet-stream'] - return web.HTTPOk() + assert request.content_type in [ + "application/pgp-keys", + "text/plain", + "text/x-python", + "application/octet-stream", + ] + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open('rb') as f: - resp = yield from client.post('/', data=f) + with fname.open("rb") as f: + resp =
await client.post("/", data=f) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_IO(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert b'data' == data['unknown'].file.read() - assert data['unknown'].content_type == 'application/octet-stream' - assert data['unknown'].filename == 'unknown' - return web.HTTPOk() +async def test_POST_FILES_IO(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert b"data" == data["unknown"].file.read() + assert data["unknown"].content_type == "application/octet-stream" + assert data["unknown"].filename == "unknown" + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) - - data = io.BytesIO(b'data') - resp = yield from client.post('/', data=[data]) - assert 200 == resp.status - resp.close() - + app.router.add_post("/", handler) + client = await aiohttp_client(app) -@pytest.mark.xfail -@asyncio.coroutine -def test_POST_MULTIPART(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - lst = list(data.values()) - assert 3 == len(lst) - assert lst[0] == 'foo' - assert lst[1] == {'bar': 'баз'} - assert b'data' == data['unknown'].file.read() - assert data['unknown'].content_type == 'application/octet-stream' - assert data['unknown'].filename == 'unknown' - return web.HTTPOk() - - app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) - - with MultipartWriter('form-data') as writer: - writer.append('foo') - writer.append_json({'bar': 'баз'}) - writer.append_form([('тест', '4'), ('сетс', '2')]) - - resp = yield from client.post('/', data=writer) + data = io.BytesIO(b"data") + resp = await client.post("/", data=[data]) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_IO_WITH_PARAMS(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['test'] == 'true' - assert data['unknown'].content_type == 'application/octet-stream' - assert data['unknown'].filename == 'unknown' - assert data['unknown'].file.read() == b'data' - assert data.getall('q') == ['t1', 't2'] +async def test_POST_FILES_IO_WITH_PARAMS(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert data["test"] == "true" + assert data["unknown"].content_type == "application/octet-stream" + assert data["unknown"].filename == "unknown" + assert data["unknown"].file.read() == b"data" + assert data.getall("q") == ["t1", "t2"] - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - data = io.BytesIO(b'data') - resp = yield from client.post('/', data=(('test', 'true'), - MultiDict( - [('q', 't1'), ('q', 't2')]), - data)) + data = io.BytesIO(b"data") + resp = await client.post( + "/", data=(("test", "true"), MultiDict([("q", "t1"), ("q", "t2")]), data) + ) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_FILES_WITH_DATA(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data['test'] == 'true' - assert data['some'].content_type in ['application/pgp-keys', - 'text/plain; charset=utf-8', - 'application/octet-stream'] - assert data['some'].filename == fname.name - with 
fname.open('rb') as f: - assert data['some'].file.read() == f.read() +async def test_POST_FILES_WITH_DATA(aiohttp_client, fname) -> None: + async def handler(request): + data = await request.post() + assert data["test"] == "true" + assert data["some"].content_type in [ + "text/x-python", + "text/plain", + "application/octet-stream", + ] + assert data["some"].filename == fname.name + with fname.open("rb") as f: + assert data["some"].file.read() == f.read() - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open() as f: - resp = yield from client.post('/', data={'test': 'true', 'some': f}) + with fname.open("rb") as f: + resp = await client.post("/", data={"test": "true", "some": f}) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_STREAM_DATA(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - assert request.content_type == 'application/octet-stream' - content = yield from request.read() - with fname.open('rb') as f: +async def test_POST_STREAM_DATA(aiohttp_client, fname) -> None: + async def handler(request): + assert request.content_type == "application/octet-stream" + content = await request.read() + with fname.open("rb") as f: expected = f.read() assert request.content_length == len(expected) assert content == expected - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open('rb') as f: + with fname.open("rb") as f: data_size = len(f.read()) - @aiohttp.streamer - def stream(writer, fname): - with fname.open('rb') as f: - data = f.read(100) - while data: - yield from writer.write(data) + with pytest.warns(DeprecationWarning): + + @aiohttp.streamer + async def stream(writer, fname): + with fname.open("rb") as f: data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) - resp = yield from client.post( - '/', data=stream(fname), headers={'Content-Length': str(data_size)}) + resp = await client.post( + "/", data=stream(fname), headers={"Content-Length": str(data_size)} + ) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_POST_STREAM_DATA_no_params(loop, test_client, fname): - @asyncio.coroutine - def handler(request): - assert request.content_type == 'application/octet-stream' - content = yield from request.read() - with fname.open('rb') as f: +async def test_POST_STREAM_DATA_no_params(aiohttp_client, fname) -> None: + async def handler(request): + assert request.content_type == "application/octet-stream" + content = await request.read() + with fname.open("rb") as f: expected = f.read() assert request.content_length == len(expected) assert content == expected - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with fname.open('rb') as f: + with fname.open("rb") as f: data_size = len(f.read()) - @aiohttp.streamer - def stream(writer): - with fname.open('rb') as f: - data = f.read(100) - while data: - yield from writer.write(data) - data = f.read(100) - - resp = yield from client.post( - '/', data=stream, headers={'Content-Length': str(data_size)}) - assert 200 == resp.status - 
resp.close() - - -@asyncio.coroutine -def test_POST_StreamReader(fname, loop, test_client): - @asyncio.coroutine - def handler(request): - assert request.content_type == 'application/octet-stream' - content = yield from request.read() - with fname.open('rb') as f: - expected = f.read() - assert request.content_length == len(expected) - assert content == expected - - return web.HTTPOk() - - app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) - - with fname.open('rb') as f: - data = f.read() + with pytest.warns(DeprecationWarning): - stream = aiohttp.StreamReader(loop=loop) - stream.feed_data(data) - stream.feed_eof() + @aiohttp.streamer + async def stream(writer): + with fname.open("rb") as f: + data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) - resp = yield from client.post( - '/', data=stream, - headers={'Content-Length': str(len(data))}) + resp = await client.post( + "/", data=stream, headers={"Content-Length": str(data_size)} + ) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_json(loop, test_client): - @asyncio.coroutine - def handler(request): - assert request.content_type == 'application/json' - data = yield from request.json() +async def test_json(aiohttp_client) -> None: + async def handler(request): + assert request.content_type == "application/json" + data = await request.json() return web.Response(body=aiohttp.JsonPayload(data)) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', json={'some': 'data'}) + resp = await client.post("/", json={"some": "data"}) assert 200 == resp.status - content = yield from resp.json() - assert content == {'some': 'data'} + content = await resp.json() + assert content == {"some": "data"} resp.close() with pytest.raises(ValueError): - yield from client.post('/', data="some data", json={'some': 'data'}) + await client.post("/", data="some data", json={"some": "data"}) -@asyncio.coroutine -def test_json_custom(loop, test_client): - @asyncio.coroutine - def handler(request): - assert request.content_type == 'application/json' - data = yield from request.json() +async def test_json_custom(aiohttp_client) -> None: + async def handler(request): + assert request.content_type == "application/json" + data = await request.json() return web.Response(body=aiohttp.JsonPayload(data)) used = False @@ -1677,32 +1636,29 @@ def dumps(obj): return json.dumps(obj) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app, json_serialize=dumps) + app.router.add_post("/", handler) + client = await aiohttp_client(app, json_serialize=dumps) - resp = yield from client.post('/', json={'some': 'data'}) + resp = await client.post("/", json={"some": "data"}) assert 200 == resp.status assert used - content = yield from resp.json() - assert content == {'some': 'data'} + content = await resp.json() + assert content == {"some": "data"} resp.close() with pytest.raises(ValueError): - yield from client.post('/', data="some data", json={'some': 'data'}) + await client.post("/", data="some data", json={"some": "data"}) -@asyncio.coroutine -def test_expect_continue(loop, test_client): +async def test_expect_continue(aiohttp_client) -> None: expect_called = False - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert data == {'some': 'data'} - return web.HTTPOk() 
+ async def handler(request): + data = await request.post() + assert data == {"some": "data"} + return web.Response() - @asyncio.coroutine - def expect_handler(request): + async def expect_handler(request): nonlocal expect_called expect = request.headers.get(hdrs.EXPECT) if expect.lower() == "100-continue": @@ -1710,197 +1666,318 @@ def expect_handler(request): expect_called = True app = web.Application() - app.router.add_post('/', handler, expect_handler=expect_handler) - client = yield from test_client(app) + app.router.add_post("/", handler, expect_handler=expect_handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data={'some': 'data'}, expect100=True) + resp = await client.post("/", data={"some": "data"}, expect100=True) assert 200 == resp.status resp.close() assert expect_called -@asyncio.coroutine -def test_encoding_deflate(loop, test_client): - @asyncio.coroutine - def handler(request): - resp = web.Response(text='text') +async def test_encoding_deflate(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") resp.enable_chunked_encoding() resp.enable_compression(web.ContentCoding.deflate) return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + txt = await resp.text() + assert txt == "text" + resp.close() + + +async def test_encoding_deflate_nochunk(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") + resp.enable_compression(web.ContentCoding.deflate) + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - txt = yield from resp.text() - assert txt == 'text' + txt = await resp.text() + assert txt == "text" resp.close() -@asyncio.coroutine -def test_encoding_gzip(loop, test_client): - @asyncio.coroutine - def handler(request): - resp = web.Response(text='text') +async def test_encoding_gzip(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") resp.enable_chunked_encoding() resp.enable_compression(web.ContentCoding.gzip) return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + txt = await resp.text() + assert txt == "text" + resp.close() + + +async def test_encoding_gzip_write_by_chunks(aiohttp_client) -> None: + async def handler(request): + resp = web.StreamResponse() + resp.enable_compression(web.ContentCoding.gzip) + await resp.prepare(request) + await resp.write(b"0") + await resp.write(b"0") + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + txt = await resp.text() + assert txt == "00" + resp.close() + + +async def test_encoding_gzip_nochunk(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") + resp.enable_compression(web.ContentCoding.gzip) + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - txt = yield 
from resp.text() - assert txt == 'text' + txt = await resp.text() + assert txt == "text" resp.close() -@asyncio.coroutine -def test_bad_payload_compression(loop, test_client): - @asyncio.coroutine - def handler(request): - resp = web.Response(text='text') - resp.headers['Content-Encoding'] = 'gzip' +async def test_bad_payload_compression(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") + resp.headers["Content-Encoding"] = "gzip" return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status with pytest.raises(aiohttp.ClientPayloadError): - yield from resp.read() + await resp.read() resp.close() -@asyncio.coroutine -def test_bad_payload_chunked_encoding(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_bad_payload_chunked_encoding(aiohttp_client) -> None: + async def handler(request): resp = web.StreamResponse() resp.force_close() resp._length_check = False - resp.headers['Transfer-Encoding'] = 'chunked' - writer = yield from resp.prepare(request) - writer.write(b'9\r\n\r\n') - yield from writer.write_eof() + resp.headers["Transfer-Encoding"] = "chunked" + writer = await resp.prepare(request) + await writer.write(b"9\r\n\r\n") + await writer.write_eof() return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status with pytest.raises(aiohttp.ClientPayloadError): - yield from resp.read() + await resp.read() resp.close() -@asyncio.coroutine -def test_bad_payload_content_length(loop, test_client): - @asyncio.coroutine - def handler(request): - resp = web.Response(text='text') - resp.headers['Content-Length'] = '10000' +async def test_bad_payload_content_length(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") + resp.headers["Content-Length"] = "10000" resp.force_close() return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status with pytest.raises(aiohttp.ClientPayloadError): - yield from resp.read() + await resp.read() resp.close() -@asyncio.coroutine -def test_chunked(loop, test_client): - @asyncio.coroutine - def handler(request): - resp = web.Response(text='text') +async def test_payload_content_length_by_chunks(aiohttp_client) -> None: + async def handler(request): + resp = web.StreamResponse(headers={"content-length": "3"}) + await resp.prepare(request) + await resp.write(b"answer") + await resp.write(b"two") + request.transport.close() + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + data = await resp.read() + assert data == b"ans" + resp.close() + + +async def test_chunked(aiohttp_client) -> None: + async def handler(request): + resp = web.Response(text="text") resp.enable_chunked_encoding() return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await 
aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - assert resp.headers['Transfer-Encoding'] == 'chunked' - txt = yield from resp.text() - assert txt == 'text' + assert resp.headers["Transfer-Encoding"] == "chunked" + txt = await resp.text() + assert txt == "text" resp.close() -@asyncio.coroutine -def test_shortcuts(test_client, loop): - @asyncio.coroutine - def handler(request): +async def test_shortcuts(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) app = web.Application() - for meth in ('get', 'post', 'put', 'delete', 'head', 'patch', 'options'): - app.router.add_route(meth.upper(), '/', handler) - client = yield from test_client(lambda loop: app) + for meth in ("get", "post", "put", "delete", "head", "patch", "options"): + app.router.add_route(meth.upper(), "/", handler) + client = await aiohttp_client(app) - for meth in ('get', 'post', 'put', 'delete', 'head', 'patch', 'options'): + for meth in ("get", "post", "put", "delete", "head", "patch", "options"): coro = getattr(client.session, meth) - resp = yield from coro(client.make_url('/')) + resp = await coro(client.make_url("/")) assert resp.status == 200 assert len(resp.history) == 0 - content1 = yield from resp.read() - content2 = yield from resp.read() + content1 = await resp.read() + content2 = await resp.read() assert content1 == content2 - content = yield from resp.text() + content = await resp.text() - if meth == 'head': - assert b'' == content1 + if meth == "head": + assert b"" == content1 else: assert meth.upper() == content -@asyncio.coroutine -def test_cookies(test_client, loop): - @asyncio.coroutine - def handler(request): - assert request.cookies.keys() == {'test1', 'test3'} - assert request.cookies['test1'] == '123' - assert request.cookies['test3'] == '456' +async def test_cookies(aiohttp_client) -> None: + async def handler(request): + assert request.cookies.keys() == {"test1", "test3"} + assert request.cookies["test1"] == "123" + assert request.cookies["test3"] == "456" return web.Response() c = http.cookies.Morsel() - c.set('test3', '456', '456') + c.set("test3", "456", "456") + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app, cookies={"test1": "123", "test2": c}) + + resp = await client.get("/") + assert 200 == resp.status + resp.close() + + +async def test_cookies_per_request(aiohttp_client) -> None: + async def handler(request): + assert request.cookies.keys() == {"test1", "test3", "test4", "test6"} + assert request.cookies["test1"] == "123" + assert request.cookies["test3"] == "456" + assert request.cookies["test4"] == "789" + assert request.cookies["test6"] == "abc" + return web.Response() + + c = http.cookies.Morsel() + c.set("test3", "456", "456") + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app, cookies={"test1": "123", "test2": c}) + + rc = http.cookies.Morsel() + rc.set("test6", "abc", "abc") + + resp = await client.get("/", cookies={"test4": "789", "test5": rc}) + assert 200 == resp.status + resp.close() + + +async def test_cookies_redirect(aiohttp_client) -> None: + async def redirect1(request): + ret = web.Response(status=301, headers={"Location": "/redirect2"}) + ret.set_cookie("c", "1") + return ret + + async def redirect2(request): + ret = web.Response(status=301, headers={"Location": "/"}) + ret.set_cookie("c", "2") + return ret + + async def handler(request): + assert request.cookies.keys() == 
{"c"} + assert request.cookies["c"] == "2" + return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client( - app, cookies={'test1': '123', 'test2': c}) + app.router.add_get("/redirect1", redirect1) + app.router.add_get("/redirect2", redirect2) + app.router.add_get("/", handler) - resp = yield from client.get('/') + client = await aiohttp_client(app) + resp = await client.get("/redirect1") assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_morsel_with_attributes(test_client, loop): +async def test_cookies_on_empty_session_jar(aiohttp_client) -> None: + async def handler(request): + assert "custom-cookie" in request.cookies + assert request.cookies["custom-cookie"] == "abc" + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app, cookies=None) + + resp = await client.get("/", cookies={"custom-cookie": "abc"}) + assert 200 == resp.status + resp.close() + + +async def test_morsel_with_attributes(aiohttp_client) -> None: # A comment from original test: # # No cookie attribute should pass here @@ -1910,193 +1987,966 @@ def test_morsel_with_attributes(test_client, loop): # Server who sent the cookie with some attributes # already knows them, no need to send this back again and again - @asyncio.coroutine - def handler(request): - assert request.cookies.keys() == {'test3'} - assert request.cookies['test3'] == '456' + async def handler(request): + assert request.cookies.keys() == {"test3"} + assert request.cookies["test3"] == "456" return web.Response() c = http.cookies.Morsel() - c.set('test3', '456', '456') - c['httponly'] = True - c['secure'] = True - c['max-age'] = 1000 + c.set("test3", "456", "456") + c["httponly"] = True + c["secure"] = True + c["max-age"] = 1000 app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app, cookies={'test2': c}) + app.router.add_get("/", handler) + client = await aiohttp_client(app, cookies={"test2": c}) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_set_cookies(test_client, loop): - @asyncio.coroutine - def handler(request): +async def test_set_cookies(aiohttp_client) -> None: + async def handler(request): ret = web.Response() - ret.set_cookie('c1', 'cookie1') - ret.set_cookie('c2', 'cookie2') - ret.headers.add('Set-Cookie', - 'ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=' - '{925EC0B8-CB17-4BEB-8A35-1033813B0523}; ' - 'HttpOnly; Path=/') + ret.set_cookie("c1", "cookie1") + ret.set_cookie("c2", "cookie2") + ret.headers.add( + "Set-Cookie", + "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=" + "{925EC0B8-CB17-4BEB-8A35-1033813B0523}; " + "HttpOnly; Path=/", + ) return ret app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - with mock.patch('aiohttp.client_reqrep.client_logger') as m_log: - resp = yield from client.get('/') + with mock.patch("aiohttp.client_reqrep.client_logger") as m_log: + resp = await client.get("/") assert 200 == resp.status cookie_names = {c.key for c in client.session.cookie_jar} - assert cookie_names == {'c1', 'c2'} + assert cookie_names == {"c1", "c2"} resp.close() - m_log.warning.assert_called_with('Can not load response cookies: %s', - mock.ANY) + m_log.warning.assert_called_with("Can not load response cookies: %s", mock.ANY) + + +async def 
test_set_cookies_expired(aiohttp_client) -> None:
+    async def handler(request):
+        ret = web.Response()
+        ret.set_cookie("c1", "cookie1")
+        ret.set_cookie("c2", "cookie2")
+        ret.headers.add(
+            "Set-Cookie",
+            "c3=cookie3; HttpOnly; Path=/; Expires=Tue, 1 Jan 1980 12:00:00 GMT",
+        )
+        return ret
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)
+
+    resp = await client.get("/")
+    assert 200 == resp.status
+    cookie_names = {c.key for c in client.session.cookie_jar}
+    assert cookie_names == {"c1", "c2"}
+    resp.close()
+
+
+async def test_set_cookies_max_age(aiohttp_client) -> None:
+    async def handler(request):
+        ret = web.Response()
+        ret.set_cookie("c1", "cookie1")
+        ret.set_cookie("c2", "cookie2")
+        ret.headers.add("Set-Cookie", "c3=cookie3; HttpOnly; Path=/; Max-Age=1")
+        return ret
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)
+
+    resp = await client.get("/")
+    assert 200 == resp.status
+    cookie_names = {c.key for c in client.session.cookie_jar}
+    assert cookie_names == {"c1", "c2", "c3"}
+    await asyncio.sleep(2)
+    cookie_names = {c.key for c in client.session.cookie_jar}
+    assert cookie_names == {"c1", "c2"}
+    resp.close()
+
+
+async def test_set_cookies_max_age_overflow(aiohttp_client) -> None:
+    async def handler(request):
+        ret = web.Response()
+        ret.headers.add(
+            "Set-Cookie",
+            "overflow=overflow; HttpOnly; Path=/; Max-Age=" + str(overflow),
+        )
+        return ret
+
+    overflow = int(
+        datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()
+    )
+    # Max-Age is so large that now + Max-Age overflows datetime.datetime;
+    # the cookie jar must cap the expiry instead of crashing.
+    empty = None
+    try:
+        empty = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
+            seconds=overflow
+        )
+    except OverflowError as ex:
+        assert isinstance(ex, OverflowError)
+    assert not isinstance(empty, datetime.datetime)
+    app = web.Application()
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)
+
+    resp = await client.get("/")
+    assert 200 == resp.status
+    for cookie in client.session.cookie_jar:
+        if cookie.key == "overflow":
+            assert int(cookie["max-age"]) == int(overflow)
+    resp.close()


-@asyncio.coroutine
-def test_request_conn_error(loop):
-    client = aiohttp.ClientSession(loop=loop)
+async def test_request_conn_error() -> None:
+    client = aiohttp.ClientSession()
     with pytest.raises(aiohttp.ClientConnectionError):
-        yield from client.get('http://0.0.0.0:1')
-    client.close()
+        await client.get("http://0.0.0.0:1")
+    await client.close()


 @pytest.mark.xfail
-@asyncio.coroutine
-def test_broken_connection(loop, test_client):
-    @asyncio.coroutine
-    def handler(request):
+async def test_broken_connection(aiohttp_client) -> None:
+    async def handler(request):
         request.transport.close()
-        return web.Response(text='answer'*1000)
+        return web.Response(text="answer" * 1000)

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

     with pytest.raises(aiohttp.ClientResponseError):
-        yield from client.get('/')
-
-
-@asyncio.coroutine
-def test_broken_connection_2(loop, test_client):
-    @asyncio.coroutine
-    def handler(request):
-        resp = web.StreamResponse(headers={'content-length': '1000'})
-        yield from resp.prepare(request)
-        yield from resp.drain()
-        resp.write(b'answer')
-        yield from resp.drain()
+        await client.get("/")
+
+
+async def test_broken_connection_2(aiohttp_client) -> None:
+    async def handler(request):
+        resp =
web.StreamResponse(headers={"content-length": "1000"}) + await resp.prepare(request) + await resp.write(b"answer") request.transport.close() return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") with pytest.raises(aiohttp.ClientPayloadError): - yield from resp.read() + await resp.read() resp.close() -@asyncio.coroutine -def test_custom_headers(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_custom_headers(aiohttp_client) -> None: + async def handler(request): assert request.headers["x-api-key"] == "foo" return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', headers={ - "Content-Type": "application/json", - "x-api-key": "foo"}) + resp = await client.post( + "/", headers={"Content-Type": "application/json", "x-api-key": "foo"} + ) assert resp.status == 200 -@asyncio.coroutine -def test_redirect_to_absolute_url(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_redirect_to_absolute_url(aiohttp_client) -> None: + async def handler(request): return web.Response(text=request.method) - @asyncio.coroutine - def redirect(request): - return web.HTTPFound(location=client.make_url('/')) + async def redirect(request): + raise web.HTTPFound(location=client.make_url("/")) app = web.Application() - app.router.add_get('/', handler) - app.router.add_get('/redirect', redirect) + app.router.add_get("/", handler) + app.router.add_get("/redirect", redirect) - client = yield from test_client(app) - resp = yield from client.get('/redirect') + client = await aiohttp_client(app) + resp = await client.get("/redirect") assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_redirect_without_location_header(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.Response(status=301) +async def test_redirect_without_location_header(aiohttp_client) -> None: + body = b"redirect" + + async def handler_redirect(request): + return web.Response(status=301, body=body) app = web.Application() - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/redirect", handler_redirect) + client = await aiohttp_client(app) - with pytest.raises(RuntimeError) as ctx: - yield from client.get('/redirect') - assert str(ctx.value) == ('GET http://127.0.0.1:{}/redirect returns ' - 'a redirect [301] status but response lacks ' - 'a Location or URI HTTP header' - .format(client.port)) + resp = await client.get("/redirect") + data = await resp.read() + assert data == body -@asyncio.coroutine -def test_encoding_deprecated(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): +async def test_chunked_deprecated(aiohttp_client) -> None: + async def handler_redirect(request): return web.Response(status=301) app = web.Application() - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/redirect", handler_redirect) + client = await aiohttp_client(app) with pytest.warns(DeprecationWarning): - yield from client.get('/', encoding='utf-8') + await client.post("/", chunked=1024) -@asyncio.coroutine -def 
test_chunked_deprecated(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.Response(status=301) +@pytest.mark.parametrize( + ("status", "expected_ok"), + ( + (200, True), + (201, True), + (301, True), + (400, False), + (403, False), + (500, False), + ), +) +async def test_ok_from_status(aiohttp_client, status, expected_ok) -> None: + async def handler(request): + return web.Response(status=status, body=b"") app = web.Application() - app.router.add_route('GET', '/redirect', handler_redirect) - client = yield from test_client(app) + app.router.add_route("GET", "/endpoint", handler) + client = await aiohttp_client(app, raise_for_status=False) + resp = await client.get("/endpoint") - with pytest.warns(DeprecationWarning): - yield from client.get('/', chunked=1024) + assert resp.ok is expected_ok -@asyncio.coroutine -def test_raise_for_status(loop, test_client): - @asyncio.coroutine - def handler_redirect(request): - return web.HTTPBadRequest() +async def test_raise_for_status(aiohttp_client) -> None: + async def handler_redirect(request): + raise web.HTTPBadRequest() app = web.Application() - app.router.add_route('GET', '/', handler_redirect) - client = yield from test_client(app, raise_for_status=True) + app.router.add_route("GET", "/", handler_redirect) + client = await aiohttp_client(app, raise_for_status=True) with pytest.raises(aiohttp.ClientResponseError): - yield from client.get('/') + await client.get("/") + + +async def test_raise_for_status_per_request(aiohttp_client) -> None: + async def handler_redirect(request): + raise web.HTTPBadRequest() + + app = web.Application() + app.router.add_route("GET", "/", handler_redirect) + client = await aiohttp_client(app) + + with pytest.raises(aiohttp.ClientResponseError): + await client.get("/", raise_for_status=True) + + +async def test_raise_for_status_disable_per_request(aiohttp_client) -> None: + async def handler_redirect(request): + raise web.HTTPBadRequest() + + app = web.Application() + app.router.add_route("GET", "/", handler_redirect) + client = await aiohttp_client(app, raise_for_status=True) + + resp = await client.get("/", raise_for_status=False) + assert 400 == resp.status + resp.close() + + +async def test_request_raise_for_status_default(aiohttp_server) -> None: + async def handler(request): + raise web.HTTPBadRequest() + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with aiohttp.request("GET", server.make_url("/")) as resp: + assert resp.status == 400 + + +async def test_request_raise_for_status_disabled(aiohttp_server) -> None: + async def handler(request): + raise web.HTTPBadRequest() + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + url = server.make_url("/") + + async with aiohttp.request("GET", url, raise_for_status=False) as resp: + assert resp.status == 400 + + +async def test_request_raise_for_status_enabled(aiohttp_server) -> None: + async def handler(request): + raise web.HTTPBadRequest() + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + url = server.make_url("/") + + with pytest.raises(aiohttp.ClientResponseError): + async with aiohttp.request("GET", url, raise_for_status=True): + assert False, "never executed" # pragma: no cover + + +async def test_invalid_idna() -> None: + session = aiohttp.ClientSession() + try: + with pytest.raises(aiohttp.InvalidURL): + await session.get("http://\u2061owhefopw.com") + finally: + 
await session.close() + + +async def test_creds_in_auth_and_url() -> None: + session = aiohttp.ClientSession() + try: + with pytest.raises(ValueError): + await session.get( + "http://user:pass@example.com", auth=aiohttp.BasicAuth("user2", "pass2") + ) + finally: + await session.close() + + +async def test_drop_auth_on_redirect_to_other_host(aiohttp_server) -> None: + async def srv1(request): + assert request.host == "host1.com" + assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" + raise web.HTTPFound("http://host2.com/path2") + + async def srv2(request): + assert request.host == "host2.com" + assert "Authorization" not in request.headers + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/path1", srv1) + app.router.add_route("GET", "/path2", srv2) + + server = await aiohttp_server(app) + + class FakeResolver(AbstractResolver): + async def resolve(self, host, port=0, family=socket.AF_INET): + return [ + { + "hostname": host, + "host": server.host, + "port": server.port, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] + + async def close(self): + pass + + connector = aiohttp.TCPConnector(resolver=FakeResolver()) + async with aiohttp.ClientSession(connector=connector) as client: + resp = await client.get( + "http://host1.com/path1", auth=aiohttp.BasicAuth("user", "pass") + ) + assert resp.status == 200 + resp = await client.get( + "http://host1.com/path1", headers={"Authorization": "Basic dXNlcjpwYXNz"} + ) + assert resp.status == 200 + + +async def test_async_with_session() -> None: + with pytest.warns(None) as cm: + async with aiohttp.ClientSession() as session: + pass + assert len(cm.list) == 0 + + assert session.closed + + +async def test_session_close_awaitable() -> None: + session = aiohttp.ClientSession() + with pytest.warns(None) as cm: + await session.close() + assert len(cm.list) == 0 + + assert session.closed + + +async def test_close_run_until_complete_not_deprecated() -> None: + session = aiohttp.ClientSession() + + with pytest.warns(None) as cm: + await session.close() + + assert len(cm.list) == 0 + + +async def test_close_resp_on_error_async_with_session(aiohttp_server) -> None: + async def handler(request): + resp = web.StreamResponse(headers={"content-length": "100"}) + await resp.prepare(request) + await asyncio.sleep(0.1) + return resp + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + with pytest.raises(RuntimeError): + async with session.get(server.make_url("/")) as resp: + resp.content.set_exception(RuntimeError()) + await resp.read() + + assert len(session._connector._conns) == 0 + + +async def test_release_resp_on_normal_exit_from_cm(aiohttp_server) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + async with session.get(server.make_url("/")) as resp: + await resp.read() + + assert len(session._connector._conns) == 1 + + +async def test_non_close_detached_session_on_error_cm(aiohttp_server) -> None: + async def handler(request): + resp = web.StreamResponse(headers={"content-length": "100"}) + await resp.prepare(request) + await asyncio.sleep(0.1) + return resp + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + session = aiohttp.ClientSession() + cm = 
session.get(server.make_url("/"))
+    assert not session.closed
+    with pytest.raises(RuntimeError):
+        async with cm as resp:
+            resp.content.set_exception(RuntimeError())
+            await resp.read()
+    assert not session.closed
+
+
+async def test_close_detached_session_on_non_existing_addr() -> None:
+    class FakeResolver(AbstractResolver):
+        async def resolve(self, host, port=0, family=socket.AF_INET):
+            # resolve to no addresses so the connect attempt always fails
+            return []
+
+        async def close(self):
+            pass
+
+    connector = aiohttp.TCPConnector(resolver=FakeResolver())
+
+    session = aiohttp.ClientSession(connector=connector)
+
+    async with session:
+        cm = session.get("http://non-existing.example.com")
+        assert not session.closed
+        with pytest.raises(Exception):
+            await cm
+
+    assert session.closed
+
+
+async def test_aiohttp_request_context_manager(aiohttp_server) -> None:
+    async def handler(request):
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    server = await aiohttp_server(app)
+
+    async with aiohttp.request("GET", server.make_url("/")) as resp:
+        await resp.read()
+        assert resp.status == 200
+
+
+async def test_aiohttp_request_ctx_manager_close_sess_on_error(
+    ssl_ctx, aiohttp_server
+) -> None:
+    async def handler(request):
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    server = await aiohttp_server(app, ssl=ssl_ctx)
+
+    cm = aiohttp.request("GET", server.make_url("/"))
+
+    with pytest.raises(aiohttp.ClientConnectionError):
+        async with cm:
+            pass
+
+    assert cm._session.closed
+
+
+async def test_aiohttp_request_ctx_manager_not_found() -> None:
+
+    with pytest.raises(aiohttp.ClientConnectionError):
+        async with aiohttp.request("GET", "http://wrong-dns-name.com"):
+            assert False, "never executed"  # pragma: no cover
+
+
+async def test_aiohttp_request_coroutine(aiohttp_server) -> None:
+    async def handler(request):
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    server = await aiohttp_server(app)
+
+    with pytest.raises(TypeError):
+        await aiohttp.request("GET", server.make_url("/"))
+
+
+async def test_yield_from_in_session_request(aiohttp_client) -> None:
+    # a test for backward compatibility with yield from syntax
+    async def handler(request):
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+
+    client = await aiohttp_client(app)
+    resp = await client.get("/")
+    assert resp.status == 200
+
+
+async def test_close_context_manager(aiohttp_client) -> None:
+    # a test for backward compatibility with yield from syntax
+    async def handler(request):
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+
+    client = await aiohttp_client(app)
+    ctx = client.get("/")
+    ctx.close()
+    assert not ctx._coro.cr_running
+
+
+async def test_session_auth(aiohttp_client) -> None:
+    async def handler(request):
+        return web.json_response({"headers": dict(request.headers)})
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+
+    client = await aiohttp_client(app, auth=aiohttp.BasicAuth("login", "pass"))
+
+    r = await client.get("/")
+    assert r.status == 200
+    content = await r.json()
+    assert content["headers"]["Authorization"] == "Basic bG9naW46cGFzcw=="
+
+
+async def test_session_auth_override(aiohttp_client) -> None:
+    async def handler(request):
+        return web.json_response({"headers": dict(request.headers)})
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+
+    client = await aiohttp_client(app, auth=aiohttp.BasicAuth("login", "pass"))
+
+    r = await
client.get("/", auth=aiohttp.BasicAuth("other_login", "pass")) + assert r.status == 200 + content = await r.json() + val = content["headers"]["Authorization"] + assert val == "Basic b3RoZXJfbG9naW46cGFzcw==" + + +async def test_session_auth_header_conflict(aiohttp_client) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app, auth=aiohttp.BasicAuth("login", "pass")) + headers = {"Authorization": "Basic b3RoZXJfbG9naW46cGFzcw=="} + with pytest.raises(ValueError): + await client.get("/", headers=headers) + + +async def test_session_headers(aiohttp_client) -> None: + async def handler(request): + return web.json_response({"headers": dict(request.headers)}) + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app, headers={"X-Real-IP": "192.168.0.1"}) + + r = await client.get("/") + assert r.status == 200 + content = await r.json() + assert content["headers"]["X-Real-IP"] == "192.168.0.1" + + +async def test_session_headers_merge(aiohttp_client) -> None: + async def handler(request): + return web.json_response({"headers": dict(request.headers)}) + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client( + app, headers=[("X-Real-IP", "192.168.0.1"), ("X-Sent-By", "requests")] + ) + + r = await client.get("/", headers={"X-Sent-By": "aiohttp"}) + assert r.status == 200 + content = await r.json() + assert content["headers"]["X-Real-IP"] == "192.168.0.1" + assert content["headers"]["X-Sent-By"] == "aiohttp" + + +async def test_multidict_headers(aiohttp_client) -> None: + async def handler(request): + assert await request.read() == data + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + + client = await aiohttp_client(app) + + data = b"sample data" + + r = await client.post( + "/", data=data, headers=MultiDict({"Content-Length": str(len(data))}) + ) + assert r.status == 200 + + +async def test_request_conn_closed(aiohttp_client) -> None: + async def handler(request): + request.transport.close() + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + with pytest.raises(aiohttp.ServerDisconnectedError) as excinfo: + resp = await client.get("/") + await resp.read() + + assert str(excinfo.value) != "" + + +async def test_dont_close_explicit_connector(aiohttp_client) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + r = await client.get("/") + await r.read() + + assert 1 == len(client.session.connector._conns) + + +async def test_server_close_keepalive_connection() -> None: + loop = asyncio.get_event_loop() + + class Proto(asyncio.Protocol): + def connection_made(self, transport): + self.transp = transport + self.data = b"" + + def data_received(self, data): + self.data += data + if data.endswith(b"\r\n\r\n"): + self.transp.write( + b"HTTP/1.1 200 OK\r\n" + b"CONTENT-LENGTH: 2\r\n" + b"CONNECTION: close\r\n" + b"\r\n" + b"ok" + ) + self.transp.close() + + def connection_lost(self, exc): + self.transp = None + + server = await loop.create_server(Proto, "127.0.0.1", unused_port()) + + addr = server.sockets[0].getsockname() + + connector = aiohttp.TCPConnector(limit=1) + session = aiohttp.ClientSession(connector=connector) + + url = "http://{}:{}/".format(*addr) + for i in range(2): + r = await 
session.request("GET", url) + await r.read() + assert 0 == len(connector._conns) + await session.close() + connector.close() + server.close() + await server.wait_closed() + + +async def test_handle_keepalive_on_closed_connection() -> None: + loop = asyncio.get_event_loop() + + class Proto(asyncio.Protocol): + def connection_made(self, transport): + self.transp = transport + self.data = b"" + + def data_received(self, data): + self.data += data + if data.endswith(b"\r\n\r\n"): + self.transp.write( + b"HTTP/1.1 200 OK\r\n" b"CONTENT-LENGTH: 2\r\n" b"\r\n" b"ok" + ) + self.transp.close() + + def connection_lost(self, exc): + self.transp = None + + server = await loop.create_server(Proto, "127.0.0.1", unused_port()) + + addr = server.sockets[0].getsockname() + + connector = aiohttp.TCPConnector(limit=1) + session = aiohttp.ClientSession(connector=connector) + + url = "http://{}:{}/".format(*addr) + + r = await session.request("GET", url) + await r.read() + assert 1 == len(connector._conns) + + with pytest.raises(aiohttp.ClientConnectionError): + await session.request("GET", url) + assert 0 == len(connector._conns) + + await session.close() + connector.close() + server.close() + await server.wait_closed() + + +async def test_error_in_performing_request(ssl_ctx, aiohttp_client, aiohttp_server): + async def handler(request): + return web.Response() + + def exception_handler(loop, context): + # skip log messages about destroyed but pending tasks + pass + + loop = asyncio.get_event_loop() + loop.set_exception_handler(exception_handler) + + app = web.Application() + app.router.add_route("GET", "/", handler) + + server = await aiohttp_server(app, ssl=ssl_ctx) + + conn = aiohttp.TCPConnector(limit=1) + client = await aiohttp_client(server, connector=conn) + + with pytest.raises(aiohttp.ClientConnectionError): + await client.get("/") + + # second try should not hang + with pytest.raises(aiohttp.ClientConnectionError): + await client.get("/") + + +async def test_await_after_cancelling(aiohttp_client) -> None: + loop = asyncio.get_event_loop() + + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + + fut1 = loop.create_future() + fut2 = loop.create_future() + + async def fetch1(): + resp = await client.get("/") + assert resp.status == 200 + fut1.set_result(None) + with pytest.raises(asyncio.CancelledError): + await fut2 + resp.release() + + async def fetch2(): + await fut1 + resp = await client.get("/") + assert resp.status == 200 + + async def canceller(): + await fut1 + fut2.cancel() + + await asyncio.gather(fetch1(), fetch2(), canceller()) + + +async def test_async_payload_generator(aiohttp_client) -> None: + async def handler(request): + data = await request.read() + assert data == b"1234567890" * 100 + return web.Response() + + app = web.Application() + app.add_routes([web.post("/", handler)]) + + client = await aiohttp_client(app) + + @async_generator + async def gen(): + for i in range(100): + await yield_(b"1234567890") + + resp = await client.post("/", data=gen()) + assert resp.status == 200 + + +async def test_read_from_closed_response(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"data") + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + + with pytest.raises(aiohttp.ClientConnectionError): + await resp.read() + + +async 
def test_read_from_closed_response2(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"data") + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + await resp.read() + + with pytest.raises(aiohttp.ClientConnectionError): + await resp.read() + + +async def test_read_from_closed_content(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"data") + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + + with pytest.raises(aiohttp.ClientConnectionError): + await resp.content.readline() + + +async def test_read_timeout(aiohttp_client) -> None: + async def handler(request): + await asyncio.sleep(5) + return web.Response() + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + timeout = aiohttp.ClientTimeout(sock_read=0.1) + client = await aiohttp_client(app, timeout=timeout) + + with pytest.raises(aiohttp.ServerTimeoutError): + await client.get("/") + + +async def test_read_timeout_on_prepared_response(aiohttp_client) -> None: + async def handler(request): + resp = aiohttp.web.StreamResponse() + await resp.prepare(request) + await asyncio.sleep(5) + await resp.drain() + return resp + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + timeout = aiohttp.ClientTimeout(sock_read=0.1) + client = await aiohttp_client(app, timeout=timeout) + + with pytest.raises(aiohttp.ServerTimeoutError): + async with await client.get("/") as resp: + await resp.read() + + +async def test_read_bufsize_session_default(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"1234567") + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app, read_bufsize=2) + + async with await client.get("/") as resp: + assert resp.content.get_read_buffer_limits() == (2, 4) + + +async def test_read_bufsize_explicit(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"1234567") + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with await client.get("/", read_bufsize=4) as resp: + assert resp.content.get_read_buffer_limits() == (4, 8) diff --git a/tests/test_client_functional_oldstyle.py b/tests/test_client_functional_oldstyle.py deleted file mode 100644 index bd43a639884..00000000000 --- a/tests/test_client_functional_oldstyle.py +++ /dev/null @@ -1,838 +0,0 @@ -"""HTTP client functional tests.""" - -import asyncio -import binascii -import cgi -import contextlib -import email.parser -import gc -import http.server -import io -import json -import logging -import os -import os.path -import re -import ssl -import sys -import threading -import traceback -import unittest -import urllib.parse -from unittest import mock - -from multidict import MultiDict - -import aiohttp -import aiohttp.http -from aiohttp import client, helpers, test_utils, web -from aiohttp.multipart import MultipartWriter -from aiohttp.test_utils import run_briefly, unused_port - - -@contextlib.contextmanager -def run_server(loop, *, listen_addr=('127.0.0.1', 0), - use_ssl=False, router=None): - properties = {} - transports = [] - - class HttpRequestHandler: - - def __init__(self, addr): - host, port = addr - self.host = host - self.port = port - 
self.address = addr - self._url = '{}://{}:{}'.format( - 'https' if use_ssl else 'http', host, port) - - def __getitem__(self, key): - return properties[key] - - def __setitem__(self, key, value): - properties[key] = value - - def url(self, *suffix): - return urllib.parse.urljoin( - self._url, '/'.join(str(s) for s in suffix)) - - @asyncio.coroutine - def handler(request): - if properties.get('close', False): - return - - for hdr, val in request.message.headers.items(): - if (hdr.upper() == 'EXPECT') and (val == '100-continue'): - request.writer.write(b'HTTP/1.0 100 Continue\r\n\r\n') - break - - rob = router(properties, request) - return (yield from rob.dispatch()) - - class TestHttpServer(web.RequestHandler): - - def connection_made(self, transport): - transports.append(transport) - super().connection_made(transport) - - if use_ssl: - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - keyfile = os.path.join(here, 'sample.key') - certfile = os.path.join(here, 'sample.crt') - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain(certfile, keyfile) - else: - sslcontext = None - - def run(loop, fut): - thread_loop = asyncio.new_event_loop() - asyncio.set_event_loop(thread_loop) - - host, port = listen_addr - server_coroutine = thread_loop.create_server( - lambda: TestHttpServer( - web.Server(handler, loop=thread_loop), keepalive_timeout=0.5), - host, port, ssl=sslcontext) - server = thread_loop.run_until_complete(server_coroutine) - - waiter = helpers.create_future(thread_loop) - loop.call_soon_threadsafe( - fut.set_result, (thread_loop, waiter, - server.sockets[0].getsockname())) - - try: - thread_loop.run_until_complete(waiter) - finally: - # call pending connection_made if present - run_briefly(thread_loop) - - # close opened transports - for tr in transports: - tr.close() - - run_briefly(thread_loop) # call close callbacks - - server.close() - thread_loop.stop() - thread_loop.close() - gc.collect() - - fut = helpers.create_future(loop) - server_thread = threading.Thread(target=run, args=(loop, fut)) - server_thread.start() - - thread_loop, waiter, addr = loop.run_until_complete(fut) - try: - yield HttpRequestHandler(addr) - finally: - thread_loop.call_soon_threadsafe(waiter.set_result, None) - server_thread.join() - - -class Router: - - _response_version = "1.1" - _responses = http.server.BaseHTTPRequestHandler.responses - - def __init__(self, props, request): - # headers - self._headers = http.client.HTTPMessage() - for hdr, val in request.message.headers.items(): - self._headers.add_header(hdr, val) - - self._props = props - self._request = request - self._method = request.message.method - self._uri = request.message.path - self._version = request.message.version - self._compression = request.message.compression - self._body = request.content - - url = urllib.parse.urlsplit(self._uri) - self._path = url.path - self._query = url.query - - @staticmethod - def define(rmatch): - def wrapper(fn): - f_locals = sys._getframe(1).f_locals - mapping = f_locals.setdefault('_mapping', []) - mapping.append((re.compile(rmatch), fn.__name__)) - return fn - - return wrapper - - def dispatch(self): # pragma: no cover - for route, fn in self._mapping: - match = route.match(self._path) - if match is not None: - try: - return (yield from getattr(self, fn)(match)) - except Exception: - out = io.StringIO() - traceback.print_exc(file=out) - return (yield from self._response(500, out.getvalue())) - - return () - - return (yield from 
self._response(self._start_response(404))) - - def _start_response(self, code): - return web.Response(status=code) - - @asyncio.coroutine - def _response(self, response, body=None, - headers=None, chunked=False, write_body=None): - r_headers = {} - for key, val in self._headers.items(): - key = '-'.join(p.capitalize() for p in key.split('-')) - r_headers[key] = val - - encoding = self._headers.get('content-encoding', '').lower() - if 'gzip' in encoding: # pragma: no cover - cmod = 'gzip' - elif 'deflate' in encoding: - cmod = 'deflate' - else: - cmod = '' - - resp = { - 'method': self._method, - 'version': '%s.%s' % self._version, - 'path': self._uri, - 'headers': r_headers, - 'origin': self._request.transport.get_extra_info('addr', ' ')[0], - 'query': self._query, - 'form': {}, - 'compression': cmod, - 'multipart-data': [] - } - if body: # pragma: no cover - resp['content'] = body - else: - resp['content'] = ( - yield from self._request.read()).decode('utf-8', 'ignore') - - ct = self._headers.get('content-type', '').lower() - - # application/x-www-form-urlencoded - if ct == 'application/x-www-form-urlencoded': - resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) - - # multipart/form-data - elif ct.startswith('multipart/form-data'): # pragma: no cover - out = io.BytesIO() - for key, val in self._headers.items(): - out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) - - b = yield from self._request.read() - out.write(b'\r\n') - out.write(b) - out.write(b'\r\n') - out.seek(0) - - message = email.parser.BytesParser().parse(out) - if message.is_multipart(): - for msg in message.get_payload(): - if msg.is_multipart(): - logging.warning('multipart msg is not expected') - else: - key, params = cgi.parse_header( - msg.get('content-disposition', '')) - params['data'] = msg.get_payload() - params['content-type'] = msg.get_content_type() - cte = msg.get('content-transfer-encoding') - if cte is not None: - resp['content-transfer-encoding'] = cte - resp['multipart-data'].append(params) - body = json.dumps(resp, indent=4, sort_keys=True) - - # default headers - hdrs = [('Connection', 'close'), - ('Content-Type', 'application/json')] - if chunked: - hdrs.append(('Transfer-Encoding', 'chunked')) - else: - hdrs.append(('Content-Length', str(len(body)))) - - # extra headers - if headers: - hdrs.extend(headers.items()) - - # headers - for key, val in hdrs: - response.headers[key] = val - - if chunked: - self._request.writer.enable_chunking() - - yield from response.prepare(self._request) - - # write payload - if write_body: - try: - write_body(response, body) - except: - return - else: - response.write(body.encode('utf8')) - - return response - - -class Functional(Router): - - @Router.define('/method/([A-Za-z]+)$') - def method(self, match): - return self._response(self._start_response(200)) - - @Router.define('/keepalive$') - def keepalive(self, match): - transport = self._request.transport - - transport._requests = getattr(transport, '_requests', 0) + 1 - resp = self._start_response(200) - if 'close=' in self._query: - return self._response( - resp, 'requests={}'.format(transport._requests)) - else: - return self._response( - resp, 'requests={}'.format(transport._requests), - headers={'CONNECTION': 'keep-alive'}) - - @Router.define('/cookies$') - def cookies(self, match): - cookies = helpers.SimpleCookie() - cookies['c1'] = 'cookie1' - cookies['c2'] = 'cookie2' - - resp = self._start_response(200) - for cookie in cookies.output(header='').split('\n'): - 
resp.headers.extend({'Set-Cookie': cookie.strip()}) - - resp.headers.extend( - {'Set-Cookie': - 'ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=' - '{925EC0B8-CB17-4BEB-8A35-1033813B0523}; HttpOnly; Path=/'}) - - return self._response(resp) - - @Router.define('/cookies_partial$') - def cookies_partial(self, match): - cookies = helpers.SimpleCookie() - cookies['c1'] = 'other_cookie1' - - resp = self._start_response(200) - for cookie in cookies.output(header='').split('\n'): - resp.add_header('Set-Cookie', cookie.strip()) - - return self._response(resp) - - @Router.define('/broken$') - def broken(self, match): - resp = self._start_response(200) - - def write_body(resp, body): - self._transport.close() - raise ValueError() - - return self._response( - resp, - body=json.dumps({'t': (b'0' * 1024).decode('utf-8')}), - write_body=write_body) - - -class TestHttpClientFunctional(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - # just in case if we have transport close callbacks - test_utils.run_briefly(self.loop) - - self.loop.close() - gc.collect() - - def test_POST_DATA_with_charset(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - form = aiohttp.FormData() - form.add_field('name', 'текст', - content_type='text/plain; charset=koi8-r') - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request('post', url, data=form)) - content = self.loop.run_until_complete(r.json()) - - self.assertEqual(1, len(content['multipart-data'])) - field = content['multipart-data'][0] - self.assertEqual('name', field['name']) - self.assertEqual('текст', field['data']) - self.assertEqual(r.status, 200) - r.close() - session.close() - - def test_POST_DATA_with_charset_pub_request(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - form = aiohttp.FormData() - form.add_field('name', 'текст', - content_type='text/plain; charset=koi8-r') - - r = self.loop.run_until_complete( - aiohttp.request('post', url, data=form, loop=self.loop)) - content = self.loop.run_until_complete(r.json()) - - self.assertEqual(1, len(content['multipart-data'])) - field = content['multipart-data'][0] - self.assertEqual('name', field['name']) - self.assertEqual('текст', field['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_DATA_with_content_transfer_encoding(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - form = aiohttp.FormData() - form.add_field('name', b'123', - content_transfer_encoding='base64') - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request('post', url, data=form)) - content = self.loop.run_until_complete(r.json()) - - self.assertEqual(1, len(content['multipart-data'])) - field = content['multipart-data'][0] - self.assertEqual('name', field['name']) - self.assertEqual(b'123', binascii.a2b_base64(field['data'])) - # self.assertEqual('base64', field['content-transfer-encoding']) - self.assertEqual(r.status, 200) - - r.close() - session.close() - - def test_POST_MULTIPART(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with MultipartWriter('form-data') as writer: - writer.append('foo') - writer.append_json({'bar': 'баз'}) - writer.append_form([('тест', '4'), ('сетс', '2')]) - - session = 
client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request('post', url, data=writer)) - - content = self.loop.run_until_complete(r.json()) - - self.assertEqual(3, len(content['multipart-data'])) - self.assertEqual({'content-type': 'text/plain', 'data': 'foo'}, - content['multipart-data'][0]) - self.assertEqual({'content-type': 'application/json', - 'data': '{"bar": "\\u0431\\u0430\\u0437"}'}, - content['multipart-data'][1]) - self.assertEqual( - {'content-type': 'application/x-www-form-urlencoded', - 'data': '%D1%82%D0%B5%D1%81%D1%82=4&' - '%D1%81%D0%B5%D1%82%D1%81=2'}, - content['multipart-data'][2]) - self.assertEqual(r.status, 200) - r.close() - session.close() - - def test_POST_STREAM_DATA(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - - with open(fname, 'rb') as f: - data = f.read() - - fut = helpers.create_future(self.loop) - - @aiohttp.streamer - def stream(writer): - yield from fut - writer.write(data) - - self.loop.call_later(0.01, fut.set_result, True) - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request( - 'post', url, data=stream(), - headers={'Content-Length': str(len(data))})) - content = self.loop.run_until_complete(r.json()) - r.close() - session.close() - - self.assertEqual(str(len(data)), - content['headers']['Content-Length']) - self.assertEqual('application/octet-stream', - content['headers']['Content-Type']) - - def test_POST_StreamReader(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - - with open(fname, 'rb') as f: - data = f.read() - - stream = aiohttp.StreamReader(loop=self.loop) - stream.feed_data(data) - stream.feed_eof() - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request( - 'post', url, data=stream, - headers={'Content-Length': str(len(data))})) - content = self.loop.run_until_complete(r.json()) - r.close() - session.close() - - self.assertEqual(str(len(data)), - content['headers']['Content-Length']) - - def test_POST_DataQueue(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - - with open(fname, 'rb') as f: - data = f.read() - - stream = aiohttp.DataQueue(loop=self.loop) - stream.feed_data(data[:100], 100) - stream.feed_data(data[100:], len(data[100:])) - stream.feed_eof() - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request( - 'post', url, data=stream, - headers={'Content-Length': str(len(data))})) - content = self.loop.run_until_complete(r.json()) - r.close() - session.close() - - self.assertEqual(str(len(data)), - content['headers']['Content-Length']) - - def test_POST_ChunksQueue(self): - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - - with open(fname, 'rb') as f: - data = f.read() - - stream = aiohttp.ChunksQueue(loop=self.loop) - stream.feed_data(data[:100], 100) - - d = data[100:] - stream.feed_data(d, len(d)) - stream.feed_eof() - - session = client.ClientSession(loop=self.loop) - r = self.loop.run_until_complete( - session.request( - 'post', url, 
data=stream, - headers={'Content-Length': str(len(data))})) - content = self.loop.run_until_complete(r.json()) - r.close() - session.close() - - self.assertEqual(str(len(data)), - content['headers']['Content-Length']) - - def test_request_conn_closed(self): - with run_server(self.loop, router=Functional) as httpd: - httpd['close'] = True - session = client.ClientSession(loop=self.loop) - with self.assertRaises(aiohttp.ServerDisconnectedError): - self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'))) - - session.close() - - def test_session_close(self): - conn = aiohttp.TCPConnector(loop=self.loop) - session = client.ClientSession(loop=self.loop, connector=conn) - - with run_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - session.request( - 'get', httpd.url('keepalive') + '?close=1')) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertEqual(content['content'], 'requests=1') - r.close() - - r = self.loop.run_until_complete( - session.request('get', httpd.url('keepalive'))) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertEqual(content['content'], 'requests=1') - r.close() - - session.close() - conn.close() - - def test_multidict_headers(self): - session = client.ClientSession(loop=self.loop) - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - data = b'sample data' - - r = self.loop.run_until_complete( - session.request( - 'post', url, data=data, - headers=MultiDict( - {'Content-Length': str(len(data))}))) - content = self.loop.run_until_complete(r.json()) - r.close() - - self.assertEqual(str(len(data)), - content['headers']['Content-Length']) - - session.close() - - def test_dont_close_explicit_connector(self): - - @asyncio.coroutine - def go(url): - connector = aiohttp.TCPConnector(loop=self.loop) - session = client.ClientSession(loop=self.loop, connector=connector) - - r = yield from session.request('GET', url) - yield from r.read() - self.assertEqual(1, len(connector._conns)) - connector.close() - session.close() - - with run_server(self.loop, router=Functional) as httpd: - url = httpd.url('keepalive') - self.loop.run_until_complete(go(url)) - - def test_server_close_keepalive_connection(self): - - class Proto(asyncio.Protocol): - - def connection_made(self, transport): - self.transp = transport - self.data = b'' - - def data_received(self, data): - self.data += data - if data.endswith(b'\r\n\r\n'): - self.transp.write( - b'HTTP/1.1 200 OK\r\n' - b'CONTENT-LENGTH: 2\r\n' - b'CONNECTION: close\r\n' - b'\r\n' - b'ok') - self.transp.close() - - def connection_lost(self, exc): - self.transp = None - - @asyncio.coroutine - def go(): - server = yield from self.loop.create_server( - Proto, '127.0.0.1', unused_port()) - - addr = server.sockets[0].getsockname() - - connector = aiohttp.TCPConnector(loop=self.loop, limit=1) - session = client.ClientSession(loop=self.loop, connector=connector) - - url = 'http://{}:{}/'.format(*addr) - for i in range(2): - r = yield from session.request('GET', url) - yield from r.read() - self.assertEqual(0, len(connector._conns)) - session.close() - connector.close() - server.close() - yield from server.wait_closed() - - self.loop.run_until_complete(go()) - - def test_handle_keepalive_on_closed_connection(self): - - class Proto(asyncio.Protocol): - - def connection_made(self, transport): - self.transp = transport - self.data = b'' - - def data_received(self, 
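# [editor's sketch] The connector tests around this point spin up a raw
# asyncio server to check keep-alive reuse. The client-side pattern under
# test, written against the modern context-manager API (URL hypothetical):
import aiohttp

async def fetch_twice(url):
    connector = aiohttp.TCPConnector(limit=1)  # a single pooled connection
    async with aiohttp.ClientSession(connector=connector) as session:
        for _ in range(2):
            async with session.get(url) as resp:
                await resp.read()  # drain so the connection can be reused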
data): - self.data += data - if data.endswith(b'\r\n\r\n'): - self.transp.write( - b'HTTP/1.1 200 OK\r\n' - b'CONTENT-LENGTH: 2\r\n' - b'\r\n' - b'ok') - self.transp.close() - - def connection_lost(self, exc): - self.transp = None - - @asyncio.coroutine - def go(): - server = yield from self.loop.create_server( - Proto, '127.0.0.1', unused_port()) - - addr = server.sockets[0].getsockname() - - connector = aiohttp.TCPConnector(loop=self.loop, limit=1) - session = client.ClientSession(loop=self.loop, connector=connector) - - url = 'http://{}:{}/'.format(*addr) - - r = yield from session.request('GET', url) - yield from r.read() - self.assertEqual(1, len(connector._conns)) - - with self.assertRaises(aiohttp.ServerDisconnectedError): - yield from session.request('GET', url) - self.assertEqual(0, len(connector._conns)) - - session.close() - connector.close() - server.close() - yield from server.wait_closed() - - self.loop.run_until_complete(go()) - - @mock.patch('aiohttp.client_reqrep.client_logger') - def test_session_cookies(self, m_log): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession(loop=self.loop) - - resp = self.loop.run_until_complete( - session.request('get', httpd.url('cookies'))) - self.assertEqual(resp.cookies['c1'].value, 'cookie1') - self.assertEqual(resp.cookies['c2'].value, 'cookie2') - resp.close() - - # Add the received cookies as shared for sending them to the test - # server, which is only accessible via IP - session.cookie_jar.update_cookies(resp.cookies) - - # Assert, that we send those cookies in next requests - r = self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'))) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertEqual( - content['headers']['Cookie'], 'c1=cookie1; c2=cookie2') - r.close() - session.close() - - def test_session_headers(self): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession( - loop=self.loop, headers={ - "X-Real-IP": "192.168.0.1" - }) - - r = self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'))) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertIn( - "X-Real-Ip", content['headers']) - self.assertEqual( - content['headers']["X-Real-Ip"], "192.168.0.1") - r.close() - session.close() - - def test_session_headers_merge(self): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession( - loop=self.loop, headers=[ - ("X-Real-IP", "192.168.0.1"), - ("X-Sent-By", "requests")]) - - r = self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'), - headers={"X-Sent-By": "aiohttp"})) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertIn( - "X-Real-Ip", content['headers']) - self.assertIn( - "X-Sent-By", content['headers']) - self.assertEqual( - content['headers']["X-Real-Ip"], "192.168.0.1") - self.assertEqual( - content['headers']["X-Sent-By"], "aiohttp") - r.close() - session.close() - - def test_session_auth(self): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession( - loop=self.loop, auth=helpers.BasicAuth("login", "pass")) - - r = self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'))) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertIn( - "Authorization", content['headers']) - self.assertEqual( - 
content['headers']["Authorization"], "Basic bG9naW46cGFzcw==") - r.close() - session.close() - - def test_session_auth_override(self): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession( - loop=self.loop, auth=helpers.BasicAuth("login", "pass")) - - r = self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'), - auth=helpers.BasicAuth("other_login", "pass"))) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.json()) - self.assertIn( - "Authorization", content['headers']) - self.assertEqual( - content['headers']["Authorization"], - "Basic b3RoZXJfbG9naW46cGFzcw==") - r.close() - session.close() - - def test_session_auth_header_conflict(self): - with run_server(self.loop, router=Functional) as httpd: - session = client.ClientSession( - loop=self.loop, auth=helpers.BasicAuth("login", "pass")) - - headers = {'Authorization': "Basic b3RoZXJfbG9naW46cGFzcw=="} - with self.assertRaises(ValueError): - self.loop.run_until_complete( - session.request('get', httpd.url('method', 'get'), - headers=headers)) - session.close() diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index c5a5699809e..85225c77dad 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -1,15 +1,15 @@ -import asyncio from unittest import mock from yarl import URL -from aiohttp.client_exceptions import ClientOSError, ClientResponseError +from aiohttp import http +from aiohttp.client_exceptions import ClientOSError, ServerDisconnectedError from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientResponse +from aiohttp.helpers import TimerNoop -@asyncio.coroutine -def test_oserror(loop): +async def test_oserror(loop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) @@ -19,9 +19,10 @@ def test_oserror(loop): assert isinstance(proto.exception(), ClientOSError) -@asyncio.coroutine -def test_pause_resume_on_error(loop): +async def test_pause_resume_on_error(loop) -> None: proto = ResponseHandler(loop=loop) + transport = mock.Mock() + proto.connection_made(transport) proto.pause_reading() assert proto._reading_paused @@ -30,42 +31,106 @@ def test_pause_resume_on_error(loop): assert not proto._reading_paused -@asyncio.coroutine -def test_client_proto_bad_message(loop): +async def test_client_proto_bad_message(loop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) - proto.set_response_params(read_until_eof=True) + proto.set_response_params() - proto.data_received(b'HTTP\r\n\r\n') + proto.data_received(b"HTTP\r\n\r\n") assert proto.should_close assert transport.close.called - assert isinstance(proto.exception(), ClientResponseError) + assert isinstance(proto.exception(), http.HttpProcessingError) -@asyncio.coroutine -def test_client_protocol_readuntil_eof(loop): +async def test_uncompleted_message(loop) -> None: + proto = ResponseHandler(loop=loop) + transport = mock.Mock() + proto.connection_made(transport) + proto.set_response_params(read_until_eof=True) + + proto.data_received( + b"HTTP/1.1 301 Moved Permanently\r\n" b"Location: http://python.org/" + ) + proto.connection_lost(None) + + exc = proto.exception() + assert isinstance(exc, ServerDisconnectedError) + assert exc.message.code == 301 + assert dict(exc.message.headers) == {"Location": "http://python.org/"} + + +async def test_client_protocol_readuntil_eof(loop) -> None: proto = ResponseHandler(loop=loop) transport = 
mock.Mock() proto.connection_made(transport) conn = mock.Mock() conn.protocol = proto - proto.data_received(b'HTTP/1.1 200 Ok\r\n\r\n') - - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) - yield from response.start(conn, read_until_eof=True) + proto.data_received(b"HTTP/1.1 200 Ok\r\n\r\n") + + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + request_info=mock.Mock(), + traces=[], + loop=loop, + session=mock.Mock(), + ) + proto.set_response_params(read_until_eof=True) + await response.start(conn) assert not response.content.is_eof() - proto.data_received(b'0000') - data = yield from response.content.readany() - assert data == b'0000' + proto.data_received(b"0000") + data = await response.content.readany() + assert data == b"0000" - proto.data_received(b'1111') - data = yield from response.content.readany() - assert data == b'1111' + proto.data_received(b"1111") + data = await response.content.readany() + assert data == b"1111" proto.connection_lost(None) assert response.content.is_eof() + + +async def test_empty_data(loop) -> None: + proto = ResponseHandler(loop=loop) + proto.data_received(b"") + + # do nothing + + +async def test_schedule_timeout(loop) -> None: + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + + +async def test_drop_timeout(loop) -> None: + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + proto._drop_timeout() + assert proto._read_timeout_handle is None + + +async def test_reschedule_timeout(loop) -> None: + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + h = proto._read_timeout_handle + proto._reschedule_timeout() + assert proto._read_timeout_handle is not None + assert proto._read_timeout_handle is not h + + +async def test_eof_received(loop) -> None: + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + proto.eof_received() + assert proto._read_timeout_handle is None diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 37775cec0b3..d6500593ab4 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -1,23 +1,29 @@ -# coding: utf-8 - import asyncio +import hashlib import io import os.path import urllib.parse import zlib +from http.cookies import BaseCookie, Morsel, SimpleCookie from unittest import mock import pytest -from multidict import CIMultiDict, CIMultiDictProxy, upstr +from async_generator import async_generator, yield_ +from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL import aiohttp -from aiohttp import BaseConnector, hdrs, helpers, payload -from aiohttp.client_reqrep import ClientRequest, ClientResponse -from aiohttp.helpers import SimpleCookie +from aiohttp import BaseConnector, hdrs, payload +from aiohttp.client_reqrep import ( + ClientRequest, + ClientResponse, + Fingerprint, + _merge_ssl_params, +) +from aiohttp.test_utils import make_mocked_coro -@pytest.yield_fixture +@pytest.fixture def make_request(loop): request = None @@ -36,994 +42,1093 @@ def buf(): return bytearray() -@pytest.yield_fixture +@pytest.fixture +def protocol(loop, transport): + protocol = mock.Mock() + protocol.transport = transport + 
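# [editor's sketch] The read-timeout tests above drive the internal
# set_response_params(read_timeout=...) handle; from the public client API
# the same behaviour is presumably reached through ClientTimeout's sock_read
# field (aiohttp 3.x):
import aiohttp

async def fetch_with_read_timeout(url):
    timeout = aiohttp.ClientTimeout(sock_read=1)  # seconds per read
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as resp:
            return await resp.read()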
protocol._drain_helper.return_value = loop.create_future() + protocol._drain_helper.return_value.set_result(None) + return protocol + + +@pytest.fixture def transport(buf): transport = mock.Mock() def write(chunk): buf.extend(chunk) - @asyncio.coroutine - def write_eof(): + async def write_eof(): pass transport.write.side_effect = write transport.write_eof.side_effect = write_eof + transport.is_closing.return_value = False return transport @pytest.fixture -def conn(stream): - return mock.Mock(writer=stream) +def conn(transport, protocol): + return mock.Mock(transport=transport, protocol=protocol) -@pytest.fixture -def stream(buf, transport): - stream = mock.Mock() - stream.transport = transport +def test_method1(make_request) -> None: + req = make_request("get", "http://python.org/") + assert req.method == "GET" - def acquire(writer): - writer.set_transport(transport) - stream.acquire.side_effect = acquire - stream.drain.return_value = () - return stream +def test_method2(make_request) -> None: + req = make_request("head", "http://python.org/") + assert req.method == "HEAD" -def test_method1(make_request): - req = make_request('get', 'http://python.org/') - assert req.method == 'GET' +def test_method3(make_request) -> None: + req = make_request("HEAD", "http://python.org/") + assert req.method == "HEAD" -def test_method2(make_request): - req = make_request('head', 'http://python.org/') - assert req.method == 'HEAD' +def test_version_1_0(make_request) -> None: + req = make_request("get", "http://python.org/", version="1.0") + assert req.version == (1, 0) -def test_method3(make_request): - req = make_request('HEAD', 'http://python.org/') - assert req.method == 'HEAD' +def test_version_default(make_request) -> None: + req = make_request("get", "http://python.org/") + assert req.version == (1, 1) -def test_version_1_0(make_request): - req = make_request('get', 'http://python.org/', version='1.0') - assert req.version == (1, 0) +def test_request_info(make_request) -> None: + req = make_request("get", "http://python.org/") + assert req.request_info == aiohttp.RequestInfo( + URL("http://python.org/"), "GET", req.headers + ) -def test_version_default(make_request): - req = make_request('get', 'http://python.org/') - assert req.version == (1, 1) +def test_request_info_with_fragment(make_request) -> None: + req = make_request("get", "http://python.org/#urlfragment") + assert req.request_info == aiohttp.RequestInfo( + URL("http://python.org/"), + "GET", + req.headers, + URL("http://python.org/#urlfragment"), + ) -def test_version_err(make_request): +def test_version_err(make_request) -> None: with pytest.raises(ValueError): - make_request('get', 'http://python.org/', version='1.c') + make_request("get", "http://python.org/", version="1.c") + +def test_https_proxy(make_request) -> None: + with pytest.raises(ValueError): + make_request("get", "http://python.org/", proxy=URL("https://proxy.org")) -def test_keep_alive(make_request): - req = make_request('get', 'http://python.org/', version=(0, 9)) + +def test_keep_alive(make_request) -> None: + req = make_request("get", "http://python.org/", version=(0, 9)) assert not req.keep_alive() - req = make_request('get', 'http://python.org/', version=(1, 0)) + req = make_request("get", "http://python.org/", version=(1, 0)) assert not req.keep_alive() - req = make_request('get', 'http://python.org/', - version=(1, 0), headers={'connection': 'keep-alive'}) + req = make_request( + "get", + "http://python.org/", + version=(1, 0), + headers={"connection": 
"keep-alive"}, + ) assert req.keep_alive() - req = make_request('get', 'http://python.org/', version=(1, 1)) + req = make_request("get", "http://python.org/", version=(1, 1)) assert req.keep_alive() - req = make_request('get', 'http://python.org/', - version=(1, 1), headers={'connection': 'close'}) + req = make_request( + "get", "http://python.org/", version=(1, 1), headers={"connection": "close"} + ) assert not req.keep_alive() -def test_host_port_default_http(make_request): - req = make_request('get', 'http://python.org/') - assert req.host == 'python.org' +def test_host_port_default_http(make_request) -> None: + req = make_request("get", "http://python.org/") + assert req.host == "python.org" assert req.port == 80 assert not req.ssl -def test_host_port_default_https(make_request): - req = make_request('get', 'https://python.org/') - assert req.host == 'python.org' +def test_host_port_default_https(make_request) -> None: + req = make_request("get", "https://python.org/") + assert req.host == "python.org" assert req.port == 443 - assert req.ssl + assert req.is_ssl() -def test_host_port_nondefault_http(make_request): - req = make_request('get', 'http://python.org:960/') - assert req.host == 'python.org' +def test_host_port_nondefault_http(make_request) -> None: + req = make_request("get", "http://python.org:960/") + assert req.host == "python.org" assert req.port == 960 - assert not req.ssl + assert not req.is_ssl() -def test_host_port_nondefault_https(make_request): - req = make_request('get', 'https://python.org:960/') - assert req.host == 'python.org' +def test_host_port_nondefault_https(make_request) -> None: + req = make_request("get", "https://python.org:960/") + assert req.host == "python.org" assert req.port == 960 - assert req.ssl + assert req.is_ssl() -def test_host_port_default_ws(make_request): - req = make_request('get', 'ws://python.org/') - assert req.host == 'python.org' +def test_host_port_default_ws(make_request) -> None: + req = make_request("get", "ws://python.org/") + assert req.host == "python.org" assert req.port == 80 - assert not req.ssl + assert not req.is_ssl() -def test_host_port_default_wss(make_request): - req = make_request('get', 'wss://python.org/') - assert req.host == 'python.org' +def test_host_port_default_wss(make_request) -> None: + req = make_request("get", "wss://python.org/") + assert req.host == "python.org" assert req.port == 443 - assert req.ssl + assert req.is_ssl() -def test_host_port_nondefault_ws(make_request): - req = make_request('get', 'ws://python.org:960/') - assert req.host == 'python.org' +def test_host_port_nondefault_ws(make_request) -> None: + req = make_request("get", "ws://python.org:960/") + assert req.host == "python.org" assert req.port == 960 - assert not req.ssl + assert not req.is_ssl() -def test_host_port_nondefault_wss(make_request): - req = make_request('get', 'wss://python.org:960/') - assert req.host == 'python.org' +def test_host_port_nondefault_wss(make_request) -> None: + req = make_request("get", "wss://python.org:960/") + assert req.host == "python.org" assert req.port == 960 - assert req.ssl + assert req.is_ssl() + + +def test_host_port_none_port(make_request) -> None: + req = make_request("get", "unix://localhost/path") + assert req.headers["Host"] == "localhost" -def test_host_port_err(make_request): +def test_host_port_err(make_request) -> None: with pytest.raises(ValueError): - make_request('get', 'http://python.org:123e/') + make_request("get", "http://python.org:123e/") -def test_hostname_err(make_request): 
+def test_hostname_err(make_request) -> None: with pytest.raises(ValueError): - make_request('get', 'http://:8080/') + make_request("get", "http://:8080/") -def test_host_header_host_without_port(make_request): - req = make_request('get', 'http://python.org/') - assert req.headers['HOST'] == 'python.org' +def test_host_header_host_first(make_request) -> None: + req = make_request("get", "http://python.org/") + assert list(req.headers)[0] == "Host" -def test_host_header_host_with_default_port(make_request): - req = make_request('get', 'http://python.org:80/') - assert req.headers['HOST'] == 'python.org' +def test_host_header_host_without_port(make_request) -> None: + req = make_request("get", "http://python.org/") + assert req.headers["HOST"] == "python.org" -def test_host_header_host_with_nondefault_port(make_request): - req = make_request('get', 'http://python.org:99/') - assert req.headers['HOST'] == 'python.org:99' +def test_host_header_host_with_default_port(make_request) -> None: + req = make_request("get", "http://python.org:80/") + assert req.headers["HOST"] == "python.org" -def test_host_header_host_idna_encode(make_request): - req = make_request('get', 'http://xn--9caa.com') - assert req.headers['HOST'] == 'xn--9caa.com' +def test_host_header_host_with_nondefault_port(make_request) -> None: + req = make_request("get", "http://python.org:99/") + assert req.headers["HOST"] == "python.org:99" -def test_host_header_host_unicode(make_request): - req = make_request('get', 'http://éé.com') - assert req.headers['HOST'] == 'xn--9caa.com' +def test_host_header_host_idna_encode(make_request) -> None: + req = make_request("get", "http://xn--9caa.com") + assert req.headers["HOST"] == "xn--9caa.com" -def test_host_header_explicit_host(make_request): - req = make_request('get', 'http://python.org/', - headers={'host': 'example.com'}) - assert req.headers['HOST'] == 'example.com' +def test_host_header_host_unicode(make_request) -> None: + req = make_request("get", "http://éé.com") + assert req.headers["HOST"] == "xn--9caa.com" -def test_host_header_explicit_host_with_port(make_request): - req = make_request('get', 'http://python.org/', - headers={'host': 'example.com:99'}) - assert req.headers['HOST'] == 'example.com:99' +def test_host_header_explicit_host(make_request) -> None: + req = make_request("get", "http://python.org/", headers={"host": "example.com"}) + assert req.headers["HOST"] == "example.com" -def test_default_loop(loop): - asyncio.set_event_loop(loop) - req = ClientRequest('get', URL('http://python.org/')) - assert req.loop is loop +def test_host_header_explicit_host_with_port(make_request) -> None: + req = make_request("get", "http://python.org/", headers={"host": "example.com:99"}) + assert req.headers["HOST"] == "example.com:99" -def test_default_headers_useragent(make_request): - req = make_request('get', 'http://python.org/') +def test_host_header_ipv4(make_request) -> None: + req = make_request("get", "http://127.0.0.2") + assert req.headers["HOST"] == "127.0.0.2" - assert 'SERVER' not in req.headers - assert 'USER-AGENT' in req.headers +def test_host_header_ipv6(make_request) -> None: + req = make_request("get", "http://[::2]") + assert req.headers["HOST"] == "[::2]" -def test_default_headers_useragent_custom(make_request): - req = make_request('get', 'http://python.org/', - headers={'user-agent': 'my custom agent'}) - assert 'USER-Agent' in req.headers - assert 'my custom agent' == req.headers['User-Agent'] +def test_host_header_ipv4_with_port(make_request) -> None: + req = 
make_request("get", "http://127.0.0.2:99") + assert req.headers["HOST"] == "127.0.0.2:99" -def test_skip_default_useragent_header(make_request): - req = make_request('get', 'http://python.org/', - skip_auto_headers=set([upstr('user-agent')])) +def test_host_header_ipv6_with_port(make_request) -> None: + req = make_request("get", "http://[::2]:99") + assert req.headers["HOST"] == "[::2]:99" - assert 'User-Agent' not in req.headers +def test_default_loop(loop) -> None: + asyncio.set_event_loop(loop) + req = ClientRequest("get", URL("http://python.org/")) + assert req.loop is loop -def test_headers(make_request): - req = make_request('get', 'http://python.org/', - headers={'Content-Type': 'text/plain'}) - assert 'CONTENT-TYPE' in req.headers - assert req.headers['CONTENT-TYPE'] == 'text/plain' - assert req.headers['ACCEPT-ENCODING'] == 'gzip, deflate' +def test_default_headers_useragent(make_request) -> None: + req = make_request("get", "http://python.org/") + assert "SERVER" not in req.headers + assert "USER-AGENT" in req.headers -def test_headers_list(make_request): - req = make_request('get', 'http://python.org/', - headers=[('Content-Type', 'text/plain')]) - assert 'CONTENT-TYPE' in req.headers - assert req.headers['CONTENT-TYPE'] == 'text/plain' +def test_default_headers_useragent_custom(make_request) -> None: + req = make_request( + "get", "http://python.org/", headers={"user-agent": "my custom agent"} + ) -def test_headers_default(make_request): - req = make_request('get', 'http://python.org/', - headers={'ACCEPT-ENCODING': 'deflate'}) - assert req.headers['ACCEPT-ENCODING'] == 'deflate' + assert "USER-Agent" in req.headers + assert "my custom agent" == req.headers["User-Agent"] -def test_invalid_url(make_request): - with pytest.raises(ValueError): - make_request('get', 'hiwpefhipowhefopw') +def test_skip_default_useragent_header(make_request) -> None: + req = make_request( + "get", "http://python.org/", skip_auto_headers={istr("user-agent")} + ) + assert "User-Agent" not in req.headers -def test_invalid_idna(make_request): - with pytest.raises(ValueError): - make_request('get', 'http://\u2061owhefopw.com') +def test_headers(make_request) -> None: + req = make_request( + "post", "http://python.org/", headers={"Content-Type": "text/plain"} + ) -def test_no_path(make_request): - req = make_request('get', 'http://python.org') - assert '/' == req.url.path + assert "CONTENT-TYPE" in req.headers + assert req.headers["CONTENT-TYPE"] == "text/plain" + assert req.headers["ACCEPT-ENCODING"] == "gzip, deflate" -def test_ipv6_default_http_port(make_request): - req = make_request('get', 'http://[2001:db8::1]/') - assert req.host == '2001:db8::1' +def test_headers_list(make_request) -> None: + req = make_request( + "post", "http://python.org/", headers=[("Content-Type", "text/plain")] + ) + assert "CONTENT-TYPE" in req.headers + assert req.headers["CONTENT-TYPE"] == "text/plain" + + +def test_headers_default(make_request) -> None: + req = make_request( + "get", "http://python.org/", headers={"ACCEPT-ENCODING": "deflate"} + ) + assert req.headers["ACCEPT-ENCODING"] == "deflate" + + +def test_invalid_url(make_request) -> None: + with pytest.raises(aiohttp.InvalidURL): + make_request("get", "hiwpefhipowhefopw") + + +def test_no_path(make_request) -> None: + req = make_request("get", "http://python.org") + assert "/" == req.url.path + + +def test_ipv6_default_http_port(make_request) -> None: + req = make_request("get", "http://[2001:db8::1]/") + assert req.host == "2001:db8::1" assert req.port == 80 
assert not req.ssl -def test_ipv6_default_https_port(make_request): - req = make_request('get', 'https://[2001:db8::1]/') - assert req.host == '2001:db8::1' +def test_ipv6_default_https_port(make_request) -> None: + req = make_request("get", "https://[2001:db8::1]/") + assert req.host == "2001:db8::1" assert req.port == 443 - assert req.ssl + assert req.is_ssl() -def test_ipv6_nondefault_http_port(make_request): - req = make_request('get', 'http://[2001:db8::1]:960/') - assert req.host == '2001:db8::1' +def test_ipv6_nondefault_http_port(make_request) -> None: + req = make_request("get", "http://[2001:db8::1]:960/") + assert req.host == "2001:db8::1" assert req.port == 960 - assert not req.ssl + assert not req.is_ssl() -def test_ipv6_nondefault_https_port(make_request): - req = make_request('get', 'https://[2001:db8::1]:960/') - assert req.host == '2001:db8::1' +def test_ipv6_nondefault_https_port(make_request) -> None: + req = make_request("get", "https://[2001:db8::1]:960/") + assert req.host == "2001:db8::1" assert req.port == 960 - assert req.ssl + assert req.is_ssl() -def test_basic_auth(make_request): - req = make_request('get', 'http://python.org', - auth=aiohttp.helpers.BasicAuth('nkim', '1234')) - assert 'AUTHORIZATION' in req.headers - assert 'Basic bmtpbToxMjM0' == req.headers['AUTHORIZATION'] +def test_basic_auth(make_request) -> None: + req = make_request( + "get", "http://python.org", auth=aiohttp.BasicAuth("nkim", "1234") + ) + assert "AUTHORIZATION" in req.headers + assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] -def test_basic_auth_utf8(make_request): - req = make_request('get', 'http://python.org', - auth=aiohttp.helpers.BasicAuth('nkim', 'секрет', - 'utf-8')) - assert 'AUTHORIZATION' in req.headers - assert 'Basic bmtpbTrRgdC10LrRgNC10YI=' == req.headers['AUTHORIZATION'] +def test_basic_auth_utf8(make_request) -> None: + req = make_request( + "get", "http://python.org", auth=aiohttp.BasicAuth("nkim", "секрет", "utf-8") + ) + assert "AUTHORIZATION" in req.headers + assert "Basic bmtpbTrRgdC10LrRgNC10YI=" == req.headers["AUTHORIZATION"] -def test_basic_auth_tuple_forbidden(make_request): +def test_basic_auth_tuple_forbidden(make_request) -> None: with pytest.raises(TypeError): - make_request('get', 'http://python.org', - auth=('nkim', '1234')) + make_request("get", "http://python.org", auth=("nkim", "1234")) -def test_basic_auth_from_url(make_request): - req = make_request('get', 'http://nkim:1234@python.org') - assert 'AUTHORIZATION' in req.headers - assert 'Basic bmtpbToxMjM0' == req.headers['AUTHORIZATION'] - assert 'python.org' == req.host +def test_basic_auth_from_url(make_request) -> None: + req = make_request("get", "http://nkim:1234@python.org") + assert "AUTHORIZATION" in req.headers + assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] + assert "python.org" == req.host -def test_basic_auth_from_url_overriden(make_request): - req = make_request('get', 'http://garbage@python.org', - auth=aiohttp.BasicAuth('nkim', '1234')) - assert 'AUTHORIZATION' in req.headers - assert 'Basic bmtpbToxMjM0' == req.headers['AUTHORIZATION'] - assert 'python.org' == req.host +def test_basic_auth_from_url_overridden(make_request) -> None: + req = make_request( + "get", "http://garbage@python.org", auth=aiohttp.BasicAuth("nkim", "1234") + ) + assert "AUTHORIZATION" in req.headers + assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] + assert "python.org" == req.host -def test_path_is_not_double_encoded1(make_request): - req = make_request('get', 
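# [editor's sketch] The Base64 literals asserted in the auth tests above are
# plain BasicAuth encodings; a quick standalone check against the real API:
import aiohttp
assert aiohttp.BasicAuth("nkim", "1234").encode() == "Basic bmtpbToxMjM0"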
"http://0.0.0.0/get/test case") +def test_path_is_not_double_encoded1(make_request) -> None: + req = make_request("get", "http://0.0.0.0/get/test case") assert req.url.raw_path == "/get/test%20case" -def test_path_is_not_double_encoded2(make_request): - req = make_request('get', "http://0.0.0.0/get/test%2fcase") +def test_path_is_not_double_encoded2(make_request) -> None: + req = make_request("get", "http://0.0.0.0/get/test%2fcase") assert req.url.raw_path == "/get/test%2Fcase" -def test_path_is_not_double_encoded3(make_request): - req = make_request('get', "http://0.0.0.0/get/test%20case") +def test_path_is_not_double_encoded3(make_request) -> None: + req = make_request("get", "http://0.0.0.0/get/test%20case") assert req.url.raw_path == "/get/test%20case" -def test_path_safe_chars_preserved(make_request): - req = make_request('get', "http://0.0.0.0/get/:=+/%2B/") +def test_path_safe_chars_preserved(make_request) -> None: + req = make_request("get", "http://0.0.0.0/get/:=+/%2B/") assert req.url.path == "/get/:=+/+/" -def test_params_are_added_before_fragment1(make_request): - req = make_request('GET', "http://example.com/path#fragment", - params={"a": "b"}) +def test_params_are_added_before_fragment1(make_request) -> None: + req = make_request("GET", "http://example.com/path#fragment", params={"a": "b"}) assert str(req.url) == "http://example.com/path?a=b" -def test_params_are_added_before_fragment2(make_request): - req = make_request('GET', "http://example.com/path?key=value#fragment", - params={"a": "b"}) +def test_params_are_added_before_fragment2(make_request) -> None: + req = make_request( + "GET", "http://example.com/path?key=value#fragment", params={"a": "b"} + ) assert str(req.url) == "http://example.com/path?key=value&a=b" -def test_path_not_contain_fragment1(make_request): - req = make_request('GET', "http://example.com/path#fragment") +def test_path_not_contain_fragment1(make_request) -> None: + req = make_request("GET", "http://example.com/path#fragment") assert req.url.path == "/path" -def test_path_not_contain_fragment2(make_request): - req = make_request('GET', "http://example.com/path?key=value#fragment") +def test_path_not_contain_fragment2(make_request) -> None: + req = make_request("GET", "http://example.com/path?key=value#fragment") assert str(req.url) == "http://example.com/path?key=value" -def test_cookies(make_request): - req = make_request('get', 'http://test.com/path', - cookies={'cookie1': 'val1'}) +def test_cookies(make_request) -> None: + req = make_request("get", "http://test.com/path", cookies={"cookie1": "val1"}) + + assert "COOKIE" in req.headers + assert "cookie1=val1" == req.headers["COOKIE"] + + +def test_cookies_is_quoted_with_special_characters(make_request) -> None: + req = make_request("get", "http://test.com/path", cookies={"cookie1": "val/one"}) - assert 'COOKIE' in req.headers - assert 'cookie1=val1' == req.headers['COOKIE'] + assert "COOKIE" in req.headers + assert 'cookie1="val/one"' == req.headers["COOKIE"] -def test_cookies_merge_with_headers(make_request): - req = make_request('get', 'http://test.com/path', - headers={'cookie': 'cookie1=val1'}, - cookies={'cookie2': 'val2'}) +def test_cookies_merge_with_headers(make_request) -> None: + req = make_request( + "get", + "http://test.com/path", + headers={"cookie": "cookie1=val1"}, + cookies={"cookie2": "val2"}, + ) - assert 'cookie1=val1; cookie2=val2' == req.headers['COOKIE'] + assert "cookie1=val1; cookie2=val2" == req.headers["COOKIE"] -def test_unicode_get1(make_request): - req = 
make_request('get', 'http://python.org', - params={'foo': 'f\xf8\xf8'}) - assert 'http://python.org/?foo=f%C3%B8%C3%B8' == str(req.url) +def test_unicode_get1(make_request) -> None: + req = make_request("get", "http://python.org", params={"foo": "f\xf8\xf8"}) + assert "http://python.org/?foo=f%C3%B8%C3%B8" == str(req.url) -def test_unicode_get2(make_request): - req = make_request('', 'http://python.org', - params={'f\xf8\xf8': 'f\xf8\xf8'}) +def test_unicode_get2(make_request) -> None: + req = make_request("", "http://python.org", params={"f\xf8\xf8": "f\xf8\xf8"}) - assert 'http://python.org/?f%C3%B8%C3%B8=f%C3%B8%C3%B8' == str(req.url) + assert "http://python.org/?f%C3%B8%C3%B8=f%C3%B8%C3%B8" == str(req.url) -def test_unicode_get3(make_request): - req = make_request('', 'http://python.org', params={'foo': 'foo'}) - assert 'http://python.org/?foo=foo' == str(req.url) +def test_unicode_get3(make_request) -> None: + req = make_request("", "http://python.org", params={"foo": "foo"}) + assert "http://python.org/?foo=foo" == str(req.url) -def test_unicode_get4(make_request): +def test_unicode_get4(make_request) -> None: def join(*suffix): - return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + return urllib.parse.urljoin("http://python.org/", "/".join(suffix)) - req = make_request('', join('\xf8'), params={'foo': 'foo'}) - assert 'http://python.org/%C3%B8?foo=foo' == str(req.url) + req = make_request("", join("\xf8"), params={"foo": "foo"}) + assert "http://python.org/%C3%B8?foo=foo" == str(req.url) -def test_query_multivalued_param(make_request): +def test_query_multivalued_param(make_request) -> None: for meth in ClientRequest.ALL_METHODS: req = make_request( - meth, 'http://python.org', - params=(('test', 'foo'), ('test', 'baz'))) + meth, "http://python.org", params=(("test", "foo"), ("test", "baz")) + ) - assert str(req.url) == 'http://python.org/?test=foo&test=baz' + assert str(req.url) == "http://python.org/?test=foo&test=baz" -def test_query_str_param(make_request): +def test_query_str_param(make_request) -> None: for meth in ClientRequest.ALL_METHODS: - req = make_request(meth, 'http://python.org', params='test=foo') - assert str(req.url) == 'http://python.org/?test=foo' + req = make_request(meth, "http://python.org", params="test=foo") + assert str(req.url) == "http://python.org/?test=foo" -def test_query_bytes_param_raises(make_request): +def test_query_bytes_param_raises(make_request) -> None: for meth in ClientRequest.ALL_METHODS: with pytest.raises(TypeError): - make_request(meth, 'http://python.org', params=b'test=foo') + make_request(meth, "http://python.org", params=b"test=foo") -def test_query_str_param_is_not_encoded(make_request): +def test_query_str_param_is_not_encoded(make_request) -> None: for meth in ClientRequest.ALL_METHODS: - req = make_request(meth, 'http://python.org', params='test=f+oo') - assert str(req.url) == 'http://python.org/?test=f+oo' - - -def test_params_update_path_and_url(make_request): - req = make_request('get', 'http://python.org', - params=(('test', 'foo'), ('test', 'baz'))) - assert str(req.url) == 'http://python.org/?test=foo&test=baz' - - -def test_params_empty_path_and_url(make_request): - req_empty = make_request('get', 'http://python.org', params={}) - assert str(req_empty.url) == 'http://python.org' - req_none = make_request('get', 'http://python.org') - assert str(req_none.url) == 'http://python.org' - - -def test_gen_netloc_all(make_request): - req = make_request('get', - 'https://aiohttp:pwpwpw@' + - 
'12345678901234567890123456789' + - '012345678901234567890:8080') - assert req.headers['HOST'] == '12345678901234567890123456789' +\ - '012345678901234567890:8080' - - -def test_gen_netloc_no_port(make_request): - req = make_request('get', - 'https://aiohttp:pwpwpw@' + - '12345678901234567890123456789' + - '012345678901234567890/') - assert req.headers['HOST'] == '12345678901234567890123456789' +\ - '012345678901234567890' - - -@asyncio.coroutine -def test_connection_header(loop, conn): - req = ClientRequest('get', URL('http://python.org'), loop=loop) + req = make_request(meth, "http://python.org", params="test=f+oo") + assert str(req.url) == "http://python.org/?test=f+oo" + + +def test_params_update_path_and_url(make_request) -> None: + req = make_request( + "get", "http://python.org", params=(("test", "foo"), ("test", "baz")) + ) + assert str(req.url) == "http://python.org/?test=foo&test=baz" + + +def test_params_empty_path_and_url(make_request) -> None: + req_empty = make_request("get", "http://python.org", params={}) + assert str(req_empty.url) == "http://python.org" + req_none = make_request("get", "http://python.org") + assert str(req_none.url) == "http://python.org" + + +def test_gen_netloc_all(make_request) -> None: + req = make_request( + "get", + "https://aiohttp:pwpwpw@" + + "12345678901234567890123456789" + + "012345678901234567890:8080", + ) + assert ( + req.headers["HOST"] + == "12345678901234567890123456789" + "012345678901234567890:8080" + ) + + +def test_gen_netloc_no_port(make_request) -> None: + req = make_request( + "get", + "https://aiohttp:pwpwpw@" + + "12345678901234567890123456789" + + "012345678901234567890/", + ) + assert ( + req.headers["HOST"] == "12345678901234567890123456789" + "012345678901234567890" + ) + + +async def test_connection_header(loop, conn) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) req.keep_alive = mock.Mock() req.headers.clear() req.keep_alive.return_value = True req.version = (1, 1) req.headers.clear() - req.send(conn) - assert req.headers.get('CONNECTION') is None + await req.send(conn) + assert req.headers.get("CONNECTION") is None req.version = (1, 0) req.headers.clear() - req.send(conn) - assert req.headers.get('CONNECTION') == 'keep-alive' + await req.send(conn) + assert req.headers.get("CONNECTION") == "keep-alive" req.keep_alive.return_value = False req.version = (1, 1) req.headers.clear() - req.send(conn) - assert req.headers.get('CONNECTION') == 'close' + await req.send(conn) + assert req.headers.get("CONNECTION") == "close" -@asyncio.coroutine -def test_no_content_length(loop, conn): - req = ClientRequest('get', URL('http://python.org'), loop=loop) - resp = req.send(conn) - assert '0' == req.headers.get('CONTENT-LENGTH') - yield from req.close() +async def test_no_content_length(loop, conn) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) + resp = await req.send(conn) + assert req.headers.get("CONTENT-LENGTH") is None + await req.close() resp.close() -@asyncio.coroutine -def test_no_content_length2(loop, conn): - req = ClientRequest('head', URL('http://python.org'), loop=loop) - resp = req.send(conn) - assert '0' == req.headers.get('CONTENT-LENGTH') - yield from req.close() +async def test_no_content_length_head(loop, conn) -> None: + req = ClientRequest("head", URL("http://python.org"), loop=loop) + resp = await req.send(conn) + assert req.headers.get("CONTENT-LENGTH") is None + await req.close() resp.close() -def test_content_type_auto_header_get(loop, conn): - req = 
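# [editor's sketch] The two rewritten Content-Length tests above capture a
# behaviour change: bodyless GET/HEAD requests no longer force
# "Content-Length: 0"; the header is simply absent. In the tests' own shape
# (module imports and the loop/conn fixtures assumed):
req = ClientRequest("get", URL("http://python.org"), loop=loop)
resp = await req.send(conn)
assert req.headers.get("CONTENT-LENGTH") is None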
ClientRequest('get', URL('http://python.org'), loop=loop) - resp = req.send(conn) - assert 'CONTENT-TYPE' not in req.headers +async def test_content_type_auto_header_get(loop, conn) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) + resp = await req.send(conn) + assert "CONTENT-TYPE" not in req.headers resp.close() -def test_content_type_auto_header_form(loop, conn): - req = ClientRequest('post', URL('http://python.org'), - data={'hey': 'you'}, loop=loop) - resp = req.send(conn) - assert 'application/x-www-form-urlencoded' == \ - req.headers.get('CONTENT-TYPE') +async def test_content_type_auto_header_form(loop, conn) -> None: + req = ClientRequest( + "post", URL("http://python.org"), data={"hey": "you"}, loop=loop + ) + resp = await req.send(conn) + assert "application/x-www-form-urlencoded" == req.headers.get("CONTENT-TYPE") resp.close() -def test_content_type_auto_header_bytes(loop, conn): - req = ClientRequest('post', URL('http://python.org'), data=b'hey you', - loop=loop) - resp = req.send(conn) - assert 'application/octet-stream' == req.headers.get('CONTENT-TYPE') +async def test_content_type_auto_header_bytes(loop, conn) -> None: + req = ClientRequest("post", URL("http://python.org"), data=b"hey you", loop=loop) + resp = await req.send(conn) + assert "application/octet-stream" == req.headers.get("CONTENT-TYPE") resp.close() -def test_content_type_skip_auto_header_bytes(loop, conn): - req = ClientRequest('post', URL('http://python.org'), data=b'hey you', - skip_auto_headers={'Content-Type'}, - loop=loop) - resp = req.send(conn) - assert 'CONTENT-TYPE' not in req.headers +async def test_content_type_skip_auto_header_bytes(loop, conn) -> None: + req = ClientRequest( + "post", + URL("http://python.org"), + data=b"hey you", + skip_auto_headers={"Content-Type"}, + loop=loop, + ) + resp = await req.send(conn) + assert "CONTENT-TYPE" not in req.headers resp.close() -def test_content_type_skip_auto_header_form(loop, conn): - req = ClientRequest('post', URL('http://python.org'), - data={'hey': 'you'}, loop=loop, - skip_auto_headers={'Content-Type'}) - resp = req.send(conn) - assert 'CONTENT-TYPE' not in req.headers +async def test_content_type_skip_auto_header_form(loop, conn) -> None: + req = ClientRequest( + "post", + URL("http://python.org"), + data={"hey": "you"}, + loop=loop, + skip_auto_headers={"Content-Type"}, + ) + resp = await req.send(conn) + assert "CONTENT-TYPE" not in req.headers resp.close() -def test_content_type_auto_header_content_length_no_skip(loop, conn): - req = ClientRequest('get', URL('http://python.org'), - data=io.BytesIO(b'hey'), - skip_auto_headers={'Content-Length'}, - loop=loop) - resp = req.send(conn) - assert req.headers.get('CONTENT-LENGTH') == '3' +async def test_content_type_auto_header_content_length_no_skip(loop, conn) -> None: + req = ClientRequest( + "post", + URL("http://python.org"), + data=io.BytesIO(b"hey"), + skip_auto_headers={"Content-Length"}, + loop=loop, + ) + resp = await req.send(conn) + assert req.headers.get("CONTENT-LENGTH") == "3" resp.close() -def test_urlencoded_formdata_charset(loop, conn): +async def test_urlencoded_formdata_charset(loop, conn) -> None: req = ClientRequest( - 'post', URL('http://python.org'), - data=aiohttp.FormData({'hey': 'you'}, charset='koi8-r'), loop=loop) - req.send(conn) - assert 'application/x-www-form-urlencoded; charset=koi8-r' == \ - req.headers.get('CONTENT-TYPE') - - -@asyncio.coroutine -def test_post_data(loop, conn): + "post", + URL("http://python.org"), + 
data=aiohttp.FormData({"hey": "you"}, charset="koi8-r"), + loop=loop, + ) + await req.send(conn) + assert "application/x-www-form-urlencoded; charset=koi8-r" == req.headers.get( + "CONTENT-TYPE" + ) + + +async def test_post_data(loop, conn) -> None: for meth in ClientRequest.POST_METHODS: req = ClientRequest( - meth, URL('http://python.org/'), - data={'life': '42'}, loop=loop) - resp = req.send(conn) - assert '/' == req.url.path - assert b'life=42' == req.body._value - assert 'application/x-www-form-urlencoded' ==\ - req.headers['CONTENT-TYPE'] - yield from req.close() + meth, URL("http://python.org/"), data={"life": "42"}, loop=loop + ) + resp = await req.send(conn) + assert "/" == req.url.path + assert b"life=42" == req.body._value + assert "application/x-www-form-urlencoded" == req.headers["CONTENT-TYPE"] + await req.close() resp.close() -@asyncio.coroutine -def test_pass_falsy_data(loop): - with mock.patch( - 'aiohttp.client_reqrep.ClientRequest.update_body_from_data'): - req = ClientRequest( - 'post', URL('http://python.org/'), - data={}, loop=loop) - req.update_body_from_data.assert_called_once_with({}, frozenset()) - yield from req.close() +async def test_pass_falsy_data(loop) -> None: + with mock.patch("aiohttp.client_reqrep.ClientRequest.update_body_from_data"): + req = ClientRequest("post", URL("http://python.org/"), data={}, loop=loop) + req.update_body_from_data.assert_called_once_with({}) + await req.close() -@asyncio.coroutine -def test_pass_falsy_data_file(loop, tmpdir): - testfile = tmpdir.join('tmpfile').open('w+b') - testfile.write(b'data') +async def test_pass_falsy_data_file(loop, tmpdir) -> None: + testfile = tmpdir.join("tmpfile").open("w+b") + testfile.write(b"data") testfile.seek(0) skip = frozenset([hdrs.CONTENT_TYPE]) req = ClientRequest( - 'post', URL('http://python.org/'), + "post", + URL("http://python.org/"), data=testfile, skip_auto_headers=skip, - loop=loop) - assert req.headers.get('CONTENT-LENGTH', None) is not None - yield from req.close() + loop=loop, + ) + assert req.headers.get("CONTENT-LENGTH", None) is not None + await req.close() -@asyncio.coroutine -def test_get_with_data(loop): +# The Elasticsearch API requires sending a request body with GET requests +async def test_get_with_data(loop) -> None: for meth in ClientRequest.GET_METHODS: req = ClientRequest( - meth, URL('http://python.org/'), data={'life': '42'}, - loop=loop) - assert '/' == req.url.path - assert b'life=42' == req.body._value - yield from req.close() + meth, URL("http://python.org/"), data={"life": "42"}, loop=loop + ) + assert "/" == req.url.path + assert b"life=42" == req.body._value + await req.close() -@asyncio.coroutine -def test_bytes_data(loop, conn): +async def test_bytes_data(loop, conn) -> None: for meth in ClientRequest.POST_METHODS: req = ClientRequest( - meth, URL('http://python.org/'), - data=b'binary data', loop=loop) - resp = req.send(conn) - assert '/' == req.url.path + meth, URL("http://python.org/"), data=b"binary data", loop=loop + ) + resp = await req.send(conn) + assert "/" == req.url.path assert isinstance(req.body, payload.BytesPayload) - assert b'binary data' == req.body._value - assert 'application/octet-stream' == req.headers['CONTENT-TYPE'] - yield from req.close() + assert b"binary data" == req.body._value + assert "application/octet-stream" == req.headers["CONTENT-TYPE"] + await req.close() resp.close() -@asyncio.coroutine -def test_content_encoding(loop, conn): - req = ClientRequest('get', URL('http://python.org/'), data='foo', - compress='deflate', loop=loop)
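# [editor's sketch] test_urlencoded_formdata_charset above exercises
# FormData's real charset parameter; in application code (session/url
# hypothetical):
form = aiohttp.FormData({"hey": "you"}, charset="koi8-r")
# await session.post(url, data=form)
# -> Content-Type: application/x-www-form-urlencoded; charset=koi8-r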
- with mock.patch('aiohttp.client_reqrep.PayloadWriter') as m_writer: - resp = req.send(conn) - assert req.headers['TRANSFER-ENCODING'] == 'chunked' - assert req.headers['CONTENT-ENCODING'] == 'deflate' - m_writer.return_value\ - .enable_compression.assert_called_with('deflate') - yield from req.close() +async def test_content_encoding(loop, conn) -> None: + req = ClientRequest( + "post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop + ) + with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() + resp = await req.send(conn) + assert req.headers["TRANSFER-ENCODING"] == "chunked" + assert req.headers["CONTENT-ENCODING"] == "deflate" + m_writer.return_value.enable_compression.assert_called_with("deflate") + await req.close() resp.close() -@asyncio.coroutine -def test_content_encoding_dont_set_headers_if_no_body(loop, conn): - req = ClientRequest('get', URL('http://python.org/'), - compress='deflate', loop=loop) - with mock.patch('aiohttp.client_reqrep.http'): - resp = req.send(conn) - assert 'TRANSFER-ENCODING' not in req.headers - assert 'CONTENT-ENCODING' not in req.headers - yield from req.close() +async def test_content_encoding_dont_set_headers_if_no_body(loop, conn) -> None: + req = ClientRequest( + "post", URL("http://python.org/"), compress="deflate", loop=loop + ) + with mock.patch("aiohttp.client_reqrep.http"): + resp = await req.send(conn) + assert "TRANSFER-ENCODING" not in req.headers + assert "CONTENT-ENCODING" not in req.headers + await req.close() resp.close() -@asyncio.coroutine -def test_content_encoding_header(loop, conn): +async def test_content_encoding_header(loop, conn) -> None: req = ClientRequest( - 'get', URL('http://python.org/'), data='foo', - headers={'Content-Encoding': 'deflate'}, loop=loop) - with mock.patch('aiohttp.client_reqrep.PayloadWriter') as m_writer: - resp = req.send(conn) + "post", + URL("http://python.org/"), + data="foo", + headers={"Content-Encoding": "deflate"}, + loop=loop, + ) + with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() + resp = await req.send(conn) assert not m_writer.return_value.enable_compression.called assert not m_writer.return_value.enable_chunking.called - yield from req.close() + await req.close() resp.close() -@asyncio.coroutine -def test_compress_and_content_encoding(loop, conn): +async def test_compress_and_content_encoding(loop, conn) -> None: with pytest.raises(ValueError): - ClientRequest('get', URL('http://python.org/'), data='foo', - headers={'content-encoding': 'deflate'}, - compress='deflate', loop=loop) + ClientRequest( + "post", + URL("http://python.org/"), + data="foo", + headers={"content-encoding": "deflate"}, + compress="deflate", + loop=loop, + ) -@asyncio.coroutine -def test_chunked(loop, conn): +async def test_chunked(loop, conn) -> None: req = ClientRequest( - 'get', URL('http://python.org/'), - headers={'TRANSFER-ENCODING': 'gzip'}, loop=loop) - resp = req.send(conn) - assert 'gzip' == req.headers['TRANSFER-ENCODING'] - yield from req.close() + "post", + URL("http://python.org/"), + headers={"TRANSFER-ENCODING": "gzip"}, + loop=loop, + ) + resp = await req.send(conn) + assert "gzip" == req.headers["TRANSFER-ENCODING"] + await req.close() resp.close() -@asyncio.coroutine -def test_chunked2(loop, conn): +async def test_chunked2(loop, conn) -> None: req = ClientRequest( - 'get', URL('http://python.org/'), - headers={'Transfer-encoding': 
'chunked'}, loop=loop) - resp = req.send(conn) - assert 'chunked' == req.headers['TRANSFER-ENCODING'] - yield from req.close() + "post", + URL("http://python.org/"), + headers={"Transfer-encoding": "chunked"}, + loop=loop, + ) + resp = await req.send(conn) + assert "chunked" == req.headers["TRANSFER-ENCODING"] + await req.close() resp.close() -@asyncio.coroutine -def test_chunked_explicit(loop, conn): - req = ClientRequest( - 'get', URL('http://python.org/'), chunked=True, loop=loop) - with mock.patch('aiohttp.client_reqrep.PayloadWriter') as m_writer: - resp = req.send(conn) +async def test_chunked_explicit(loop, conn) -> None: + req = ClientRequest("post", URL("http://python.org/"), chunked=True, loop=loop) + with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() + resp = await req.send(conn) - assert 'chunked' == req.headers['TRANSFER-ENCODING'] + assert "chunked" == req.headers["TRANSFER-ENCODING"] m_writer.return_value.enable_chunking.assert_called_with() - yield from req.close() + await req.close() resp.close() -@asyncio.coroutine -def test_chunked_length(loop, conn): +async def test_chunked_length(loop, conn) -> None: with pytest.raises(ValueError): ClientRequest( - 'get', URL('http://python.org/'), - headers={'CONTENT-LENGTH': '1000'}, chunked=True, loop=loop) + "post", + URL("http://python.org/"), + headers={"CONTENT-LENGTH": "1000"}, + chunked=True, + loop=loop, + ) -@asyncio.coroutine -def test_chunked_transfer_encoding(loop, conn): +async def test_chunked_transfer_encoding(loop, conn) -> None: with pytest.raises(ValueError): ClientRequest( - 'get', URL('http://python.org/'), - headers={'TRANSFER-ENCODING': 'chunked'}, chunked=True, loop=loop) + "post", + URL("http://python.org/"), + headers={"TRANSFER-ENCODING": "chunked"}, + chunked=True, + loop=loop, + ) -@asyncio.coroutine -def test_file_upload_not_chunked(loop): +async def test_file_upload_not_chunked(loop) -> None: here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - with open(fname, 'rb') as f: - req = ClientRequest( - 'post', URL('http://python.org/'), - data=f, - loop=loop) + fname = os.path.join(here, "aiohttp.png") + with open(fname, "rb") as f: + req = ClientRequest("post", URL("http://python.org/"), data=f, loop=loop) assert not req.chunked - assert req.headers['CONTENT-LENGTH'] == str(os.path.getsize(fname)) - yield from req.close() + assert req.headers["CONTENT-LENGTH"] == str(os.path.getsize(fname)) + await req.close() -@asyncio.coroutine -def test_precompressed_data_stays_intact(loop): - data = zlib.compress(b'foobar') +async def test_precompressed_data_stays_intact(loop) -> None: + data = zlib.compress(b"foobar") req = ClientRequest( - 'post', URL('http://python.org/'), + "post", + URL("http://python.org/"), data=data, - headers={'CONTENT-ENCODING': 'deflate'}, + headers={"CONTENT-ENCODING": "deflate"}, compress=False, - loop=loop) + loop=loop, + ) assert not req.compress assert not req.chunked - assert req.headers['CONTENT-ENCODING'] == 'deflate' - yield from req.close() + assert req.headers["CONTENT-ENCODING"] == "deflate" + await req.close() -@asyncio.coroutine -def test_file_upload_not_chunked_seek(loop): +async def test_file_upload_not_chunked_seek(loop) -> None: here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - with open(fname, 'rb') as f: + fname = os.path.join(here, "aiohttp.png") + with open(fname, "rb") as f: f.seek(100) - req = ClientRequest( - 'post', URL('http://python.org/'), 
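# [editor's sketch] The two ValueError tests above pin down that
# chunked=True may not be combined with an explicit Content-Length or
# Transfer-Encoding header; the conflict is rejected already at request
# construction time, before anything is sent:
with pytest.raises(ValueError):
    ClientRequest("post", URL("http://python.org/"),
                  headers={"CONTENT-LENGTH": "1000"}, chunked=True, loop=loop)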
- data=f, - loop=loop) - assert req.headers['CONTENT-LENGTH'] == \ - str(os.path.getsize(fname) - 100) - yield from req.close() + req = ClientRequest("post", URL("http://python.org/"), data=f, loop=loop) + assert req.headers["CONTENT-LENGTH"] == str(os.path.getsize(fname) - 100) + await req.close() -@asyncio.coroutine -def test_file_upload_force_chunked(loop): +async def test_file_upload_force_chunked(loop) -> None: here = os.path.dirname(__file__) - fname = os.path.join(here, 'sample.key') - with open(fname, 'rb') as f: + fname = os.path.join(here, "aiohttp.png") + with open(fname, "rb") as f: req = ClientRequest( - 'post', URL('http://python.org/'), - data=f, - chunked=True, - loop=loop) + "post", URL("http://python.org/"), data=f, chunked=True, loop=loop + ) assert req.chunked - assert 'CONTENT-LENGTH' not in req.headers - yield from req.close() + assert "CONTENT-LENGTH" not in req.headers + await req.close() -def test_expect100(loop, conn): - req = ClientRequest('get', URL('http://python.org/'), - expect100=True, loop=loop) - resp = req.send(conn) - assert '100-continue' == req.headers['EXPECT'] +async def test_expect100(loop, conn) -> None: + req = ClientRequest("get", URL("http://python.org/"), expect100=True, loop=loop) + resp = await req.send(conn) + assert "100-continue" == req.headers["EXPECT"] assert req._continue is not None req.terminate() resp.close() -def test_expect_100_continue_header(loop, conn): - req = ClientRequest('get', URL('http://python.org/'), - headers={'expect': '100-continue'}, loop=loop) - resp = req.send(conn) - assert '100-continue' == req.headers['EXPECT'] +async def test_expect_100_continue_header(loop, conn) -> None: + req = ClientRequest( + "get", URL("http://python.org/"), headers={"expect": "100-continue"}, loop=loop + ) + resp = await req.send(conn) + assert "100-continue" == req.headers["EXPECT"] assert req._continue is not None req.terminate() resp.close() -@asyncio.coroutine -def test_data_stream(loop, buf, conn): - @aiohttp.streamer - def gen(writer): - writer.write(b'binary data') - writer.write(b' result') +async def test_data_stream(loop, buf, conn) -> None: + @async_generator + async def gen(): + await yield_(b"binary data") + await yield_(b" result") - req = ClientRequest( - 'POST', URL('http://python.org/'), data=gen(), loop=loop) + req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) assert req.chunked - assert req.headers['TRANSFER-ENCODING'] == 'chunked' + assert req.headers["TRANSFER-ENCODING"] == "chunked" - resp = req.send(conn) - assert isinstance(req._writer, asyncio.Future) - yield from resp.wait_for_close() + resp = await req.send(conn) + assert asyncio.isfuture(req._writer) + await resp.wait_for_close() assert req._writer is None + assert ( + buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" + ) + await req.close() + + +async def test_data_stream_deprecated(loop, buf, conn) -> None: + with pytest.warns(DeprecationWarning): - assert buf.split(b'\r\n\r\n', 1)[1] == \ - b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' - yield from req.close() + @aiohttp.streamer + async def gen(writer): + await writer.write(b"binary data") + await writer.write(b" result") + req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) + assert req.chunked + assert req.headers["TRANSFER-ENCODING"] == "chunked" + + resp = await req.send(conn) + assert asyncio.isfuture(req._writer) + await resp.wait_for_close() + assert req._writer is None + assert ( + buf.split(b"\r\n\r\n", 1)[1] 
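# [editor's sketch] The async_generator package used above backports async
# generators to Python 3.5; on 3.6+ the same streaming body is a plain
# native async generator (session/url hypothetical):
async def gen():
    yield b"binary data"
    yield b" result"
# await session.post(url, data=gen())  # sent chunked, as the tests assert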
== b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" + ) + await req.close() -@asyncio.coroutine -def test_data_file(loop, buf, conn): + +async def test_data_file(loop, buf, conn) -> None: req = ClientRequest( - 'POST', URL('http://python.org/'), - data=io.BufferedReader(io.BytesIO(b'*' * 2)), - loop=loop) + "POST", + URL("http://python.org/"), + data=io.BufferedReader(io.BytesIO(b"*" * 2)), + loop=loop, + ) assert req.chunked assert isinstance(req.body, payload.BufferedReaderPayload) - assert req.headers['TRANSFER-ENCODING'] == 'chunked' + assert req.headers["TRANSFER-ENCODING"] == "chunked" - resp = req.send(conn) - assert isinstance(req._writer, asyncio.Future) - yield from resp.wait_for_close() + resp = await req.send(conn) + assert asyncio.isfuture(req._writer) + await resp.wait_for_close() assert req._writer is None - assert buf.split(b'\r\n\r\n', 1)[1] == \ - b'2\r\n' + b'*' * 2 + b'\r\n0\r\n\r\n' - yield from req.close() + assert buf.split(b"\r\n\r\n", 1)[1] == b"2\r\n" + b"*" * 2 + b"\r\n0\r\n\r\n" + await req.close() -@asyncio.coroutine -def test_data_stream_exc(loop, conn): - fut = helpers.create_future(loop) +async def test_data_stream_exc(loop, conn) -> None: + fut = loop.create_future() - @aiohttp.streamer - def gen(writer): - writer.write(b'binary data') - yield from fut + @async_generator + async def gen(): + await yield_(b"binary data") + await fut - req = ClientRequest( - 'POST', URL('http://python.org/'), data=gen(), loop=loop) + req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) assert req.chunked - assert req.headers['TRANSFER-ENCODING'] == 'chunked' + assert req.headers["TRANSFER-ENCODING"] == "chunked" - @asyncio.coroutine - def exc(): - yield from asyncio.sleep(0.01, loop=loop) + async def throw_exc(): + await asyncio.sleep(0.01) fut.set_exception(ValueError) - helpers.ensure_future(exc(), loop=loop) + loop.create_task(throw_exc()) - req.send(conn) - yield from req._writer + await req.send(conn) + await req._writer # assert conn.close.called assert conn.protocol.set_exception.called - yield from req.close() + await req.close() -@asyncio.coroutine -def test_data_stream_exc_chain(loop, conn): - fut = helpers.create_future(loop) +async def test_data_stream_exc_chain(loop, conn) -> None: + fut = loop.create_future() - @aiohttp.streamer - def gen(writer): - yield from fut + @async_generator + async def gen(): + await fut - req = ClientRequest('POST', URL('http://python.org/'), - data=gen(), loop=loop) + req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) inner_exc = ValueError() - @asyncio.coroutine - def exc(): - yield from asyncio.sleep(0.01, loop=loop) + async def throw_exc(): + await asyncio.sleep(0.01) fut.set_exception(inner_exc) - helpers.ensure_future(exc(), loop=loop) + loop.create_task(throw_exc()) - req.send(conn) - yield from req._writer + await req.send(conn) + await req._writer # assert connection.close.called assert conn.protocol.set_exception.called outer_exc = conn.protocol.set_exception.call_args[0][0] assert isinstance(outer_exc, ValueError) assert inner_exc is outer_exc assert inner_exc is outer_exc - yield from req.close() + await req.close() -@asyncio.coroutine -def test_data_stream_continue(loop, buf, conn): - @aiohttp.streamer - def gen(writer): - writer.write(b'binary data') - writer.write(b' result') - yield from writer.write_eof() +async def test_data_stream_continue(loop, buf, conn) -> None: + @async_generator + async def gen(): + await yield_(b"binary data") + await yield_(b" result") req = 
ClientRequest( - 'POST', URL('http://python.org/'), data=gen(), - expect100=True, loop=loop) + "POST", URL("http://python.org/"), data=gen(), expect100=True, loop=loop + ) assert req.chunked - def coro(): - yield from asyncio.sleep(0.0001, loop=loop) + async def coro(): + await asyncio.sleep(0.0001) req._continue.set_result(1) - helpers.ensure_future(coro(), loop=loop) + loop.create_task(coro()) - resp = req.send(conn) - yield from req._writer - assert buf.split(b'\r\n\r\n', 1)[1] == \ - b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' - yield from req.close() + resp = await req.send(conn) + await req._writer + assert ( + buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" + ) + await req.close() resp.close() -@asyncio.coroutine -def test_data_continue(loop, buf, conn): +async def test_data_continue(loop, buf, conn) -> None: req = ClientRequest( - 'POST', URL('http://python.org/'), data=b'data', - expect100=True, loop=loop) + "POST", URL("http://python.org/"), data=b"data", expect100=True, loop=loop + ) - def coro(): - yield from asyncio.sleep(0.0001, loop=loop) + async def coro(): + await asyncio.sleep(0.0001) req._continue.set_result(1) - helpers.ensure_future(coro(), loop=loop) + loop.create_task(coro()) - resp = req.send(conn) + resp = await req.send(conn) - yield from req._writer - assert buf.split(b'\r\n\r\n', 1)[1] == b'data' - yield from req.close() + await req._writer + assert buf.split(b"\r\n\r\n", 1)[1] == b"data" + await req.close() resp.close() -@asyncio.coroutine -def test_close(loop, buf, conn): - @aiohttp.streamer - def gen(writer): - yield from asyncio.sleep(0.00001, loop=loop) - writer.write(b'result') +async def test_close(loop, buf, conn) -> None: + @async_generator + async def gen(): + await asyncio.sleep(0.00001) + await yield_(b"result") - req = ClientRequest( - 'POST', URL('http://python.org/'), data=gen(), loop=loop) - resp = req.send(conn) - yield from req.close() - assert buf.split(b'\r\n\r\n', 1)[1] == b'6\r\nresult\r\n0\r\n\r\n' - yield from req.close() + req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) + resp = await req.send(conn) + await req.close() + assert buf.split(b"\r\n\r\n", 1)[1] == b"6\r\nresult\r\n0\r\n\r\n" + await req.close() resp.close() -@asyncio.coroutine -def test_custom_response_class(loop, conn): +async def test_custom_response_class(loop, conn) -> None: class CustomResponse(ClientResponse): def read(self, decode=False): - return 'customized!' + return "customized!" req = ClientRequest( - 'GET', URL('http://python.org/'), response_class=CustomResponse, - loop=loop) - resp = req.send(conn) - assert 'customized!' == resp.read() - yield from req.close() + "GET", URL("http://python.org/"), response_class=CustomResponse, loop=loop + ) + resp = await req.send(conn) + assert "customized!" 
== resp.read() + await req.close() resp.close() -@asyncio.coroutine -def test_terminate(loop, conn): - req = ClientRequest('get', URL('http://python.org'), loop=loop) - resp = req.send(conn) +async def test_oserror_on_write_bytes(loop, conn) -> None: + req = ClientRequest("POST", URL("http://python.org/"), loop=loop) + + writer = mock.Mock() + writer.write.side_effect = OSError + + await req.write_bytes(writer, conn) + + assert conn.protocol.set_exception.called + exc = conn.protocol.set_exception.call_args[0][0] + assert isinstance(exc, aiohttp.ClientOSError) + + +async def test_terminate(loop, conn) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) + resp = await req.send(conn) assert req._writer is not None writer = req._writer = mock.Mock() @@ -1033,11 +1138,19 @@ def test_terminate(loop, conn): resp.close() -def test_terminate_with_closed_loop(loop, conn): - req = ClientRequest('get', URL('http://python.org'), loop=loop) - resp = req.send(conn) - assert req._writer is not None - writer = req._writer = mock.Mock() +def test_terminate_with_closed_loop(loop, conn) -> None: + req = resp = writer = None + + async def go(): + nonlocal req, resp, writer + req = ClientRequest("get", URL("http://python.org")) + resp = await req.send(conn) + assert req._writer is not None + writer = req._writer = mock.Mock() + + await asyncio.sleep(0.05) + + loop.run_until_complete(go()) loop.close() req.terminate() @@ -1046,48 +1159,51 @@ def test_terminate_with_closed_loop(loop, conn): resp.close() -def test_terminate_without_writer(loop): - req = ClientRequest('get', URL('http://python.org'), loop=loop) +def test_terminate_without_writer(loop) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) assert req._writer is None req.terminate() assert req._writer is None -@asyncio.coroutine -def test_custom_req_rep(loop): +async def test_custom_req_rep(loop) -> None: conn = None class CustomResponse(ClientResponse): - @asyncio.coroutine - def start(self, connection, read_until_eof=False): + async def start(self, connection, read_until_eof=False): nonlocal conn conn = connection self.status = 123 - self.reason = 'Test OK' - self.headers = CIMultiDictProxy(CIMultiDict()) + self.reason = "Test OK" + self._headers = CIMultiDictProxy(CIMultiDict()) self.cookies = SimpleCookie() return called = False class CustomRequest(ClientRequest): - - def send(self, conn): - resp = self.response_class(self.method, - self.url, - writer=self._writer, - continue100=self._continue) - resp._post_init(self.loop) + async def send(self, conn): + resp = self.response_class( + self.method, + self.url, + writer=self._writer, + continue100=self._continue, + timer=self._timer, + request_info=self.request_info, + traces=self._traces, + loop=self.loop, + session=self._session, + ) self.response = resp nonlocal called called = True return resp - @asyncio.coroutine - def create_connection(req): + async def create_connection(req, traces, timeout): assert isinstance(req, CustomRequest) return mock.Mock() + connector = BaseConnector(loop=loop) connector._create_connection = create_connection @@ -1095,12 +1211,56 @@ def create_connection(req): request_class=CustomRequest, response_class=CustomResponse, connector=connector, - loop=loop) + loop=loop, + ) - resp = yield from session.request( - 'get', URL('http://example.com/path/to')) + resp = await session.request("get", URL("http://example.com/path/to")) assert isinstance(resp, CustomResponse) assert called resp.close() - session.close() + await 
session.close() conn.close() + + +def test_verify_ssl_false_with_ssl_context(loop, ssl_ctx) -> None: + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + _merge_ssl_params( + None, verify_ssl=False, ssl_context=ssl_ctx, fingerprint=None + ) + + +def test_bad_fingerprint(loop) -> None: + with pytest.raises(ValueError): + Fingerprint(b"invalid") + + +def test_insecure_fingerprint_md5(loop) -> None: + with pytest.raises(ValueError): + Fingerprint(hashlib.md5(b"foo").digest()) + + +def test_insecure_fingerprint_sha1(loop) -> None: + with pytest.raises(ValueError): + Fingerprint(hashlib.sha1(b"foo").digest()) + + +def test_loose_cookies_types(loop) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) + morsel = Morsel() + morsel.set(key="string", val="Another string", coded_val="really") + + accepted_types = [ + [("str", BaseCookie())], + [("str", morsel)], + [ + ("str", "str"), + ], + {"str": BaseCookie()}, + {"str": morsel}, + {"str": "str"}, + SimpleCookie(), + ] + + for loose_cookies_type in accepted_types: + req.update_cookies(cookies=loose_cookies_type) diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 2a8df75918e..55aae970861 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -1,22 +1,66 @@ -# -*- coding: utf-8 -*- -"""Tests for aiohttp/client.py""" +# Tests for aiohttp/client.py -import asyncio import gc +import sys from unittest import mock import pytest +from multidict import CIMultiDict from yarl import URL import aiohttp -from aiohttp import helpers -from aiohttp.client_reqrep import ClientResponse +from aiohttp import http +from aiohttp.client_reqrep import ClientResponse, RequestInfo +from aiohttp.helpers import TimerNoop +from aiohttp.test_utils import make_mocked_coro -def test_del(): +@pytest.fixture +def session(): + return mock.Mock() + + +async def test_http_processing_error(session) -> None: loop = mock.Mock() - response = ClientResponse('get', URL('http://del-cl-resp.org')) - response._post_init(loop) + request_info = mock.Mock() + response = ClientResponse( + "get", + URL("http://del-cl-resp.org"), + request_info=request_info, + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + loop.get_debug = mock.Mock() + loop.get_debug.return_value = True + + connection = mock.Mock() + connection.protocol = aiohttp.DataQueue(loop) + connection.protocol.set_response_params = mock.Mock() + connection.protocol.set_exception(http.HttpProcessingError()) + + with pytest.raises(aiohttp.ClientResponseError) as info: + await response.start(connection) + + assert info.value.request_info is request_info + + +def test_del(session) -> None: + loop = mock.Mock() + response = ClientResponse( + "get", + URL("http://del-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) loop.get_debug = mock.Mock() loop.get_debug.return_value = True @@ -32,9 +76,18 @@ def test_del(): connection.release.assert_called_with() -def test_close(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +def test_close(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) response._closed = False response._connection = mock.Mock() 
     response.close()

@@ -43,86 +96,165 @@
     response.close()


-def test_wait_for_100_1(loop):
+def test_wait_for_100_1(loop, session) -> None:
     response = ClientResponse(
-        'get', URL('http://python.org'), continue100=object())
-    response._post_init(loop)
+        "get",
+        URL("http://python.org"),
+        continue100=object(),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     assert response._continue is not None
     response.close()


-def test_wait_for_100_2(loop):
+def test_wait_for_100_2(loop, session) -> None:
     response = ClientResponse(
-        'get', URL('http://python.org'))
-    response._post_init(loop)
+        "get",
+        URL("http://python.org"),
+        request_info=mock.Mock(),
+        continue100=None,
+        writer=mock.Mock(),
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     assert response._continue is None
     response.close()


-def test_repr(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
+def test_repr(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     response.status = 200
-    response.reason = 'Ok'
-    assert '<ClientResponse(http://def-cl-resp.org) [200 Ok]>'\
-        in repr(response)
-
-
-def test_repr_non_ascii_url():
-    response = ClientResponse('get', URL('http://fake-host.org/\u03bb'))
-    assert "<ClientResponse(http://fake-host.org/%CE%BB) [None None]>"\
-        in repr(response)
+    response.reason = "Ok"
+    assert "<ClientResponse(http://def-cl-resp.org) [200 Ok]>" in repr(response)


-def test_repr_non_ascii_reason():
-    response = ClientResponse('get', URL('http://fake-host.org/path'))
-    response.reason = '\u03bb'
-    assert "<ClientResponse(http://fake-host.org/path) [None \\u03bb]>"\
-        in repr(response)
-
-
-def test_url_obj_deprecated():
-    response = ClientResponse('get', URL('http://fake-host.org/'))
+def test_repr_non_ascii_url() -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://fake-host.org/\u03bb"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=mock.Mock(),
+        session=mock.Mock(),
+    )
+    assert "<ClientResponse(http://fake-host.org/%CE%BB) [None None]>" in repr(response)
+
+
+def test_repr_non_ascii_reason() -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://fake-host.org/path"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=mock.Mock(),
+        session=mock.Mock(),
+    )
+    response.reason = "\u03bb"
+    assert "<ClientResponse(http://fake-host.org/path) [None \\u03bb]>" in repr(
+        response
+    )
+
+
+def test_url_obj_deprecated() -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://fake-host.org/"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=mock.Mock(),
+        session=mock.Mock(),
+    )
     with pytest.warns(DeprecationWarning):
         response.url_obj


-@asyncio.coroutine
-def test_read_and_release_connection(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
+async def test_read_and_release_connection(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )

     def side_effect(*args, **kwargs):
-        fut = helpers.create_future(loop)
-        fut.set_result(b'payload')
+        fut = loop.create_future()
+        fut.set_result(b"payload")
         return fut
+
     content = response.content = mock.Mock()
     content.read.side_effect = side_effect
-    res = yield from response.read()
-    assert res == b'payload'
+    res = await response.read()
+    assert res == b"payload"
     assert response._connection is None


-@asyncio.coroutine
-def test_read_and_release_connection_with_error(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
+async def test_read_and_release_connection_with_error(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     content = response.content = mock.Mock()
-    content.read.return_value = helpers.create_future(loop)
+    content.read.return_value = loop.create_future()
     content.read.return_value.set_exception(ValueError)

     with pytest.raises(ValueError):
-        yield from response.read()
+        await response.read()

     assert response._closed


-@asyncio.coroutine
-def test_release(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
-    fut = helpers.create_future(loop)
-    fut.set_result(b'')
+async def test_release(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    fut = loop.create_future()
+    fut.set_result(b"")
     content = response.content = mock.Mock()
     content.readany.return_value = fut

@@ -130,14 +262,26 @@ def test_release(loop):
     assert response._connection is None


-@asyncio.coroutine
-def test_release_on_del(loop):
+@pytest.mark.skipif(
+    sys.implementation.name != "cpython",
+    reason="Other implementations have different GC strategies",
+)
+async def test_release_on_del(loop, session) -> None:
     connection = mock.Mock()
     connection.protocol.upgraded = False

     def run(conn):
-        response = ClientResponse('get', URL('http://def-cl-resp.org'))
-        response._post_init(loop)
+        response = ClientResponse(
+            "get",
+            URL("http://def-cl-resp.org"),
+            request_info=mock.Mock(),
+            writer=mock.Mock(),
+            continue100=None,
+            timer=TimerNoop(),
+            traces=[],
+            loop=loop,
+            session=session,
+        )
         response._closed = False
         response._connection = conn

@@ -146,10 +290,18 @@ def run(conn):
     assert connection.release.called


-@asyncio.coroutine
-def test_response_eof(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
+async def test_response_eof(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     response._closed = False
     conn = response._connection = mock.Mock()
     conn.protocol.upgraded = False

@@ -159,10 +311,18 @@ def test_response_eof(loop):
     assert response._connection is None


-@asyncio.coroutine
-def test_response_eof_upgraded(loop):
-    response = ClientResponse('get', URL('http://def-cl-resp.org'))
-    response._post_init(loop)
+async def test_response_eof_upgraded(loop, session) -> None:
+    response = ClientResponse(
+        "get",
+        URL("http://def-cl-resp.org"),
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
     conn = response._connection = mock.Mock()
     conn.protocol.upgraded = True

@@ -172,10 +332,18 @@ def test_response_eof_upgraded(loop):
     assert response._connection is conn


-@asyncio.coroutine
-def test_response_eof_after_connection_detach(loop):
-    response = ClientResponse('get',
URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_response_eof_after_connection_detach(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) response._closed = False conn = response._connection = mock.Mock() conn.protocol = None @@ -185,256 +353,906 @@ def test_response_eof_after_connection_detach(loop): assert response._connection is None -@asyncio.coroutine -def test_text(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_text(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut - response.headers = { - 'Content-Type': 'application/json;charset=cp1251'} + response._headers = {"Content-Type": "application/json;charset=cp1251"} content = response.content = mock.Mock() content.read.side_effect = side_effect - res = yield from response.text() + res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None -@asyncio.coroutine -def test_text_bad_encoding(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_text_bad_encoding(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тестkey": "пройденvalue"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тестkey": "пройденvalue"}'.encode("cp1251")) return fut # lie about the encoding - response.headers = { - 'Content-Type': 'application/json;charset=utf-8'} + response._headers = {"Content-Type": "application/json;charset=utf-8"} content = response.content = mock.Mock() content.read.side_effect = side_effect with pytest.raises(UnicodeDecodeError): - yield from response.text() + await response.text() # only the valid utf-8 characters will be returned - res = yield from response.text(errors='ignore') + res = await response.text(errors="ignore") assert res == '{"key": "value"}' assert response._connection is None -@asyncio.coroutine -def test_text_custom_encoding(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_text_custom_encoding(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut - response.headers = { - 'Content-Type': 'application/json'} + response._headers = {"Content-Type": "application/json"} content = response.content = 
mock.Mock() content.read.side_effect = side_effect - response._get_encoding = mock.Mock() + response.get_encoding = mock.Mock() - res = yield from response.text(encoding='cp1251') + res = await response.text(encoding="cp1251") assert res == '{"тест": "пройден"}' assert response._connection is None - assert not response._get_encoding.called + assert not response.get_encoding.called -@asyncio.coroutine -def test_text_detect_encoding(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_text_detect_encoding(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut - response.headers = {'Content-Type': 'text/plain'} + response._headers = {"Content-Type": "text/plain"} content = response.content = mock.Mock() content.read.side_effect = side_effect - yield from response.read() - res = yield from response.text() + await response.read() + res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None -@asyncio.coroutine -def test_text_after_read(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_text_detect_encoding_if_invalid_charset(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut - response.headers = { - 'Content-Type': 'application/json;charset=cp1251'} + response._headers = {"Content-Type": "text/plain;charset=invalid"} content = response.content = mock.Mock() content.read.side_effect = side_effect - res = yield from response.text() + await response.read() + res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None + assert response.get_encoding().lower() in ("windows-1251", "maccyrillic") -@asyncio.coroutine -def test_json(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_get_encoding_body_none(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"encoding": "test"}') return fut - response.headers = { - 'Content-Type': 'application/json;charset=cp1251'} + response._headers = {"Content-Type": "text/html"} content = response.content = mock.Mock() content.read.side_effect = side_effect - res = yield from response.json() - assert res == {'тест': 'пройден'} + with pytest.raises( + RuntimeError, + match="^Cannot guess the encoding of a not yet read body$", + ): + 
response.get_encoding() + assert response.closed + + +async def test_text_after_read(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) + return fut + + response._headers = {"Content-Type": "application/json;charset=cp1251"} + content = response.content = mock.Mock() + content.read.side_effect = side_effect + + res = await response.text() + assert res == '{"тест": "пройден"}' assert response._connection is None -@asyncio.coroutine -def test_json_custom_loader(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) - response.headers = { - 'Content-Type': 'application/json;charset=cp1251'} - response._content = b'data' +async def test_json(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) - def custom(content): - return content + '-custom' + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) + return fut - res = yield from response.json(loads=custom) - assert res == 'data-custom' + response._headers = {"Content-Type": "application/json;charset=cp1251"} + content = response.content = mock.Mock() + content.read.side_effect = side_effect + res = await response.json() + assert res == {"тест": "пройден"} + assert response._connection is None -@asyncio.coroutine -def test_json_no_content(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) - response.headers = { - 'Content-Type': 'data/octet-stream'} - response._content = b'' - with pytest.raises(aiohttp.ClientResponseError): - yield from response.json() +async def test_json_extended_content_type(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) - res = yield from response.json(content_type=None) - assert res is None + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) + return fut + + response._headers = { + "Content-Type": "application/this.is-1_content+subtype+json;charset=cp1251" + } + content = response.content = mock.Mock() + content.read.side_effect = side_effect + + res = await response.json() + assert res == {"тест": "пройден"} + assert response._connection is None -@asyncio.coroutine -def test_json_override_encoding(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) +async def test_json_custom_content_type(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) def side_effect(*args, **kwargs): - fut = helpers.create_future(loop) - fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut - response.headers = { - 'Content-Type': 
'application/json;charset=utf8'} + response._headers = {"Content-Type": "custom/type;charset=cp1251"} content = response.content = mock.Mock() content.read.side_effect = side_effect - response._get_encoding = mock.Mock() - res = yield from response.json(encoding='cp1251') - assert res == {'тест': 'пройден'} + res = await response.json(content_type="custom/type") + assert res == {"тест": "пройден"} assert response._connection is None - assert not response._get_encoding.called -@pytest.mark.xfail -def test_override_flow_control(loop): - class MyResponse(ClientResponse): - flow_control_class = aiohttp.StreamReader - response = MyResponse('get', URL('http://my-cl-resp.org')) - response._post_init(loop) - response._connection = mock.Mock() - assert isinstance(response.content, aiohttp.StreamReader) - response.close() +async def test_json_custom_loader(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + response._headers = {"Content-Type": "application/json;charset=cp1251"} + response._body = b"data" + def custom(content): + return content + "-custom" -def test_get_encoding_unknown(loop): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response._post_init(loop) + res = await response.json(loads=custom) + assert res == "data-custom" - response.headers = {'Content-Type': 'application/json'} - with mock.patch('aiohttp.client_reqrep.chardet') as m_chardet: - m_chardet.detect.return_value = {'encoding': None} - assert response._get_encoding() == 'utf-8' +async def test_json_invalid_content_type(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + response._headers = {"Content-Type": "data/octet-stream"} + response._body = b"" + + with pytest.raises(aiohttp.ContentTypeError) as info: + await response.json() + + assert info.value.request_info == response.request_info + + +async def test_json_no_content(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + response._headers = {"Content-Type": "data/octet-stream"} + response._body = b"" + + res = await response.json(content_type=None) + assert res is None + + +async def test_json_override_encoding(loop, session) -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result('{"тест": "пройден"}'.encode("cp1251")) + return fut -def test_raise_for_status_2xx(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) + response._headers = {"Content-Type": "application/json;charset=utf8"} + content = response.content = mock.Mock() + content.read.side_effect = side_effect + response.get_encoding = mock.Mock() + + res = await response.json(encoding="cp1251") + assert res == {"тест": "пройден"} + assert response._connection is None + assert not response.get_encoding.called + + +def test_get_encoding_unknown(loop, session) -> None: + response = 
ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=session, + ) + + response._headers = {"Content-Type": "application/json"} + with mock.patch("aiohttp.client_reqrep.chardet") as m_chardet: + m_chardet.detect.return_value = {"encoding": None} + assert response.get_encoding() == "utf-8" + + +def test_raise_for_status_2xx() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) response.status = 200 - response.reason = 'OK' + response.reason = "OK" response.raise_for_status() # should not raise -def test_raise_for_status_4xx(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) +def test_raise_for_status_4xx() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) response.status = 409 - response.reason = 'CONFLICT' + response.reason = "CONFLICT" with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() - assert str(cm.value.code) == '409' + assert str(cm.value.status) == "409" assert str(cm.value.message) == "CONFLICT" + assert response.closed + + +def test_raise_for_status_4xx_without_reason() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response.status = 404 + response.reason = "" + with pytest.raises(aiohttp.ClientResponseError) as cm: + response.raise_for_status() + assert str(cm.value.status) == "404" + assert str(cm.value.message) == "" + assert response.closed + + +def test_resp_host() -> None: + response = ClientResponse( + "get", + URL("http://del-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + assert "del-cl-resp.org" == response.host + + +def test_content_type() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {"Content-Type": "application/json;charset=cp1251"} + + assert "application/json" == response.content_type + + +def test_content_type_no_header() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {} + + assert "application/octet-stream" == response.content_type + + +def test_charset() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {"Content-Type": "application/json;charset=cp1251"} + + assert "cp1251" == response.charset + + +def test_charset_no_header() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + 
request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {} + + assert response.charset is None + +def test_charset_no_charset() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {"Content-Type": "application/json"} -def test_resp_host(): - response = ClientResponse('get', URL('http://del-cl-resp.org')) - assert 'del-cl-resp.org' == response.host + assert response.charset is None -def test_content_type(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response.headers = {'Content-Type': 'application/json;charset=cp1251'} +def test_content_disposition_full() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = { + "Content-Disposition": 'attachment; filename="archive.tar.gz"; foo=bar' + } + + assert "attachment" == response.content_disposition.type + assert "bar" == response.content_disposition.parameters["foo"] + assert "archive.tar.gz" == response.content_disposition.filename + with pytest.raises(TypeError): + response.content_disposition.parameters["foo"] = "baz" + + +def test_content_disposition_no_parameters() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {"Content-Disposition": "attachment"} + + assert "attachment" == response.content_disposition.type + assert response.content_disposition.filename is None + assert {} == response.content_disposition.parameters + + +def test_content_disposition_no_header() -> None: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response._headers = {} + + assert response.content_disposition is None + + +def test_response_request_info() -> None: + url = "http://def-cl-resp.org" + headers = {"Content-Type": "application/json;charset=cp1251"} + response = ClientResponse( + "get", + URL(url), + request_info=RequestInfo(url, "get", headers), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + assert url == response.request_info.url + assert "get" == response.request_info.method + assert headers == response.request_info.headers + + +def test_request_info_in_exception() -> None: + url = "http://def-cl-resp.org" + headers = {"Content-Type": "application/json;charset=cp1251"} + response = ClientResponse( + "get", + URL(url), + request_info=RequestInfo(url, "get", headers), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response.status = 409 + response.reason = "CONFLICT" + with pytest.raises(aiohttp.ClientResponseError) as cm: + response.raise_for_status() + assert cm.value.request_info == response.request_info - assert 'application/json' == response.content_type +def test_no_redirect_history_in_exception() -> 
None: + url = "http://def-cl-resp.org" + headers = {"Content-Type": "application/json;charset=cp1251"} + response = ClientResponse( + "get", + URL(url), + request_info=RequestInfo(url, "get", headers), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response.status = 409 + response.reason = "CONFLICT" + with pytest.raises(aiohttp.ClientResponseError) as cm: + response.raise_for_status() + assert () == cm.value.history -def test_content_type_no_header(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response.headers = {} - assert 'application/octet-stream' == response.content_type +def test_redirect_history_in_exception() -> None: + hist_url = "http://def-cl-resp.org" + url = "http://def-cl-resp.org/index.htm" + hist_headers = {"Content-Type": "application/json;charset=cp1251", "Location": url} + headers = {"Content-Type": "application/json;charset=cp1251"} + response = ClientResponse( + "get", + URL(url), + request_info=RequestInfo(url, "get", headers), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + response.status = 409 + response.reason = "CONFLICT" + + hist_response = ClientResponse( + "get", + URL(hist_url), + request_info=RequestInfo(url, "get", headers), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + ) + + hist_response._headers = hist_headers + hist_response.status = 301 + hist_response.reason = "REDIRECT" + + response._history = [hist_response] + with pytest.raises(aiohttp.ClientResponseError) as cm: + response.raise_for_status() + assert [hist_response] == cm.value.history -def test_charset(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response.headers = {'Content-Type': 'application/json;charset=cp1251'} +async def test_response_read_triggers_callback(loop, session) -> None: + trace = mock.Mock() + trace.send_response_chunk_received = make_mocked_coro() + response_method = "get" + response_url = URL("http://def-cl-resp.org") + response_body = b"This is response" - assert 'cp1251' == response.charset + response = ClientResponse( + response_method, + response_url, + request_info=mock.Mock, + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + loop=loop, + session=session, + traces=[trace], + ) + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result(response_body) + return fut -def test_charset_no_header(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response.headers = {} + response._headers = {"Content-Type": "application/json;charset=cp1251"} + content = response.content = mock.Mock() + content.read.side_effect = side_effect - assert response.charset is None + res = await response.read() + assert res == response_body + assert response._connection is None + assert trace.send_response_chunk_received.called + assert trace.send_response_chunk_received.call_args == mock.call( + response_method, response_url, response_body + ) -def test_charset_no_charset(): - response = ClientResponse('get', URL('http://def-cl-resp.org')) - response.headers = {'Content-Type': 'application/json'} - assert response.charset is None +def test_response_real_url(loop, session) -> None: + url = URL("http://def-cl-resp.org/#urlfragment") + response = ClientResponse( + "get", + url, + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + 
traces=[],
+        loop=loop,
+        session=session,
+    )
+    assert response.url == url.with_fragment(None)
+    assert response.real_url == url
+
+
+def test_response_links_comma_separated(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict(
+        [
+            (
+                "Link",
+                (
+                    "<http://example.com/page/1.html>; rel=next, "
+                    "<http://example.com/>; rel=home"
+                ),
+            )
+        ]
+    )
+    assert response.links == {
+        "next": {"url": URL("http://example.com/page/1.html"), "rel": "next"},
+        "home": {"url": URL("http://example.com/"), "rel": "home"},
+    }
+
+
+def test_response_links_multiple_headers(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict(
+        [
+            ("Link", "<http://example.com/page/1.html>; rel=next"),
+            ("Link", "<http://example.com/>; rel=home"),
+        ]
+    )
+    assert response.links == {
+        "next": {"url": URL("http://example.com/page/1.html"), "rel": "next"},
+        "home": {"url": URL("http://example.com/"), "rel": "home"},
+    }
+
+
+def test_response_links_no_rel(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict([("Link", "<http://example.com/>")])
+    assert response.links == {
+        "http://example.com/": {"url": URL("http://example.com/")}
+    }
+
+
+def test_response_links_quoted(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict(
+        [
+            ("Link", '<http://example.com/>; rel="home-page"'),
+        ]
+    )
+    assert response.links == {
+        "home-page": {"url": URL("http://example.com/"), "rel": "home-page"}
+    }
+
+
+def test_response_links_relative(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict(
+        [
+            ("Link", "</relative/path>; rel=rel"),
+        ]
+    )
+    assert response.links == {
+        "rel": {"url": URL("http://def-cl-resp.org/relative/path"), "rel": "rel"}
+    }
+
+
+def test_response_links_empty(loop, session) -> None:
+    url = URL("http://def-cl-resp.org/")
+    response = ClientResponse(
+        "get",
+        url,
+        request_info=mock.Mock(),
+        writer=mock.Mock(),
+        continue100=None,
+        timer=TimerNoop(),
+        traces=[],
+        loop=loop,
+        session=session,
+    )
+    response._headers = CIMultiDict()
+    assert response.links == {}
diff --git a/tests/test_client_session.py b/tests/test_client_session.py
index 28846cad3a9..298dac9f274 100644
--- a/tests/test_client_session.py
+++ b/tests/test_client_session.py
@@ -1,45 +1,55 @@
 import asyncio
 import contextlib
 import gc
-import re
-import types
+import json
+import sys
+from http.cookies import SimpleCookie
+from io import BytesIO
 from unittest import mock

 import pytest
 from multidict import CIMultiDict, MultiDict
+from re_assert import Matches
 from yarl import URL

 import aiohttp
-from aiohttp import web
+from aiohttp import client, hdrs, web
 from
aiohttp.client import ClientSession +from aiohttp.client_reqrep import ClientRequest from aiohttp.connector import BaseConnector, TCPConnector -from aiohttp.helpers import SimpleCookie +from aiohttp.helpers import DEBUG, PY_36 +from aiohttp.test_utils import make_mocked_coro @pytest.fixture def connector(loop): - conn = BaseConnector(loop=loop) + async def make_conn(): + return BaseConnector(loop=loop) + + conn = loop.run_until_complete(make_conn()) proto = mock.Mock() - conn._conns['a'] = [(proto, 123)] - return conn + conn._conns["a"] = [(proto, 123)] + yield conn + conn.close() -@pytest.yield_fixture +@pytest.fixture def create_session(loop): session = None - def maker(*args, **kwargs): + async def maker(*args, **kwargs): nonlocal session session = ClientSession(*args, loop=loop, **kwargs) return session + yield maker if session is not None: - session.close() + loop.run_until_complete(session.close()) @pytest.fixture -def session(create_session): - return create_session() +def session(create_session, loop): + return loop.run_until_complete(create_session()) @pytest.fixture @@ -52,256 +62,272 @@ def params(): compress="deflate", chunked=True, expect100=True, - read_until_eof=False) - + read_until_eof=False, + ) -@asyncio.coroutine -def test_close_deprecated(create_session): - session = create_session() - with pytest.warns(DeprecationWarning): - yield from session.close() +async def test_close_coro(create_session) -> None: + session = await create_session() + await session.close() -def test_init_headers_simple_dict(create_session): - session = create_session(headers={"h1": "header1", - "h2": "header2"}) - assert (sorted(session._default_headers.items()) == - ([("H1", "header1"), ("H2", "header2")])) +async def test_init_headers_simple_dict(create_session) -> None: + session = await create_session(headers={"h1": "header1", "h2": "header2"}) + assert sorted(session.headers.items()) == ([("h1", "header1"), ("h2", "header2")]) -def test_init_headers_list_of_tuples(create_session): - session = create_session(headers=[("h1", "header1"), - ("h2", "header2"), - ("h3", "header3")]) - assert (session._default_headers == - CIMultiDict([("h1", "header1"), - ("h2", "header2"), - ("h3", "header3")])) +async def test_init_headers_list_of_tuples(create_session) -> None: + session = await create_session( + headers=[("h1", "header1"), ("h2", "header2"), ("h3", "header3")] + ) + assert session.headers == CIMultiDict( + [("h1", "header1"), ("h2", "header2"), ("h3", "header3")] + ) -def test_init_headers_MultiDict(create_session): - session = create_session(headers=MultiDict([("h1", "header1"), - ("h2", "header2"), - ("h3", "header3")])) - assert (session._default_headers == - CIMultiDict([("H1", "header1"), - ("H2", "header2"), - ("H3", "header3")])) +async def test_init_headers_MultiDict(create_session) -> None: + session = await create_session( + headers=MultiDict([("h1", "header1"), ("h2", "header2"), ("h3", "header3")]) + ) + assert session.headers == CIMultiDict( + [("H1", "header1"), ("H2", "header2"), ("H3", "header3")] + ) -def test_init_headers_list_of_tuples_with_duplicates(create_session): - session = create_session(headers=[("h1", "header11"), - ("h2", "header21"), - ("h1", "header12")]) - assert (session._default_headers == - CIMultiDict([("H1", "header11"), - ("H2", "header21"), - ("H1", "header12")])) +async def test_init_headers_list_of_tuples_with_duplicates(create_session) -> None: + session = await create_session( + headers=[("h1", "header11"), ("h2", "header21"), ("h1", "header12")] + ) + 
assert session.headers == CIMultiDict( + [("H1", "header11"), ("H2", "header21"), ("H1", "header12")] + ) -def test_init_cookies_with_simple_dict(create_session): - session = create_session(cookies={"c1": "cookie1", - "c2": "cookie2"}) +async def test_init_cookies_with_simple_dict(create_session) -> None: + session = await create_session(cookies={"c1": "cookie1", "c2": "cookie2"}) cookies = session.cookie_jar.filter_cookies() - assert set(cookies) == {'c1', 'c2'} - assert cookies['c1'].value == 'cookie1' - assert cookies['c2'].value == 'cookie2' + assert set(cookies) == {"c1", "c2"} + assert cookies["c1"].value == "cookie1" + assert cookies["c2"].value == "cookie2" -def test_init_cookies_with_list_of_tuples(create_session): - session = create_session(cookies=[("c1", "cookie1"), - ("c2", "cookie2")]) +async def test_init_cookies_with_list_of_tuples(create_session) -> None: + session = await create_session(cookies=[("c1", "cookie1"), ("c2", "cookie2")]) cookies = session.cookie_jar.filter_cookies() - assert set(cookies) == {'c1', 'c2'} - assert cookies['c1'].value == 'cookie1' - assert cookies['c2'].value == 'cookie2' + assert set(cookies) == {"c1", "c2"} + assert cookies["c1"].value == "cookie1" + assert cookies["c2"].value == "cookie2" -def test_merge_headers(create_session): - # Check incoming simple dict - session = create_session(headers={"h1": "header1", - "h2": "header2"}) +async def test_merge_headers(create_session) -> None: + # Check incoming simple dict + session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers({"h1": "h1"}) assert isinstance(headers, CIMultiDict) - assert headers == CIMultiDict([("h2", "header2"), - ("h1", "h1")]) + assert headers == {"h1": "h1", "h2": "header2"} -def test_merge_headers_with_multi_dict(create_session): - session = create_session(headers={"h1": "header1", - "h2": "header2"}) +async def test_merge_headers_with_multi_dict(create_session) -> None: + session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers(MultiDict([("h1", "h1")])) assert isinstance(headers, CIMultiDict) - assert headers == CIMultiDict([("h2", "header2"), - ("h1", "h1")]) + assert headers == {"h1": "h1", "h2": "header2"} -def test_merge_headers_with_list_of_tuples(create_session): - session = create_session(headers={"h1": "header1", - "h2": "header2"}) +async def test_merge_headers_with_list_of_tuples(create_session) -> None: + session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers([("h1", "h1")]) assert isinstance(headers, CIMultiDict) - assert headers == CIMultiDict([("h2", "header2"), - ("h1", "h1")]) + assert headers == {"h1": "h1", "h2": "header2"} -def test_merge_headers_with_list_of_tuples_duplicated_names(create_session): - session = create_session(headers={"h1": "header1", - "h2": "header2"}) +async def test_merge_headers_with_list_of_tuples_duplicated_names( + create_session, +) -> None: + session = await create_session(headers={"h1": "header1", "h2": "header2"}) - headers = session._prepare_headers([("h1", "v1"), - ("h1", "v2")]) + headers = session._prepare_headers([("h1", "v1"), ("h1", "v2")]) assert isinstance(headers, CIMultiDict) - assert headers == CIMultiDict([("H2", "header2"), - ("H1", "v1"), - ("H1", "v2")]) - - -def test_http_GET(session, params): - with mock.patch("aiohttp.client.ClientSession._request") as patched: - session.get("http://test.example.com", - params={"x": 1}, - **params) + assert 
list(sorted(headers.items())) == [
+        ("h1", "v1"),
+        ("h1", "v2"),
+        ("h2", "header2"),
+    ]
+
+
+def test_http_GET(session, params) -> None:
+    # Python 3.8 automatically uses mock.AsyncMock, which behaves differently
+    with mock.patch(
+        "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock
+    ) as patched:
+        session.get("http://test.example.com", params={"x": 1}, **params)
     assert patched.called, "`ClientSession._request` not called"
-    assert list(patched.call_args) == [("GET", "http://test.example.com",),
-                                       dict(
-                                           params={"x": 1},
-                                           allow_redirects=True,
-                                           **params)]
-
-
-def test_http_OPTIONS(session, params):
-    with mock.patch("aiohttp.client.ClientSession._request") as patched:
-        session.options("http://opt.example.com",
-                        params={"x": 2},
-                        **params)
+    assert list(patched.call_args) == [
+        (
+            "GET",
+            "http://test.example.com",
+        ),
+        dict(params={"x": 1}, allow_redirects=True, **params),
+    ]
+
+
+def test_http_OPTIONS(session, params) -> None:
+    with mock.patch(
+        "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock
+    ) as patched:
+        session.options("http://opt.example.com", params={"x": 2}, **params)
     assert patched.called, "`ClientSession._request` not called"
-    assert list(patched.call_args) == [("OPTIONS", "http://opt.example.com",),
-                                       dict(
-                                           params={"x": 2},
-                                           allow_redirects=True,
-                                           **params)]
-
-
-def test_http_HEAD(session, params):
-    with mock.patch("aiohttp.client.ClientSession._request") as patched:
-        session.head("http://head.example.com",
-                     params={"x": 2},
-                     **params)
+    assert list(patched.call_args) == [
+        (
+            "OPTIONS",
+            "http://opt.example.com",
+        ),
+        dict(params={"x": 2}, allow_redirects=True, **params),
+    ]
+
+
+def test_http_HEAD(session, params) -> None:
+    with mock.patch(
+        "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock
+    ) as patched:
+        session.head("http://head.example.com", params={"x": 2}, **params)
     assert patched.called, "`ClientSession._request` not called"
-    assert list(patched.call_args) == [("HEAD", "http://head.example.com",),
-                                       dict(
-                                           params={"x": 2},
-                                           allow_redirects=False,
-                                           **params)]
-
-
-def test_http_POST(session, params):
-    with mock.patch("aiohttp.client.ClientSession._request") as patched:
-        session.post("http://post.example.com",
-                     params={"x": 2},
-                     data="Some_data",
-                     **params)
+    assert list(patched.call_args) == [
+        (
+            "HEAD",
+            "http://head.example.com",
+        ),
+        dict(params={"x": 2}, allow_redirects=False, **params),
+    ]
+
+
+def test_http_POST(session, params) -> None:
+    with mock.patch(
+        "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock
+    ) as patched:
+        session.post(
+            "http://post.example.com", params={"x": 2}, data="Some_data", **params
+        )
     assert patched.called, "`ClientSession._request` not called"
-    assert list(patched.call_args) == [("POST", "http://post.example.com",),
-                                       dict(
-                                           params={"x": 2},
-                                           data="Some_data",
-                                           **params)]
-
-
-def test_http_PUT(session, params):
-    with mock.patch("aiohttp.client.ClientSession._request") as patched:
-        session.put("http://put.example.com",
-                    params={"x": 2},
-                    data="Some_data",
-                    **params)
+    assert list(patched.call_args) == [
+        (
+            "POST",
+            "http://post.example.com",
+        ),
+        dict(params={"x": 2}, data="Some_data", **params),
+    ]
+
+
+def test_http_PUT(session, params) -> None:
+    with mock.patch(
+        "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock
+    ) as patched:
+        session.put(
+            "http://put.example.com", params={"x": 2}, data="Some_data", **params
+        )
     assert patched.called, "`ClientSession._request`
not called" - assert list(patched.call_args) == [("PUT", "http://put.example.com",), - dict( - params={"x": 2}, - data="Some_data", - **params)] - - -def test_http_PATCH(session, params): - with mock.patch("aiohttp.client.ClientSession._request") as patched: - session.patch("http://patch.example.com", - params={"x": 2}, - data="Some_data", - **params) + assert list(patched.call_args) == [ + ( + "PUT", + "http://put.example.com", + ), + dict(params={"x": 2}, data="Some_data", **params), + ] + + +def test_http_PATCH(session, params) -> None: + with mock.patch( + "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock + ) as patched: + session.patch( + "http://patch.example.com", params={"x": 2}, data="Some_data", **params + ) assert patched.called, "`ClientSession._request` not called" - assert list(patched.call_args) == [("PATCH", "http://patch.example.com",), - dict( - params={"x": 2}, - data="Some_data", - **params)] - - -def test_http_DELETE(session, params): - with mock.patch("aiohttp.client.ClientSession._request") as patched: - session.delete("http://delete.example.com", - params={"x": 2}, - **params) + assert list(patched.call_args) == [ + ( + "PATCH", + "http://patch.example.com", + ), + dict(params={"x": 2}, data="Some_data", **params), + ] + + +def test_http_DELETE(session, params) -> None: + with mock.patch( + "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock + ) as patched: + session.delete("http://delete.example.com", params={"x": 2}, **params) assert patched.called, "`ClientSession._request` not called" - assert list(patched.call_args) == [("DELETE", - "http://delete.example.com",), - dict( - params={"x": 2}, - **params)] + assert list(patched.call_args) == [ + ( + "DELETE", + "http://delete.example.com", + ), + dict(params={"x": 2}, **params), + ] -def test_close(create_session, connector): - session = create_session(connector=connector) +async def test_close(create_session, connector) -> None: + session = await create_session(connector=connector) - session.close() + await session.close() assert session.connector is None assert connector.closed -def test_closed(session): +async def test_closed(session) -> None: assert not session.closed - session.close() + await session.close() assert session.closed -def test_connector(create_session, loop, mocker): +async def test_connector(create_session, loop, mocker) -> None: connector = TCPConnector(loop=loop) - mocker.spy(connector, 'close') - session = create_session(connector=connector) + mocker.spy(connector, "close") + session = await create_session(connector=connector) assert session.connector is connector - session.close() + await session.close() assert connector.close.called connector.close() -def test_create_connector(create_session, loop, mocker): - session = create_session() +async def test_create_connector(create_session, loop, mocker) -> None: + session = await create_session() connector = session.connector - mocker.spy(session.connector, 'close') + mocker.spy(session.connector, "close") - session.close() + await session.close() assert connector.close.called -def test_connector_loop(loop): +def test_connector_loop(loop) -> None: with contextlib.ExitStack() as stack: another_loop = asyncio.new_event_loop() stack.enter_context(contextlib.closing(another_loop)) - connector = TCPConnector(loop=another_loop) + + async def make_connector(): + return TCPConnector() + + connector = another_loop.run_until_complete(make_connector()) + stack.enter_context(contextlib.closing(connector)) with 
pytest.raises(RuntimeError) as ctx: - ClientSession(connector=connector, loop=loop) - assert re.match("Session and connector has to use same event loop", - str(ctx.value)) + + async def make_sess(): + return ClientSession(connector=connector, loop=loop) + + loop.run_until_complete(make_sess()) + assert ( + Matches("Session and connector has to use same event loop") + == str(ctx.value).strip() + ) -def test_detach(session): +def test_detach(session) -> None: conn = session.connector try: assert not conn.closed @@ -313,98 +339,146 @@ def test_detach(session): conn.close() -@asyncio.coroutine -def test_request_closed_session(session): - session.close() +async def test_request_closed_session(session) -> None: + await session.close() with pytest.raises(RuntimeError): - yield from session.request('get', '/') + await session.request("get", "/") -def test_close_flag_for_closed_connector(session): +def test_close_flag_for_closed_connector(session) -> None: conn = session.connector assert not session.closed conn.close() assert session.closed -def test_double_close(connector, create_session): - session = create_session(connector=connector) +async def test_double_close(connector, create_session) -> None: + session = await create_session(connector=connector) - session.close() + await session.close() assert session.connector is None - session.close() + await session.close() assert session.closed assert connector.closed -def test_del(connector, loop): +async def test_del(connector, loop) -> None: + loop.set_debug(False) # N.B. don't use session fixture, it stores extra reference internally session = ClientSession(connector=connector, loop=loop) - loop.set_exception_handler(lambda loop, ctx: None) + logs = [] + loop.set_exception_handler(lambda loop, ctx: logs.append(ctx)) with pytest.warns(ResourceWarning): del session gc.collect() + assert len(logs) == 1 + expected = {"client_session": mock.ANY, "message": "Unclosed client session"} + assert logs[0] == expected -def test_context_manager(connector, loop): - with ClientSession(loop=loop, connector=connector) as session: - pass - assert session.closed +async def test_del_debug(connector, loop) -> None: + loop.set_debug(True) + # N.B. 
don't use session fixture, it stores extra reference internally + session = ClientSession(connector=connector, loop=loop) + logs = [] + loop.set_exception_handler(lambda loop, ctx: logs.append(ctx)) + + with pytest.warns(ResourceWarning): + del session + gc.collect() + + assert len(logs) == 1 + expected = { + "client_session": mock.ANY, + "message": "Unclosed client session", + "source_traceback": mock.ANY, + } + assert logs[0] == expected -def test_borrow_connector_loop(connector, create_session, loop): +async def test_context_manager(connector, loop) -> None: + with pytest.raises(TypeError): + with ClientSession(loop=loop, connector=connector) as session: + pass + + assert session.closed + + +async def test_borrow_connector_loop(connector, create_session, loop) -> None: session = ClientSession(connector=connector, loop=None) try: assert session._loop, loop finally: - session.close() + await session.close() -@asyncio.coroutine -def test_reraise_os_error(create_session): +async def test_reraise_os_error(create_session) -> None: err = OSError(1, "permission error") req = mock.Mock() req_factory = mock.Mock(return_value=req) req.send = mock.Mock(side_effect=err) - session = create_session(request_class=req_factory) + session = await create_session(request_class=req_factory) - @asyncio.coroutine - def create_connection(req): + async def create_connection(req, traces, timeout): # return self.transport, self.protocol return mock.Mock() + session._connector._create_connection = create_connection + session._connector._release = mock.Mock() with pytest.raises(aiohttp.ClientOSError) as ctx: - yield from session.request('get', 'http://example.com') + await session.request("get", "http://example.com") e = ctx.value assert e.errno == err.errno assert e.strerror == err.strerror -@asyncio.coroutine -def test_request_ctx_manager_props(loop): - yield from asyncio.sleep(0, loop=loop) # to make it a task - with aiohttp.ClientSession(loop=loop) as client: - ctx_mgr = client.get('http://example.com') +async def test_close_conn_on_error(create_session) -> None: + class UnexpectedException(BaseException): + pass + + err = UnexpectedException("permission error") + req = mock.Mock() + req_factory = mock.Mock(return_value=req) + req.send = mock.Mock(side_effect=err) + session = await create_session(request_class=req_factory) + + connections = [] + original_connect = session._connector.connect - next(ctx_mgr) - assert isinstance(ctx_mgr.gi_frame, types.FrameType) - assert not ctx_mgr.gi_running - assert isinstance(ctx_mgr.gi_code, types.CodeType) - yield from asyncio.sleep(0.1, loop=loop) + async def connect(req, traces, timeout): + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection(req, traces, timeout): + # return self.transport, self.protocol + conn = mock.Mock() + return conn + + session._connector.connect = connect + session._connector._create_connection = create_connection + session._connector._release = mock.Mock() + with pytest.raises(UnexpectedException): + async with session.request("get", "http://example.com") as resp: + await resp.text() -@asyncio.coroutine -def test_cookie_jar_usage(loop, test_client): + # normally called during garbage collection. 
triggers an exception + # if the connection wasn't already closed + for c in connections: + c.__del__() + + +async def test_cookie_jar_usage(loop, aiohttp_client) -> None: req_url = None jar = mock.Mock() jar.filter_cookies.return_value = None - @asyncio.coroutine - def handler(request): + async def handler(request): nonlocal req_url req_url = "http://%s/" % request.host @@ -413,17 +487,17 @@ def handler(request): return resp app = web.Application() - app.router.add_route('GET', '/', handler) - session = yield from test_client(app, - cookies={"request": "req_value"}, - cookie_jar=jar) + app.router.add_route("GET", "/", handler) + session = await aiohttp_client( + app, cookies={"request": "req_value"}, cookie_jar=jar + ) # Updating the cookie jar with initial user defined cookies jar.update_cookies.assert_called_with({"request": "req_value"}) jar.update_cookies.reset_mock() - resp = yield from session.get("/") - yield from resp.release() + resp = await session.get("/") + await resp.release() # Filtering the cookie jar before sending the request, # getting the request URL as only parameter @@ -437,38 +511,213 @@ def handler(request): assert resp_cookies["response"].value == "resp_value" -def test_session_default_version(loop): +async def test_session_default_version(loop) -> None: session = aiohttp.ClientSession(loop=loop) assert session.version == aiohttp.HttpVersion11 -def test_session_loop(loop): +async def test_session_loop(loop) -> None: session = aiohttp.ClientSession(loop=loop) - assert session.loop is loop - session.close() + with pytest.warns(DeprecationWarning): + assert session.loop is loop + await session.close() -def test_proxy_str(session, params): - with mock.patch("aiohttp.client.ClientSession._request") as patched: - session.get("http://test.example.com", - proxy='http://proxy.com', - **params) +def test_proxy_str(session, params) -> None: + with mock.patch( + "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock + ) as patched: + session.get("http://test.example.com", proxy="http://proxy.com", **params) assert patched.called, "`ClientSession._request` not called" - assert list(patched.call_args) == [("GET", "http://test.example.com",), - dict( - allow_redirects=True, - proxy='http://proxy.com', - **params)] + assert list(patched.call_args) == [ + ( + "GET", + "http://test.example.com", + ), + dict(allow_redirects=True, proxy="http://proxy.com", **params), + ] -def test_client_session_implicit_loop_warn(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) +async def test_request_tracing(loop, aiohttp_client) -> None: + async def handler(request): + return web.json_response({"ok": True}) + + app = web.Application() + app.router.add_post("/", handler) + + trace_config_ctx = mock.Mock() + trace_request_ctx = {} + body = "This is request body" + gathered_req_body = BytesIO() + gathered_res_body = BytesIO() + on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_redirect = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + async def on_request_chunk_sent(session, context, params): + gathered_req_body.write(params.chunk) + + async def on_response_chunk_received(session, context, params): + gathered_res_body.write(params.chunk) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) 
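    # Every TraceConfig signal is a list of async callbacks invoked as
    # cb(session, trace_config_ctx, params); the assertions below pin down
    # the params types (TraceRequestStartParams, TraceRequestEndParams, ...)
    # that carry the traced request's method, URL and headers. A minimal
    # sketch of the same wiring outside the test harness, with a
    # hypothetical on_start callback:
    #
    #     async def on_start(session, ctx, params):
    #         print("starting", params.method, params.url)
    #
    #     tc = aiohttp.TraceConfig()
    #     tc.on_request_start.append(on_start)
    #     async with aiohttp.ClientSession(trace_configs=[tc]) as sess:
    #         await sess.get("http://example.com")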
+ trace_config.on_request_chunk_sent.append(on_request_chunk_sent) + trace_config.on_response_chunk_received.append(on_response_chunk_received) + trace_config.on_request_redirect.append(on_request_redirect) + + session = await aiohttp_client(app, trace_configs=[trace_config]) + + async with session.post( + "/", data=body, trace_request_ctx=trace_request_ctx + ) as resp: + + await resp.json() + + on_request_start.assert_called_once_with( + session.session, + trace_config_ctx, + aiohttp.TraceRequestStartParams( + hdrs.METH_POST, session.make_url("/"), CIMultiDict() + ), + ) + + on_request_end.assert_called_once_with( + session.session, + trace_config_ctx, + aiohttp.TraceRequestEndParams( + hdrs.METH_POST, session.make_url("/"), CIMultiDict(), resp + ), + ) + assert not on_request_redirect.called + assert gathered_req_body.getvalue() == body.encode("utf8") + assert gathered_res_body.getvalue() == json.dumps({"ok": True}).encode("utf8") + + +async def test_request_tracing_exception() -> None: + loop = asyncio.get_event_loop() + on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_exception = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_end.append(on_request_end) + trace_config.on_request_exception.append(on_request_exception) + + with mock.patch("aiohttp.client.TCPConnector.connect") as connect_patched: + error = Exception() + if sys.version_info >= (3, 8, 1): + connect_patched.side_effect = error + else: + loop = asyncio.get_event_loop() + f = loop.create_future() + f.set_exception(error) + connect_patched.return_value = f + + session = aiohttp.ClientSession(loop=loop, trace_configs=[trace_config]) + + try: + await session.get("http://example.com") + except Exception: + pass + + on_request_exception.assert_called_once_with( + session, + mock.ANY, + aiohttp.TraceRequestExceptionParams( + hdrs.METH_GET, URL("http://example.com"), CIMultiDict(), error + ), + ) + assert not on_request_end.called + + +async def test_request_tracing_interpose_headers(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + + class MyClientRequest(ClientRequest): + headers = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + MyClientRequest.headers = self.headers + + async def new_headers(session, trace_config_ctx, data): + data.headers["foo"] = "bar" + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(new_headers) + + session = await aiohttp_client( + app, request_class=MyClientRequest, trace_configs=[trace_config] + ) + + await session.get("/") + assert MyClientRequest.headers["foo"] == "bar" + + +@pytest.mark.skipif(not PY_36, reason="Python 3.6+ required") +def test_client_session_inheritance() -> None: + with pytest.warns(DeprecationWarning): + + class A(ClientSession): + pass - with pytest.warns(ResourceWarning): - session = aiohttp.ClientSession() - assert session._loop is loop - session.close() - asyncio.set_event_loop(None) - loop.close() +@pytest.mark.skipif(not DEBUG, reason="The check is applied in DEBUG mode only") +async def test_client_session_custom_attr(loop) -> None: + session = ClientSession(loop=loop) + with pytest.warns(DeprecationWarning): + session.custom = None + + +async def test_client_session_timeout_args(loop) -> None: + session1 = ClientSession(loop=loop) + assert session1._timeout == client.DEFAULT_TIMEOUT + + with 
pytest.warns(DeprecationWarning): + session2 = ClientSession(loop=loop, read_timeout=20 * 60, conn_timeout=30 * 60) + assert session2._timeout == client.ClientTimeout(total=20 * 60, connect=30 * 60) + + with pytest.raises(ValueError): + ClientSession( + loop=loop, timeout=client.ClientTimeout(total=10 * 60), read_timeout=20 * 60 + ) + + with pytest.raises(ValueError): + ClientSession( + loop=loop, timeout=client.ClientTimeout(total=10 * 60), conn_timeout=30 * 60 + ) + + +async def test_client_session_timeout_default_args(loop) -> None: + session1 = ClientSession() + assert session1.timeout == client.DEFAULT_TIMEOUT + + +async def test_client_session_timeout_argument() -> None: + session = ClientSession(timeout=500) + assert session.timeout == 500 + + +async def test_requote_redirect_url_default() -> None: + session = ClientSession() + assert session.requote_redirect_url + + +async def test_requote_redirect_url_default_disable() -> None: + session = ClientSession(requote_redirect_url=False) + assert not session.requote_redirect_url + + +async def test_requote_redirect_setter() -> None: + session = ClientSession() + assert session.requote_redirect_url + with pytest.warns(DeprecationWarning): + session.requote_redirect_url = False + assert not session.requote_redirect_url diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 1f85d1bd89d..baa4469e334 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -7,9 +7,10 @@ import pytest import aiohttp -from aiohttp import client, hdrs, helpers +from aiohttp import client, hdrs from aiohttp.http import WS_KEY -from aiohttp.log import ws_logger +from aiohttp.streams import EofStream +from aiohttp.test_utils import make_mocked_coro @pytest.fixture @@ -27,483 +28,670 @@ def ws_key(key): return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode() -@asyncio.coroutine -def test_ws_connect(ws_key, loop, key_data): +async def test_ws_connect(ws_key, loop, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, - hdrs.SEC_WEBSOCKET_PROTOCOL: 'chat' + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - res = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) assert isinstance(res, client.ClientWebSocketResponse) - assert res.protocol == 'chat' + assert res.protocol == "chat" assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] -@asyncio.coroutine -def test_ws_connect_with_origin(key_data, loop): +async def test_ws_connect_with_origin(key_data, loop) -> None: resp = mock.Mock() resp.status = 403 - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = 
helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - origin = 'https://example.org/page.html' + origin = "https://example.org/page.html" with pytest.raises(client.WSServerHandshakeError): - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', origin=origin) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", origin=origin + ) assert hdrs.ORIGIN in m_req.call_args[1]["headers"] assert m_req.call_args[1]["headers"][hdrs.ORIGIN] == origin -@asyncio.coroutine -def test_ws_connect_custom_response(loop, ws_key, key_data): - +async def test_ws_connect_custom_response(loop, ws_key, key_data) -> None: class CustomResponse(client.ClientWebSocketResponse): def read(self, decode=False): - return 'customized!' + return "customized!" resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - res = yield from aiohttp.ClientSession( - ws_response_class=CustomResponse, loop=loop).ws_connect( - 'http://test.org') + res = await aiohttp.ClientSession( + ws_response_class=CustomResponse, loop=loop + ).ws_connect("http://test.org") - assert res.read() == 'customized!' + assert res.read() == "customized!" -@asyncio.coroutine -def test_ws_connect_err_status(loop, ws_key, key_data): +async def test_ws_connect_err_status(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 500 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_ACCEPT: ws_key + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError) as ctx: - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) - assert ctx.value.message == 'Invalid response status' + assert ctx.value.message == "Invalid response status" -@asyncio.coroutine -def test_ws_connect_err_upgrade(loop, ws_key, key_data): +async def test_ws_connect_err_upgrade(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: 'test', - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_ACCEPT: ws_key + hdrs.UPGRADE: "test", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") 
as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError) as ctx: - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) - assert ctx.value.message == 'Invalid upgrade header' + assert ctx.value.message == "Invalid upgrade header" -@asyncio.coroutine -def test_ws_connect_err_conn(loop, ws_key, key_data): +async def test_ws_connect_err_conn(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: 'close', - hdrs.SEC_WEBSOCKET_ACCEPT: ws_key + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "close", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError) as ctx: - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) - assert ctx.value.message == 'Invalid connection header' + assert ctx.value.message == "Invalid connection header" -@asyncio.coroutine -def test_ws_connect_err_challenge(loop, ws_key, key_data): +async def test_ws_connect_err_challenge(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_ACCEPT: 'asdfasdfasdfasdfasdfasdf' + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: "asdfasdfasdfasdfasdfasdf", } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError) as ctx: - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) - assert ctx.value.message == 'Invalid challenge response' + assert ctx.value.message == "Invalid challenge response" -@asyncio.coroutine -def test_ws_connect_common_headers(ws_key, loop, key_data): - """Emulate a headers dict being reused for a second ws_connect. +async def test_ws_connect_common_headers(ws_key, loop, key_data) -> None: + # Emulate a headers dict being reused for a second ws_connect. - In this scenario, we need to ensure that the newly generated secret key - is sent to the server, not the stale key. - """ + # In this scenario, we need to ensure that the newly generated secret key + # is sent to the server, not the stale key. 
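    # The mock_get helper below recomputes the RFC 6455 accept token for
    # whatever Sec-WebSocket-Key the client actually sent, so a stale key
    # left in the reused headers dict would fail the handshake. The
    # server-side computation is, in essence:
    #
    #     accept = base64.b64encode(
    #         hashlib.sha1(key + WS_KEY).digest()
    #     ).decode()
    #
    # where `key` is the base64 nonce taken from the request headers and
    # WS_KEY is the fixed RFC 6455 GUID
    # b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11".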
headers = {} - @asyncio.coroutine - def test_connection(): - @asyncio.coroutine - def mock_get(*args, **kwargs): + async def test_connection() -> None: + async def mock_get(*args, **kwargs): resp = mock.Mock() resp.status = 101 - key = kwargs.get('headers').get(hdrs.SEC_WEBSOCKET_KEY) + key = kwargs.get("headers").get(hdrs.SEC_WEBSOCKET_KEY) accept = base64.b64encode( - hashlib.sha1(base64.b64encode(base64.b64decode(key)) + WS_KEY) - .digest()).decode() + hashlib.sha1(base64.b64encode(base64.b64decode(key)) + WS_KEY).digest() + ).decode() resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: accept, - hdrs.SEC_WEBSOCKET_PROTOCOL: 'chat' + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } return resp - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get', - side_effect=mock_get) as m_req: + + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch( + "aiohttp.client.ClientSession.request", side_effect=mock_get + ) as m_req: m_os.urandom.return_value = key_data - res = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat'), - headers=headers) + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat"), headers=headers + ) assert isinstance(res, client.ClientWebSocketResponse) - assert res.protocol == 'chat' + assert res.protocol == "chat" assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] - yield from test_connection() + await test_connection() # Generate a new ws key key_data = os.urandom(16) - yield from test_connection() + await test_connection() -@asyncio.coroutine -def test_close(loop, ws_key, key_data): +async def test_close(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.WebSocketWriter') as WebSocketWriter: - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - writer = WebSocketWriter.return_value = mock.Mock() + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) - resp = yield from session.ws_connect( - 'http://test.org') + resp = await session.ws_connect("http://test.org") assert not resp.closed resp._reader.feed_data( - aiohttp.WSMessage(aiohttp.WSMsgType.CLOSE, b'', b''), 0) + aiohttp.WSMessage(aiohttp.WSMsgType.CLOSE, b"", b""), 0 + ) - res = yield from resp.close() - writer.close.assert_called_with(1000, b'') + res = await resp.close() + writer.close.assert_called_with(1000, b"") assert resp.closed assert res assert resp.exception() is None # idempotent - res = yield from resp.close() + res = await resp.close() assert not res assert writer.close.call_count == 1 - session.close() + await session.close() -@asyncio.coroutine -def test_close_exc(loop, ws_key, key_data): +async def 
test_close_eofstream(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.WebSocketWriter') as WebSocketWriter: - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - WebSocketWriter.return_value = mock.Mock() + writer = WebSocketWriter.return_value = mock.Mock() + + session = aiohttp.ClientSession(loop=loop) + resp = await session.ws_connect("http://test.org") + assert not resp.closed + + exc = EofStream() + resp._reader.set_exception(exc) + + await resp.receive() + writer.close.assert_called_with(1000, b"") + assert resp.closed + + await session.close() + + +async def test_close_exc(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) - resp = yield from session.ws_connect('http://test.org') + resp = await session.ws_connect("http://test.org") assert not resp.closed exc = ValueError() resp._reader.set_exception(exc) - yield from resp.close() + await resp.close() assert resp.closed assert resp.exception() is exc - session.close() + await session.close() -@asyncio.coroutine -def test_close_exc2(loop, ws_key, key_data): +async def test_close_exc2(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.WebSocketWriter') as WebSocketWriter: - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) writer = WebSocketWriter.return_value = mock.Mock() - resp = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org') + resp = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org" + ) assert not resp.closed exc = ValueError() writer.close.side_effect = exc - yield from resp.close() + await resp.close() assert resp.closed assert resp.exception() is exc resp._closed = False writer.close.side_effect = asyncio.CancelledError() with 
pytest.raises(asyncio.CancelledError): - yield from resp.close() + await resp.close() -@asyncio.coroutine -def test_send_data_after_close(ws_key, key_data, loop, mocker): +async def test_send_data_after_close(ws_key, key_data, loop) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - resp = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org') + resp = await aiohttp.ClientSession(loop=loop).ws_connect("http://test.org") resp._writer._closing = True - mocker.spy(ws_logger, 'warning') + for meth, args in ( + (resp.ping, ()), + (resp.pong, ()), + (resp.send_str, ("s",)), + (resp.send_bytes, (b"b",)), + (resp.send_json, ({},)), + ): + with pytest.raises(ConnectionResetError): + await meth(*args) - for meth, args in ((resp.ping, ()), - (resp.pong, ()), - (resp.send_str, ('s',)), - (resp.send_bytes, (b'b',)), - (resp.send_json, ({},))): - meth(*args) - assert ws_logger.warning.called - ws_logger.warning.reset_mock() - -@asyncio.coroutine -def test_send_data_type_errors(ws_key, key_data, loop): +async def test_send_data_type_errors(ws_key, key_data, loop) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.WebSocketWriter') as WebSocketWriter: - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) WebSocketWriter.return_value = mock.Mock() - resp = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org') + resp = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org" + ) - pytest.raises(TypeError, resp.send_str, b's') - pytest.raises(TypeError, resp.send_bytes, 'b') - pytest.raises(TypeError, resp.send_json, set()) + with pytest.raises(TypeError): + await resp.send_str(b"s") + with pytest.raises(TypeError): + await resp.send_bytes("b") + with pytest.raises(TypeError): + await resp.send_json(set()) -@asyncio.coroutine -def test_reader_read_exception(ws_key, key_data, loop): +async def test_reader_read_exception(ws_key, key_data, loop) -> None: hresp = mock.Mock() hresp.status = 101 hresp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.WebSocketWriter') as WebSocketWriter: - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with 
mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(hresp) - WebSocketWriter.return_value = mock.Mock() + + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) - resp = yield from session.ws_connect('http://test.org') + resp = await session.ws_connect("http://test.org") exc = ValueError() resp._reader.set_exception(exc) - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.ERROR - assert msg.type is msg.tp assert resp.exception() is exc - session.close() + await session.close() -@asyncio.coroutine -def test_receive_runtime_err(loop): +async def test_receive_runtime_err(loop) -> None: resp = client.ClientWebSocketResponse( - mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), 10.0, - True, True, loop) + mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), 10.0, True, True, loop + ) resp._waiting = True with pytest.raises(RuntimeError): - yield from resp.receive() + await resp.receive() -@asyncio.coroutine -def test_ws_connect_close_resp_on_err(loop, ws_key, key_data): +async def test_ws_connect_close_resp_on_err(loop, ws_key, key_data) -> None: resp = mock.Mock() resp.status = 500 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_ACCEPT: ws_key + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError): - yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) resp.close.assert_called_with() -@asyncio.coroutine -def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data): +async def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, - hdrs.SEC_WEBSOCKET_PROTOCOL: 'other,another' + hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - res = yield from aiohttp.ClientSession(loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) 
assert res.protocol is None -@asyncio.coroutine -def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data): +async def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, - hdrs.SEC_WEBSOCKET_PROTOCOL: 'other,another' + hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } - with mock.patch('aiohttp.client.os') as m_os: - with mock.patch('aiohttp.client.ClientSession.get') as m_req: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data - m_req.return_value = helpers.create_future(loop) + m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) connector = aiohttp.TCPConnector(loop=loop, force_close=True) - res = yield from aiohttp.ClientSession( - connector=connector, loop=loop).ws_connect( - 'http://test.org', - protocols=('t1', 't2', 'chat')) + res = await aiohttp.ClientSession( + connector=connector, loop=loop + ).ws_connect("http://test.org", protocols=("t1", "t2", "chat")) assert res.protocol is None del res + + +async def test_ws_connect_deflate(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) + + assert res.compress == 15 + assert res.client_notakeover is False + + +async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", + } + with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + writer = WebSocketWriter.return_value = mock.Mock() + send = writer.send = make_mocked_coro() + + session = aiohttp.ClientSession(loop=loop) + resp = await session.ws_connect("http://test.org") + + await resp.send_str("string", compress=-1) + send.assert_called_with("string", binary=False, compress=-1) + + await resp.send_bytes(b"bytes", compress=15) + send.assert_called_with(b"bytes", binary=True, compress=15) + + await resp.send_json([{}], compress=-9) + send.assert_called_with("[{}]", binary=False, compress=-9) + + await session.close() + + +async def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + 
m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) + + assert res.compress == 0 + assert res.client_notakeover is False + + +async def test_ws_connect_deflate_notakeover(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " + "client_no_context_takeover", + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) + + assert res.compress == 15 + assert res.client_notakeover is True + + +async def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " + "client_max_window_bits=10", + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) + + assert res.compress == 10 + assert res.client_notakeover is False + + +async def test_ws_connect_deflate_client_wbits_bad(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " + "client_max_window_bits=6", + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + with pytest.raises(client.WSServerHandshakeError): + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) + + +async def test_ws_connect_deflate_server_ext_bad(loop, ws_key, key_data) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; bad", + } + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + with pytest.raises(client.WSServerHandshakeError): + await aiohttp.ClientSession(loop=loop).ws_connect( + "http://test.org", compress=15 + ) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index b7b3dfbd34a..e423765acb4 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,617 +1,804 @@ import asyncio +import async_timeout import pytest import aiohttp -from aiohttp import hdrs, helpers, web +from aiohttp import hdrs, web -@pytest.fixture -def ceil(mocker): - def ceil(val): - return val - 
- mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - -@asyncio.coroutine -def test_send_recv_text(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_send_recv_text(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_str() - ws.send_str(msg+'/answer') - yield from ws.close() + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') - resp.send_str('ask') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") - assert resp.get_extra_info('socket') is not None + assert resp.get_extra_info("socket") is not None - data = yield from resp.receive_str() - assert data == 'ask/answer' - yield from resp.close() + data = await resp.receive_str() + assert data == "ask/answer" + await resp.close() - assert resp.get_extra_info('socket') is None + assert resp.get_extra_info("socket") is None -@asyncio.coroutine -def test_send_recv_bytes_bad_type(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_send_recv_bytes_bad_type(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_str() - ws.send_str(msg+'/answer') - yield from ws.close() + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') - resp.send_str('ask') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") with pytest.raises(TypeError): - yield from resp.receive_bytes() - yield from resp.close() - + await resp.receive_bytes() + await resp.close() -@asyncio.coroutine -def test_send_recv_bytes(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_send_recv_bytes(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_bytes() - ws.send_bytes(msg+b'/answer') - yield from ws.close() + msg = await ws.receive_bytes() + await ws.send_bytes(msg + b"/answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") - data = yield from resp.receive_bytes() - assert data == b'ask/answer' + data = await resp.receive_bytes() + assert data == b"ask/answer" - yield from resp.close() + await resp.close() -@asyncio.coroutine -def test_send_recv_text_bad_type(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_send_recv_text_bad_type(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_bytes() - 
ws.send_bytes(msg+b'/answer') - yield from ws.close() + msg = await ws.receive_bytes() + await ws.send_bytes(msg + b"/answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") with pytest.raises(TypeError): - yield from resp.receive_str() - - yield from resp.close() + await resp.receive_str() + await resp.close() -@asyncio.coroutine -def test_send_recv_json(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_send_recv_json(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - data = yield from ws.receive_json() - ws.send_json({'response': data['request']}) - yield from ws.close() + data = await ws.receive_json() + await ws.send_json({"response": data["request"]}) + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') - payload = {'request': 'test'} - resp.send_json(payload) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + payload = {"request": "test"} + await resp.send_json(payload) - data = yield from resp.receive_json() - assert data['response'] == payload['request'] - yield from resp.close() + data = await resp.receive_json() + assert data["response"] == payload["request"] + await resp.close() -@asyncio.coroutine -def test_ping_pong(loop, test_client): +async def test_ping_pong(aiohttp_client) -> None: + loop = asyncio.get_event_loop() + closed = loop.create_future() - closed = helpers.create_future(loop) - - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_bytes() - ws.ping() - ws.send_bytes(msg+b'/answer') + msg = await ws.receive_bytes() + await ws.ping() + await ws.send_bytes(msg + b"/answer") try: - yield from ws.close() + await ws.close() finally: closed.set_result(1) return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") - resp.ping() - resp.send_bytes(b'ask') + await resp.ping() + await resp.send_bytes(b"ask") - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.BINARY - assert msg.data == b'ask/answer' + assert msg.data == b"ask/answer" - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE - yield from resp.close() - yield from closed - + await resp.close() + await closed -@asyncio.coroutine -def test_ping_pong_manual(loop, test_client): - closed = helpers.create_future(loop) +async def test_ping_pong_manual(aiohttp_client) -> None: + loop = asyncio.get_event_loop() + closed = loop.create_future() - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - msg = yield from ws.receive_bytes() - 
ws.ping() - ws.send_bytes(msg+b'/answer') + msg = await ws.receive_bytes() + await ws.ping() + await ws.send_bytes(msg + b"/answer") try: - yield from ws.close() + await ws.close() finally: closed.set_result(1) return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', autoping=False) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", autoping=False) - resp.ping() - resp.send_bytes(b'ask') + await resp.ping() + await resp.send_bytes(b"ask") - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.PONG - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.PING - resp.pong() + await resp.pong() - msg = yield from resp.receive() - assert msg.data == b'ask/answer' + msg = await resp.receive() + assert msg.data == b"ask/answer" - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE - yield from closed + await closed -@asyncio.coroutine -def test_close(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_close(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - yield from ws.receive_bytes() - ws.send_str('test') + await ws.receive_bytes() + await ws.send_str("test") - yield from ws.receive() + await ws.receive() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") - closed = yield from resp.close() + closed = await resp.close() assert closed assert resp.closed assert resp.close_code == 1000 - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSED -@asyncio.coroutine -def test_concurrent_close(loop, test_client): +async def test_concurrent_close(aiohttp_client) -> None: client_ws = None - @asyncio.coroutine - def handler(request): + async def handler(request): nonlocal client_ws ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - yield from ws.receive_bytes() - ws.send_str('test') + await ws.receive_bytes() + await ws.send_str("test") - yield from client_ws.close() + await client_ws.close() - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.CLOSE return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - ws = client_ws = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + ws = client_ws = await client.ws_connect("/") - ws.send_bytes(b'ask') + await ws.send_bytes(b"ask") - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.CLOSING - yield from asyncio.sleep(0.01, loop=loop) - msg = yield from ws.receive() + await asyncio.sleep(0.01) + msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.CLOSED -@asyncio.coroutine -def test_close_from_server(loop, test_client): - - closed = helpers.create_future(loop) +async def test_close_from_server(aiohttp_client) -> None: + loop = 
asyncio.get_event_loop() + closed = loop.create_future() - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) try: - yield from ws.receive_bytes() - yield from ws.close() + await ws.receive_bytes() + await ws.close() finally: closed.set_result(1) return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert resp.closed - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSED - yield from closed + await closed -@asyncio.coroutine -def test_close_manual(loop, test_client): +async def test_close_manual(aiohttp_client) -> None: + loop = asyncio.get_event_loop() + closed = loop.create_future() - closed = helpers.create_future(loop) - - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - yield from ws.receive_bytes() - ws.send_str('test') + await ws.receive_bytes() + await ws.send_str("test") try: - yield from ws.close() + await ws.close() finally: closed.set_result(1) return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', autoclose=False) - resp.send_bytes(b'ask') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", autoclose=False) + await resp.send_bytes(b"ask") - msg = yield from resp.receive() - assert msg.data == 'test' + msg = await resp.receive() + assert msg.data == "test" - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert msg.data == 1000 - assert msg.extra == '' + assert msg.extra == "" assert not resp.closed - yield from resp.close() - yield from closed + await resp.close() + await closed assert resp.closed -@asyncio.coroutine -def test_close_timeout(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_close_timeout(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive_bytes() - ws.send_str('test') - yield from asyncio.sleep(1, loop=loop) + await ws.prepare(request) + await ws.receive_bytes() + await ws.send_str("test") + await asyncio.sleep(1) return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', timeout=0.2, autoclose=False) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", timeout=0.2, autoclose=False) - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") - msg = yield from resp.receive() - assert msg.data == 'test' + msg = await resp.receive() + assert msg.data == "test" assert msg.type == aiohttp.WSMsgType.TEXT - msg = yield from resp.close() + msg = await resp.close() assert resp.closed assert isinstance(resp.exception(), asyncio.TimeoutError) -@asyncio.coroutine -def test_close_cancel(loop, test_client): 
+async def test_close_cancel(aiohttp_client) -> None: + loop = asyncio.get_event_loop() - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive_bytes() - ws.send_str('test') - yield from asyncio.sleep(10, loop=loop) + await ws.prepare(request) + await ws.receive_bytes() + await ws.send_str("test") + await asyncio.sleep(10) app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', autoclose=False) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", autoclose=False) - resp.send_bytes(b'ask') + await resp.send_bytes(b"ask") - text = yield from resp.receive() - assert text.data == 'test' + text = await resp.receive() + assert text.data == "test" t = loop.create_task(resp.close()) - yield from asyncio.sleep(0.1, loop=loop) + await asyncio.sleep(0.1) t.cancel() - yield from asyncio.sleep(0.1, loop=loop) + await asyncio.sleep(0.1) assert resp.closed assert resp.exception() is None -@asyncio.coroutine -def test_override_default_headers(loop, test_client): - - @asyncio.coroutine - def handler(request): - assert request.headers[hdrs.SEC_WEBSOCKET_VERSION] == '8' +async def test_override_default_headers(aiohttp_client) -> None: + async def handler(request): + assert request.headers[hdrs.SEC_WEBSOCKET_VERSION] == "8" ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.send_str('answer') - yield from ws.close() + await ws.prepare(request) + await ws.send_str("answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - headers = {hdrs.SEC_WEBSOCKET_VERSION: '8'} - client = yield from test_client(app) - resp = yield from client.ws_connect('/', headers=headers) - msg = yield from resp.receive() - assert msg.data == 'answer' - yield from resp.close() - - -@asyncio.coroutine -def test_additional_headers(loop, test_client): - - @asyncio.coroutine - def handler(request): - assert request.headers['x-hdr'] == 'xtra' + app.router.add_route("GET", "/", handler) + headers = {hdrs.SEC_WEBSOCKET_VERSION: "8"} + client = await aiohttp_client(app) + resp = await client.ws_connect("/", headers=headers) + msg = await resp.receive() + assert msg.data == "answer" + await resp.close() + + +async def test_additional_headers(aiohttp_client) -> None: + async def handler(request): + assert request.headers["x-hdr"] == "xtra" ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - ws.send_str('answer') - yield from ws.close() + await ws.send_str("answer") + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', headers={'x-hdr': 'xtra'}) - msg = yield from resp.receive() - assert msg.data == 'answer' - yield from resp.close() - + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", headers={"x-hdr": "xtra"}) + msg = await resp.receive() + assert msg.data == "answer" + await resp.close() -@asyncio.coroutine -def test_recv_protocol_error(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_recv_protocol_error(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - 
yield from ws.receive_str() - ws._writer.writer.write(b'01234' * 100) - yield from ws.close() + await ws.receive_str() + ws._writer.transport.write(b"01234" * 100) + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') - resp.send_str('ask') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") - msg = yield from resp.receive() + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.ERROR assert type(msg.data) is aiohttp.WebSocketError - assert msg.data.args[0] == 'Received frame with non-zero reserved bits' + assert msg.data.code == aiohttp.WSCloseCode.PROTOCOL_ERROR + assert str(msg.data) == "Received frame with non-zero reserved bits" assert msg.extra is None - yield from resp.close() + await resp.close() -@asyncio.coroutine -def test_recv_timeout(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_recv_timeout(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - yield from ws.receive_str() + await ws.receive_str() - yield from asyncio.sleep(0.1, loop=request.app.loop) + await asyncio.sleep(0.1) - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') - resp.send_str('ask') + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") with pytest.raises(asyncio.TimeoutError): - with aiohttp.Timeout(0.01, loop=app.loop): - yield from resp.receive() - - yield from resp.close() + with async_timeout.timeout(0.01): + await resp.receive() + await resp.close() -@asyncio.coroutine -def test_receive_timeout(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_receive_timeout(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive() - yield from ws.close() + await ws.prepare(request) + await ws.receive() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', receive_timeout=0.1) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", receive_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - yield from resp.receive(0.05) + await resp.receive(0.05) - yield from resp.close() + await resp.close() -@asyncio.coroutine -def test_custom_receive_timeout(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_custom_receive_timeout(aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive() - yield from ws.close() + await ws.prepare(request) + await ws.receive() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/') + client = await aiohttp_client(app) + resp = await client.ws_connect("/") with pytest.raises(asyncio.TimeoutError): - yield from 
resp.receive(0.05) + await resp.receive(0.05) - yield from resp.close() + await resp.close() -@asyncio.coroutine -def test_heartbeat(loop, test_client, ceil): +async def test_heartbeat(aiohttp_client) -> None: ping_received = False - @asyncio.coroutine - def handler(request): + async def handler(request): nonlocal ping_received ws = web.WebSocketResponse(autoping=False) - yield from ws.prepare(request) - msg = yield from ws.receive() + await ws.prepare(request) + msg = await ws.receive() if msg.type == aiohttp.WSMsgType.ping: ping_received = True - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - - client = yield from test_client(app) - resp = yield from client.ws_connect('/', heartbeat=0.01) + app.router.add_route("GET", "/", handler) - yield from resp.receive() - yield from resp.close() + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.01) + await asyncio.sleep(0.1) + await resp.receive() + await resp.close() assert ping_received -@asyncio.coroutine -def test_heartbeat_no_pong(loop, test_client, ceil): +async def test_heartbeat_no_pong(aiohttp_client) -> None: ping_received = False - @asyncio.coroutine - def handler(request): + async def handler(request): nonlocal ping_received ws = web.WebSocketResponse(autoping=False) - yield from ws.prepare(request) - msg = yield from ws.receive() + await ws.prepare(request) + msg = await ws.receive() if msg.type == aiohttp.WSMsgType.ping: ping_received = True - yield from ws.receive() + await ws.receive() return ws app = web.Application() - app.router.add_route('GET', '/', handler) + app.router.add_route("GET", "/", handler) - client = yield from test_client(app) - resp = yield from client.ws_connect('/', heartbeat=0.05) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.05) - yield from resp.receive() - yield from resp.receive() + await resp.receive() + await resp.receive() assert ping_received + + +async def test_send_recv_compress(aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", compress=15) + await resp.send_str("ask") + + assert resp.compress == 15 + + data = await resp.receive_str() + assert data == "ask/answer" + + await resp.close() + assert resp.get_extra_info("socket") is None + + +async def test_send_recv_compress_wbits(aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/", compress=9) + await resp.send_str("ask") + + # The client indicates support for wbits 15 + # The server supports wbits 15 for decoding + assert resp.compress == 15 + + data = await resp.receive_str() + assert data == "ask/answer" + + await resp.close() + assert resp.get_extra_info("socket") is None + + +async def test_send_recv_compress_wbit_error(aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_bytes() + await ws.send_bytes(msg + b"/answer")
+ await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + with pytest.raises(ValueError): + await client.ws_connect("/", compress=1) + + +async def test_ws_client_async_for(aiohttp_client) -> None: + items = ["q1", "q2", "q3"] + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + for i in items: + await ws.send_str(i) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + it = iter(items) + async for msg in resp: + assert msg.data == next(it) + + with pytest.raises(StopIteration): + next(it) + + assert resp.closed + + +async def test_ws_async_with(aiohttp_server) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + msg = await ws.receive() + await ws.send_str(msg.data + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as client: + async with client.ws_connect(server.make_url("/")) as ws: + await ws.send_str("request") + msg = await ws.receive() + assert msg.data == "request/answer" + + assert ws.closed + + +async def test_ws_async_with_send(aiohttp_server) -> None: + # send_xxx methods have to return awaitable objects + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + msg = await ws.receive() + await ws.send_str(msg.data + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as client: + async with client.ws_connect(server.make_url("/")) as ws: + await ws.send_str("request") + msg = await ws.receive() + assert msg.data == "request/answer" + + assert ws.closed + + +async def test_ws_async_with_shortcut(aiohttp_server) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + msg = await ws.receive() + await ws.send_str(msg.data + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as client: + async with client.ws_connect(server.make_url("/")) as ws: + await ws.send_str("request") + msg = await ws.receive() + assert msg.data == "request/answer" + + assert ws.closed + + +async def test_closed_async_for(aiohttp_client) -> None: + loop = asyncio.get_event_loop() + closed = loop.create_future() + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + try: + await ws.send_bytes(b"started") + await ws.receive_bytes() + finally: + closed.set_result(1) + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + + messages = [] + async for msg in resp: + messages.append(msg) + if b"started" == msg.data: + await resp.send_bytes(b"ask") + await resp.close() + + assert 1 == len(messages) + assert messages[0].type == aiohttp.WSMsgType.BINARY + assert messages[0].data == b"started" + assert resp.closed + + await closed + + +async def test_peer_connection_lost(aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = 
await ws.receive_str() + assert msg == "ask" + await ws.send_str("answer") + request.transport.close() + await asyncio.sleep(10) + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") + assert "answer" == await resp.receive_str() + + msg = await resp.receive() + assert msg.type == aiohttp.WSMsgType.CLOSED + await resp.close() + + +async def test_peer_connection_lost_iter(aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + assert msg == "ask" + await ws.send_str("answer") + request.transport.close() + await asyncio.sleep(100) + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_str("ask") + async for msg in resp: + assert "answer" == msg.data + + await resp.close() diff --git a/tests/test_connector.py b/tests/test_connector.py index 8a4e96074e7..09841923e16 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1,48 +1,160 @@ -"""Tests of http client with custom Connector""" +# Tests of http client with custom Connector import asyncio import gc -import os.path +import hashlib import platform -import shutil import socket import ssl -import tempfile -import unittest +import sys +import uuid +from collections import deque from unittest import mock import pytest from yarl import URL import aiohttp -from aiohttp import client, helpers, web -from aiohttp.client import ClientRequest -from aiohttp.connector import Connection -from aiohttp.test_utils import unused_port +from aiohttp import client, web +from aiohttp.client import ClientRequest, ClientTimeout +from aiohttp.client_reqrep import ConnectionKey +from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable +from aiohttp.helpers import PY_37 +from aiohttp.locks import EventResultOrError +from aiohttp.test_utils import make_mocked_coro, unused_port +from aiohttp.tracing import Trace @pytest.fixture() def key(): - """Connection key""" - return ('localhost1', 80, False) + # Connection key + return ConnectionKey("localhost", 80, False, None, None, None, None) @pytest.fixture def key2(): - """Connection key""" - return ('localhost2', 80, False) + # Connection key + return ConnectionKey("localhost", 80, False, None, None, None, None) @pytest.fixture def ssl_key(): - """Connection key""" - return ('localhost', 80, True) + # Connection key + return ConnectionKey("localhost", 80, True, None, None, None, None) -def test_del(loop): - conn = aiohttp.BaseConnector(loop=loop) +@pytest.fixture +def unix_sockname(shorttmpdir): + sock_path = shorttmpdir / "socket.sock" + return str(sock_path) + + +@pytest.fixture +def unix_server(loop, unix_sockname): + runners = [] + + async def go(app): + runner = web.AppRunner(app) + runners.append(runner) + await runner.setup() + site = web.UnixSite(runner, unix_sockname) + await site.start() + + yield go + + for runner in runners: + loop.run_until_complete(runner.cleanup()) + + +@pytest.fixture +def named_pipe_server(proactor_loop, pipe_name): + runners = [] + + async def go(app): + runner = web.AppRunner(app) + runners.append(runner) + await runner.setup() + site = web.NamedPipeSite(runner, pipe_name) + await site.start() + + yield go + + for runner in runners: + proactor_loop.run_until_complete(runner.cleanup()) + + +def 
create_mocked_conn(conn_closing_result=None, **kwargs): + loop = asyncio.get_event_loop() + proto = mock.Mock(**kwargs) + proto.closed = loop.create_future() + proto.closed.set_result(conn_closing_result) + return proto + + +def test_connection_del(loop) -> None: + connector = mock.Mock() + key = mock.Mock() + protocol = mock.Mock() + loop.set_debug(0) + conn = Connection(connector, key, protocol, loop=loop) + exc_handler = mock.Mock() + loop.set_exception_handler(exc_handler) + + with pytest.warns(ResourceWarning): + del conn + gc.collect() + + connector._release.assert_called_with(key, protocol, should_close=True) + msg = { + "message": mock.ANY, + "client_connection": mock.ANY, + } + exc_handler.assert_called_with(loop, msg) + + +def test_connection_del_loop_debug(loop) -> None: + connector = mock.Mock() + key = mock.Mock() + protocol = mock.Mock() + loop.set_debug(1) + conn = Connection(connector, key, protocol, loop=loop) + exc_handler = mock.Mock() + loop.set_exception_handler(exc_handler) + + with pytest.warns(ResourceWarning): + del conn + gc.collect() + + msg = { + "message": mock.ANY, + "client_connection": mock.ANY, + "source_traceback": mock.ANY, + } + exc_handler.assert_called_with(loop, msg) + + +def test_connection_del_loop_closed(loop) -> None: + connector = mock.Mock() + key = mock.Mock() + protocol = mock.Mock() + loop.set_debug(1) + conn = Connection(connector, key, protocol, loop=loop) + exc_handler = mock.Mock() + loop.set_exception_handler(exc_handler) + loop.close() + + with pytest.warns(ResourceWarning): + del conn + gc.collect() + + assert not connector._release.called + assert not exc_handler.called + + +async def test_del(loop) -> None: + conn = aiohttp.BaseConnector() proto = mock.Mock(should_close=False) - conn._release('a', proto) + conn._release("a", proto) conns_impl = conn._conns exc_handler = mock.Mock() @@ -54,21 +166,22 @@ def test_del(loop): assert not conns_impl proto.close.assert_called_with() - msg = {'connector': mock.ANY, # conn was deleted - 'connections': mock.ANY, - 'message': 'Unclosed connector'} + msg = { + "connector": mock.ANY, # conn was deleted + "connections": mock.ANY, + "message": "Unclosed connector", + } if loop.get_debug(): - msg['source_traceback'] = mock.ANY + msg["source_traceback"] = mock.ANY exc_handler.assert_called_with(loop, msg) @pytest.mark.xfail -@asyncio.coroutine -def test_del_with_scheduled_cleanup(loop): +async def test_del_with_scheduled_cleanup(loop) -> None: loop.set_debug(True) conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=0.01) transp = mock.Mock() - conn._conns['a'] = [(transp, 'proto', 123)] + conn._conns["a"] = [(transp, 123)] conns_impl = conn._conns exc_handler = mock.Mock() @@ -78,22 +191,27 @@ def test_del_with_scheduled_cleanup(loop): # deletion is not triggered here because the loop still holds a strong # reference to the connector's instance method
del conn - yield from asyncio.sleep(0.01, loop=loop) + await asyncio.sleep(0.01) gc.collect() assert not conns_impl transp.close.assert_called_with() - msg = {'connector': mock.ANY, # conn was deleted - 'message': 'Unclosed connector'} + msg = {"connector": mock.ANY, "message": "Unclosed connector"} # conn was deleted if loop.get_debug(): - msg['source_traceback'] = mock.ANY + msg["source_traceback"] = mock.ANY exc_handler.assert_called_with(loop, msg) -def test_del_with_closed_loop(loop): - conn = aiohttp.BaseConnector(loop=loop) +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="CPython GC is required for the test" +) +def test_del_with_closed_loop(loop) -> None: + async def make_conn(): + return aiohttp.BaseConnector() + + conn = loop.run_until_complete(make_conn()) transp = mock.Mock() - conn._conns['a'] = [(transp, 'proto', 123)] + conn._conns["a"] = [(transp, 123)] conns_impl = conn._conns exc_handler = mock.Mock() @@ -109,7 +227,7 @@ def test_del_with_closed_loop(loop): assert exc_handler.called -def test_del_empty_conector(loop): +async def test_del_empty_connector(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) exc_handler = mock.Mock() @@ -120,36 +238,37 @@ def test_del_empty_conector(loop): assert not exc_handler.called -@asyncio.coroutine -def test_create_conn(loop): +async def test_create_conn(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) with pytest.raises(NotImplementedError): - yield from conn._create_connection(object()) + await conn._create_connection(object(), [], object()) -def test_context_manager(loop): +async def test_context_manager(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) - conn.close = mock.Mock() - with conn as c: - assert conn is c + with pytest.warns(DeprecationWarning): + with conn as c: + assert conn is c - assert conn.close.called + assert conn.closed -def test_ctor_loop(): - with mock.patch('aiohttp.connector.asyncio') as m_asyncio: - session = aiohttp.BaseConnector() +async def test_async_context_manager(loop) -> None: + conn = aiohttp.BaseConnector(loop=loop) - assert session._loop is m_asyncio.get_event_loop.return_value + async with conn as c: + assert conn is c + + assert conn.closed -def test_close(loop): +async def test_close(loop) -> None: proto = mock.Mock() conn = aiohttp.BaseConnector(loop=loop) assert not conn.closed - conn._conns[('host', 8080, False)] = [(proto, object())] + conn._conns[("host", 8080, False)] = [(proto, object())] conn.close() assert not conn._conns @@ -157,7 +276,7 @@ def test_close(loop): assert conn.closed -def test_get(loop): +async def test_get(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) assert conn._get(1) is None @@ -167,30 +286,65 @@ def test_get(loop): conn.close() -def test_get_expired(loop): +async def test_get_unconnected_proto(loop) -> None: + conn = aiohttp.BaseConnector() + key = ConnectionKey("localhost", 80, False, None, None, None, None) + assert conn._get(key) is None + + proto = create_mocked_conn(loop) + conn._conns[key] = [(proto, loop.time())] + assert conn._get(key) == proto + + assert conn._get(key) is None + conn._conns[key] = [(proto, loop.time())] + proto.is_connected = lambda *args: False + assert conn._get(key) is None + await conn.close() + + +async def test_get_unconnected_proto_ssl(loop) -> None: + conn = aiohttp.BaseConnector() + key = ConnectionKey("localhost", 80, True, None, None, None, None) + assert conn._get(key) is None + + proto = create_mocked_conn(loop) + conn._conns[key] = [(proto, loop.time())] + assert conn._get(key) == 
proto + + assert conn._get(key) is None + conn._conns[key] = [(proto, loop.time())] + proto.is_connected = lambda *args: False + assert conn._get(key) is None + await conn.close() + + +async def test_get_expired(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) - assert conn._get(('localhost', 80, False)) is None + key = ConnectionKey("localhost", 80, False, None, None, None, None) + assert conn._get(key) is None proto = mock.Mock() - conn._conns[('localhost', 80, False)] = [(proto, loop.time() - 1000)] - assert conn._get(('localhost', 80, False)) is None + conn._conns[key] = [(proto, loop.time() - 1000)] + assert conn._get(key) is None assert not conn._conns conn.close() -def test_get_expired_ssl(loop): +async def test_get_expired_ssl(loop) -> None: conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True) - assert conn._get(('localhost', 80, True)) is None + key = ConnectionKey("localhost", 80, True, None, None, None, None) + assert conn._get(key) is None proto = mock.Mock() - conn._conns[('localhost', 80, True)] = [(proto, loop.time() - 1000)] - assert conn._get(('localhost', 80, True)) is None + transport = proto.transport + conn._conns[key] = [(proto, loop.time() - 1000)] + assert conn._get(key) is None assert not conn._conns - assert conn._cleanup_closed_transports == [proto.close.return_value] + assert conn._cleanup_closed_transports == [transport] conn.close() -def test_release_acquired(loop, key): +async def test_release_acquired(loop, key) -> None: proto = mock.Mock() conn = aiohttp.BaseConnector(loop=loop, limit=5) conn._release_waiter = mock.Mock() @@ -209,7 +363,7 @@ def test_release_acquired(loop, key): conn.close() -def test_release_acquired_closed(loop, key): +async def test_release_acquired_closed(loop, key) -> None: proto = mock.Mock() conn = aiohttp.BaseConnector(loop=loop, limit=5) conn._release_waiter = mock.Mock() @@ -224,9 +378,7 @@ def test_release_acquired_closed(loop, key): conn.close() -def test_release(loop, key): - loop.time = mock.Mock(return_value=10) - +async def test_release(loop, key) -> None: conn = aiohttp.BaseConnector(loop=loop) conn._release_waiter = mock.Mock() @@ -237,27 +389,28 @@ def test_release(loop, key): conn._release(key, proto) assert conn._release_waiter.called - assert conn._conns[key][0] == (proto, 10) + assert conn._cleanup_handle is not None + assert conn._conns[key][0][0] == proto + assert conn._conns[key][0][1] == pytest.approx(loop.time(), abs=0.1) assert not conn._cleanup_closed_transports conn.close() -def test_release_ssl_transport(loop, ssl_key): - loop.time = mock.Mock(return_value=10) - +async def test_release_ssl_transport(loop, ssl_key) -> None: conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True) conn._release_waiter = mock.Mock() proto = mock.Mock() + transport = proto.transport conn._acquired.add(proto) conn._acquired_per_host[ssl_key].add(proto) conn._release(ssl_key, proto, should_close=True) - assert conn._cleanup_closed_transports == [proto.close.return_value] + assert conn._cleanup_closed_transports == [transport] conn.close() -def test_release_already_closed(loop): +async def test_release_already_closed(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock() @@ -273,18 +426,19 @@ def test_release_already_closed(loop): assert not conn._release_acquired.called -def test_release_waiter(loop, key, key2): +async def test_release_waiter_no_limit(loop, key, key2) -> None: # limit is 0 conn = aiohttp.BaseConnector(limit=0, loop=loop) w = mock.Mock() w.done.return_value = False 
conn._waiters[key].append(w) conn._release_waiter() - assert len(conn._waiters) == 1 - assert not w.done.called + assert len(conn._waiters[key]) == 0 + assert w.done.called conn.close() - # release first available + +async def test_release_waiter_first_available(loop, key, key2) -> None: conn = aiohttp.BaseConnector(loop=loop) w1, w2 = mock.Mock(), mock.Mock() w1.done.return_value = False @@ -292,103 +446,518 @@ def test_release_waiter(loop, key, key2): conn._waiters[key].append(w2) conn._waiters[key2].append(w1) conn._release_waiter() - assert (w1.set_result.called and not w2.set_result.called or - not w1.set_result.called and w2.set_result.called) + assert ( + w1.set_result.called + and not w2.set_result.called + or not w1.set_result.called + and w2.set_result.called + ) conn.close() - # limited available + +async def test_release_waiter_release_first(loop, key, key2) -> None: conn = aiohttp.BaseConnector(loop=loop, limit=1) w1, w2 = mock.Mock(), mock.Mock() w1.done.return_value = False w2.done.return_value = False - conn._waiters[key] = [w1, w2] + conn._waiters[key] = deque([w1, w2]) conn._release_waiter() assert w1.set_result.called assert not w2.set_result.called conn.close() - # limited available + +async def test_release_waiter_skip_done_waiter(loop, key, key2) -> None: conn = aiohttp.BaseConnector(loop=loop, limit=1) w1, w2 = mock.Mock(), mock.Mock() w1.done.return_value = True w2.done.return_value = False - conn._waiters[key] = [w1, w2] + conn._waiters[key] = deque([w1, w2]) conn._release_waiter() assert not w1.set_result.called - assert not w2.set_result.called + assert w2.set_result.called conn.close() -def test_release_waiter_per_host(loop, key, key2): +async def test_release_waiter_per_host(loop, key, key2) -> None: # no limit conn = aiohttp.BaseConnector(loop=loop, limit=0, limit_per_host=2) w1, w2 = mock.Mock(), mock.Mock() w1.done.return_value = False w2.done.return_value = False - conn._waiters[key] = [w1] - conn._waiters[key2] = [w2] + conn._waiters[key] = deque([w1]) + conn._waiters[key2] = deque([w2]) conn._release_waiter() - assert ((w1.set_result.called and not w2.set_result.called) or - (not w1.set_result.called and w2.set_result.called)) + assert (w1.set_result.called and not w2.set_result.called) or ( + not w1.set_result.called and w2.set_result.called + ) conn.close() -def test_release_close(loop): +async def test_release_waiter_no_available(loop, key, key2) -> None: + # limit is 0 + conn = aiohttp.BaseConnector(limit=0, loop=loop) + w = mock.Mock() + w.done.return_value = False + conn._waiters[key].append(w) + conn._available_connections = mock.Mock(return_value=0) + conn._release_waiter() + assert len(conn._waiters) == 1 + assert not w.done.called + conn.close() + + +async def test_release_close(loop, key) -> None: conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock(should_close=True) - key = ('localhost', 80, False) conn._acquired.add(proto) conn._release(key, proto) assert not conn._conns assert proto.close.called -@asyncio.coroutine -def test_tcp_connector_resolve_host_use_dns_cache(loop): +async def test__drop_acquire_per_host1(loop) -> None: + conn = aiohttp.BaseConnector(loop=loop) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 0 + + +async def test__drop_acquire_per_host2(loop) -> None: + conn = aiohttp.BaseConnector(loop=loop) + conn._acquired_per_host[123].add(456) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 0 + + +async def test__drop_acquire_per_host3(loop) -> None: + conn 
= aiohttp.BaseConnector(loop=loop) + conn._acquired_per_host[123].add(456) + conn._acquired_per_host[123].add(789) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 1 + assert conn._acquired_per_host[123] == {789} + + +async def test_tcp_connector_certificate_error(loop) -> None: + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + + async def certificate_error(*args, **kwargs): + raise ssl.CertificateError + + conn = aiohttp.TCPConnector(loop=loop) + conn._loop.create_connection = certificate_error + + with pytest.raises(aiohttp.ClientConnectorCertificateError) as ctx: + await conn.connect(req, [], ClientTimeout()) + + assert isinstance(ctx.value, ssl.CertificateError) + assert isinstance(ctx.value.certificate_error, ssl.CertificateError) + assert isinstance(ctx.value, aiohttp.ClientSSLError) + + +async def test_tcp_connector_multiple_hosts_errors(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + + ip1 = "192.168.1.1" + ip2 = "192.168.1.2" + ip3 = "192.168.1.3" + ip4 = "192.168.1.4" + ip5 = "192.168.1.5" + ips = [ip1, ip2, ip3, ip4, ip5] + ips_tried = [] + + fingerprint = hashlib.sha256(b"foo").digest() + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + ssl=aiohttp.Fingerprint(fingerprint), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + + os_error = certificate_error = ssl_error = fingerprint_error = False + connected = False + + async def create_connection(*args, **kwargs): + nonlocal os_error, certificate_error, ssl_error, fingerprint_error + nonlocal connected + + ip = args[1] + + ips_tried.append(ip) + + if ip == ip1: + os_error = True + raise OSError + + if ip == ip2: + certificate_error = True + raise ssl.CertificateError + + if ip == ip3: + ssl_error = True + raise ssl.SSLError + + if ip == ip4: + fingerprint_error = True + tr, pr = mock.Mock(), mock.Mock() + + def get_extra_info(param): + if param == "sslcontext": + return True + + if param == "ssl_object": + s = mock.Mock() + s.getpeercert.return_value = b"not foo" + return s + + if param == "peername": + return ("192.168.1.5", 12345) + + assert False, param + + tr.get_extra_info = get_extra_info + return tr, pr + + if ip == ip5: + connected = True + tr, pr = mock.Mock(), mock.Mock() + + def get_extra_info(param): + if param == "sslcontext": + return True + + if param == "ssl_object": + s = mock.Mock() + s.getpeercert.return_value = b"foo" + return s + + assert False + + tr.get_extra_info = get_extra_info + return tr, pr + + assert False + + conn._loop.create_connection = create_connection + + await conn.connect(req, [], ClientTimeout()) + assert ips == ips_tried + + assert os_error + assert certificate_error + assert ssl_error + assert fingerprint_error + assert connected + + +async def test_tcp_connector_resolve_host(loop) -> None: conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True) - res = yield from conn._resolve_host('localhost', 8080) + res = await conn._resolve_host("localhost", 8080) assert res for rec in res: - if rec['family'] == socket.AF_INET: - assert rec['host'] == '127.0.0.1' - assert rec['hostname'] == 'localhost' - assert rec['port'] == 8080 - elif rec['family'] == socket.AF_INET6: - assert rec['hostname'] == 'localhost' - assert rec['port'] == 8080 - if platform.system() == 'Darwin': - assert rec['host'] in 
('::1', 'fe80::1', 'fe80::1%lo0') + if rec["family"] == socket.AF_INET: + assert rec["host"] == "127.0.0.1" + assert rec["hostname"] == "localhost" + assert rec["port"] == 8080 + elif rec["family"] == socket.AF_INET6: + assert rec["hostname"] == "localhost" + assert rec["port"] == 8080 + if platform.system() == "Darwin": + assert rec["host"] in ("::1", "fe80::1", "fe80::1%lo0") else: - assert rec['host'] == '::1' + assert rec["host"] == "::1" -@asyncio.coroutine -def test_tcp_connector_resolve_host_twice_use_dns_cache(loop): - conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True) +@pytest.fixture +def dns_response(loop): + async def coro(): + # simulates a network operation + await asyncio.sleep(0) + return ["127.0.0.1"] + + return coro + + +async def test_tcp_connector_dns_cache_not_expired(loop, dns_response) -> None: + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + m_resolver().resolve.return_value = dns_response() + await conn._resolve_host("localhost", 8080) + await conn._resolve_host("localhost", 8080) + m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) + + +async def test_tcp_connector_dns_cache_forever(loop, dns_response) -> None: + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + m_resolver().resolve.return_value = dns_response() + await conn._resolve_host("localhost", 8080) + await conn._resolve_host("localhost", 8080) + m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) + + +async def test_tcp_connector_use_dns_cache_disabled(loop, dns_response) -> None: + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) + m_resolver().resolve.side_effect = [dns_response(), dns_response()] + await conn._resolve_host("localhost", 8080) + await conn._resolve_host("localhost", 8080) + m_resolver().resolve.assert_has_calls( + [ + mock.call("localhost", 8080, family=0), + mock.call("localhost", 8080, family=0), + ] + ) + + +async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None: + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + m_resolver().resolve.return_value = dns_response() + loop.create_task(conn._resolve_host("localhost", 8080)) + loop.create_task(conn._resolve_host("localhost", 8080)) + await asyncio.sleep(0) + m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) + + +async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> None: + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + e = Exception() + m_resolver().resolve.side_effect = e + r1 = loop.create_task(conn._resolve_host("localhost", 8080)) + r2 = loop.create_task(conn._resolve_host("localhost", 8080)) + await asyncio.sleep(0) + assert r1.exception() == e + assert r2.exception() == e + + +async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( + loop, dns_response +): + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + m_resolver().resolve.return_value = dns_response() + loop.create_task(conn._resolve_host("localhost", 8080)) + f = 
loop.create_task(conn._resolve_host("localhost", 8080)) + + await asyncio.sleep(0) + conn.close() + + with pytest.raises(asyncio.CancelledError): + await f + + +@pytest.fixture +def dns_response_error(loop): + async def coro(): + # simulates a network operation + await asyncio.sleep(0) + raise socket.gaierror(-3, "Temporary failure in name resolution") + + return coro + + +async def test_tcp_connector_cancel_dns_error_captured( + loop, dns_response_error +) -> None: + + exception_handler_called = False + + def exception_handler(loop, context): + nonlocal exception_handler_called + exception_handler_called = True + + loop.set_exception_handler(mock.Mock(side_effect=exception_handler)) + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + req = ClientRequest( + method="GET", url=URL("http://temporary-failure:80"), loop=loop + ) + conn = aiohttp.TCPConnector( + use_dns_cache=False, + ) + m_resolver().resolve.return_value = dns_response_error() + f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0))) + + await asyncio.sleep(0) + f.cancel() + with pytest.raises(asyncio.CancelledError): + await f + + gc.collect() + assert exception_handler_called is False + + +async def test_tcp_connector_dns_tracing(loop, dns_response) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_cache_hit = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_cache_miss = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) + trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) + trace_config.on_dns_cache_hit.append(on_dns_cache_hit) + trace_config.on_dns_cache_miss.append(on_dns_cache_miss) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + + m_resolver().resolve.return_value = dns_response() + + await conn._resolve_host("localhost", 8080, traces=traces) + on_dns_resolvehost_start.assert_called_once_with( + session, + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams("localhost"), + ) + on_dns_resolvehost_end.assert_called_once_with( + session, trace_config_ctx, aiohttp.TraceDnsResolveHostEndParams("localhost") + ) + on_dns_cache_miss.assert_called_once_with( + session, trace_config_ctx, aiohttp.TraceDnsCacheMissParams("localhost") + ) + assert not on_dns_cache_hit.called + + await conn._resolve_host("localhost", 8080, traces=traces) + on_dns_cache_hit.assert_called_once_with( + session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost") + ) + + +async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) + 
trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) + + m_resolver().resolve.side_effect = [dns_response(), dns_response()] + + await conn._resolve_host("localhost", 8080, traces=traces) + + await conn._resolve_host("localhost", 8080, traces=traces) + + on_dns_resolvehost_start.assert_has_calls( + [ + mock.call( + session, + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams("localhost"), + ), + mock.call( + session, + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams("localhost"), + ), + ] + ) + on_dns_resolvehost_end.assert_has_calls( + [ + mock.call( + session, + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams("localhost"), + ), + mock.call( + session, + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams("localhost"), + ), + ] + ) + + +async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_dns_cache_hit = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_cache_miss = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_cache_hit.append(on_dns_cache_hit) + trace_config.on_dns_cache_miss.append(on_dns_cache_miss) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10) + m_resolver().resolve.return_value = dns_response() + loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) + loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) + await asyncio.sleep(0) + on_dns_cache_hit.assert_called_once_with( + session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost") + ) + on_dns_cache_miss.assert_called_once_with( + session, trace_config_ctx, aiohttp.TraceDnsCacheMissParams("localhost") + ) - res = yield from conn._resolve_host('localhost', 8080) - res2 = yield from conn._resolve_host('localhost', 8080) - assert res is res2 +async def test_dns_error(loop) -> None: + connector = aiohttp.TCPConnector(loop=loop) + connector._resolve_host = make_mocked_coro( + raise_exception=OSError("dont take it serious") + ) + req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) -def test_get_pop_empty_conns(loop): + with pytest.raises(aiohttp.ClientConnectorError): + await connector.connect(req, [], ClientTimeout()) + + +async def test_get_pop_empty_conns(loop) -> None: # see issue #473 conn = aiohttp.BaseConnector(loop=loop) - key = ('127.0.0.1', 80, False) + key = ("127.0.0.1", 80, False) conn._conns[key] = [] proto = conn._get(key) assert proto is None assert not conn._conns -def test_release_close_do_not_add_to_pool(loop): +async def test_release_close_do_not_add_to_pool(loop, key) -> None: # see issue #473 conn = aiohttp.BaseConnector(loop=loop) - key = ('127.0.0.1', 80, False) proto = mock.Mock(should_close=True) conn._acquired.add(proto) @@ -396,11 +965,10 @@ def test_release_close_do_not_add_to_pool(loop): assert not conn._conns -def test_release_close_do_not_delete_existing_connections(loop): - key = ('127.0.0.1', 80, False) +async def 
test_release_close_do_not_delete_existing_connections(key) -> None: proto1 = mock.Mock() - conn = aiohttp.BaseConnector(loop=loop) + conn = aiohttp.BaseConnector() conn._conns[key] = [(proto1, 1)] proto = mock.Mock(should_close=True) @@ -411,43 +979,42 @@ def test_release_close_do_not_delete_existing_connections(loop): conn.close() -def test_release_not_started(loop): - loop.time = mock.Mock(return_value=10) +async def test_release_not_started(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock(should_close=False) key = 1 conn._acquired.add(proto) conn._release(key, proto) - assert conn._conns == {1: [(proto, 10)]} + # assert conn._conns == {1: [(proto, 10)]} + rec = conn._conns[1] + assert rec[0][0] == proto + assert rec[0][1] == pytest.approx(loop.time(), abs=0.05) assert not proto.close.called conn.close() -def test_release_not_opened(loop): +async def test_release_not_opened(loop, key) -> None: conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock() - key = ('localhost', 80, False) conn._acquired.add(proto) conn._release(key, proto) assert proto.close.called -@asyncio.coroutine -def test_connect(loop): +async def test_connect(loop, key) -> None: proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://host:80'), loop=loop) + req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector(loop=loop) - key = ('host', 80, False) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) + conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection = yield from conn.connect(req) + connection = await conn.connect(req, [], ClientTimeout()) assert not conn._create_connection.called assert connection._protocol is proto assert connection.transport is proto.transport @@ -455,36 +1022,76 @@ def test_connect(loop): connection.close() -@asyncio.coroutine -def test_connect_oserr(loop): +async def test_connect_tracing(loop) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_connection_create_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_create_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_create_start.append(on_connection_create_start) + trace_config.on_connection_create_end.append(on_connection_create_end) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest("GET", URL("http://host:80"), loop=loop) + conn = aiohttp.BaseConnector(loop=loop) conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) - err = OSError(1, 'permission error') - conn._create_connection.return_value.set_exception(err) + conn._create_connection.return_value = loop.create_future() + conn._create_connection.return_value.set_result(proto) + + conn2 = await conn.connect(req, traces, ClientTimeout()) + conn2.release() + + on_connection_create_start.assert_called_with( + session, trace_config_ctx, aiohttp.TraceConnectionCreateStartParams() + ) + on_connection_create_end.assert_called_with( + session, trace_config_ctx, aiohttp.TraceConnectionCreateEndParams() + ) - with 
pytest.raises(aiohttp.ClientOSError) as ctx: - req = mock.Mock() - yield from conn.connect(req) - assert 1 == ctx.value.errno - assert ctx.value.strerror.startswith('Cannot connect to') - assert ctx.value.strerror.endswith('[permission error]') + +async def test_close_during_connect(loop) -> None: + proto = mock.Mock() + proto.is_connected.return_value = True + + fut = loop.create_future() + req = ClientRequest("GET", URL("http://host:80"), loop=loop) + + conn = aiohttp.BaseConnector(loop=loop) + conn._create_connection = mock.Mock() + conn._create_connection.return_value = fut + + task = loop.create_task(conn.connect(req, None, ClientTimeout())) + await asyncio.sleep(0) + conn.close() + + fut.set_result(proto) + with pytest.raises(aiohttp.ClientConnectionError): + await task + + assert proto.close.called -def test_ctor_cleanup(): +async def test_ctor_cleanup() -> None: loop = mock.Mock() loop.time.return_value = 1.5 conn = aiohttp.BaseConnector( - loop=loop, keepalive_timeout=10, enable_cleanup_closed=True) + loop=loop, keepalive_timeout=10, enable_cleanup_closed=True + ) assert conn._cleanup_handle is None assert conn._cleanup_closed_handle is not None -def test_cleanup(): - key = ('localhost', 80, False) +async def test_cleanup(key) -> None: testset = { - key: [(mock.Mock(), 10), - (mock.Mock(), 300)], + key: [(mock.Mock(), 10), (mock.Mock(), 300)], } testset[key][0][0].is_connected.return_value = True testset[key][1][0].is_connected.return_value = False @@ -498,13 +1105,13 @@ def test_cleanup(): conn._cleanup() assert existing_handle.cancel.called assert conn._conns == {} - assert conn._cleanup_handle is not None + assert conn._cleanup_handle is None -def test_cleanup_close_ssl_transport(): +async def test_cleanup_close_ssl_transport(ssl_key) -> None: proto = mock.Mock() - key = ('localhost', 80, True) - testset = {key: [(proto, 10)]} + transport = proto.transport + testset = {ssl_key: [(proto, 10)]} loop = mock.Mock() loop.time.return_value = 300 @@ -515,10 +1122,10 @@ def test_cleanup_close_ssl_transport(): conn._cleanup() assert existing_handle.cancel.called assert conn._conns == {} - assert conn._cleanup_closed_transports == [proto.close.return_value] + assert conn._cleanup_closed_transports == [transport] -def test_cleanup2(): +async def test_cleanup2() -> None: testset = {1: [(mock.Mock(), 300)]} testset[1][0][0].is_connected.return_value = True @@ -535,10 +1142,8 @@ def test_cleanup2(): conn.close() -def test_cleanup3(): - key = ('localhost', 80, False) - testset = {key: [(mock.Mock(), 290.1), - (mock.Mock(), 305.1)]} +async def test_cleanup3(key) -> None: + testset = {key: [(mock.Mock(), 290.1), (mock.Mock(), 305.1)]} testset[key][0][0].is_connected.return_value = True loop = mock.Mock() @@ -555,8 +1160,11 @@ def test_cleanup3(): conn.close() -def test_cleanup_closed(loop, mocker): - mocker.spy(loop, 'call_at') +async def test_cleanup_closed(loop, mocker) -> None: + if not hasattr(loop, "__dict__"): + pytest.skip("can not override loop attributes") + + mocker.spy(loop, "call_at") conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True) tr = mock.Mock() @@ -569,9 +1177,8 @@ def test_cleanup_closed(loop, mocker): assert cleanup_closed_handle.cancel.called -def test_cleanup_closed_disabled(loop, mocker): - conn = aiohttp.BaseConnector( - loop=loop, enable_cleanup_closed=False) +async def test_cleanup_closed_disabled(loop, mocker) -> None: + conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=False) tr = mock.Mock() conn._cleanup_closed_transports = [tr] @@ 
-580,68 +1187,125 @@ def test_cleanup_closed_disabled(loop, mocker): assert not conn._cleanup_closed_transports -def test_tcp_connector_ctor(loop): +async def test_tcp_connector_ctor(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) - assert conn.verify_ssl - assert conn.fingerprint is None + assert conn._ssl is None assert conn.use_dns_cache assert conn.family == 0 - assert conn.cached_hosts == {} -def test_tcp_connector_ctor_fingerprint_valid(loop): - valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=' - conn = aiohttp.TCPConnector(loop=loop, fingerprint=valid) - assert conn.fingerprint == valid +async def test_tcp_connector_ctor_fingerprint_valid(loop) -> None: + valid = aiohttp.Fingerprint(hashlib.sha256(b"foo").digest()) + conn = aiohttp.TCPConnector(ssl=valid, loop=loop) + assert conn._ssl is valid -def test_tcp_connector_fingerprint_invalid(loop): - invalid = b'\x00' +async def test_insecure_fingerprint_md5(loop) -> None: with pytest.raises(ValueError): - aiohttp.TCPConnector(loop=loop, fingerprint=invalid) + aiohttp.TCPConnector( + ssl=aiohttp.Fingerprint(hashlib.md5(b"foo").digest()), loop=loop + ) -def test_tcp_connector_clear_dns_cache(loop): +async def test_insecure_fingerprint_sha1(loop) -> None: + with pytest.raises(ValueError): + aiohttp.TCPConnector( + ssl=aiohttp.Fingerprint(hashlib.sha1(b"foo").digest()), loop=loop + ) + + +async def test_tcp_connector_clear_dns_cache(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) - info = object() - conn._cached_hosts[('localhost', 123)] = info - conn._cached_hosts[('localhost', 124)] = info - conn.clear_dns_cache('localhost', 123) - assert conn.cached_hosts == {('localhost', 124): info} - conn.clear_dns_cache('localhost', 123) - assert conn.cached_hosts == {('localhost', 124): info} + hosts = ["a", "b"] + conn._cached_hosts.add(("localhost", 123), hosts) + conn._cached_hosts.add(("localhost", 124), hosts) + conn.clear_dns_cache("localhost", 123) + with pytest.raises(KeyError): + conn._cached_hosts.next_addrs(("localhost", 123)) + + assert conn._cached_hosts.next_addrs(("localhost", 124)) == hosts + + # Removing an already-removed element is OK + conn.clear_dns_cache("localhost", 123) + with pytest.raises(KeyError): + conn._cached_hosts.next_addrs(("localhost", 123)) + conn.clear_dns_cache() - assert conn.cached_hosts == {} + with pytest.raises(KeyError): + conn._cached_hosts.next_addrs(("localhost", 124)) -def test_tcp_connector_clear_dns_cache_bad_args(loop): +async def test_tcp_connector_clear_dns_cache_bad_args(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) with pytest.raises(ValueError): - conn.clear_dns_cache('localhost') + conn.clear_dns_cache("localhost") -def test_ambigous_verify_ssl_and_ssl_context(loop): - with pytest.raises(ValueError): - aiohttp.TCPConnector( - verify_ssl=False, - ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23), - loop=loop) +async def test_dont_recreate_ssl_context(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + ctx = conn._make_ssl_context(True) + assert ctx is conn._make_ssl_context(True) + + +async def test_dont_recreate_ssl_context2(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + ctx = conn._make_ssl_context(False) + assert ctx is conn._make_ssl_context(False) -def test_dont_recreate_ssl_context(loop): +async def test___get_ssl_context1(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) - ctx = conn.ssl_context - assert ctx is conn.ssl_context + req = mock.Mock() + req.is_ssl.return_value = False + assert conn._get_ssl_context(req) is None + + +async def 
test___get_ssl_context2(loop) -> None: + ctx = ssl.SSLContext() + conn = aiohttp.TCPConnector(loop=loop) + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = ctx + assert conn._get_ssl_context(req) is ctx + + +async def test___get_ssl_context3(loop) -> None: + ctx = ssl.SSLContext() + conn = aiohttp.TCPConnector(loop=loop, ssl=ctx) + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = None + assert conn._get_ssl_context(req) is ctx + +async def test___get_ssl_context4(loop) -> None: + ctx = ssl.SSLContext() + conn = aiohttp.TCPConnector(loop=loop, ssl=ctx) + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = False + assert conn._get_ssl_context(req) is conn._make_ssl_context(False) -def test_respect_precreated_ssl_context(loop): - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - conn = aiohttp.TCPConnector(loop=loop, ssl_context=ctx) - assert ctx is conn.ssl_context +async def test___get_ssl_context5(loop) -> None: + ctx = ssl.SSLContext() + conn = aiohttp.TCPConnector(loop=loop, ssl=ctx) + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest()) + assert conn._get_ssl_context(req) is conn._make_ssl_context(False) -def test_close_twice(loop): + +async def test___get_ssl_context6(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = None + assert conn._get_ssl_context(req) is conn._make_ssl_context(True) + + +async def test_close_twice(loop) -> None: proto = mock.Mock() conn = aiohttp.BaseConnector(loop=loop) @@ -652,12 +1316,12 @@ def test_close_twice(loop): assert proto.close.called assert conn.closed - conn._conns = 'Invalid' # fill with garbage + conn._conns = "Invalid" # fill with garbage conn.close() assert conn.closed -def test_close_cancels_cleanup_handle(loop): +async def test_close_cancels_cleanup_handle(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) conn._release(1, mock.Mock(should_close=False)) assert conn._cleanup_handle is not None @@ -665,7 +1329,7 @@ def test_close_cancels_cleanup_handle(loop): assert conn._cleanup_handle is None -def test_close_abort_closed_transports(loop): +async def test_close_abort_closed_transports(loop) -> None: tr = mock.Mock() conn = aiohttp.BaseConnector(loop=loop) @@ -677,37 +1341,33 @@ def test_close_abort_closed_transports(loop): assert conn.closed -def test_close_cancels_cleanup_closed_handle(loop): +async def test_close_cancels_cleanup_closed_handle(loop) -> None: conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True) assert conn._cleanup_closed_handle is not None conn.close() assert conn._cleanup_closed_handle is None -def test_ctor_with_default_loop(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) +async def test_ctor_with_default_loop(loop) -> None: conn = aiohttp.BaseConnector() assert loop is conn._loop - loop.close() -@asyncio.coroutine -def test_connect_with_limit(loop, key): +async def test_connect_with_limit(loop, key) -> None: proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), - loop=loop, - response_class=mock.Mock()) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) conn = aiohttp.BaseConnector(loop=loop, limit=1) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) + conn._create_connection.return_value = loop.create_future() 
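+    # A pre-created future stands in for the real connection here; it is
+    # completed on the next line, so connect() succeeds without network I/O.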
conn._create_connection.return_value.set_result(proto) - connection1 = yield from conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) assert connection1._protocol == proto assert 1 == len(conn._acquired) @@ -717,146 +1377,214 @@ def test_connect_with_limit(loop, key): acquired = False - @asyncio.coroutine - def f(): + async def f(): nonlocal acquired - connection2 = yield from conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True assert 1 == len(conn._acquired) assert 1 == len(conn._acquired_per_host[key]) connection2.release() - task = helpers.ensure_future(f(), loop=loop) + task = loop.create_task(f()) - yield from asyncio.sleep(0.01, loop=loop) + await asyncio.sleep(0.01) assert not acquired connection1.release() - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0) assert acquired - yield from task + await task + conn.close() + + +async def test_connect_queued_operation_tracing(loop, key) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_connection_queued_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_queued_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_queued_start.append(on_connection_queued_start) + trace_config.on_connection_queued_end.append(on_connection_queued_end) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest( + "GET", URL("http://localhost1:80"), loop=loop, response_class=mock.Mock() + ) + + conn = aiohttp.BaseConnector(loop=loop, limit=1) + conn._conns[key] = [(proto, loop.time())] + conn._create_connection = mock.Mock() + conn._create_connection.return_value = loop.create_future() + conn._create_connection.return_value.set_result(proto) + + connection1 = await conn.connect(req, traces, ClientTimeout()) + + async def f(): + connection2 = await conn.connect(req, traces, ClientTimeout()) + on_connection_queued_start.assert_called_with( + session, trace_config_ctx, aiohttp.TraceConnectionQueuedStartParams() + ) + on_connection_queued_end.assert_called_with( + session, trace_config_ctx, aiohttp.TraceConnectionQueuedEndParams() + ) + connection2.release() + + task = asyncio.ensure_future(f(), loop=loop) + await asyncio.sleep(0.01) + connection1.release() + await task conn.close() -@asyncio.coroutine -def test_connect_with_limit_and_limit_per_host(loop, key): +async def test_connect_reuseconn_tracing(loop, key) -> None: + session = mock.Mock() + trace_config_ctx = mock.Mock() + on_connection_reuseconn = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_reuseconn.append(on_connection_reuseconn) + trace_config.freeze() + traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] + proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + + conn = aiohttp.BaseConnector(loop=loop, limit=1) + conn._conns[key] = [(proto, loop.time())] + conn2 = await conn.connect(req, traces, ClientTimeout()) + conn2.release() + + 
on_connection_reuseconn.assert_called_with( + session, trace_config_ctx, aiohttp.TraceConnectionReuseconnParams() + ) + conn.close() + + +async def test_connect_with_limit_and_limit_per_host(loop, key) -> None: + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector(loop=loop, limit=1000, limit_per_host=1) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) + conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = yield from conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) - @asyncio.coroutine - def f(): + async def f(): nonlocal acquired - connection2 = yield from conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True assert 1 == len(conn._acquired) assert 1 == len(conn._acquired_per_host[key]) connection2.release() - task = helpers.ensure_future(f(), loop=loop) + task = loop.create_task(f()) - yield from asyncio.sleep(0.01, loop=loop) + await asyncio.sleep(0.01) assert not acquired connection1.release() - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0) assert acquired - yield from task + await task conn.close() -@asyncio.coroutine -def test_connect_with_no_limit_and_limit_per_host(loop, key): +async def test_connect_with_no_limit_and_limit_per_host(loop, key) -> None: proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop) + req = ClientRequest("GET", URL("http://localhost1:80"), loop=loop) conn = aiohttp.BaseConnector(loop=loop, limit=0, limit_per_host=1) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) + conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = yield from conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) - @asyncio.coroutine - def f(): + async def f(): nonlocal acquired - connection2 = yield from conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True connection2.release() - task = helpers.ensure_future(f(), loop=loop) + task = loop.create_task(f()) - yield from asyncio.sleep(0.01, loop=loop) + await asyncio.sleep(0.01) assert not acquired connection1.release() - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0) assert acquired - yield from task + await task conn.close() -@asyncio.coroutine -def test_connect_with_no_limits(loop, key): +async def test_connect_with_no_limits(loop, key) -> None: proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop) + req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector(loop=loop, limit=0, limit_per_host=0) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() - conn._create_connection.return_value = helpers.create_future(loop) + conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = yield from conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) - 
@asyncio.coroutine
-    def f():
+    async def f():
         nonlocal acquired
-        connection2 = yield from conn.connect(req)
+        connection2 = await conn.connect(req, None, ClientTimeout())
         acquired = True
         assert 1 == len(conn._acquired)
         assert 1 == len(conn._acquired_per_host[key])
         connection2.release()

-    task = helpers.ensure_future(f(), loop=loop)
+    task = loop.create_task(f())

-    yield from asyncio.sleep(0.01, loop=loop)
+    await asyncio.sleep(0.01)
     assert acquired
     connection1.release()
-    yield from task
+    await task
     conn.close()


-@asyncio.coroutine
-def test_connect_with_limit_cancelled(loop):
+async def test_connect_with_limit_cancelled(loop) -> None:
     proto = mock.Mock()
     proto.is_connected.return_value = True

-    req = ClientRequest('GET', URL('http://host:80'), loop=loop)
+    req = ClientRequest("GET", URL("http://host:80"), loop=loop)

     conn = aiohttp.BaseConnector(loop=loop, limit=1)
-    key = ('host', 80, False)
+    key = ("host", 80, False)
     conn._conns[key] = [(proto, loop.time())]

     conn._create_connection = mock.Mock()
-    conn._create_connection.return_value = helpers.create_future(loop)
+    conn._create_connection.return_value = loop.create_future()
     conn._create_connection.return_value.set_result(proto)

-    connection = yield from conn.connect(req)
+    connection = await conn.connect(req, None, ClientTimeout())
     assert connection._protocol == proto
     assert connection.transport == proto.transport
@@ -864,39 +1592,34 @@ def test_connect_with_limit_cancelled(loop):
     with pytest.raises(asyncio.TimeoutError):
         # limit exhausted
-        yield from asyncio.wait_for(conn.connect(req), 0.01,
-                                    loop=loop)
+        await asyncio.wait_for(conn.connect(req, None, ClientTimeout()), 0.01)
     connection.close()


-@asyncio.coroutine
-def test_connect_with_capacity_release_waiters(loop):
-
+async def test_connect_with_capacity_release_waiters(loop) -> None:
-    def check_with_exc(err):
+    async def check_with_exc(err):
         conn = aiohttp.BaseConnector(limit=1, loop=loop)
         conn._create_connection = mock.Mock()
-        conn._create_connection.return_value = \
-            helpers.create_future(loop)
+        conn._create_connection.return_value = loop.create_future()
         conn._create_connection.return_value.set_exception(err)

         with pytest.raises(Exception):
             req = mock.Mock()
-            yield from conn.connect(req)
+            await conn.connect(req, None, ClientTimeout())
         assert not conn._waiters

-    check_with_exc(OSError(1, 'permission error'))
-    check_with_exc(RuntimeError())
-    check_with_exc(asyncio.TimeoutError())
+    await check_with_exc(OSError(1, "permission error"))
+    await check_with_exc(RuntimeError())
+    await check_with_exc(asyncio.TimeoutError())


-@asyncio.coroutine
-def test_connect_with_limit_concurrent(loop):
+async def test_connect_with_limit_concurrent(loop) -> None:
     proto = mock.Mock()
     proto.should_close = False
     proto.is_connected.return_value = True

-    req = ClientRequest('GET', URL('http://host:80'), loop=loop)
+    req = ClientRequest("GET", URL("http://host:80"), loop=loop)

     max_connections = 2
     num_connections = 0
@@ -906,11 +1629,10 @@ def test_connect_with_limit_concurrent(loop):
     # Use a real coroutine for _create_connection; a mock would mask
     # problems that only happen when the method yields.

-    @asyncio.coroutine
-    def create_connection(req):
+    async def create_connection(req, traces, timeout):
         nonlocal num_connections
         num_connections += 1
-        yield from asyncio.sleep(0, loop=loop)
+        await asyncio.sleep(0)

         # Make a new transport mock each time because acquired
         # transports are stored in a set. Reusing the same object
@@ -928,47 +1650,85 @@ def create_connection(req):
     # with multiple concurrent requests and stops when it hits a
     # predefined maximum number of requests.
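+    # Far more requests than available connections, so most requests
+    # have to wait for a free slot in the pool before they can proceed.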
-    max_requests = 10
+    max_requests = 50
     num_requests = 0
     start_requests = max_connections + 1

-    @asyncio.coroutine
-    def f(start=True):
+    async def f(start=True):
         nonlocal num_requests
         if num_requests == max_requests:
             return
         num_requests += 1
         if not start:
-            connection = yield from conn.connect(req)
-            yield from asyncio.sleep(0, loop=loop)
+            connection = await conn.connect(req, None, ClientTimeout())
+            await asyncio.sleep(0)
             connection.release()

-    tasks = [
-        helpers.ensure_future(f(start=False), loop=loop)
-        for i in range(start_requests)
-    ]
-    yield from asyncio.wait(tasks, loop=loop)
+    await asyncio.sleep(0)
+    tasks = [loop.create_task(f(start=False)) for i in range(start_requests)]
+    await asyncio.wait(tasks)

-    yield from f()
+    await f()
     conn.close()

     assert max_connections == num_connections


-@asyncio.coroutine
-def test_close_with_acquired_connection(loop):
+async def test_connect_waiters_cleanup(loop) -> None:
     proto = mock.Mock()
     proto.is_connected.return_value = True

-    req = ClientRequest('GET', URL('http://host:80'), loop=loop)
+    req = ClientRequest("GET", URL("http://host:80"), loop=loop)

     conn = aiohttp.BaseConnector(loop=loop, limit=1)
-    key = ('host', 80, False)
+    conn._available_connections = mock.Mock(return_value=0)
+
+    t = loop.create_task(conn.connect(req, None, ClientTimeout()))
+
+    await asyncio.sleep(0)
+    assert conn._waiters.keys()
+
+    t.cancel()
+    await asyncio.sleep(0)
+    assert not conn._waiters.keys()
+
+
+async def test_connect_waiters_cleanup_key_error(loop) -> None:
+    proto = mock.Mock()
+    proto.is_connected.return_value = True
+
+    req = ClientRequest("GET", URL("http://host:80"), loop=loop)
+
+    conn = aiohttp.BaseConnector(loop=loop, limit=1)
+    conn._available_connections = mock.Mock(return_value=0)
+
+    t = loop.create_task(conn.connect(req, None, ClientTimeout()))
+
+    await asyncio.sleep(0)
+    assert conn._waiters.keys()
+
+    # We delete the entry explicitly before the cancelled
+    # connection grabs the loop again; we must expect a
+    # failure-free termination.
+    conn._waiters.clear()
+    t.cancel()
+    await asyncio.sleep(0)
+    assert not conn._waiters.keys()
+
+
+async def test_close_with_acquired_connection(loop) -> None:
+    proto = mock.Mock()
+    proto.is_connected.return_value = True
+
+    req = ClientRequest("GET", URL("http://host:80"), loop=loop)
+
+    conn = aiohttp.BaseConnector(loop=loop, limit=1)
+    key = ("host", 80, False)
     conn._conns[key] = [(proto, loop.time())]

     conn._create_connection = mock.Mock()
-    conn._create_connection.return_value = helpers.create_future(loop)
+    conn._create_connection.return_value = loop.create_future()
     conn._create_connection.return_value.set_result(proto)

-    connection = yield from conn.connect(req)
+    connection = await conn.connect(req, None, ClientTimeout())
     assert 1 == len(conn._acquired)
     conn.close()
@@ -981,44 +1741,42 @@ def test_close_with_acquired_connection(loop):
     assert connection.closed


-def test_default_force_close(loop):
+async def test_default_force_close(loop) -> None:
     connector = aiohttp.BaseConnector(loop=loop)
     assert not connector.force_close


-def test_limit_property(loop):
+async def test_limit_property(loop) -> None:
     conn = aiohttp.BaseConnector(loop=loop, limit=15)
     assert 15 == conn.limit

     conn.close()


-def test_limit_by_host_property(loop):
+async def test_limit_per_host_property(loop) -> None:
     conn = aiohttp.BaseConnector(loop=loop, limit_per_host=15)
     assert 15 == conn.limit_per_host

     conn.close()


-def test_limit_property_default(loop):
+async def test_limit_property_default(loop) -> None:
     conn =
aiohttp.BaseConnector(loop=loop) assert conn.limit == 100 conn.close() -def test_limit_per_host_property_default(loop): +async def test_limit_per_host_property_default(loop) -> None: conn = aiohttp.BaseConnector(loop=loop) assert conn.limit_per_host == 0 conn.close() -def test_force_close_and_explicit_keep_alive(loop): +async def test_force_close_and_explicit_keep_alive(loop) -> None: with pytest.raises(ValueError): - aiohttp.BaseConnector(loop=loop, keepalive_timeout=30, - force_close=True) + aiohttp.BaseConnector(loop=loop, keepalive_timeout=30, force_close=True) - conn = aiohttp.BaseConnector(loop=loop, force_close=True, - keepalive_timeout=None) + conn = aiohttp.BaseConnector(loop=loop, force_close=True, keepalive_timeout=None) assert conn conn = aiohttp.BaseConnector(loop=loop, force_close=True) @@ -1026,126 +1784,479 @@ def test_force_close_and_explicit_keep_alive(loop): assert conn -@asyncio.coroutine -def test_tcp_connector(test_client, loop): - @asyncio.coroutine - def handler(request): - return web.HTTPOk() +async def test_error_on_connection(loop, key) -> None: + conn = aiohttp.BaseConnector(limit=1, loop=loop) + + req = mock.Mock() + req.connection_key = key + proto = mock.Mock() + i = 0 + + fut = loop.create_future() + exc = OSError() + + async def create_connection(req, traces, timeout): + nonlocal i + i += 1 + if i == 1: + await fut + raise exc + elif i == 2: + return proto + + conn._create_connection = create_connection + + t1 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t2 = loop.create_task(conn.connect(req, None, ClientTimeout())) + await asyncio.sleep(0) + assert not t1.done() + assert not t2.done() + assert len(conn._acquired_per_host[key]) == 1 + + fut.set_result(None) + with pytest.raises(OSError): + await t1 + + ret = await t2 + assert len(conn._acquired_per_host[key]) == 1 + + assert ret._key == key + assert ret.protocol == proto + assert proto in conn._acquired + ret.release() + + +async def test_cancelled_waiter(loop) -> None: + conn = aiohttp.BaseConnector(limit=1, loop=loop) + req = mock.Mock() + req.connection_key = "key" + proto = mock.Mock() + + async def create_connection(req, traces=None): + await asyncio.sleep(1) + return proto + + conn._create_connection = create_connection + + conn._acquired.add(proto) + + conn2 = loop.create_task(conn.connect(req, None, ClientTimeout())) + await asyncio.sleep(0) + conn2.cancel() + + with pytest.raises(asyncio.CancelledError): + await conn2 + + +async def test_error_on_connection_with_cancelled_waiter(loop, key) -> None: + conn = aiohttp.BaseConnector(limit=1, loop=loop) + + req = mock.Mock() + req.connection_key = key + proto = mock.Mock() + i = 0 + + fut1 = loop.create_future() + fut2 = loop.create_future() + exc = OSError() + + async def create_connection(req, traces, timeout): + nonlocal i + i += 1 + if i == 1: + await fut1 + raise exc + if i == 2: + await fut2 + elif i == 3: + return proto + + conn._create_connection = create_connection + + t1 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t2 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t3 = loop.create_task(conn.connect(req, None, ClientTimeout())) + await asyncio.sleep(0) + assert not t1.done() + assert not t2.done() + assert len(conn._acquired_per_host[key]) == 1 + + fut1.set_result(None) + fut2.cancel() + with pytest.raises(OSError): + await t1 + + with pytest.raises(asyncio.CancelledError): + await t2 + + ret = await t3 + assert len(conn._acquired_per_host[key]) == 1 + + assert ret._key == key + assert 
ret.protocol == proto + assert proto in conn._acquired + ret.release() + + +async def test_tcp_connector(aiohttp_client, loop) -> None: + async def handler(request): + return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - r = yield from client.get('/') + r = await client.get("/") assert r.status == 200 -def test_default_use_dns_cache(loop): - conn = aiohttp.TCPConnector(loop=loop) +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket") +async def test_unix_connector_not_found(loop) -> None: + connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop) + + req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) + with pytest.raises(aiohttp.ClientConnectorError): + await connector.connect(req, None, ClientTimeout()) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket") +async def test_unix_connector_permission(loop) -> None: + loop.create_unix_connection = make_mocked_coro(raise_exception=PermissionError()) + connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop) + + req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) + with pytest.raises(aiohttp.ClientConnectorError): + await connector.connect(req, None, ClientTimeout()) + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_connector_wrong_loop(selector_loop, pipe_name) -> None: + with pytest.raises(RuntimeError): + aiohttp.NamedPipeConnector(pipe_name, loop=asyncio.get_event_loop()) + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_connector_not_found(proactor_loop, pipe_name) -> None: + connector = aiohttp.NamedPipeConnector(pipe_name, loop=proactor_loop) + + req = ClientRequest("GET", URL("http://www.python.org"), loop=proactor_loop) + with pytest.raises(aiohttp.ClientConnectorError): + await connector.connect(req, None, ClientTimeout()) + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_connector_permission(proactor_loop, pipe_name) -> None: + proactor_loop.create_pipe_connection = make_mocked_coro( + raise_exception=PermissionError() + ) + connector = aiohttp.NamedPipeConnector(pipe_name, loop=proactor_loop) + + req = ClientRequest("GET", URL("http://www.python.org"), loop=proactor_loop) + with pytest.raises(aiohttp.ClientConnectorError): + await connector.connect(req, None, ClientTimeout()) + + +async def test_default_use_dns_cache() -> None: + conn = aiohttp.TCPConnector() assert conn.use_dns_cache -class TestHttpClientConnector(unittest.TestCase): +async def test_resolver_not_called_with_address_is_ip(loop) -> None: + resolver = mock.MagicMock() + connector = aiohttp.TCPConnector(resolver=resolver) - def setUp(self): - self.handler = None - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) + req = ClientRequest( + "GET", + URL(f"http://127.0.0.1:{unused_port()}"), + loop=loop, + response_class=mock.Mock(), + ) - def tearDown(self): - if self.handler: - self.loop.run_until_complete(self.handler.finish_connections()) - self.loop.stop() - self.loop.run_forever() - self.loop.close() - gc.collect() + with pytest.raises(OSError): + await connector.connect(req, None, ClientTimeout()) - @asyncio.coroutine 
- def create_server(self, method, path, handler): - app = web.Application() - app.router.add_route(method, path, handler) - - port = unused_port() - self.handler = app.make_handler(loop=self.loop, tcp_keepalive=False) - srv = yield from self.loop.create_server( - self.handler, '127.0.0.1', port) - url = "http://127.0.0.1:{}".format(port) + path - self.addCleanup(srv.close) - return app, srv, url - - @asyncio.coroutine - def create_unix_server(self, method, path, handler): - tmpdir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, tmpdir) - app = web.Application() - app.router.add_route(method, path, handler) - - self.handler = app.make_handler( - loop=self.loop, tcp_keepalive=False, access_log=None) - sock_path = os.path.join(tmpdir, 'socket.sock') - srv = yield from self.loop.create_unix_server( - self.handler, sock_path) - url = "http://127.0.0.1" + path - self.addCleanup(srv.close) - return app, srv, url, sock_path - - def test_tcp_connector_uses_provided_local_addr(self): - @asyncio.coroutine - def handler(request): - return web.HTTPOk() - - app, srv, url = self.loop.run_until_complete( - self.create_server('get', '/', handler) - ) + resolver.resolve.assert_not_called() - port = unused_port() - conn = aiohttp.TCPConnector(loop=self.loop, - local_addr=('127.0.0.1', port)) - session = aiohttp.ClientSession(connector=conn) +async def test_tcp_connector_raise_connector_ssl_error( + aiohttp_server, + ssl_ctx, +) -> None: + async def handler(request): + return web.Response() - r = self.loop.run_until_complete( - session.request('get', url) - ) + app = web.Application() + app.router.add_get("/", handler) - r.release() - first_conn = next(iter(conn._conns.values()))[0][0] - self.assertEqual( - first_conn.transport._sock.getsockname(), ('127.0.0.1', port)) - r.close() - session.close() - conn.close() + srv = await aiohttp_server(app, ssl=ssl_ctx) + + port = unused_port() + conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) + + session = aiohttp.ClientSession(connector=conn) + url = srv.make_url("/") + + if PY_37: + err = aiohttp.ClientConnectorCertificateError + else: + err = aiohttp.ClientConnectorSSLError + with pytest.raises(err) as ctx: + await session.get(url) + + if PY_37: + assert isinstance(ctx.value, aiohttp.ClientConnectorCertificateError) + assert isinstance(ctx.value.certificate_error, ssl.SSLError) + else: + assert isinstance(ctx.value, aiohttp.ClientSSLError) + assert isinstance(ctx.value.os_error, ssl.SSLError) + + await session.close() + + +async def test_tcp_connector_do_not_raise_connector_ssl_error( + aiohttp_server, + ssl_ctx, + client_ssl_ctx, +) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + + srv = await aiohttp_server(app, ssl=ssl_ctx) + port = unused_port() + conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) + + session = aiohttp.ClientSession(connector=conn) + url = srv.make_url("/") + + r = await session.get(url, ssl=client_ssl_ctx) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'requires unix') - def test_unix_connector(self): - @asyncio.coroutine - def handler(request): - return web.HTTPOk() + r.release() + first_conn = next(iter(conn._conns.values()))[0][0] - app, srv, url, sock_path = self.loop.run_until_complete( - self.create_unix_server('get', '/', handler)) + try: + _sslcontext = first_conn.transport._ssl_protocol._sslcontext + except AttributeError: + _sslcontext = first_conn.transport._sslcontext - connector = aiohttp.UnixConnector(sock_path, 
loop=self.loop) - self.assertEqual(sock_path, connector.path) + assert _sslcontext is client_ssl_ctx + r.close() - session = client.ClientSession( - connector=connector, loop=self.loop) - r = self.loop.run_until_complete( - session.request('get', url)) - self.assertEqual(r.status, 200) - r.close() - session.close() + await session.close() + conn.close() + + +async def test_tcp_connector_uses_provided_local_addr(aiohttp_server) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + srv = await aiohttp_server(app) + + port = unused_port() + conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) + + session = aiohttp.ClientSession(connector=conn) + url = srv.make_url("/") + + r = await session.get(url) + r.release() + + first_conn = next(iter(conn._conns.values()))[0][0] + assert first_conn.transport.get_extra_info("sockname") == ("127.0.0.1", port) + r.close() + await session.close() + conn.close() + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") +async def test_unix_connector(unix_server, unix_sockname) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + await unix_server(app) + + url = "http://127.0.0.1/" + + connector = aiohttp.UnixConnector(unix_sockname) + assert unix_sockname == connector.path + + session = client.ClientSession(connector=connector) + r = await session.get(url) + assert r.status == 200 + r.close() + await session.close() + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_connector( + proactor_loop, named_pipe_server, pipe_name +) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + await named_pipe_server(app) + + url = "http://this-does-not-matter.com" + + connector = aiohttp.NamedPipeConnector(pipe_name) + assert pipe_name == connector.path + + session = client.ClientSession(connector=connector) + r = await session.get(url) + assert r.status == 200 + r.close() + await session.close() - def test_resolver_not_called_with_address_is_ip(self): - resolver = mock.MagicMock() - connector = aiohttp.TCPConnector(resolver=resolver, loop=self.loop) - req = ClientRequest('GET', - URL('http://127.0.0.1:{}'.format(unused_port())), - loop=self.loop, - response_class=mock.Mock()) +class TestDNSCacheTable: + @pytest.fixture + def dns_cache_table(self): + return _DNSCacheTable() - with self.assertRaises(OSError): - self.loop.run_until_complete(connector.connect(req)) + def test_next_addrs_basic(self, dns_cache_table) -> None: + dns_cache_table.add("localhost", ["127.0.0.1"]) + dns_cache_table.add("foo", ["127.0.0.2"]) + + addrs = dns_cache_table.next_addrs("localhost") + assert addrs == ["127.0.0.1"] + addrs = dns_cache_table.next_addrs("foo") + assert addrs == ["127.0.0.2"] + with pytest.raises(KeyError): + dns_cache_table.next_addrs("no-such-host") + + def test_remove(self, dns_cache_table) -> None: + dns_cache_table.add("localhost", ["127.0.0.1"]) + dns_cache_table.remove("localhost") + with pytest.raises(KeyError): + dns_cache_table.next_addrs("localhost") + + def test_clear(self, dns_cache_table) -> None: + dns_cache_table.add("localhost", ["127.0.0.1"]) + dns_cache_table.clear() + with pytest.raises(KeyError): + dns_cache_table.next_addrs("localhost") + + def test_not_expired_ttl_None(self, dns_cache_table) -> None: + 
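+        # ttl=None (the fixture default) disables expiry, so the entry
+        # added below must never be reported as expired.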
dns_cache_table.add("localhost", ["127.0.0.1"])
+        assert not dns_cache_table.expired("localhost")
+
+    def test_not_expired_ttl(self) -> None:
+        dns_cache_table = _DNSCacheTable(ttl=0.1)
+        dns_cache_table.add("localhost", ["127.0.0.1"])
+        assert not dns_cache_table.expired("localhost")
+
+    async def test_expired_ttl(self, loop) -> None:
+        dns_cache_table = _DNSCacheTable(ttl=0.01)
+        dns_cache_table.add("localhost", ["127.0.0.1"])
+        await asyncio.sleep(0.02)
+        assert dns_cache_table.expired("localhost")
+
+    def test_next_addrs(self, dns_cache_table) -> None:
+        dns_cache_table.add("foo", ["127.0.0.1", "127.0.0.2", "127.0.0.3"])
+
+        # Each call to next_addrs returns the hosts using
+        # a round-robin strategy.
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.1", "127.0.0.2", "127.0.0.3"]
+
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.2", "127.0.0.3", "127.0.0.1"]
+
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.3", "127.0.0.1", "127.0.0.2"]
+
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.1", "127.0.0.2", "127.0.0.3"]
+
+    def test_next_addrs_single(self, dns_cache_table) -> None:
+        dns_cache_table.add("foo", ["127.0.0.1"])
+
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.1"]
+
+        addrs = dns_cache_table.next_addrs("foo")
+        assert addrs == ["127.0.0.1"]
+
+
+async def test_connector_cache_trace_race():
+    class DummyTracer:
+        async def send_dns_cache_hit(self, *args, **kwargs):
+            connector._cached_hosts.remove(("", 0))
+
+    token = object()
+    connector = TCPConnector()
+    connector._cached_hosts.add(("", 0), [token])
+
+    traces = [DummyTracer()]
+    assert await connector._resolve_host("", 0, traces) == [token]
+
+
+async def test_connector_throttle_trace_race(loop):
+    key = ("", 0)
+    token = object()
+
+    class DummyTracer:
+        async def send_dns_cache_hit(self, *args, **kwargs):
+            event = connector._throttle_dns_events.pop(key)
+            event.set()
+            connector._cached_hosts.add(key, [token])
+
+    connector = TCPConnector()
+    connector._throttle_dns_events[key] = EventResultOrError(loop)
+    traces = [DummyTracer()]
+    assert await connector._resolve_host("", 0, traces) == [token]
+
+
+async def test_connector_does_not_remove_needed_waiters(loop, key) -> None:
+    proto = create_mocked_conn(loop)
+    proto.is_connected.return_value = True
-
+    req = ClientRequest("GET", URL("https://localhost:80"), loop=loop)
+    connection_key = req.connection_key
+
+    connector = aiohttp.BaseConnector()
+    connector._available_connections = mock.Mock(return_value=0)
+    connector._conns[key] = [(proto, loop.time())]
+    connector._create_connection = create_mocked_conn(loop)
+    connector._create_connection.return_value = loop.create_future()
+    connector._create_connection.return_value.set_result(proto)
+
+    dummy_waiter = loop.create_future()
+
+    async def await_connection_and_check_waiters():
+        connection = await connector.connect(req, [], ClientTimeout())
+        try:
+            assert connection_key in connector._waiters
+            assert dummy_waiter in connector._waiters[connection_key]
+        finally:
+            connection.close()
+
+    async def allow_connection_and_add_dummy_waiter():
+        # `asyncio.gather` may execute coroutines not in order.
+        # Skip one event loop run cycle in such a case.
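+        # (the membership check below keeps this helper correct in
+        # either execution order)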
+ if connection_key not in connector._waiters: + await asyncio.sleep(0) + connector._waiters[connection_key].popleft().set_result(None) + del connector._waiters[connection_key] + connector._waiters[connection_key].append(dummy_waiter) + + await asyncio.gather( + await_connection_and_check_waiters(), + allow_connection_and_add_dummy_waiter(), + ) diff --git a/tests/test_cookiejar.py b/tests/test_cookiejar.py index eef27ddfe8c..12bcebc01ab 100644 --- a/tests/test_cookiejar.py +++ b/tests/test_cookiejar.py @@ -1,19 +1,47 @@ import asyncio import datetime +import itertools import os import tempfile import unittest +from http.cookies import BaseCookie, Morsel, SimpleCookie from unittest import mock import pytest +from freezegun import freeze_time from yarl import URL -from aiohttp import CookieJar -from aiohttp.helpers import SimpleCookie +from aiohttp import CookieJar, DummyCookieJar @pytest.fixture def cookies_to_send(): + return SimpleCookie( + "shared-cookie=first; " + "domain-cookie=second; Domain=example.com; " + "subdomain1-cookie=third; Domain=test1.example.com; " + "subdomain2-cookie=fourth; Domain=test2.example.com; " + "dotted-domain-cookie=fifth; Domain=.example.com; " + "different-domain-cookie=sixth; Domain=different.org; " + "secure-cookie=seventh; Domain=secure.com; Secure; " + "no-path-cookie=eighth; Domain=pathtest.com; " + "path1-cookie=nineth; Domain=pathtest.com; Path=/; " + "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " + "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " + "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " + "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" + " Expires=Tue, 1 Jan 2039 12:00:00 GMT; " + "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" + " Max-Age=60; " + "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " + " Max-Age=string; " + "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " + " Expires=string;" + ) + + +@pytest.fixture +def cookies_to_send_with_expired(): return SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " @@ -53,27 +81,31 @@ def cookies_to_receive(): ) -def test_date_parsing(): +def test_date_parsing() -> None: parse_func = CookieJar._parse_date utc = datetime.timezone.utc assert parse_func("") is None # 70 -> 1970 - assert parse_func("Tue, 1 Jan 70 00:00:00 GMT") == \ - datetime.datetime(1970, 1, 1, tzinfo=utc) + assert parse_func("Tue, 1 Jan 70 00:00:00 GMT") == datetime.datetime( + 1970, 1, 1, tzinfo=utc + ) # 10 -> 2010 - assert parse_func("Tue, 1 Jan 10 00:00:00 GMT") == \ - datetime.datetime(2010, 1, 1, tzinfo=utc) + assert parse_func("Tue, 1 Jan 10 00:00:00 GMT") == datetime.datetime( + 2010, 1, 1, tzinfo=utc + ) # No day of week string - assert parse_func("1 Jan 1970 00:00:00 GMT") == \ - datetime.datetime(1970, 1, 1, tzinfo=utc) + assert parse_func("1 Jan 1970 00:00:00 GMT") == datetime.datetime( + 1970, 1, 1, tzinfo=utc + ) # No timezone string - assert parse_func("Tue, 1 Jan 1970 00:00:00") == \ - datetime.datetime(1970, 1, 1, tzinfo=utc) + assert parse_func("Tue, 1 Jan 1970 00:00:00") == datetime.datetime( + 1970, 1, 1, tzinfo=utc + ) # No year assert parse_func("Tue, 1 Jan 00:00:00 GMT") is None @@ -97,7 +129,7 @@ def test_date_parsing(): assert parse_func("Tue, 1 Jan 1970 77:88:99 GMT") is None -def test_domain_matching(): +def test_domain_matching() -> None: test_func = CookieJar._is_domain_match assert test_func("test.com", "test.com") @@ -110,7 +142,7 @@ def test_domain_matching(): assert not 
test_func("test.com", "127.0.0.1") -def test_path_matching(): +def test_path_matching() -> None: test_func = CookieJar._is_path_match assert test_func("/", "") @@ -132,7 +164,7 @@ def test_path_matching(): assert not test_func("/different-folder/", "/folder/") -def test_constructor(loop, cookies_to_send, cookies_to_receive): +async def test_constructor(loop, cookies_to_send, cookies_to_receive) -> None: jar = CookieJar(loop=loop) jar.update_cookies(cookies_to_send) jar_cookies = SimpleCookie() @@ -143,8 +175,21 @@ def test_constructor(loop, cookies_to_send, cookies_to_receive): assert jar._loop is loop -def test_save_load(loop, cookies_to_send, cookies_to_receive): - file_path = tempfile.mkdtemp() + '/aiohttp.test.cookie' +async def test_constructor_with_expired( + loop, cookies_to_send_with_expired, cookies_to_receive +) -> None: + jar = CookieJar() + jar.update_cookies(cookies_to_send_with_expired) + jar_cookies = SimpleCookie() + for cookie in jar: + dict.__setitem__(jar_cookies, cookie.key, cookie) + expected_cookies = cookies_to_send_with_expired + assert jar_cookies != expected_cookies + assert jar._loop is loop + + +async def test_save_load(loop, cookies_to_send, cookies_to_receive) -> None: + file_path = tempfile.mkdtemp() + "/aiohttp.test.cookie" # export cookie jar jar_save = CookieJar(loop=loop) @@ -162,7 +207,7 @@ def test_save_load(loop, cookies_to_send, cookies_to_receive): assert jar_test == cookies_to_receive -def test_update_cookie_with_unicode_domain(loop): +async def test_update_cookie_with_unicode_domain(loop) -> None: cookies = ( "idna-domain-first=first; Domain=xn--9caa.com; Path=/;", "idna-domain-second=second; Domain=xn--9caa.com; Path=/;", @@ -179,22 +224,16 @@ def test_update_cookie_with_unicode_domain(loop): assert jar_test == SimpleCookie(" ".join(cookies)) -def test_filter_cookie_with_unicode_domain(loop): - jar = CookieJar(loop=loop) - jar.update_cookies(SimpleCookie( - "idna-domain-first=first; Domain=xn--9caa.com; Path=/; " - )) +async def test_filter_cookie_with_unicode_domain(loop) -> None: + jar = CookieJar() + jar.update_cookies( + SimpleCookie("idna-domain-first=first; Domain=xn--9caa.com; Path=/; ") + ) assert len(jar.filter_cookies(URL("http://éé.com"))) == 1 assert len(jar.filter_cookies(URL("http://xn--9caa.com"))) == 1 -def test_ctor_ith_default_loop(loop): - asyncio.set_event_loop(loop) - jar = CookieJar() - assert jar._loop is loop - - -def test_domain_filter_ip_cookie_send(loop): +async def test_domain_filter_ip_cookie_send(loop) -> None: jar = CookieJar(loop=loop) cookies = SimpleCookie( "shared-cookie=first; " @@ -220,58 +259,68 @@ def test_domain_filter_ip_cookie_send(loop): ) jar.update_cookies(cookies) - cookies_sent = jar.filter_cookies(URL("http://1.2.3.4/")).output( - header='Cookie:') - assert cookies_sent == 'Cookie: shared-cookie=first' + cookies_sent = jar.filter_cookies(URL("http://1.2.3.4/")).output(header="Cookie:") + assert cookies_sent == "Cookie: shared-cookie=first" -def test_domain_filter_ip_cookie_receive(loop, cookies_to_receive): - jar = CookieJar(loop=loop) +async def test_domain_filter_ip_cookie_receive(cookies_to_receive) -> None: + jar = CookieJar() jar.update_cookies(cookies_to_receive, URL("http://1.2.3.4/")) assert len(jar) == 0 -def test_preserving_ip_domain_cookies(loop): - jar = CookieJar(loop=loop, unsafe=True) - jar.update_cookies(SimpleCookie( - "shared-cookie=first; " - "ip-cookie=second; Domain=127.0.0.1;" - )) - cookies_sent = jar.filter_cookies(URL("http://127.0.0.1/")).output( - header='Cookie:') - assert 
cookies_sent == ('Cookie: ip-cookie=second\r\n' - 'Cookie: shared-cookie=first') - - -def test_preserving_quoted_cookies(loop): - jar = CookieJar(loop=loop, unsafe=True) - jar.update_cookies(SimpleCookie( - "ip-cookie=\"second\"; Domain=127.0.0.1;" - )) - cookies_sent = jar.filter_cookies(URL("http://127.0.0.1/")).output( - header='Cookie:') - assert cookies_sent == 'Cookie: ip-cookie=\"second\"' - - -def test_ignore_domain_ending_with_dot(loop): +@pytest.mark.parametrize( + ("cookies", "expected", "quote_bool"), + [ + ( + "shared-cookie=first; ip-cookie=second; Domain=127.0.0.1;", + "Cookie: ip-cookie=second\r\nCookie: shared-cookie=first", + True, + ), + ('ip-cookie="second"; Domain=127.0.0.1;', 'Cookie: ip-cookie="second"', True), + ("custom-cookie=value/one;", 'Cookie: custom-cookie="value/one"', True), + ("custom-cookie=value1;", "Cookie: custom-cookie=value1", True), + ("custom-cookie=value/one;", "Cookie: custom-cookie=value/one", False), + ], + ids=( + "IP domain preserved", + "no shared cookie", + "quoted cookie with special char", + "quoted cookie w/o special char", + "unquoted cookie with special char", + ), +) +async def test_quotes_correctly_based_on_input( + loop, cookies, expected, quote_bool +) -> None: + jar = CookieJar(unsafe=True, quote_cookie=quote_bool) + jar.update_cookies(SimpleCookie(cookies)) + cookies_sent = jar.filter_cookies(URL("http://127.0.0.1/")).output(header="Cookie:") + assert cookies_sent == expected + + +async def test_ignore_domain_ending_with_dot(loop) -> None: jar = CookieJar(loop=loop, unsafe=True) - jar.update_cookies(SimpleCookie("cookie=val; Domain=example.com.;"), - URL("http://www.example.com")) + jar.update_cookies( + SimpleCookie("cookie=val; Domain=example.com.;"), URL("http://www.example.com") + ) cookies_sent = jar.filter_cookies(URL("http://www.example.com/")) - assert cookies_sent.output(header='Cookie:') == "Cookie: cookie=val" + assert cookies_sent.output(header="Cookie:") == "Cookie: cookie=val" cookies_sent = jar.filter_cookies(URL("http://example.com/")) - assert cookies_sent.output(header='Cookie:') == "" + assert cookies_sent.output(header="Cookie:") == "" class TestCookieJarBase(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) # N.B. 
those need to be overridden in child test cases - self.jar = CookieJar(loop=self.loop) + async def make_jar(): + return CookieJar() + + self.jar = self.loop.run_until_complete(make_jar()) def tearDown(self): self.loop.close() @@ -293,7 +342,6 @@ def request_reply_with_same_url(self, url): class TestCookieJarSafe(TestCookieJarBase): - def setUp(self): super().setUp() @@ -332,86 +380,102 @@ def setUp(self): "wrong-path-cookie=nineth; Domain=pathtest.com; Path=somepath;" ) - self.jar = CookieJar(loop=self.loop) + async def make_jar(): + return CookieJar() + + self.jar = self.loop.run_until_complete(make_jar()) def timed_request(self, url, update_time, send_time): - with mock.patch.object(self.loop, 'time', return_value=update_time): + if isinstance(update_time, int): + update_time = datetime.timedelta(seconds=update_time) + elif isinstance(update_time, float): + update_time = datetime.datetime.fromtimestamp(update_time) + if isinstance(send_time, int): + send_time = datetime.timedelta(seconds=send_time) + elif isinstance(send_time, float): + send_time = datetime.datetime.fromtimestamp(send_time) + + with freeze_time(update_time): self.jar.update_cookies(self.cookies_to_send) - with mock.patch.object(self.loop, 'time', return_value=send_time): + with freeze_time(send_time): cookies_sent = self.jar.filter_cookies(URL(url)) self.jar.clear() return cookies_sent - def test_domain_filter_same_host(self): - cookies_sent, cookies_received = ( - self.request_reply_with_same_url("http://example.com/")) - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "domain-cookie", - "dotted-domain-cookie" - }) - - self.assertEqual(set(cookies_received.keys()), { - "unconstrained-cookie", - "domain-cookie", - "dotted-domain-cookie" - }) - - def test_domain_filter_same_host_and_subdomain(self): - cookies_sent, cookies_received = ( - self.request_reply_with_same_url("http://test1.example.com/")) - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "domain-cookie", - "subdomain1-cookie", - "dotted-domain-cookie" - }) - - self.assertEqual(set(cookies_received.keys()), { - "unconstrained-cookie", - "domain-cookie", - "subdomain1-cookie", - "dotted-domain-cookie" - }) - - def test_domain_filter_same_host_diff_subdomain(self): - cookies_sent, cookies_received = ( - self.request_reply_with_same_url("http://different.example.com/")) - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "domain-cookie", - "dotted-domain-cookie" - }) - - self.assertEqual(set(cookies_received.keys()), { - "unconstrained-cookie", - "domain-cookie", - "dotted-domain-cookie" - }) - - def test_domain_filter_diff_host(self): - cookies_sent, cookies_received = ( - self.request_reply_with_same_url("http://different.org/")) - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "different-domain-cookie" - }) - - self.assertEqual(set(cookies_received.keys()), { - "unconstrained-cookie", - "different-domain-cookie" - }) - - def test_domain_filter_host_only(self): - self.jar.update_cookies(self.cookies_to_receive, - URL("http://example.com/")) + def test_domain_filter_same_host(self) -> None: + cookies_sent, cookies_received = self.request_reply_with_same_url( + "http://example.com/" + ) + + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "domain-cookie", "dotted-domain-cookie"}, + ) + + self.assertEqual( + set(cookies_received.keys()), + {"unconstrained-cookie", "domain-cookie", "dotted-domain-cookie"}, + ) + + def test_domain_filter_same_host_and_subdomain(self) -> 
None: + cookies_sent, cookies_received = self.request_reply_with_same_url( + "http://test1.example.com/" + ) + + self.assertEqual( + set(cookies_sent.keys()), + { + "shared-cookie", + "domain-cookie", + "subdomain1-cookie", + "dotted-domain-cookie", + }, + ) + + self.assertEqual( + set(cookies_received.keys()), + { + "unconstrained-cookie", + "domain-cookie", + "subdomain1-cookie", + "dotted-domain-cookie", + }, + ) + + def test_domain_filter_same_host_diff_subdomain(self) -> None: + cookies_sent, cookies_received = self.request_reply_with_same_url( + "http://different.example.com/" + ) + + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "domain-cookie", "dotted-domain-cookie"}, + ) + + self.assertEqual( + set(cookies_received.keys()), + {"unconstrained-cookie", "domain-cookie", "dotted-domain-cookie"}, + ) + + def test_domain_filter_diff_host(self) -> None: + cookies_sent, cookies_received = self.request_reply_with_same_url( + "http://different.org/" + ) + + self.assertEqual( + set(cookies_sent.keys()), {"shared-cookie", "different-domain-cookie"} + ) + + self.assertEqual( + set(cookies_received.keys()), + {"unconstrained-cookie", "different-domain-cookie"}, + ) + + def test_domain_filter_host_only(self) -> None: + self.jar.update_cookies(self.cookies_to_receive, URL("http://example.com/")) cookies_sent = self.jar.filter_cookies(URL("http://example.com/")) self.assertIn("unconstrained-cookie", set(cookies_sent.keys())) @@ -419,161 +483,214 @@ def test_domain_filter_host_only(self): cookies_sent = self.jar.filter_cookies(URL("http://different.org/")) self.assertNotIn("unconstrained-cookie", set(cookies_sent.keys())) - def test_secure_filter(self): - cookies_sent, _ = ( - self.request_reply_with_same_url("http://secure.com/")) + def test_secure_filter(self) -> None: + cookies_sent, _ = self.request_reply_with_same_url("http://secure.com/") - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie"}) - cookies_sent, _ = ( - self.request_reply_with_same_url("https://secure.com/")) + cookies_sent, _ = self.request_reply_with_same_url("https://secure.com/") - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "secure-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie", "secure-cookie"}) - def test_path_filter_root(self): - cookies_sent, _ = ( - self.request_reply_with_same_url("http://pathtest.com/")) + def test_path_filter_root(self) -> None: + cookies_sent, _ = self.request_reply_with_same_url("http://pathtest.com/") - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "no-path-cookie", - "path1-cookie" - }) + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "no-path-cookie", "path1-cookie"}, + ) - def test_path_filter_folder(self): + def test_path_filter_folder(self) -> None: - cookies_sent, _ = ( - self.request_reply_with_same_url("http://pathtest.com/one/")) + cookies_sent, _ = self.request_reply_with_same_url("http://pathtest.com/one/") - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "no-path-cookie", - "path1-cookie", - "path2-cookie" - }) + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "no-path-cookie", "path1-cookie", "path2-cookie"}, + ) - def test_path_filter_file(self): + def test_path_filter_file(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( - "http://pathtest.com/one/two") + "http://pathtest.com/one/two" + ) - self.assertEqual(set(cookies_sent.keys()), { - 
"shared-cookie", - "no-path-cookie", - "path1-cookie", - "path2-cookie", - "path3-cookie" - }) + self.assertEqual( + set(cookies_sent.keys()), + { + "shared-cookie", + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "path3-cookie", + }, + ) - def test_path_filter_subfolder(self): + def test_path_filter_subfolder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( - "http://pathtest.com/one/two/") + "http://pathtest.com/one/two/" + ) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "no-path-cookie", - "path1-cookie", - "path2-cookie", - "path3-cookie", - "path4-cookie" - }) + self.assertEqual( + set(cookies_sent.keys()), + { + "shared-cookie", + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "path3-cookie", + "path4-cookie", + }, + ) - def test_path_filter_subsubfolder(self): + def test_path_filter_subsubfolder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( - "http://pathtest.com/one/two/three/") - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "no-path-cookie", - "path1-cookie", - "path2-cookie", - "path3-cookie", - "path4-cookie" - }) - - def test_path_filter_different_folder(self): - - cookies_sent, _ = ( - self.request_reply_with_same_url("http://pathtest.com/hundred/")) - - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "no-path-cookie", - "path1-cookie" - }) - - def test_path_value(self): - _, cookies_received = ( - self.request_reply_with_same_url("http://pathtest.com/")) - - self.assertEqual(set(cookies_received.keys()), { - "unconstrained-cookie", - "no-path-cookie", - "path-cookie", - "wrong-path-cookie" - }) + "http://pathtest.com/one/two/three/" + ) + + self.assertEqual( + set(cookies_sent.keys()), + { + "shared-cookie", + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "path3-cookie", + "path4-cookie", + }, + ) + + def test_path_filter_different_folder(self) -> None: + + cookies_sent, _ = self.request_reply_with_same_url( + "http://pathtest.com/hundred/" + ) + + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "no-path-cookie", "path1-cookie"}, + ) + + def test_path_value(self) -> None: + _, cookies_received = self.request_reply_with_same_url("http://pathtest.com/") + + self.assertEqual( + set(cookies_received.keys()), + { + "unconstrained-cookie", + "no-path-cookie", + "path-cookie", + "wrong-path-cookie", + }, + ) self.assertEqual(cookies_received["no-path-cookie"]["path"], "/") self.assertEqual(cookies_received["path-cookie"]["path"], "/somepath") self.assertEqual(cookies_received["wrong-path-cookie"]["path"], "/") - def test_expires(self): + def test_expires(self) -> None: ts_before = datetime.datetime( - 1975, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + 1975, 1, 1, tzinfo=datetime.timezone.utc + ).timestamp() ts_after = datetime.datetime( - 2115, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + 2030, 1, 1, tzinfo=datetime.timezone.utc + ).timestamp() cookies_sent = self.timed_request( - "http://expirestest.com/", ts_before, ts_before) + "http://expirestest.com/", ts_before, ts_before + ) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "expires-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie", "expires-cookie"}) cookies_sent = self.timed_request( - "http://expirestest.com/", ts_before, ts_after) + "http://expirestest.com/", ts_before, ts_after + ) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie"}) - def 
test_max_age(self): - cookies_sent = self.timed_request( - "http://maxagetest.com/", 1000, 1000) + def test_max_age(self) -> None: + cookies_sent = self.timed_request("http://maxagetest.com/", 1000, 1000) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "max-age-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie", "max-age-cookie"}) - cookies_sent = self.timed_request( - "http://maxagetest.com/", 1000, 2000) + cookies_sent = self.timed_request("http://maxagetest.com/", 1000, 2000) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie" - }) + self.assertEqual(set(cookies_sent.keys()), {"shared-cookie"}) - def test_invalid_values(self): - cookies_sent, cookies_received = ( - self.request_reply_with_same_url("http://invalid-values.com/")) + def test_invalid_values(self) -> None: + cookies_sent, cookies_received = self.request_reply_with_same_url( + "http://invalid-values.com/" + ) - self.assertEqual(set(cookies_sent.keys()), { - "shared-cookie", - "invalid-max-age-cookie", - "invalid-expires-cookie" - }) + self.assertEqual( + set(cookies_sent.keys()), + {"shared-cookie", "invalid-max-age-cookie", "invalid-expires-cookie"}, + ) cookie = cookies_sent["invalid-max-age-cookie"] self.assertEqual(cookie["max-age"], "") cookie = cookies_sent["invalid-expires-cookie"] self.assertEqual(cookie["expires"], "") + + def test_cookie_not_expired_when_added_after_removal(self) -> None: + # Test case for https://github.com/aio-libs/aiohttp/issues/2084 + timestamps = [ + 533588.993, + 533588.993, + 533588.993, + 533588.993, + 533589.093, + 533589.093, + ] + + loop = mock.Mock() + loop.time.side_effect = itertools.chain( + timestamps, itertools.cycle([timestamps[-1]]) + ) + + async def make_jar(): + return CookieJar(unsafe=True) + + jar = self.loop.run_until_complete(make_jar()) + # Remove `foo` cookie. + jar.update_cookies(SimpleCookie('foo=""; Max-Age=0')) + # Set `foo` cookie to `bar`. + jar.update_cookies(SimpleCookie('foo="bar"')) + + # Assert that there is a cookie. 
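+        # Regression check: the re-added cookie must not inherit the
+        # expired state of the cookie removed above (see the issue link).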
+ assert len(jar) == 1 + + +async def test_dummy_cookie_jar() -> None: + cookie = SimpleCookie("foo=bar; Domain=example.com;") + dummy_jar = DummyCookieJar() + assert len(dummy_jar) == 0 + dummy_jar.update_cookies(cookie) + assert len(dummy_jar) == 0 + with pytest.raises(StopIteration): + next(iter(dummy_jar)) + assert not dummy_jar.filter_cookies(URL("http://example.com/")) + dummy_jar.clear() + + +async def test_loose_cookies_types() -> None: + jar = CookieJar() + + accepted_types = [ + [("str", BaseCookie())], + [("str", Morsel())], + [ + ("str", "str"), + ], + {"str": BaseCookie()}, + {"str": Morsel()}, + {"str": "str"}, + SimpleCookie(), + ] + + for loose_cookies_type in accepted_types: + jar.update_cookies(cookies=loose_cookies_type) diff --git a/tests/test_flowcontrol_streams.py b/tests/test_flowcontrol_streams.py index cea5ebc2e0b..f9cce43bf4b 100644 --- a/tests/test_flowcontrol_streams.py +++ b/tests/test_flowcontrol_streams.py @@ -1,194 +1,129 @@ -import asyncio -import unittest from unittest import mock -from aiohttp import streams - - -class TestFlowControlStreamReader(unittest.TestCase): - - def setUp(self): - self.protocol = mock.Mock(_reading_paused=False) - self.transp = self.protocol.transport - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def _make_one(self, allow_pause=True, *args, **kwargs): - out = streams.FlowControlStreamReader( - self.protocol, buffer_limit=1, loop=self.loop, *args, **kwargs) - out._allow_pause = allow_pause - return out - - def test_read(self): - r = self._make_one() - r.feed_data(b'da', 2) - res = self.loop.run_until_complete(r.read(1)) - self.assertEqual(res, b'd') - self.assertFalse(r._protocol.resume_reading.called) - - def test_read_resume_paused(self): - r = self._make_one() - r.feed_data(b'test', 4) - r._protocol._reading_paused = True - - res = self.loop.run_until_complete(r.read(1)) - self.assertEqual(res, b't') - self.assertTrue(r._protocol.pause_reading.called) - - def test_readline(self): - r = self._make_one() - r.feed_data(b'data\n', 5) - res = self.loop.run_until_complete(r.readline()) - self.assertEqual(res, b'data\n') - self.assertFalse(r._protocol.resume_reading.called) - - def test_readline_resume_paused(self): - r = self._make_one() - r._protocol._reading_paused = True - r.feed_data(b'data\n', 5) - res = self.loop.run_until_complete(r.readline()) - self.assertEqual(res, b'data\n') - self.assertTrue(r._protocol.resume_reading.called) - - def test_readany(self): - r = self._make_one() - r.feed_data(b'data', 4) - res = self.loop.run_until_complete(r.readany()) - self.assertEqual(res, b'data') - self.assertFalse(r._protocol.resume_reading.called) - - def test_readany_resume_paused(self): - r = self._make_one() - r._protocol._reading_paused = True - r.feed_data(b'data', 4) - res = self.loop.run_until_complete(r.readany()) - self.assertEqual(res, b'data') - self.assertTrue(r._protocol.resume_reading.called) - - def test_readexactly(self): - r = self._make_one() - r.feed_data(b'data', 4) - res = self.loop.run_until_complete(r.readexactly(3)) - self.assertEqual(res, b'dat') - self.assertFalse(r._protocol.resume_reading.called) - - def test_readexactly_resume_paused(self): - r = self._make_one() - r._protocol._reading_paused = True - r.feed_data(b'data', 4) - res = self.loop.run_until_complete(r.readexactly(3)) - self.assertEqual(res, b'dat') - self.assertTrue(r._protocol.resume_reading.called) - - def test_feed_data(self): - r = self._make_one() - 
r._protocol._reading_paused = False - r.feed_data(b'datadata', 8) - self.assertTrue(r._protocol.pause_reading.called) - - def test_read_nowait(self): - r = self._make_one() - r._protocol._reading_paused = True - r.feed_data(b'data1', 5) - r.feed_data(b'data2', 5) - r.feed_data(b'data3', 5) - res = self.loop.run_until_complete(r.read(5)) - self.assertTrue(res == b'data1') - self.assertTrue(r._protocol.resume_reading.call_count == 0) - - res = r.read_nowait(5) - self.assertTrue(res == b'data2') - self.assertTrue(r._protocol.resume_reading.call_count == 0) - - res = r.read_nowait(5) - self.assertTrue(res == b'data3') - self.assertTrue(r._protocol.resume_reading.call_count == 1) - - r._protocol._reading_paused = False - res = r.read_nowait(5) - self.assertTrue(res == b'') - self.assertTrue(r._protocol.resume_reading.call_count == 1) - - -class FlowControlMixin: - - def test_feed_pause(self): - out = self._make_one() - out._protocol._reading_paused = False - out.feed_data(object(), 100) - - self.assertTrue(out._protocol.pause_reading.called) - - def test_resume_on_read(self): - out = self._make_one() - out.feed_data(object(), 100) - - out._protocol._reading_paused = True - self.loop.run_until_complete(out.read()) - self.assertTrue(out._protocol.resume_reading.called) - - -class TestFlowControlDataQueue(unittest.TestCase, FlowControlMixin): - - def setUp(self): - self.protocol = mock.Mock() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def _make_one(self, *args, **kwargs): - out = streams.FlowControlDataQueue( - self.protocol, limit=1, loop=self.loop, *args, **kwargs) - out._allow_pause = True - return out - - -class TestFlowControlChunksQueue(unittest.TestCase, FlowControlMixin): - - def setUp(self): - self.protocol = mock.Mock() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def _make_one(self, *args, **kwargs): - out = streams.FlowControlChunksQueue( - self.protocol, limit=1, loop=self.loop, *args, **kwargs) - out._allow_pause = True - return out - - def test_read_eof(self): - out = self._make_one() - read_task = asyncio.Task(out.read(), loop=self.loop) - - def cb(): - out.feed_eof() - self.loop.call_soon(cb) +import pytest - self.loop.run_until_complete(read_task) - self.assertTrue(out.at_eof()) - - def test_read_until_eof(self): - item = object() - - out = self._make_one() - out.feed_data(item, 1) - out.feed_eof() - - data = self.loop.run_until_complete(out.read()) - self.assertIs(data, item) +from aiohttp import streams - thing = self.loop.run_until_complete(out.read()) - self.assertEqual(thing, b'') - self.assertTrue(out.at_eof()) - def test_readany(self): - out = self._make_one() - self.assertIs(out.read.__func__, out.readany.__func__) +@pytest.fixture +def protocol(): + return mock.Mock(_reading_paused=False) + + +@pytest.fixture +def stream(loop, protocol): + out = streams.StreamReader(protocol, limit=1, loop=loop) + out._allow_pause = True + return out + + +@pytest.fixture +def buffer(loop, protocol): + out = streams.FlowControlDataQueue(protocol, limit=1, loop=loop) + out._allow_pause = True + return out + + +class TestFlowControlStreamReader: + async def test_read(self, stream) -> None: + stream.feed_data(b"da", 2) + res = await stream.read(1) + assert res == b"d" + assert not stream._protocol.resume_reading.called + + async def test_read_resume_paused(self, stream) -> None: + stream.feed_data(b"test", 4) + stream._protocol._reading_paused = 
True + + res = await stream.read(1) + assert res == b"t" + assert stream._protocol.pause_reading.called + + async def test_readline(self, stream) -> None: + stream.feed_data(b"d\n", 5) + res = await stream.readline() + assert res == b"d\n" + assert not stream._protocol.resume_reading.called + + async def test_readline_resume_paused(self, stream) -> None: + stream._protocol._reading_paused = True + stream.feed_data(b"d\n", 5) + res = await stream.readline() + assert res == b"d\n" + assert stream._protocol.resume_reading.called + + async def test_readany(self, stream) -> None: + stream.feed_data(b"data", 4) + res = await stream.readany() + assert res == b"data" + assert not stream._protocol.resume_reading.called + + async def test_readany_resume_paused(self, stream) -> None: + stream._protocol._reading_paused = True + stream.feed_data(b"data", 4) + res = await stream.readany() + assert res == b"data" + assert stream._protocol.resume_reading.called + + async def test_readchunk(self, stream) -> None: + stream.feed_data(b"data", 4) + res, end_of_http_chunk = await stream.readchunk() + assert res == b"data" + assert not end_of_http_chunk + assert not stream._protocol.resume_reading.called + + async def test_readchunk_resume_paused(self, stream) -> None: + stream._protocol._reading_paused = True + stream.feed_data(b"data", 4) + res, end_of_http_chunk = await stream.readchunk() + assert res == b"data" + assert not end_of_http_chunk + assert stream._protocol.resume_reading.called + + async def test_readexactly(self, stream) -> None: + stream.feed_data(b"data", 4) + res = await stream.readexactly(3) + assert res == b"dat" + assert not stream._protocol.resume_reading.called + + async def test_feed_data(self, stream) -> None: + stream._protocol._reading_paused = False + stream.feed_data(b"datadata", 8) + assert stream._protocol.pause_reading.called + + async def test_read_nowait(self, stream) -> None: + stream._protocol._reading_paused = True + stream.feed_data(b"data1", 5) + stream.feed_data(b"data2", 5) + stream.feed_data(b"data3", 5) + res = await stream.read(5) + assert res == b"data1" + assert stream._protocol.resume_reading.call_count == 0 + + res = stream.read_nowait(5) + assert res == b"data2" + assert stream._protocol.resume_reading.call_count == 0 + + res = stream.read_nowait(5) + assert res == b"data3" + assert stream._protocol.resume_reading.call_count == 1 + + stream._protocol._reading_paused = False + res = stream.read_nowait(5) + assert res == b"" + assert stream._protocol.resume_reading.call_count == 1 + + +class TestFlowControlDataQueue: + def test_feed_pause(self, buffer) -> None: + buffer._protocol._reading_paused = False + buffer.feed_data(object(), 100) + + assert buffer._protocol.pause_reading.called + + async def test_resume_on_read(self, buffer) -> None: + buffer.feed_data(object(), 100) + + buffer._protocol._reading_paused = True + await buffer.read() + assert buffer._protocol.resume_reading.called diff --git a/tests/test_formdata.py b/tests/test_formdata.py index 63235ec1241..987a262d586 100644 --- a/tests/test_formdata.py +++ b/tests/test_formdata.py @@ -1,9 +1,8 @@ -import asyncio from unittest import mock import pytest -from aiohttp.formdata import FormData +from aiohttp import ClientSession, FormData @pytest.fixture @@ -15,70 +14,86 @@ def buf(): def writer(buf): writer = mock.Mock() - def write(chunk): + async def write(chunk): buf.extend(chunk) - return () writer.write.side_effect = write return writer -def test_invalid_formdata_payload(): +def 
test_formdata_multipart(buf, writer) -> None: form = FormData() - form.add_field('test', object(), filename='test.txt') + assert not form.is_multipart + + form.add_field("test", b"test", filename="test.txt") + assert form.is_multipart + + +def test_invalid_formdata_payload() -> None: + form = FormData() + form.add_field("test", object(), filename="test.txt") with pytest.raises(TypeError): form() -def test_invalid_formdata_params(): +def test_invalid_formdata_params() -> None: with pytest.raises(TypeError): - FormData('asdasf') + FormData("asdasf") -def test_invalid_formdata_params2(): +def test_invalid_formdata_params2() -> None: with pytest.raises(TypeError): - FormData('as') # 2-char str is not allowed + FormData("as") # 2-char str is not allowed -def test_invalid_formdata_content_type(): +def test_invalid_formdata_content_type() -> None: form = FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] + invalid_vals = [0, 0.1, {}, [], b"foo"] for invalid_val in invalid_vals: with pytest.raises(TypeError): - form.add_field('foo', 'bar', content_type=invalid_val) + form.add_field("foo", "bar", content_type=invalid_val) -def test_invalid_formdata_filename(): +def test_invalid_formdata_filename() -> None: form = FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] + invalid_vals = [0, 0.1, {}, [], b"foo"] for invalid_val in invalid_vals: with pytest.raises(TypeError): - form.add_field('foo', 'bar', filename=invalid_val) + form.add_field("foo", "bar", filename=invalid_val) -def test_invalid_formdata_content_transfer_encoding(): +def test_invalid_formdata_content_transfer_encoding() -> None: form = FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] + invalid_vals = [0, 0.1, {}, [], b"foo"] for invalid_val in invalid_vals: with pytest.raises(TypeError): - form.add_field('foo', - 'bar', - content_transfer_encoding=invalid_val) + form.add_field("foo", "bar", content_transfer_encoding=invalid_val) -@asyncio.coroutine -def test_formdata_field_name_is_quoted(buf, writer): +async def test_formdata_field_name_is_quoted(buf, writer) -> None: form = FormData(charset="ascii") form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") payload = form() - yield from payload.write(writer) + await payload.write(writer) assert b'name="emails%5B%5D"' in buf -@asyncio.coroutine -def test_formdata_field_name_is_not_quoted(buf, writer): +async def test_formdata_field_name_is_not_quoted(buf, writer) -> None: form = FormData(quote_fields=False, charset="ascii") form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") payload = form() - yield from payload.write(writer) + await payload.write(writer) assert b'name="emails[]"' in buf + + +async def test_mark_formdata_as_processed() -> None: + async with ClientSession() as session: + url = "http://httpbin.org/anything" + data = FormData() + data.add_field("test", "test_value", content_type="application/json") + + await session.post(url, data=data) + assert len(data._writer._parts) == 1 + + with pytest.raises(RuntimeError): + await session.post(url, data=data) diff --git a/tests/test_frozenlist.py b/tests/test_frozenlist.py new file mode 100644 index 00000000000..68241a2c38f --- /dev/null +++ b/tests/test_frozenlist.py @@ -0,0 +1,230 @@ +from collections.abc import MutableSequence + +import pytest + +from aiohttp.frozenlist import FrozenList, PyFrozenList + + +class FrozenListMixin: + FrozenList = NotImplemented + + SKIP_METHODS = {"__abstractmethods__", "__slots__"} + + def test_subclass(self) -> None: + assert issubclass(self.FrozenList, 
MutableSequence) + + def test_iface(self) -> None: + for name in set(dir(MutableSequence)) - self.SKIP_METHODS: + if ( + name.startswith("_") and not name.endswith("_") + ) or name == "__class_getitem__": + continue + assert hasattr(self.FrozenList, name) + + def test_ctor_default(self) -> None: + _list = self.FrozenList([]) + assert not _list.frozen + + def test_ctor(self) -> None: + _list = self.FrozenList([1]) + assert not _list.frozen + + def test_ctor_copy_list(self) -> None: + orig = [1] + _list = self.FrozenList(orig) + del _list[0] + assert _list != orig + + def test_freeze(self) -> None: + _list = self.FrozenList() + _list.freeze() + assert _list.frozen + + def test_repr(self) -> None: + _list = self.FrozenList([1]) + assert repr(_list) == "<FrozenList(frozen=False, [1])>" + _list.freeze() + assert repr(_list) == "<FrozenList(frozen=True, [1])>" + + def test_getitem(self) -> None: + _list = self.FrozenList([1, 2]) + assert _list[1] == 2 + + def test_setitem(self) -> None: + _list = self.FrozenList([1, 2]) + _list[1] = 3 + assert _list[1] == 3 + + def test_delitem(self) -> None: + _list = self.FrozenList([1, 2]) + del _list[0] + assert len(_list) == 1 + assert _list[0] == 2 + + def test_len(self) -> None: + _list = self.FrozenList([1]) + assert len(_list) == 1 + + def test_iter(self) -> None: + _list = self.FrozenList([1, 2]) + assert list(iter(_list)) == [1, 2] + + def test_reversed(self) -> None: + _list = self.FrozenList([1, 2]) + assert list(reversed(_list)) == [2, 1] + + def test_eq(self) -> None: + _list = self.FrozenList([1]) + assert _list == [1] + + def test_ne(self) -> None: + _list = self.FrozenList([1]) + assert _list != [2] + + def test_le(self) -> None: + _list = self.FrozenList([1]) + assert _list <= [1] + + def test_lt(self) -> None: + _list = self.FrozenList([1]) + assert _list < [3] + + def test_ge(self) -> None: + _list = self.FrozenList([1]) + assert _list >= [1] + + def test_gt(self) -> None: + _list = self.FrozenList([2]) + assert _list > [1] + + def test_insert(self) -> None: + _list = self.FrozenList([2]) + _list.insert(0, 1) + assert _list == [1, 2] + + def test_frozen_setitem(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list[0] = 2 + + def test_frozen_delitem(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + del _list[0] + + def test_frozen_insert(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.insert(0, 2) + + def test_contains(self) -> None: + _list = self.FrozenList([2]) + assert 2 in _list + + def test_iadd(self) -> None: + _list = self.FrozenList([1]) + _list += [2] + assert _list == [1, 2] + + def test_iadd_frozen(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list += [2] + assert _list == [1] + + def test_index(self) -> None: + _list = self.FrozenList([1]) + assert _list.index(1) == 0 + + def test_remove(self) -> None: + _list = self.FrozenList([1]) + _list.remove(1) + assert len(_list) == 0 + + def test_remove_frozen(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.remove(1) + assert _list == [1] + + def test_clear(self) -> None: + _list = self.FrozenList([1]) + _list.clear() + assert len(_list) == 0 + + def test_clear_frozen(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.clear() + assert _list == [1] + + def test_extend(self) -> None: + _list = self.FrozenList([1]) + 
_list.extend([2]) + assert _list == [1, 2] + + def test_extend_frozen(self) -> None: + _list = self.FrozenList([1]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.extend([2]) + assert _list == [1] + + def test_reverse(self) -> None: + _list = self.FrozenList([1, 2]) + _list.reverse() + assert _list == [2, 1] + + def test_reverse_frozen(self) -> None: + _list = self.FrozenList([1, 2]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.reverse() + assert _list == [1, 2] + + def test_pop(self) -> None: + _list = self.FrozenList([1, 2]) + assert _list.pop(0) == 1 + assert _list == [2] + + def test_pop_default(self) -> None: + _list = self.FrozenList([1, 2]) + assert _list.pop() == 2 + assert _list == [1] + + def test_pop_frozen(self) -> None: + _list = self.FrozenList([1, 2]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.pop() + assert _list == [1, 2] + + def test_append(self) -> None: + _list = self.FrozenList([1, 2]) + _list.append(3) + assert _list == [1, 2, 3] + + def test_append_frozen(self) -> None: + _list = self.FrozenList([1, 2]) + _list.freeze() + with pytest.raises(RuntimeError): + _list.append(3) + assert _list == [1, 2] + + def test_count(self) -> None: + _list = self.FrozenList([1, 2]) + assert _list.count(1) == 1 + + +class TestFrozenList(FrozenListMixin): + FrozenList = FrozenList + + +class TestFrozenListPy(FrozenListMixin): + FrozenList = PyFrozenList diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 0d051014295..3367c24b78a 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,252 +1,216 @@ import asyncio -import datetime +import base64 import gc +import os +import platform import sys +import tempfile +from math import isclose, modf from unittest import mock import pytest +from multidict import MultiDict +from yarl import URL from aiohttp import helpers - -def test_parse_mimetype_1(): - assert helpers.parse_mimetype('') == ('', '', '', {}) - - -def test_parse_mimetype_2(): - assert helpers.parse_mimetype('*') == ('*', '*', '', {}) - - -def test_parse_mimetype_3(): - assert (helpers.parse_mimetype('application/json') == - ('application', 'json', '', {})) - - -def test_parse_mimetype_4(): - assert ( - helpers.parse_mimetype('application/json; charset=utf-8') == - ('application', 'json', '', {'charset': 'utf-8'})) - - -def test_parse_mimetype_5(): - assert ( - helpers.parse_mimetype('''application/json; charset=utf-8;''') == - ('application', 'json', '', {'charset': 'utf-8'})) - - -def test_parse_mimetype_6(): - assert( - helpers.parse_mimetype('ApPlIcAtIoN/JSON;ChaRseT="UTF-8"') == - ('application', 'json', '', {'charset': 'UTF-8'})) - - -def test_parse_mimetype_7(): - assert ( - helpers.parse_mimetype('application/rss+xml') == - ('application', 'rss', 'xml', {})) - - -def test_parse_mimetype_8(): - assert ( - helpers.parse_mimetype('text/plain;base64') == - ('text', 'plain', '', {'base64': ''})) - - -def test_basic_auth1(): +IS_PYPY = platform.python_implementation() == "PyPy" + + +# ------------------- parse_mimetype ---------------------------------- + + +@pytest.mark.parametrize( + "mimetype, expected", + [ + ("", helpers.MimeType("", "", "", MultiDict())), + ("*", helpers.MimeType("*", "*", "", MultiDict())), + ("application/json", helpers.MimeType("application", "json", "", MultiDict())), + ( + "application/json; charset=utf-8", + helpers.MimeType( + "application", "json", "", MultiDict({"charset": "utf-8"}) + ), + ), + ( + """application/json; charset=utf-8;""", + helpers.MimeType( + "application", 
"json", "", MultiDict({"charset": "utf-8"}) + ), + ), + ( + 'ApPlIcAtIoN/JSON;ChaRseT="UTF-8"', + helpers.MimeType( + "application", "json", "", MultiDict({"charset": "UTF-8"}) + ), + ), + ( + "application/rss+xml", + helpers.MimeType("application", "rss", "xml", MultiDict()), + ), + ( + "text/plain;base64", + helpers.MimeType("text", "plain", "", MultiDict({"base64": ""})), + ), + ], +) +def test_parse_mimetype(mimetype, expected) -> None: + result = helpers.parse_mimetype(mimetype) + + assert isinstance(result, helpers.MimeType) + assert result == expected + + +# ------------------- guess_filename ---------------------------------- + + +def test_guess_filename_with_tempfile() -> None: + with tempfile.TemporaryFile() as fp: + assert helpers.guess_filename(fp, "no-throw") is not None + + +# ------------------- BasicAuth ----------------------------------- + + +def test_basic_auth1() -> None: # missing password here with pytest.raises(ValueError): helpers.BasicAuth(None) -def test_basic_auth2(): +def test_basic_auth2() -> None: with pytest.raises(ValueError): - helpers.BasicAuth('nkim', None) + helpers.BasicAuth("nkim", None) -def test_basic_with_auth_colon_in_login(): +def test_basic_with_auth_colon_in_login() -> None: with pytest.raises(ValueError): - helpers.BasicAuth('nkim:1', 'pwd') + helpers.BasicAuth("nkim:1", "pwd") -def test_basic_auth3(): - auth = helpers.BasicAuth('nkim') - assert auth.login == 'nkim' - assert auth.password == '' +def test_basic_auth3() -> None: + auth = helpers.BasicAuth("nkim") + assert auth.login == "nkim" + assert auth.password == "" -def test_basic_auth4(): - auth = helpers.BasicAuth('nkim', 'pwd') - assert auth.login == 'nkim' - assert auth.password == 'pwd' - assert auth.encode() == 'Basic bmtpbTpwd2Q=' +def test_basic_auth4() -> None: + auth = helpers.BasicAuth("nkim", "pwd") + assert auth.login == "nkim" + assert auth.password == "pwd" + assert auth.encode() == "Basic bmtpbTpwd2Q=" -def test_basic_auth_decode(): - auth = helpers.BasicAuth.decode('Basic bmtpbTpwd2Q=') - assert auth.login == 'nkim' - assert auth.password == 'pwd' +@pytest.mark.parametrize( + "header", + ( + "Basic bmtpbTpwd2Q=", + "basic bmtpbTpwd2Q=", + ), +) +def test_basic_auth_decode(header) -> None: + auth = helpers.BasicAuth.decode(header) + assert auth.login == "nkim" + assert auth.password == "pwd" -def test_basic_auth_invalid(): +def test_basic_auth_invalid() -> None: with pytest.raises(ValueError): - helpers.BasicAuth.decode('bmtpbTpwd2Q=') + helpers.BasicAuth.decode("bmtpbTpwd2Q=") -def test_basic_auth_decode_not_basic(): +def test_basic_auth_decode_not_basic() -> None: with pytest.raises(ValueError): - helpers.BasicAuth.decode('Complex bmtpbTpwd2Q=') + helpers.BasicAuth.decode("Complex bmtpbTpwd2Q=") -def test_basic_auth_decode_bad_base64(): +def test_basic_auth_decode_bad_base64() -> None: with pytest.raises(ValueError): - helpers.BasicAuth.decode('Basic bmtpbTpwd2Q') - - -# ------------- access logger ------------------------- - - -def test_access_logger_format(): - log_format = '%T {%{SPAM}e} "%{ETag}o" %X {X} %%P %{FOO_TEST}e %{FOO1}e' - mock_logger = mock.Mock() - access_logger = helpers.AccessLogger(mock_logger, log_format) - expected = '%s {%s} "%s" %%X {X} %%%s %s %s' - assert expected == access_logger._log_format - - -def test_access_logger_atoms(mocker): - mock_datetime = mocker.patch("aiohttp.helpers.datetime") - mock_getpid = mocker.patch("os.getpid") - utcnow = datetime.datetime(1843, 1, 1, 0, 0) - mock_datetime.datetime.utcnow.return_value = utcnow - 
mock_getpid.return_value = 42 - log_format = '%a %t %P %l %u %r %s %b %O %T %Tf %D' - mock_logger = mock.Mock() - access_logger = helpers.AccessLogger(mock_logger, log_format) - message = mock.Mock(headers={}, method="GET", path="/path", version=(1, 1)) - environ = {} - response = mock.Mock(headers={}, body_length=42, status=200) - transport = mock.Mock() - transport.get_extra_info.return_value = ("127.0.0.2", 1234) - access_logger.log(message, environ, response, transport, 3.1415926) - assert not mock_logger.exception.called - expected = ('127.0.0.2 [01/Jan/1843:00:00:00 +0000] <42> - - ' - 'GET /path HTTP/1.1 200 42 42 3 3.141593 3141593') - extra = { - 'bytes_sent': 42, - 'first_request_line': 'GET /path HTTP/1.1', - 'process_id': '<42>', - 'remote_address': '127.0.0.2', - 'request_time': 3, - 'request_time_frac': '3.141593', - 'request_time_micro': 3141593, - 'response_size': 42, - 'response_status': 200 - } - - mock_logger.info.assert_called_with(expected, extra=extra) - - -def test_access_logger_dicts(): - log_format = '%{User-Agent}i %{Content-Length}o %{SPAM}e %{None}i' - mock_logger = mock.Mock() - access_logger = helpers.AccessLogger(mock_logger, log_format) - message = mock.Mock(headers={"User-Agent": "Mock/1.0"}, version=(1, 1)) - environ = {"SPAM": "EGGS"} - response = mock.Mock(headers={"Content-Length": 123}) - transport = mock.Mock() - transport.get_extra_info.return_value = ("127.0.0.2", 1234) - access_logger.log(message, environ, response, transport, 0.0) - assert not mock_logger.error.called - expected = 'Mock/1.0 123 EGGS -' - extra = { - 'environ': {'SPAM': 'EGGS'}, - 'request_header': {'None': '-'}, - 'response_header': {'Content-Length': 123} - } - - mock_logger.info.assert_called_with(expected, extra=extra) - - -def test_access_logger_unix_socket(): - log_format = '|%a|' - mock_logger = mock.Mock() - access_logger = helpers.AccessLogger(mock_logger, log_format) - message = mock.Mock(headers={"User-Agent": "Mock/1.0"}, version=(1, 1)) - environ = {} - response = mock.Mock() - transport = mock.Mock() - transport.get_extra_info.return_value = "" - access_logger.log(message, environ, response, transport, 0.0) - assert not mock_logger.error.called - expected = '||' - mock_logger.info.assert_called_with(expected, extra={'remote_address': ''}) - - -def test_logger_no_message_and_environ(): - mock_logger = mock.Mock() - mock_transport = mock.Mock() - mock_transport.get_extra_info.return_value = ("127.0.0.3", 0) - access_logger = helpers.AccessLogger(mock_logger, - "%r %{FOOBAR}e %{content-type}i") - extra_dict = { - 'environ': {'FOOBAR': '-'}, - 'first_request_line': '-', - 'request_header': {'content-type': '(no headers)'} - } - - access_logger.log(None, None, None, mock_transport, 0.0) - mock_logger.info.assert_called_with("- - (no headers)", extra=extra_dict) - - -def test_logger_internal_error(): - mock_logger = mock.Mock() - mock_transport = mock.Mock() - mock_transport.get_extra_info.return_value = ("127.0.0.3", 0) - access_logger = helpers.AccessLogger(mock_logger, "%D") - access_logger.log(None, None, None, mock_transport, 'invalid') - mock_logger.exception.assert_called_with("Error in logging") - - -def test_logger_no_transport(): - mock_logger = mock.Mock() - access_logger = helpers.AccessLogger(mock_logger, "%a") - access_logger.log(None, None, None, None, 0) - mock_logger.info.assert_called_with("-", extra={'remote_address': '-'}) - - -class TestReify: - - def test_reify(self): + helpers.BasicAuth.decode("Basic bmtpbTpwd2Q") + + 
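+# Note: "bmtpbTpwd2Q" above is the valid credential payload from the earlier +# tests with its trailing "=" padding stripped (b64decode needs a length that +# is a multiple of four), so base64.b64decode() raises binascii.Error, which +# BasicAuth.decode() reports as a ValueError.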
+@pytest.mark.parametrize("header", ("Basic ???", "Basic   ")) +def test_basic_auth_decode_illegal_chars_base64(header) -> None: + with pytest.raises(ValueError, match="Invalid base64 encoding."): + helpers.BasicAuth.decode(header) + + +def test_basic_auth_decode_invalid_credentials() -> None: + with pytest.raises(ValueError, match="Invalid credentials."): + header = "Basic {}".format(base64.b64encode(b"username").decode()) + helpers.BasicAuth.decode(header) + + +@pytest.mark.parametrize( + "credentials, expected_auth", + ( + (":", helpers.BasicAuth(login="", password="", encoding="latin1")), + ( + "username:", + helpers.BasicAuth(login="username", password="", encoding="latin1"), + ), + ( + ":password", + helpers.BasicAuth(login="", password="password", encoding="latin1"), + ), + ( + "username:password", + helpers.BasicAuth(login="username", password="password", encoding="latin1"), + ), + ), +) +def test_basic_auth_decode_blank_username(credentials, expected_auth) -> None: + header = "Basic {}".format(base64.b64encode(credentials.encode()).decode()) + assert helpers.BasicAuth.decode(header) == expected_auth + + +def test_basic_auth_from_url() -> None: + url = URL("http://user:pass@example.com") + auth = helpers.BasicAuth.from_url(url) + assert auth.login == "user" + assert auth.password == "pass" + + +def test_basic_auth_from_not_url() -> None: + with pytest.raises(TypeError): + helpers.BasicAuth.from_url("http://user:pass@example.com") + + +class ReifyMixin: + + reify = NotImplemented + + def test_reify(self) -> None: class A: def __init__(self): self._cache = {} - @helpers.reify + @self.reify def prop(self): return 1 a = A() assert 1 == a.prop - def test_reify_class(self): + def test_reify_class(self) -> None: class A: def __init__(self): self._cache = {} - @helpers.reify + @self.reify def prop(self): """Docstring.""" return 1 - assert isinstance(A.prop, helpers.reify) - assert 'Docstring.' == A.prop.__doc__ + assert isinstance(A.prop, self.reify) + assert "Docstring." == A.prop.__doc__ - def test_reify_assignment(self): + def test_reify_assignment(self) -> None: class A: def __init__(self): self._cache = {} - @helpers.reify + @self.reify def prop(self): return 1 @@ -256,34 +220,20 @@ def prop(self): a.prop = 123 -@pytest.mark.skipif(sys.version_info < (3, 5), reason='old python') -def test_create_future_with_new_loop(): - # We should use the new create_future() if it's available. - mock_loop = mock.Mock() - expected = 'hello' - mock_loop.create_future.return_value = expected - assert expected == helpers.create_future(mock_loop) +class TestPyReify(ReifyMixin): + reify = helpers.reify_py -@pytest.mark.skipif(sys.version_info >= (3, 5, 2), reason='new python') -def test_create_future_with_old_loop(mocker): - MockFuture = mocker.patch('asyncio.Future') - # The old loop (without create_future()) should just have a Future object - # wrapped around it. 
- mock_loop = mock.Mock() - del mock_loop.create_future +if not helpers.NO_EXTENSIONS and not IS_PYPY and hasattr(helpers, "reify_c"): - expected = 'hello' - MockFuture.return_value = expected + class TestCReify(ReifyMixin): + reify = helpers.reify_c - future = helpers.create_future(mock_loop) - MockFuture.assert_called_with(loop=mock_loop) - assert expected == future # ----------------------------------- is_ip_address() ---------------------- -def test_is_ip_address(): +def test_is_ip_address() -> None: assert helpers.is_ip_address("127.0.0.1") assert helpers.is_ip_address("::1") assert helpers.is_ip_address("FE80:0000:0000:0000:0202:B3FF:FE1E:8329") @@ -301,7 +251,7 @@ def test_is_ip_address(): assert not helpers.is_ip_address("1200::AB00:1234::2552:7777:1313") -def test_is_ip_address_bytes(): +def test_is_ip_address_bytes() -> None: assert helpers.is_ip_address(b"127.0.0.1") assert helpers.is_ip_address(b"::1") assert helpers.is_ip_address(b"FE80:0000:0000:0000:0202:B3FF:FE1E:8329") @@ -319,36 +269,46 @@ def test_is_ip_address_bytes(): assert not helpers.is_ip_address(b"1200::AB00:1234::2552:7777:1313") -def test_ip_addresses(): +def test_ipv4_addresses() -> None: ip_addresses = [ - '0.0.0.0', - '127.0.0.1', - '255.255.255.255', - '0:0:0:0:0:0:0:0', - 'FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF', - '00AB:0002:3008:8CFD:00AB:0002:3008:8CFD', - '00ab:0002:3008:8cfd:00ab:0002:3008:8cfd', - 'AB:02:3008:8CFD:AB:02:3008:8CFD', - 'AB:02:3008:8CFD::02:3008:8CFD', - '::', - '1::1', + "0.0.0.0", + "127.0.0.1", + "255.255.255.255", ] for address in ip_addresses: + assert helpers.is_ipv4_address(address) + assert not helpers.is_ipv6_address(address) assert helpers.is_ip_address(address) -def test_host_addresses(): +def test_ipv6_addresses() -> None: + ip_addresses = [ + "0:0:0:0:0:0:0:0", + "FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", + "00AB:0002:3008:8CFD:00AB:0002:3008:8CFD", + "00ab:0002:3008:8cfd:00ab:0002:3008:8cfd", + "AB:02:3008:8CFD:AB:02:3008:8CFD", + "AB:02:3008:8CFD::02:3008:8CFD", + "::", + "1::1", + ] + for address in ip_addresses: + assert not helpers.is_ipv4_address(address) + assert helpers.is_ipv6_address(address) + assert helpers.is_ip_address(address) + + +def test_host_addresses() -> None: hosts = [ - 'www.four.part.host' - 'www.python.org', - 'foo.bar', - 'localhost', + "www.four.part.host" "www.python.org", + "foo.bar", + "localhost", ] for host in hosts: assert not helpers.is_ip_address(host) -def test_is_ip_address_invalid_type(): +def test_is_ip_address_invalid_type() -> None: with pytest.raises(TypeError): helpers.is_ip_address(123) @@ -356,57 +316,10 @@ def test_is_ip_address_invalid_type(): helpers.is_ip_address(object()) -# ----------------------------------- TimeService ---------------------- - - -@pytest.fixture -def time_service(loop): - return helpers.TimeService(loop, interval=0.1) - - -class TestTimeService: - - def test_ctor(self, time_service): - assert time_service._cb is not None - assert time_service._time is not None - assert time_service._strtime is None - - def test_stop(self, time_service): - time_service.close() - assert time_service._cb is None - assert time_service._loop is None - - def test_double_stopping(self, time_service): - time_service.close() - time_service.close() - assert time_service._cb is None - assert time_service._loop is None - - def test_time(self, time_service): - t = time_service._time - assert t == time_service.time() - - def test_strtime(self, time_service): - time_service._time = 1477797232 - assert time_service.strtime() == 'Sun, 30 
Oct 2016 03:13:52 GMT' - # second call should use cached value - assert time_service.strtime() == 'Sun, 30 Oct 2016 03:13:52 GMT' - - def test_recalc_time(self, time_service, mocker): - mocker.spy(time_service._loop, 'time') - - time_service._time = 123 - time_service._strtime = 'asd' - time_service._count = 1000000 - time_service._on_cb() - assert time_service._strtime is None - assert time_service._time > 1234 - assert time_service._count == 0 - - # ----------------------------------- TimeoutHandle ------------------- -def test_timeout_handle(loop): + +def test_timeout_handle(loop) -> None: handle = helpers.TimeoutHandle(loop, 10.2) cb = mock.Mock() handle.register(cb) @@ -415,7 +328,19 @@ def test_timeout_handle(loop): assert not handle._callbacks -def test_timeout_handle_cb_exc(loop): +def test_when_timeout_smaller_second(loop) -> None: + timeout = 0.1 + timer = loop.time() + timeout + + handle = helpers.TimeoutHandle(loop, timeout) + when = handle.start()._when + handle.close() + + assert isinstance(when, float) + assert isclose(when - timer, 0, abs_tol=0.001) + + +def test_timeout_handle_cb_exc(loop) -> None: handle = helpers.TimeoutHandle(loop, 10.2) cb = mock.Mock() handle.register(cb) @@ -425,40 +350,48 @@ def test_timeout_handle_cb_exc(loop): assert not handle._callbacks -# ----------------------------------- FrozenList ---------------------- +def test_timer_context_cancelled() -> None: + with mock.patch("aiohttp.helpers.asyncio") as m_asyncio: + m_asyncio.TimeoutError = asyncio.TimeoutError + loop = mock.Mock() + ctx = helpers.TimerContext(loop) + ctx.timeout() + with pytest.raises(asyncio.TimeoutError): + with ctx: + pass -class TestFrozenList: - def test_eq(self): - l = helpers.FrozenList([1]) - assert l == [1] + if helpers.PY_37: + assert m_asyncio.current_task.return_value.cancel.called + else: + assert m_asyncio.Task.current_task.return_value.cancel.called - def test_le(self): - l = helpers.FrozenList([1]) - assert l < [2] + +def test_timer_context_no_task(loop) -> None: + with pytest.raises(RuntimeError): + with helpers.TimerContext(loop): + pass # -------------------------------- CeilTimeout -------------------------- -@asyncio.coroutine -def test_weakref_handle(loop): +async def test_weakref_handle(loop) -> None: cb = mock.Mock() - helpers.weakref_handle(cb, 'test', 0.01, loop, False) - yield from asyncio.sleep(0.1, loop=loop) + helpers.weakref_handle(cb, "test", 0.01, loop) + await asyncio.sleep(0.1) assert cb.test.called -@asyncio.coroutine -def test_weakref_handle_weak(loop): +async def test_weakref_handle_weak(loop) -> None: cb = mock.Mock() - helpers.weakref_handle(cb, 'test', 0.01, loop, False) + helpers.weakref_handle(cb, "test", 0.01, loop) del cb gc.collect() - yield from asyncio.sleep(0.1, loop=loop) + await asyncio.sleep(0.1) -def test_ceil_call_later(): +def test_ceil_call_later() -> None: cb = mock.Mock() loop = mock.Mock() loop.time.return_value = 10.1 @@ -466,45 +399,235 @@ def test_ceil_call_later(): loop.call_at.assert_called_with(21.0, cb) -def test_ceil_call_later_no_timeout(): +def test_ceil_call_later_no_timeout() -> None: cb = mock.Mock() loop = mock.Mock() helpers.call_later(cb, 0, loop) assert not loop.call_at.called -@asyncio.coroutine -def test_ceil_timeout(loop): - with helpers.CeilTimeout(0, loop=loop) as timeout: +async def test_ceil_timeout(loop) -> None: + with helpers.CeilTimeout(None, loop=loop) as timeout: assert timeout._timeout is None assert timeout._cancel_handler is None +def test_ceil_timeout_no_task(loop) -> None: + with 
pytest.raises(RuntimeError): + with helpers.CeilTimeout(10, loop=loop): + pass + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="TimerHandle.when() doesn't exist" +) +async def test_ceil_timeout_round(loop) -> None: + with helpers.CeilTimeout(7.5, loop=loop) as cm: + frac, integer = modf(cm._cancel_handler.when()) + assert frac == 0 + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="TimerHandle.when() doesn't exist" +) +async def test_ceil_timeout_small(loop) -> None: + with helpers.CeilTimeout(1.1, loop=loop) as cm: + frac, integer = modf(cm._cancel_handler.when()) + # a chance for exact integer with zero fraction is negligible + assert frac != 0 + + # -------------------------------- ContentDisposition ------------------- -def test_content_disposition(): - assert (helpers.content_disposition_header('attachment', foo='bar') == - 'attachment; foo="bar"') + +def test_content_disposition() -> None: + assert ( + helpers.content_disposition_header("attachment", foo="bar") + == 'attachment; foo="bar"' + ) -def test_content_disposition_bad_type(): +def test_content_disposition_bad_type() -> None: with pytest.raises(ValueError): - helpers.content_disposition_header('foo bar') + helpers.content_disposition_header("foo bar") with pytest.raises(ValueError): - helpers.content_disposition_header('тест') + helpers.content_disposition_header("тест") with pytest.raises(ValueError): - helpers.content_disposition_header('foo\x00bar') + helpers.content_disposition_header("foo\x00bar") with pytest.raises(ValueError): - helpers.content_disposition_header('') + helpers.content_disposition_header("") -def test_set_content_disposition_bad_param(): +def test_set_content_disposition_bad_param() -> None: with pytest.raises(ValueError): - helpers.content_disposition_header('inline', **{'foo bar': 'baz'}) + helpers.content_disposition_header("inline", **{"foo bar": "baz"}) with pytest.raises(ValueError): - helpers.content_disposition_header('inline', **{'тест': 'baz'}) + helpers.content_disposition_header("inline", **{"тест": "baz"}) with pytest.raises(ValueError): - helpers.content_disposition_header('inline', **{'': 'baz'}) + helpers.content_disposition_header("inline", **{"": "baz"}) with pytest.raises(ValueError): - helpers.content_disposition_header('inline', - **{'foo\x00bar': 'baz'}) + helpers.content_disposition_header("inline", **{"foo\x00bar": "baz"}) + + +# --------------------- proxies_from_env ------------------------------ + + +def test_proxies_from_env_http(mocker) -> None: + url = URL("http://aiohttp.io/path") + mocker.patch.dict(os.environ, {"http_proxy": str(url)}) + ret = helpers.proxies_from_env() + assert ret.keys() == {"http"} + assert ret["http"].proxy == url + assert ret["http"].proxy_auth is None + + +def test_proxies_from_env_http_proxy_for_https_proto(mocker) -> None: + url = URL("http://aiohttp.io/path") + mocker.patch.dict(os.environ, {"https_proxy": str(url)}) + ret = helpers.proxies_from_env() + assert ret.keys() == {"https"} + assert ret["https"].proxy == url + assert ret["https"].proxy_auth is None + + +def test_proxies_from_env_https_proxy_skipped(mocker) -> None: + url = URL("https://aiohttp.io/path") + mocker.patch.dict(os.environ, {"https_proxy": str(url)}) + log = mocker.patch("aiohttp.log.client_logger.warning") + assert helpers.proxies_from_env() == {} + log.assert_called_with( + "HTTPS proxies %s are not supported, ignoring", URL("https://aiohttp.io/path") + ) + + +def test_proxies_from_env_http_with_auth(mocker) -> None: + url = 
URL("http://user:pass@aiohttp.io/path") + mocker.patch.dict(os.environ, {"http_proxy": str(url)}) + ret = helpers.proxies_from_env() + assert ret.keys() == {"http"} + assert ret["http"].proxy == url.with_user(None) + proxy_auth = ret["http"].proxy_auth + assert proxy_auth.login == "user" + assert proxy_auth.password == "pass" + assert proxy_auth.encoding == "latin1" + + +# ------------ get_running_loop --------------------------------- + + +def test_get_running_loop_not_running(loop) -> None: + with pytest.warns(DeprecationWarning): + helpers.get_running_loop() + + +async def test_get_running_loop_ok(loop) -> None: + assert helpers.get_running_loop() is loop + + +# ------------- set_result / set_exception ---------------------- + + +async def test_set_result(loop) -> None: + fut = loop.create_future() + helpers.set_result(fut, 123) + assert 123 == await fut + + +async def test_set_result_cancelled(loop) -> None: + fut = loop.create_future() + fut.cancel() + helpers.set_result(fut, 123) + + with pytest.raises(asyncio.CancelledError): + await fut + + +async def test_set_exception(loop) -> None: + fut = loop.create_future() + helpers.set_exception(fut, RuntimeError()) + with pytest.raises(RuntimeError): + await fut + + +async def test_set_exception_cancelled(loop) -> None: + fut = loop.create_future() + fut.cancel() + helpers.set_exception(fut, RuntimeError()) + + with pytest.raises(asyncio.CancelledError): + await fut + + +# ----------- ChainMapProxy -------------------------- + + +class TestChainMapProxy: + @pytest.mark.skipif(not helpers.PY_36, reason="Requires Python 3.6+") + def test_inheritance(self) -> None: + with pytest.raises(TypeError): + + class A(helpers.ChainMapProxy): + pass + + def test_getitem(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert cp["a"] == 2 + assert cp["b"] == 3 + + def test_getitem_not_found(self) -> None: + d = {"a": 1} + cp = helpers.ChainMapProxy([d]) + with pytest.raises(KeyError): + cp["b"] + + def test_get(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert cp.get("a") == 2 + + def test_get_default(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert cp.get("c", 4) == 4 + + def test_get_non_default(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert cp.get("a", 4) == 2 + + def test_len(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert len(cp) == 2 + + def test_iter(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert set(cp) == {"a", "b"} + + def test_contains(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + assert "a" in cp + assert "b" in cp + assert "c" not in cp + + def test_bool(self) -> None: + assert helpers.ChainMapProxy([{"a": 1}]) + assert not helpers.ChainMapProxy([{}, {}]) + assert not helpers.ChainMapProxy([]) + + def test_repr(self) -> None: + d1 = {"a": 2, "b": 3} + d2 = {"a": 1} + cp = helpers.ChainMapProxy([d1, d2]) + expected = f"ChainMapProxy({d1!r}, {d2!r})" + assert expected == repr(cp) diff --git a/tests/test_http_exceptions.py b/tests/test_http_exceptions.py index bc6994ba5d6..26a5adb3bfc 100644 --- a/tests/test_http_exceptions.py +++ b/tests/test_http_exceptions.py @@ -1,20 +1,149 @@ -"""Tests for http_exceptions.py""" +# Tests for http_exceptions.py + +import pickle from aiohttp 
import http_exceptions -def test_bad_status_line1(): - err = http_exceptions.BadStatusLine(b'') - assert str(err) == "b''" +class TestHttpProcessingError: + def test_ctor(self) -> None: + err = http_exceptions.HttpProcessingError( + code=500, message="Internal error", headers={} + ) + assert err.code == 500 + assert err.message == "Internal error" + assert err.headers == {} + + def test_pickle(self) -> None: + err = http_exceptions.HttpProcessingError( + code=500, message="Internal error", headers={} + ) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.code == 500 + assert err2.message == "Internal error" + assert err2.headers == {} + assert err2.foo == "bar" + + def test_str(self) -> None: + err = http_exceptions.HttpProcessingError( + code=500, message="Internal error", headers={} + ) + assert str(err) == "500, message='Internal error'" + + def test_repr(self) -> None: + err = http_exceptions.HttpProcessingError( + code=500, message="Internal error", headers={} + ) + assert repr(err) == ("<HttpProcessingError: 500, message='Internal error'>") + + +class TestBadHttpMessage: + def test_ctor(self) -> None: + err = http_exceptions.BadHttpMessage("Bad HTTP message", headers={}) + assert err.code == 400 + assert err.message == "Bad HTTP message" + assert err.headers == {} + + def test_pickle(self) -> None: + err = http_exceptions.BadHttpMessage(message="Bad HTTP message", headers={}) + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.code == 400 + assert err2.message == "Bad HTTP message" + assert err2.headers == {} + assert err2.foo == "bar" + + def test_str(self) -> None: + err = http_exceptions.BadHttpMessage(message="Bad HTTP message", headers={}) + assert str(err) == "400, message='Bad HTTP message'" + + def test_repr(self) -> None: + err = http_exceptions.BadHttpMessage(message="Bad HTTP message", headers={}) + assert repr(err) == "<BadHttpMessage: 400, message='Bad HTTP message'>" + + +class TestLineTooLong: + def test_ctor(self) -> None: + err = http_exceptions.LineTooLong("spam", "10", "12") + assert err.code == 400 + assert err.message == "Got more than 10 bytes (12) when reading spam." 
+ assert err.headers is None + + def test_pickle(self) -> None: + err = http_exceptions.LineTooLong(line="spam", limit="10", actual_size="12") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.code == 400 + assert err2.message == ("Got more than 10 bytes (12) " "when reading spam.") + assert err2.headers is None + assert err2.foo == "bar" + + def test_str(self) -> None: + err = http_exceptions.LineTooLong(line="spam", limit="10", actual_size="12") + assert str(err) == ( + "400, message='Got more than 10 bytes (12) " "when reading spam.'" + ) + + def test_repr(self) -> None: + err = http_exceptions.LineTooLong(line="spam", limit="10", actual_size="12") + assert repr(err) == ( + "<LineTooLong: 400, message='Got more than 10 bytes (12) " "when reading spam.'>" + ) + + +class TestInvalidHeader: + def test_ctor(self) -> None: + err = http_exceptions.InvalidHeader("X-Spam") + assert err.code == 400 + assert err.message == "Invalid HTTP Header: X-Spam" + assert err.headers is None + + def test_pickle(self) -> None: + err = http_exceptions.InvalidHeader(hdr="X-Spam") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.code == 400 + assert err2.message == "Invalid HTTP Header: X-Spam" + assert err2.headers is None + assert err2.foo == "bar" + + def test_str(self) -> None: + err = http_exceptions.InvalidHeader(hdr="X-Spam") + assert str(err) == "400, message='Invalid HTTP Header: X-Spam'" + + def test_repr(self) -> None: + err = http_exceptions.InvalidHeader(hdr="X-Spam") + assert repr(err) == ( + "<InvalidHeader: 400, message='Invalid HTTP Header: X-Spam'>" + ) -def test_bad_status_line2(): - err = http_exceptions.BadStatusLine('Test') - assert str(err) == 'Test' +class TestBadStatusLine: + def test_ctor(self) -> None: + err = http_exceptions.BadStatusLine("Test") + assert err.line == "Test" + assert str(err) == "400, message=\"Bad status line 'Test'\"" + def test_ctor2(self) -> None: + err = http_exceptions.BadStatusLine(b"") + assert err.line == "b''" + assert str(err) == "400, message='Bad status line \"b\\'\\'\"'" -def test_http_error_exception(): - exc = http_exceptions.HttpProcessingError( - code=500, message='Internal error') - assert exc.code == 500 - assert exc.message == 'Internal error' + def test_pickle(self) -> None: + err = http_exceptions.BadStatusLine("Test") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.line == "Test" + assert err2.foo == "bar" diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index cbe302a29b7..87e98eaad37 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1,9 +1,8 @@ -"""Tests for aiohttp/protocol.py""" +# Tests for aiohttp/protocol.py import asyncio -import unittest -import zlib from unittest import mock +from urllib.parse import quote import pytest from multidict import CIMultiDict @@ -11,16 +10,27 @@ import aiohttp from aiohttp import http_exceptions, streams -from aiohttp.http_parser import (DeflateBuffer, HttpPayloadParser, - HttpRequestParserPy, HttpResponseParserPy) +from aiohttp.http_parser import ( + DeflateBuffer, + HttpPayloadParser, + HttpRequestParserPy, + HttpResponseParserPy, +) + +try: + import brotli +except ImportError: + brotli = None + REQUEST_PARSERS = [HttpRequestParserPy] RESPONSE_PARSERS = [HttpResponseParserPy] try: - from aiohttp import _http_parser - REQUEST_PARSERS.append(_http_parser.HttpRequestParserC) - 
RESPONSE_PARSERS.append(_http_parser.HttpResponseParserC) + from aiohttp.http_parser import HttpRequestParserC, HttpResponseParserC + + REQUEST_PARSERS.append(HttpRequestParserC) + RESPONSE_PARSERS.append(HttpResponseParserC) except ImportError: # pragma: no cover pass @@ -32,86 +42,111 @@ def protocol(): @pytest.fixture(params=REQUEST_PARSERS) def parser(loop, protocol, request): - """Parser implementations""" - return request.param(protocol, loop, 8190, 32768, 8190) + # Parser implementations + return request.param( + protocol, + loop, + 2 ** 16, + max_line_size=8190, + max_headers=32768, + max_field_size=8190, + ) @pytest.fixture(params=REQUEST_PARSERS) def request_cls(request): - """Request Parser class""" + # Request Parser class return request.param @pytest.fixture(params=RESPONSE_PARSERS) def response(loop, protocol, request): - """Parser implementations""" - return request.param(protocol, loop, 8190, 32768, 8190) + # Parser implementations + return request.param( + protocol, + loop, + 2 ** 16, + max_line_size=8190, + max_headers=32768, + max_field_size=8190, + ) @pytest.fixture(params=RESPONSE_PARSERS) def response_cls(request): - """Parser implementations""" + # Parser implementations return request.param -def test_parse_headers(parser): - text = b'''GET /test HTTP/1.1\r +@pytest.fixture +def stream(): + return mock.Mock() + + +def test_parse_headers(parser) -> None: + text = b"""GET /test HTTP/1.1\r test: line\r continue\r test2: data\r \r -''' +""" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [('Test', 'line continue'), - ('Test2', 'data')] - assert msg.raw_headers == ((b'test', b'line continue'), - (b'test2', b'data')) + assert list(msg.headers.items()) == [("test", "line continue"), ("test2", "data")] + assert msg.raw_headers == ((b"test", b"line continue"), (b"test2", b"data")) assert not msg.should_close assert msg.compression is None assert not msg.upgrade -def test_parse(parser): - text = b'GET /test HTTP/1.1\r\n\r\n' +def test_parse(parser) -> None: + text = b"GET /test HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg, _ = messages[0] assert msg.compression is None assert not msg.upgrade - assert msg.method == 'GET' - assert msg.path == '/test' + assert msg.method == "GET" + assert msg.path == "/test" assert msg.version == (1, 1) -@asyncio.coroutine -def test_parse_body(parser): - text = b'GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody' +async def test_parse_body(parser) -> None: + text = b"GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" + messages, upgrade, tail = parser.feed_data(text) + assert len(messages) == 1 + _, payload = messages[0] + body = await payload.read(4) + assert body == b"body" + + +async def test_parse_body_with_CRLF(parser) -> None: + text = b"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 _, payload = messages[0] - body = yield from payload.read(4) - assert body == b'body' + body = await payload.read(4) + assert body == b"body" -def test_parse_delayed(parser): - text = b'GET /test HTTP/1.1\r\n' +def test_parse_delayed(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 0 assert not upgrade - messages, upgrade, tail = parser.feed_data(b'\r\n') + messages, upgrade, tail = parser.feed_data(b"\r\n") assert len(messages) == 1 msg = messages[0][0] - assert 
msg.method == 'GET' + assert msg.method == "GET" -def test_headers_multi_feed(parser): - text1 = b'GET /test HTTP/1.1\r\n' - text2 = b'test: line\r' - text3 = b'\n continue\r\n\r\n' +def test_headers_multi_feed(parser) -> None: + text1 = b"GET /test HTTP/1.1\r\n" + text2 = b"test: line\r" + text3 = b"\n continue\r\n\r\n" messages, upgrade, tail = parser.feed_data(text1) assert len(messages) == 0 @@ -123,106 +158,128 @@ def test_headers_multi_feed(parser): assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [('Test', 'line continue')] - assert msg.raw_headers == ((b'test', b'line continue'),) + assert list(msg.headers.items()) == [("test", "line continue")] + assert msg.raw_headers == ((b"test", b"line continue"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade -def test_parse_headers_multi(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'Set-Cookie: c1=cookie1\r\n' - b'Set-Cookie: c2=cookie2\r\n\r\n') +def test_headers_split_field(parser) -> None: + text1 = b"GET /test HTTP/1.1\r\n" + text2 = b"t" + text3 = b"es" + text4 = b"t: value\r\n\r\n" + + messages, upgrade, tail = parser.feed_data(text1) + messages, upgrade, tail = parser.feed_data(text2) + messages, upgrade, tail = parser.feed_data(text3) + assert len(messages) == 0 + messages, upgrade, tail = parser.feed_data(text4) + assert len(messages) == 1 + + msg = messages[0][0] + assert list(msg.headers.items()) == [("test", "value")] + assert msg.raw_headers == ((b"test", b"value"),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + + +def test_parse_headers_multi(parser) -> None: + text = ( + b"GET /test HTTP/1.1\r\n" + b"Set-Cookie: c1=cookie1\r\n" + b"Set-Cookie: c2=cookie2\r\n\r\n" + ) messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [('Set-Cookie', 'c1=cookie1'), - ('Set-Cookie', 'c2=cookie2')] - assert msg.raw_headers == ((b'Set-Cookie', b'c1=cookie1'), - (b'Set-Cookie', b'c2=cookie2')) + assert list(msg.headers.items()) == [ + ("Set-Cookie", "c1=cookie1"), + ("Set-Cookie", "c2=cookie2"), + ] + assert msg.raw_headers == ( + (b"Set-Cookie", b"c1=cookie1"), + (b"Set-Cookie", b"c2=cookie2"), + ) assert not msg.should_close assert msg.compression is None -def test_conn_default_1_0(parser): - text = b'GET /test HTTP/1.0\r\n\r\n' +def test_conn_default_1_0(parser) -> None: + text = b"GET /test HTTP/1.0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close -def test_conn_default_1_1(parser): - text = b'GET /test HTTP/1.1\r\n\r\n' +def test_conn_default_1_1(parser) -> None: + text = b"GET /test HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close -def test_conn_close(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'connection: close\r\n\r\n') +def test_conn_close(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"connection: close\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close -def test_conn_close_1_0(parser): - text = (b'GET /test HTTP/1.0\r\n' - b'connection: close\r\n\r\n') +def test_conn_close_1_0(parser) -> None: + text = b"GET /test HTTP/1.0\r\n" b"connection: close\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close -def test_conn_keep_alive_1_0(parser): - text = (b'GET /test HTTP/1.0\r\n' - b'connection: keep-alive\r\n\r\n') 
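+# An HTTP/1.0 connection defaults to closing once the response is sent (see +# test_conn_default_1_0 above); an explicit "connection: keep-alive" header is +# what keeps it open, which the next test verifies.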
+def test_conn_keep_alive_1_0(parser) -> None: + text = b"GET /test HTTP/1.0\r\n" b"connection: keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close -def test_conn_keep_alive_1_1(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'connection: keep-alive\r\n\r\n') +def test_conn_keep_alive_1_1(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"connection: keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close -def test_conn_other_1_0(parser): - text = (b'GET /test HTTP/1.0\r\n' - b'connection: test\r\n\r\n') +def test_conn_other_1_0(parser) -> None: + text = b"GET /test HTTP/1.0\r\n" b"connection: test\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close -def test_conn_other_1_1(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'connection: test\r\n\r\n') +def test_conn_other_1_1(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"connection: test\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close -def test_request_chunked(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') +def test_request_chunked(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"transfer-encoding: chunked\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert msg.chunked assert not upgrade - assert isinstance(payload, streams.FlowControlStreamReader) + assert isinstance(payload, streams.StreamReader) -def test_conn_upgrade(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'connection: upgrade\r\n' - b'upgrade: websocket\r\n\r\n') +def test_conn_upgrade(parser) -> None: + text = ( + b"GET /test HTTP/1.1\r\n" + b"connection: upgrade\r\n" + b"upgrade: websocket\r\n\r\n" + ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close @@ -230,180 +287,294 @@ def test_conn_upgrade(parser): assert upgrade -def test_compression_deflate(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'content-encoding: deflate\r\n\r\n') +def test_compression_empty(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-encoding: \r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] - assert msg.compression == 'deflate' + assert msg.compression is None -def test_compression_gzip(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'content-encoding: gzip\r\n\r\n') +def test_compression_deflate(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-encoding: deflate\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] - assert msg.compression == 'gzip' + assert msg.compression == "deflate" -def test_compression_unknown(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'content-encoding: compress\r\n\r\n') +def test_compression_gzip(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-encoding: gzip\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] - assert not msg.compression + assert msg.compression == "gzip" -def test_headers_connect(parser): - text = (b'CONNECT www.google.com HTTP/1.1\r\n' - b'content-length: 0\r\n\r\n') +@pytest.mark.skipif(brotli is None, reason="brotli is not installed") +def test_compression_brotli(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-encoding: br\r\n\r\n" + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.compression == "br" + + 
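# Illustrative sketch, not part of the patch: the content-encoding tests in
# this hunk map a header value to a decompressor ("deflate", "gzip", "br").
# The compressed bodies such tests consume can be regenerated with the stdlib
# zlib module plus the optional third-party `brotli` package assumed here.
import zlib

import brotli  # optional; the brotli tests are skipped when it is absent

c = zlib.compressobj(wbits=zlib.MAX_WBITS)  # zlib-wrapped stream -> "deflate"
deflate_body = c.compress(b"data") + c.flush()
g = zlib.compressobj(wbits=16 + zlib.MAX_WBITS)  # gzip framing -> "gzip"
gzip_body = g.compress(b"data") + g.flush()
br_body = brotli.compress(b"data")  # brotli stream -> "br"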
+def test_compression_unknown(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-encoding: compress\r\n\r\n" + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.compression is None + + +def test_headers_connect(parser) -> None: + text = b"CONNECT www.google.com HTTP/1.1\r\n" b"content-length: 0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert upgrade - assert isinstance(payload, streams.FlowControlStreamReader) + assert isinstance(payload, streams.StreamReader) -def test_headers_old_websocket_key1(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'SEC-WEBSOCKET-KEY1: line\r\n\r\n') +def test_headers_old_websocket_key1(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"SEC-WEBSOCKET-KEY1: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) -def test_headers_content_length_err_1(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'content-length: line\r\n\r\n') +def test_headers_content_length_err_1(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-length: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) -def test_headers_content_length_err_2(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'content-length: -1\r\n\r\n') +def test_headers_content_length_err_2(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"content-length: -1\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) -def test_invalid_header(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'test line\r\n\r\n') +def test_invalid_header(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"test line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) -def test_invalid_name(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'test[]: line\r\n\r\n') +def test_invalid_name(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"test[]: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) -def test_max_header_field_size(parser): - name = b'test' * 10 * 1024 - text = (b'GET /test HTTP/1.1\r\n' + name + b':data\r\n\r\n') +@pytest.mark.parametrize("size", [40960, 8191]) +def test_max_header_field_size(parser, size) -> None: + name = b"t" * size + text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" - with pytest.raises(http_exceptions.LineTooLong): + match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) -def test_max_header_value_size(parser): - name = b'test' * 10 * 1024 - text = (b'GET /test HTTP/1.1\r\n' - b'data:' + name + b'\r\n\r\n') +def test_max_header_field_size_under_limit(parser) -> None: + name = b"t" * 8190 + text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" - with pytest.raises(http_exceptions.LineTooLong): + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.method == "GET" + assert msg.path == "/test" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict({name.decode(): "data"}) + assert msg.raw_headers == ((name, b"data"),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/test") + + +@pytest.mark.parametrize("size", [40960, 8191]) +def test_max_header_value_size(parser, size) -> None: + name = b"t" * size + text = b"GET /test HTTP/1.1\r\n" b"data:" + name + b"\r\n\r\n" + + match = f"400, message='Got 
more than 8190 bytes \\({size}\\) when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) -def test_max_header_value_size_continuation(parser): - name = b'test' * 10 * 1024 - text = (b'GET /test HTTP/1.1\r\n' - b'data: test\r\n ' + name + b'\r\n\r\n') +def test_max_header_value_size_under_limit(parser) -> None: + value = b"A" * 8190 + text = b"GET /test HTTP/1.1\r\n" b"data:" + value + b"\r\n\r\n" - with pytest.raises(http_exceptions.LineTooLong): + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.method == "GET" + assert msg.path == "/test" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict({"data": value.decode()}) + assert msg.raw_headers == ((b"data", value),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/test") + + +@pytest.mark.parametrize("size", [40965, 8191]) +def test_max_header_value_size_continuation(parser, size) -> None: + name = b"T" * (size - 5) + text = b"GET /test HTTP/1.1\r\n" b"data: test\r\n " + name + b"\r\n\r\n" + + match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) -def test_http_request_parser(parser): - text = b'GET /path HTTP/1.1\r\n\r\n' +def test_max_header_value_size_continuation_under_limit(parser) -> None: + value = b"A" * 8185 + text = b"GET /test HTTP/1.1\r\n" b"data: test\r\n " + value + b"\r\n\r\n" + + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.method == "GET" + assert msg.path == "/test" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict({"data": "test " + value.decode()}) + assert msg.raw_headers == ((b"data", b"test " + value),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/test") + + +def test_http_request_parser(parser) -> None: + text = b"GET /path HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] - assert msg == ('GET', '/path', (1, 1), CIMultiDict(), (), - False, None, False, False, URL('/path')) + assert msg.method == "GET" + assert msg.path == "/path" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict() + assert msg.raw_headers == () + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/path") -def test_http_request_bad_status_line(parser): - text = b'getpath \r\n\r\n' +def test_http_request_bad_status_line(parser) -> None: + text = b"getpath \r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) -def test_http_request_upgrade(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'connection: upgrade\r\n' - b'upgrade: websocket\r\n\r\n' - b'some raw data') +def test_http_request_upgrade(parser) -> None: + text = ( + b"GET /test HTTP/1.1\r\n" + b"connection: upgrade\r\n" + b"upgrade: websocket\r\n\r\n" + b"some raw data" + ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert upgrade - assert tail == b'some raw data' + assert tail == b"some raw data" -def test_http_request_parser_utf8(parser): - text = 'GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n'.encode('utf-8') +def test_http_request_parser_utf8(parser) -> None: + text = "GET /path 
HTTP/1.1\r\nx-test:тест\r\n\r\n".encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] - assert msg == ('GET', '/path', (1, 1), - CIMultiDict([('X-TEST', 'тест')]), - ((b'x-test', 'тест'.encode('utf-8')),), - False, None, False, False, URL('/path')) + assert msg.method == "GET" + assert msg.path == "/path" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict([("X-TEST", "тест")]) + assert msg.raw_headers == ((b"x-test", "тест".encode()),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/path") -def test_http_request_parser_non_utf8(parser): - text = 'GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n'.encode('cp1251') +def test_http_request_parser_non_utf8(parser) -> None: + text = "GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n".encode("cp1251") msg = parser.feed_data(text)[0][0][0] - assert msg == ('GET', '/path', (1, 1), - CIMultiDict([('X-TEST', 'тест'.encode('cp1251').decode( - 'utf-8', 'surrogateescape'))]), - ((b'x-test', 'тест'.encode('cp1251')),), - False, None, False, False, URL('/path')) + assert msg.method == "GET" + assert msg.path == "/path" + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict( + [("X-TEST", "тест".encode("cp1251").decode("utf8", "surrogateescape"))] + ) + assert msg.raw_headers == ((b"x-test", "тест".encode("cp1251")),) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/path") -def test_http_request_parser_two_slashes(parser): - text = b'GET //path HTTP/1.1\r\n\r\n' +def test_http_request_parser_two_slashes(parser) -> None: + text = b"GET //path HTTP/1.1\r\n\r\n" msg = parser.feed_data(text)[0][0][0] - assert msg[:-1] == ('GET', '//path', (1, 1), CIMultiDict(), (), - False, None, False, False) + assert msg.method == "GET" + assert msg.path == "//path" + assert msg.url.path == "//path" + assert msg.version == (1, 1) + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked -def test_http_request_parser_bad_method(parser): +def test_http_request_parser_bad_method(parser) -> None: with pytest.raises(http_exceptions.BadStatusLine): - parser.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + parser.feed_data(b'=":(e),[T];?" 
/get HTTP/1.1\r\n\r\n') -def test_http_request_parser_bad_version(parser): +def test_http_request_parser_bad_version(parser) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - parser.feed_data(b'GET //get HT/11\r\n\r\n') + parser.feed_data(b"GET //get HT/11\r\n\r\n") + +@pytest.mark.parametrize("size", [40965, 8191]) +def test_http_request_max_status_line(parser, size) -> None: + path = b"t" * (size - 5) + match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): + parser.feed_data(b"GET /path" + path + b" HTTP/1.1\r\n\r\n") -def test_http_request_max_status_line(parser): - with pytest.raises(http_exceptions.LineTooLong): - parser.feed_data( - b'GET /path' + b'test' * 10 * 1024 + b' HTTP/1.1\r\n\r\n') +def test_http_request_max_status_line_under_limit(parser) -> None: + path = b"t" * (8190 - 5) + messages, upgraded, tail = parser.feed_data( + b"GET /path" + path + b" HTTP/1.1\r\n\r\n" + ) + msg = messages[0][0] + + assert msg.method == "GET" + assert msg.path == "/path" + path.decode() + assert msg.version == (1, 1) + assert msg.headers == CIMultiDict() + assert msg.raw_headers == () + assert not msg.should_close + assert msg.compression is None + assert not msg.upgrade + assert not msg.chunked + assert msg.url == URL("/path" + path.decode()) -def test_http_response_parser_utf8(response): - text = 'HTTP/1.1 200 Ok\r\nx-test:тест\r\n\r\n'.encode('utf-8') + +def test_http_response_parser_utf8(response) -> None: + text = "HTTP/1.1 200 Ok\r\nx-test:тест\r\n\r\n".encode() messages, upgraded, tail = response.feed_data(text) assert len(messages) == 1 @@ -411,262 +582,543 @@ def test_http_response_parser_utf8(response): assert msg.version == (1, 1) assert msg.code == 200 - assert msg.reason == 'Ok' - assert msg.headers == CIMultiDict([('X-TEST', 'тест')]) - assert msg.raw_headers == ((b'x-test', 'тест'.encode('utf-8')),) + assert msg.reason == "Ok" + assert msg.headers == CIMultiDict([("X-TEST", "тест")]) + assert msg.raw_headers == ((b"x-test", "тест".encode()),) assert not upgraded assert not tail -def test_http_response_parser_bad_status_line_too_long(response): - with pytest.raises(http_exceptions.LineTooLong): - response.feed_data( - b'HTTP/1.1 200 Ok' + b'test' * 10 * 1024 + b'\r\n\r\n') +@pytest.mark.parametrize("size", [40962, 8191]) +def test_http_response_parser_bad_status_line_too_long(response, size) -> None: + reason = b"t" * (size - 2) + match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): + response.feed_data(b"HTTP/1.1 200 Ok" + reason + b"\r\n\r\n") + +def test_http_response_parser_status_line_under_limit(response) -> None: + reason = b"O" * 8190 + messages, upgraded, tail = response.feed_data( + b"HTTP/1.1 200 " + reason + b"\r\n\r\n" + ) + msg = messages[0][0] + assert msg.version == (1, 1) + assert msg.code == 200 + assert msg.reason == reason.decode() -def test_http_response_parser_bad_version(response): + +def test_http_response_parser_bad_version(response) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - response.feed_data(b'HT/11 200 Ok\r\n\r\n') + response.feed_data(b"HT/11 200 Ok\r\n\r\n") -def test_http_response_parser_no_reason(response): - msg = response.feed_data(b'HTTP/1.1 200\r\n\r\n')[0][0][0] +def test_http_response_parser_no_reason(response) -> None: + msg = response.feed_data(b"HTTP/1.1 200\r\n\r\n")[0][0][0] assert msg.version == (1, 1) assert msg.code == 200 - assert 
not msg.reason + assert msg.reason == "" -def test_http_response_parser_bad(response): +def test_http_response_parser_bad(response) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - response.feed_data(b'HTT/1\r\n\r\n') + response.feed_data(b"HTT/1\r\n\r\n") -def test_http_response_parser_code_under_100(response): - msg = response.feed_data(b'HTTP/1.1 99 test\r\n\r\n')[0][0][0] +def test_http_response_parser_code_under_100(response) -> None: + msg = response.feed_data(b"HTTP/1.1 99 test\r\n\r\n")[0][0][0] assert msg.code == 99 -def test_http_response_parser_code_above_999(response): +def test_http_response_parser_code_above_999(response) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - response.feed_data(b'HTTP/1.1 9999 test\r\n\r\n') + response.feed_data(b"HTTP/1.1 9999 test\r\n\r\n") -def test_http_response_parser_code_not_int(response): +def test_http_response_parser_code_not_int(response) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - response.feed_data(b'HTTP/1.1 ttt test\r\n\r\n') + response.feed_data(b"HTTP/1.1 ttt test\r\n\r\n") -def test_http_request_chunked_payload(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') +def test_http_request_chunked_payload(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"transfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert msg.chunked assert not payload.is_eof() - assert isinstance(payload, streams.FlowControlStreamReader) + assert isinstance(payload, streams.StreamReader) - parser.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n') + parser.feed_data(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n") - assert b'dataline' == b''.join(d for d in payload._buffer) + assert b"dataline" == b"".join(d for d in payload._buffer) + assert [4, 8] == payload._http_chunk_splits assert payload.is_eof() -def test_http_request_chunked_payload_and_next_message(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') +def test_http_request_chunked_payload_and_next_message(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"transfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] messages, upgraded, tail = parser.feed_data( - b'4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n' - b'POST /test2 HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n" + b"POST /test2 HTTP/1.1\r\n" + b"transfer-encoding: chunked\r\n\r\n" + ) - assert b'dataline' == b''.join(d for d in payload._buffer) + assert b"dataline" == b"".join(d for d in payload._buffer) + assert [4, 8] == payload._http_chunk_splits assert payload.is_eof() assert len(messages) == 1 msg2, payload2 = messages[0] - assert msg2.method == 'POST' + assert msg2.method == "POST" assert msg2.chunked assert not payload2.is_eof() -def test_http_request_chunked_payload_chunks(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') +def test_http_request_chunked_payload_chunks(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"transfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] - parser.feed_data(b'4\r\ndata\r') - parser.feed_data(b'\n4') - parser.feed_data(b'\r') - parser.feed_data(b'\n') - parser.feed_data(b'line\r\n0\r\n') - parser.feed_data(b'test: test\r\n') + parser.feed_data(b"4\r\ndata\r") + parser.feed_data(b"\n4") + parser.feed_data(b"\r") + parser.feed_data(b"\n") + parser.feed_data(b"li") + parser.feed_data(b"ne\r\n0\r\n") + parser.feed_data(b"test: test\r\n") - assert 
b'dataline' == b''.join(d for d in payload._buffer) + assert b"dataline" == b"".join(d for d in payload._buffer) + assert [4, 8] == payload._http_chunk_splits assert not payload.is_eof() - parser.feed_data(b'\r\n') - assert b'dataline' == b''.join(d for d in payload._buffer) + parser.feed_data(b"\r\n") + assert b"dataline" == b"".join(d for d in payload._buffer) + assert [4, 8] == payload._http_chunk_splits assert payload.is_eof() -def test_parse_chunked_payload_chunk_extension(parser): - text = (b'GET /test HTTP/1.1\r\n' - b'transfer-encoding: chunked\r\n\r\n') +def test_parse_chunked_payload_chunk_extension(parser) -> None: + text = b"GET /test HTTP/1.1\r\n" b"transfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] - parser.feed_data( - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest: test\r\n\r\n') + parser.feed_data(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest: test\r\n\r\n") - assert b'dataline' == b''.join(d for d in payload._buffer) + assert b"dataline" == b"".join(d for d in payload._buffer) + assert [4, 8] == payload._http_chunk_splits assert payload.is_eof() def _test_parse_no_length_or_te_on_post(loop, protocol, request_cls): parser = request_cls(protocol, loop, readall=True) - text = b'POST /test HTTP/1.1\r\n\r\n' + text = b"POST /test HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() -def test_parse_payload_response_without_body(loop, protocol, response_cls): - parser = response_cls(protocol, loop, response_with_body=False) - text = (b'HTTP/1.1 200 Ok\r\n' - b'content-length: 10\r\n\r\n') +def test_parse_payload_response_without_body(loop, protocol, response_cls) -> None: + parser = response_cls(protocol, loop, 2 ** 16, response_with_body=False) + text = b"HTTP/1.1 200 Ok\r\n" b"content-length: 10\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() -def test_parse_length_payload(response): - text = (b'HTTP/1.1 200 Ok\r\n' - b'content-length: 4\r\n\r\n') +def test_parse_length_payload(response) -> None: + text = b"HTTP/1.1 200 Ok\r\n" b"content-length: 4\r\n\r\n" msg, payload = response.feed_data(text)[0][0] assert not payload.is_eof() - response.feed_data(b'da') - response.feed_data(b't') - response.feed_data(b'aHT') + response.feed_data(b"da") + response.feed_data(b"t") + response.feed_data(b"aHT") + + assert payload.is_eof() + assert b"data" == b"".join(d for d in payload._buffer) + +def test_parse_no_length_payload(parser) -> None: + text = b"PUT / HTTP/1.1\r\n\r\n" + msg, payload = parser.feed_data(text)[0][0] + assert payload.is_eof() + + +def test_partial_url(parser) -> None: + messages, upgrade, tail = parser.feed_data(b"GET /te") + assert len(messages) == 0 + messages, upgrade, tail = parser.feed_data(b"st HTTP/1.1\r\n\r\n") + assert len(messages) == 1 + + msg, payload = messages[0] + + assert msg.method == "GET" + assert msg.path == "/test" + assert msg.version == (1, 1) assert payload.is_eof() - assert b'data' == b''.join(d for d in payload._buffer) -class TestParsePayload(unittest.TestCase): +def test_url_parse_non_strict_mode(parser) -> None: + payload = "GET /test/тест HTTP/1.1\r\n\r\n".encode() + messages, upgrade, tail = parser.feed_data(payload) + assert len(messages) == 1 + + msg, payload = messages[0] + + assert msg.method == "GET" + assert msg.path == "/test/тест" + assert msg.version == (1, 1) + assert payload.is_eof() + + +@pytest.mark.parametrize( + ("uri", "path", "query", "fragment"), + [ + ("/path%23frag", "/path#frag", {}, ""), + ("/path%2523frag", "/path%23frag", {}, ""), + 
("/path?key=value%23frag", "/path", {"key": "value#frag"}, ""), + ("/path?key=value%2523frag", "/path", {"key": "value%23frag"}, ""), + ("/path#frag%20", "/path", {}, "frag "), + ("/path#frag%2520", "/path", {}, "frag%20"), + ], +) +def test_parse_uri_percent_encoded(parser, uri, path, query, fragment) -> None: + text = (f"GET {uri} HTTP/1.1\r\n\r\n").encode() + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + + assert msg.path == uri + assert msg.url == URL(uri) + assert msg.url.path == path + assert msg.url.query == query + assert msg.url.fragment == fragment + + +def test_parse_uri_utf8(parser) -> None: + text = ("GET /путь?ключ=знач#фраг HTTP/1.1\r\n\r\n").encode() + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + + assert msg.path == "/путь?ключ=знач#фраг" + assert msg.url.path == "/путь" + assert msg.url.query == {"ключ": "знач"} + assert msg.url.fragment == "фраг" + - def setUp(self): - self.stream = mock.Mock() - asyncio.set_event_loop(None) +def test_parse_uri_utf8_percent_encoded(parser) -> None: + text = ( + "GET %s HTTP/1.1\r\n\r\n" % quote("/путь?ключ=знач#фраг", safe="/?=#") + ).encode() + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] - def test_parse_eof_payload(self): - out = aiohttp.FlowControlDataQueue(self.stream) + assert msg.path == quote("/путь?ключ=знач#фраг", safe="/?=#") + assert msg.url == URL("/путь?ключ=знач#фраг") + assert msg.url.path == "/путь" + assert msg.url.query == {"ключ": "знач"} + assert msg.url.fragment == "фраг" + + +@pytest.mark.skipif( + "HttpRequestParserC" not in dir(aiohttp.http_parser), + reason="C based HTTP parser not available", +) +def test_parse_bad_method_for_c_parser_raises(loop, protocol): + payload = b"GET1 /test HTTP/1.1\r\n\r\n" + parser = HttpRequestParserC( + protocol, + loop, + 2 ** 16, + max_line_size=8190, + max_headers=32768, + max_field_size=8190, + ) + + with pytest.raises(aiohttp.http_exceptions.BadStatusLine): + messages, upgrade, tail = parser.feed_data(payload) + + +class TestParsePayload: + async def test_parse_eof_payload(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) p = HttpPayloadParser(out, readall=True) - p.feed_data(b'data') + p.feed_data(b"data") p.feed_eof() - self.assertTrue(out.is_eof()) - self.assertEqual([(bytearray(b'data'), 4)], list(out._buffer)) + assert out.is_eof() + assert [(bytearray(b"data"), 4)] == list(out._buffer) - def test_parse_length_payload_eof(self): - out = aiohttp.FlowControlDataQueue(self.stream) + async def test_parse_no_body(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, method="PUT") + + assert out.is_eof() + assert p.done + + async def test_parse_length_payload_eof(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) p = HttpPayloadParser(out, length=4) - p.feed_data(b'da') + p.feed_data(b"da") with pytest.raises(http_exceptions.ContentLengthError): p.feed_eof() - def test_parse_chunked_payload_size_error(self): - out = aiohttp.FlowControlDataQueue(self.stream) + async def test_parse_chunked_payload_size_error(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) p = HttpPayloadParser(out, chunked=True) - self.assertRaises( - http_exceptions.TransferEncodingError, p.feed_data, b'blah\r\n') - self.assertIsInstance( - 
out.exception(), http_exceptions.TransferEncodingError) + with pytest.raises(http_exceptions.TransferEncodingError): + p.feed_data(b"blah\r\n") + assert isinstance(out.exception(), http_exceptions.TransferEncodingError) - def test_http_payload_parser_length(self): - out = aiohttp.FlowControlDataQueue(self.stream) - p = HttpPayloadParser(out, length=2) - eof, tail = p.feed_data(b'1245') - self.assertTrue(eof) - - self.assertEqual(b'12', b''.join(d for d, _ in out._buffer)) - self.assertEqual(b'45', tail) - - _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) - - def test_http_payload_parser_deflate(self): - length = len(self._COMPRESSED) - out = aiohttp.FlowControlDataQueue(self.stream) - p = HttpPayloadParser( - out, length=length, compression='deflate') - p.feed_data(self._COMPRESSED) - self.assertEqual(b'data', b''.join(d for d, _ in out._buffer)) - self.assertTrue(out.is_eof()) - - def test_http_payload_parser_length_zero(self): - out = aiohttp.FlowControlDataQueue(self.stream) - p = HttpPayloadParser(out, length=0) - self.assertTrue(p.done) - self.assertTrue(out.is_eof()) + async def test_parse_chunked_payload_split_end(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\n") + p.feed_data(b"\r\n") + + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) + + async def test_parse_chunked_payload_split_end2(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\n\r") + p.feed_data(b"\n") + + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) + async def test_parse_chunked_payload_split_end_trailers(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\n") + p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n") + p.feed_data(b"\r\n") + + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) -class TestDeflateBuffer(unittest.TestCase): + async def test_parse_chunked_payload_split_end_trailers2(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\n") + p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r") + p.feed_data(b"\n") - def setUp(self): - self.stream = mock.Mock() - asyncio.set_event_loop(None) + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) + + async def test_parse_chunked_payload_split_end_trailers3(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\nContent-MD5: ") + p.feed_data(b"912ec803b2ce49e4a541068d495ab570\r\n\r\n") - def test_feed_data(self): - buf = aiohttp.FlowControlDataQueue(self.stream) - dbuf = DeflateBuffer(buf, 'deflate') + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) - dbuf.zlib = mock.Mock() - dbuf.zlib.decompress.return_value = b'line' + async def test_parse_chunked_payload_split_end_trailers4(self, protocol) -> None: + out = aiohttp.StreamReader(protocol, 2 ** 16, loop=None) + p = HttpPayloadParser(out, chunked=True) + p.feed_data(b"4\r\nasdf\r\n0\r\n" b"C") + p.feed_data(b"ontent-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r\n") - dbuf.feed_data(b'data', 4) - self.assertEqual([b'line'], 
list(d for d, _ in buf._buffer)) + assert out.is_eof() + assert b"asdf" == b"".join(out._buffer) - def test_feed_data_err(self): - buf = aiohttp.FlowControlDataQueue(self.stream) - dbuf = DeflateBuffer(buf, 'deflate') + async def test_http_payload_parser_length(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=2) + eof, tail = p.feed_data(b"1245") + assert eof + + assert b"12" == b"".join(d for d, _ in out._buffer) + assert b"45" == tail + + async def test_http_payload_parser_deflate(self, stream) -> None: + # c=compressobj(wbits=15); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" + + length = len(COMPRESSED) + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=length, compression="deflate") + p.feed_data(COMPRESSED) + assert b"data" == b"".join(d for d, _ in out._buffer) + assert out.is_eof() + + async def test_http_payload_parser_deflate_no_hdrs(self, stream) -> None: + """Tests incorrectly formed data (no zlib headers).""" + + # c=compressobj(wbits=-15); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b"KI,I\x04\x00" + + length = len(COMPRESSED) + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=length, compression="deflate") + p.feed_data(COMPRESSED) + assert b"data" == b"".join(d for d, _ in out._buffer) + assert out.is_eof() + + async def test_http_payload_parser_deflate_light(self, stream) -> None: + # c=compressobj(wbits=9); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b"\x18\x95KI,I\x04\x00\x04\x00\x01\x9b" + + length = len(COMPRESSED) + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=length, compression="deflate") + p.feed_data(COMPRESSED) + assert b"data" == b"".join(d for d, _ in out._buffer) + assert out.is_eof() + + async def test_http_payload_parser_deflate_split(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, compression="deflate", readall=True) + # Feeding one correct byte should be enough to choose the exact + # deflate decompressor. + p.feed_data(b"x", 1) + p.feed_data(b"\x9cKI,I\x04\x00\x04\x00\x01\x9b", 11) + p.feed_eof() + assert b"data" == b"".join(d for d, _ in out._buffer) + + async def test_http_payload_parser_deflate_split_err(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, compression="deflate", readall=True) + # Feeding one wrong (headerless) byte should be enough to fall back to + # the raw deflate decompressor. + p.feed_data(b"K", 1) + p.feed_data(b"I,I\x04\x00", 5) + p.feed_eof() + assert b"data" == b"".join(d for d, _ in out._buffer) + + async def test_http_payload_parser_length_zero(self, stream) -> None: + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=0) + assert p.done + assert out.is_eof() + + @pytest.mark.skipif(brotli is None, reason="brotli is not installed") + async def test_http_payload_brotli(self, stream) -> None: + compressed = brotli.compress(b"brotli data") + out = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + p = HttpPayloadParser(out, length=len(compressed), compression="br")
+ p.feed_data(compressed) + assert b"brotli data" == b"".join(d for d, _ in out._buffer) + assert out.is_eof() + + +class TestDeflateBuffer: + async def test_feed_data(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "deflate") + + dbuf.decompressor = mock.Mock() + dbuf.decompressor.decompress.return_value = b"line" + + # The first byte must be b'x' so that the code keeps the mocked decoder. + dbuf.feed_data(b"xxxx", 4) + assert [b"line"] == list(d for d, _ in buf._buffer) + + async def test_feed_data_err(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "deflate") exc = ValueError() - dbuf.zlib = mock.Mock() - dbuf.zlib.decompress.side_effect = exc + dbuf.decompressor = mock.Mock() + dbuf.decompressor.decompress.side_effect = exc - self.assertRaises( - http_exceptions.ContentEncodingError, dbuf.feed_data, b'data', 4) + with pytest.raises(http_exceptions.ContentEncodingError): + # Must be more than 4 bytes to trigger the deflate FSM error, and + # must start with b'x', otherwise the code swaps out the mocked decoder. + dbuf.feed_data(b"xsomedata", 9) - def test_feed_eof(self): - buf = aiohttp.FlowControlDataQueue(self.stream) - dbuf = DeflateBuffer(buf, 'deflate') + async def test_feed_eof(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "deflate") - dbuf.zlib = mock.Mock() - dbuf.zlib.flush.return_value = b'line' + dbuf.decompressor = mock.Mock() + dbuf.decompressor.flush.return_value = b"line" dbuf.feed_eof() - self.assertEqual([b'line'], list(d for d, _ in buf._buffer)) - self.assertTrue(buf._eof) + assert [b"line"] == list(d for d, _ in buf._buffer) + assert buf._eof + + async def test_feed_eof_err_deflate(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "deflate") + + dbuf.decompressor = mock.Mock() + dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.eof = False - def test_feed_eof_err(self): - buf = aiohttp.FlowControlDataQueue(self.stream) - dbuf = DeflateBuffer(buf, 'deflate') + with pytest.raises(http_exceptions.ContentEncodingError): + dbuf.feed_eof() - dbuf.zlib = mock.Mock() - dbuf.zlib.flush.return_value = b'line' - dbuf.zlib.eof = False + async def test_feed_eof_no_err_gzip(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "gzip") - self.assertRaises(http_exceptions.ContentEncodingError, dbuf.feed_eof) + dbuf.decompressor = mock.Mock() + dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.eof = False + + dbuf.feed_eof() + assert [b"line"] == list(d for d, _ in buf._buffer) + + async def test_feed_eof_no_err_brotli(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf = DeflateBuffer(buf, "br") + + dbuf.decompressor = mock.Mock() + dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.eof = False + + dbuf.feed_eof() + assert [b"line"] == list(d for d, _ in buf._buffer) - def test_empty_body(self): - buf = aiohttp.FlowControlDataQueue(self.stream) - dbuf = DeflateBuffer(buf, 'deflate') + async def test_empty_body(self, stream) -> None: + buf = aiohttp.FlowControlDataQueue( + stream, 2 ** 16, loop=asyncio.get_event_loop() + ) + dbuf
= DeflateBuffer(buf, "deflate") dbuf.feed_eof() - self.assertTrue(buf.at_eof()) + assert buf.at_eof() diff --git a/tests/test_http_stream_writer.py b/tests/test_http_stream_writer.py deleted file mode 100644 index d7b962cbc32..00000000000 --- a/tests/test_http_stream_writer.py +++ /dev/null @@ -1,318 +0,0 @@ -import socket -from unittest import mock - -import pytest - -from aiohttp.http_writer import CORK, PayloadWriter, StreamWriter - -has_ipv6 = socket.has_ipv6 -if has_ipv6: - # The socket.has_ipv6 flag may be True if Python was built with IPv6 - # support, but the target system still may not have it. - # So let's ensure that we really have IPv6 support. - try: - socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - except OSError: - has_ipv6 = False - - -# nodelay - -def test_nodelay_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -def test_set_nodelay_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -def test_set_nodelay_enable_unix(loop): - # do not set nodelay for unix socket - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable_no_socket(loop): - transport = mock.Mock() - 
transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - assert writer._socket is None - - -# cork - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_cork_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_unix(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_no_socket(loop): - transport = mock.Mock() - transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - assert writer._socket is None - - -def test_set_cork_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -# cork and nodelay interference - 
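# Illustrative sketch, not part of the patch: the deleted tests below checked
# that enabling TCP_CORK clears TCP_NODELAY and vice versa, since corking
# batches small writes while NODELAY flushes them immediately. At the raw
# socket level the two toggles look like this (TCP_CORK is Linux-only, hence
# the hasattr() guard):
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)  # push segments at once
if hasattr(socket, "TCP_CORK"):
    s.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)  # do not hold data back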
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_cork_disables_nodelay(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_cork(True) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_nodelay_disables_cork(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -# payload writers management - -def test_acquire(loop): - transport = mock.Mock() - stream = StreamWriter(mock.Mock(), transport, loop) - assert stream.available - - payload = PayloadWriter(stream, loop) - assert not stream.available - assert payload._transport is transport - - payload2 = PayloadWriter(stream, loop) - assert payload2._transport is None - assert payload2 in stream._waiters - - -def test_acquire2(loop): - transport = mock.Mock() - stream = StreamWriter(mock.Mock(), transport, loop) - - payload = PayloadWriter(stream, loop) - stream.release() - assert stream.available - - stream.acquire(payload) - assert not stream.available - assert payload._transport is transport - - -def test_release(loop): - transport = mock.Mock() - stream = StreamWriter(mock.Mock(), transport, loop) - stream.available = False - - writer = mock.Mock() - - stream.acquire(writer) - assert not stream.available - assert not writer.set_transport.called - - stream.release() - assert not stream.available - writer.set_transport.assert_called_with(transport) - - stream.release() - assert stream.available - - -def test_replace(loop): - transport = mock.Mock() - stream = StreamWriter(mock.Mock(), transport, loop) - stream.available = False - - payload = PayloadWriter(stream, loop) - assert payload._transport is None - assert payload in stream._waiters - - payload2 = stream.replace(payload, PayloadWriter) - assert payload2._transport is None - assert payload2 in stream._waiters - assert payload not in stream._waiters - - stream.release() - assert payload2._transport is transport - assert not stream._waiters - - -def test_replace_available(loop): - transport = mock.Mock() - stream = StreamWriter(mock.Mock(), transport, loop) - - payload = PayloadWriter(stream, loop, False) - payload2 = stream.replace(payload, PayloadWriter) - assert payload2._transport is transport - assert payload2 not in stream._waiters diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 5c5d69b6245..6aca2ea2d9a 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -1,12 +1,11 @@ -"""Tests for aiohttp/http_writer.py""" - -import asyncio -import zlib +# Tests for aiohttp/http_writer.py +import array from unittest import mock import pytest from aiohttp import http +from aiohttp.test_utils import make_mocked_coro @pytest.fixture @@ -22,136 +21,228 @@ def write(chunk): buf.extend(chunk) transport.write.side_effect 
= write + transport.is_closing.return_value = False return transport @pytest.fixture -def stream(transport): - stream = mock.Mock(transport=transport) +def protocol(loop, transport): + protocol = mock.Mock(transport=transport) + protocol._drain_helper = make_mocked_coro() + return protocol - def acquire(writer): - writer.set_transport(transport) - stream.acquire = acquire - stream.drain.return_value = () - return stream +def test_payloadwriter_properties(transport, protocol, loop) -> None: + writer = http.StreamWriter(protocol, loop) + assert writer.protocol == protocol + assert writer.transport == transport -def test_write_payload_eof(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_eof(transport, protocol, loop) -> None: + write = transport.write = mock.Mock() + msg = http.StreamWriter(protocol, loop) - msg.write(b'data1') - msg.write(b'data2') - msg.write_eof() + await msg.write(b"data1") + await msg.write(b"data2") + await msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - assert b'data1data2' == content.split(b'\r\n\r\n', 1)[-1] + content = b"".join([c[1][0] for c in list(write.mock_calls)]) + assert b"data1data2" == content.split(b"\r\n\r\n", 1)[-1] -@asyncio.coroutine -def test_write_payload_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked(buf, protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() - msg.write(b'data') - yield from msg.write_eof() + await msg.write(b"data") + await msg.write_eof() - assert b'4\r\ndata\r\n0\r\n\r\n' == buf + assert b"4\r\ndata\r\n0\r\n\r\n" == buf -@asyncio.coroutine -def test_write_payload_chunked_multiple(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_multiple(buf, protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() - msg.write(b'data1') - msg.write(b'data2') - yield from msg.write_eof() + await msg.write(b"data1") + await msg.write(b"data2") + await msg.write_eof() - assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf + assert b"5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n" == buf -@asyncio.coroutine -def test_write_payload_length(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_length(protocol, transport, loop) -> None: + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.StreamWriter(protocol, loop) msg.length = 2 - msg.write(b'd') - msg.write(b'ata') - yield from msg.write_eof() + await msg.write(b"d") + await msg.write(b"ata") + await msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - assert b'da' == content.split(b'\r\n\r\n', 1)[-1] + content = b"".join([c[1][0] for c in list(write.mock_calls)]) + assert b"da" == content.split(b"\r\n\r\n", 1)[-1] -@asyncio.coroutine -def test_write_payload_chunked_filter(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_chunked_filter(protocol, transport, loop) -> None: + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() - msg.write(b'da') - msg.write(b'ta') - yield from msg.write_eof() + await msg.write(b"da") + await msg.write(b"ta") + await msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - assert 
content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n') + content = b"".join([c[1][0] for c in list(write.mock_calls)]) + assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n") -@asyncio.coroutine -def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, loop): + write = transport.write = mock.Mock() + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() - msg.write(b'da') - msg.write(b'ta') - msg.write(b'1d') - msg.write(b'at') - msg.write(b'a2') - yield from msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) + await msg.write(b"da") + await msg.write(b"ta") + await msg.write(b"1d") + await msg.write(b"at") + await msg.write(b"a2") + await msg.write_eof() + content = b"".join([c[1][0] for c in list(write.mock_calls)]) assert content.endswith( - b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' - b'2\r\na2\r\n0\r\n\r\n') - + b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n" b"2\r\na2\r\n0\r\n\r\n" + ) -compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS) -COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()]) +async def test_write_payload_deflate_compression(protocol, transport, loop) -> None: -@asyncio.coroutine -def test_write_payload_deflate_compression(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) - msg.enable_compression('deflate') - msg.write(b'data') - yield from msg.write_eof() + COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" + write = transport.write = mock.Mock() + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + await msg.write(b"data") + await msg.write_eof() chunks = [c[1][0] for c in list(write.mock_calls)] assert all(chunks) - content = b''.join(chunks) - assert COMPRESSED == content.split(b'\r\n\r\n', 1)[-1] + content = b"".join(chunks) + assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1] + + +async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop): + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"da") + await msg.write(b"ta") + await msg.write_eof() + + thing = b"2\r\nx\x9c\r\n" b"a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n" b"0\r\n\r\n" + assert thing == buf + + +async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop): + + msg = http.StreamWriter(protocol, loop) + + mv = memoryview(b"abcd") + + await msg.write(mv) + await msg.write_eof() + + thing = b"abcd" + assert thing == buf + + +async def test_write_payload_short_ints_memoryview(buf, protocol, transport, loop): + msg = http.StreamWriter(protocol, loop) + msg.enable_chunking() + + payload = memoryview(array.array("H", [65, 66, 67])) + + await msg.write(payload) + await msg.write_eof() + + endians = ( + (b"6\r\n" b"\x00A\x00B\x00C\r\n" b"0\r\n\r\n"), + (b"6\r\n" b"A\x00B\x00C\x00\r\n" b"0\r\n\r\n"), + ) + assert buf in endians -@asyncio.coroutine -def test_write_payload_deflate_and_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) - msg.enable_compression('deflate') +async def test_write_payload_2d_shape_memoryview(buf, protocol, transport, loop): + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() - msg.write(b'da') - msg.write(b'ta') - yield from msg.write_eof() + mv = memoryview(b"ABCDEF") + payload = mv.cast("c", [3, 2]) - assert 
b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf + await msg.write(payload) + await msg.write_eof() + thing = b"6\r\n" b"ABCDEF\r\n" b"0\r\n\r\n" + assert thing == buf -def test_write_drain(stream, loop): - msg = http.PayloadWriter(stream, loop) - msg.drain = mock.Mock() - msg.write(b'1' * (64 * 1024 * 2), drain=False) + +async def test_write_payload_slicing_long_memoryview(buf, protocol, transport, loop): + msg = http.StreamWriter(protocol, loop) + msg.length = 4 + + mv = memoryview(b"ABCDEF") + payload = mv.cast("c", [3, 2]) + + await msg.write(payload) + await msg.write_eof() + + thing = b"ABCD" + assert thing == buf + + +async def test_write_drain(protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) + msg.drain = make_mocked_coro() + await msg.write(b"1" * (64 * 1024 * 2), drain=False) assert not msg.drain.called - msg.write(b'1', drain=True) + await msg.write(b"1", drain=True) assert msg.drain.called assert msg.buffer_size == 0 + + +async def test_write_calls_callback(protocol, transport, loop) -> None: + on_chunk_sent = make_mocked_coro() + msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) + chunk = b"1" + await msg.write(chunk) + assert on_chunk_sent.called + assert on_chunk_sent.call_args == mock.call(chunk) + + +async def test_write_eof_calls_callback(protocol, transport, loop) -> None: + on_chunk_sent = make_mocked_coro() + msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) + chunk = b"1" + await msg.write_eof(chunk=chunk) + assert on_chunk_sent.called + assert on_chunk_sent.call_args == mock.call(chunk) + + +async def test_write_to_closing_transport(protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) + + await msg.write(b"Before closing") + transport.is_closing.return_value = True + + with pytest.raises(ConnectionResetError): + await msg.write(b"After closing") + + +async def test_drain(protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) + await msg.drain() + assert protocol._drain_helper.called + + +async def test_drain_no_transport(protocol, transport, loop) -> None: + msg = http.StreamWriter(protocol, loop) + msg._protocol.transport = None + await msg.drain() + assert not protocol._drain_helper.called diff --git a/tests/test_locks.py b/tests/test_locks.py new file mode 100644 index 00000000000..55fd2330ec4 --- /dev/null +++ b/tests/test_locks.py @@ -0,0 +1,54 @@ +# Tests of custom aiohttp locks implementations +import asyncio + +import pytest + +from aiohttp.locks import EventResultOrError + + +class TestEventResultOrError: + async def test_set_exception(self, loop) -> None: + ev = EventResultOrError(loop=loop) + + async def c(): + try: + await ev.wait() + except Exception as e: + return e + return 1 + + t = loop.create_task(c()) + await asyncio.sleep(0, loop=loop) + e = Exception() + ev.set(exc=e) + assert (await t) == e + + async def test_set(self, loop) -> None: + ev = EventResultOrError(loop=loop) + + async def c(): + await ev.wait() + return 1 + + t = loop.create_task(c()) + await asyncio.sleep(0, loop=loop) + ev.set() + assert (await t) == 1 + + async def test_cancel_waiters(self, loop) -> None: + ev = EventResultOrError(loop=loop) + + async def c(): + await ev.wait() + + t1 = loop.create_task(c()) + t2 = loop.create_task(c()) + await asyncio.sleep(0, loop=loop) + ev.cancel() + ev.set() + + with pytest.raises(asyncio.CancelledError): + await t1 + + with pytest.raises(asyncio.CancelledError): + await t2 diff --git a/tests/test_loop.py b/tests/test_loop.py new file 
mode 100644 index 00000000000..24c979ebd55 --- /dev/null +++ b/tests/test_loop.py @@ -0,0 +1,43 @@ +import asyncio +import platform +import threading + +import pytest + +from aiohttp import web +from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="the test is not valid for Windows" +) +async def test_subprocess_co(loop) -> None: + assert isinstance(threading.current_thread(), threading._MainThread) + proc = await asyncio.create_subprocess_shell( + "exit 0", + stdin=asyncio.subprocess.DEVNULL, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await proc.wait() + + +class TestCase(AioHTTPTestCase): + async def get_application(self): + app = web.Application() + app.on_startup.append(self.on_startup_hook) + return app + + async def on_startup_hook(self, app): + self.on_startup_called = True + + @unittest_run_loop + async def test_on_startup_hook(self) -> None: + self.assertTrue(self.on_startup_called) + + def test_default_loop(self) -> None: + self.assertIs(self.loop, asyncio.get_event_loop()) + + +def test_default_loop(loop) -> None: + assert asyncio.get_event_loop() is loop diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 8dcc6aadc4e..6c3f1214d9e 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,21 +1,26 @@ import asyncio -import functools import io -import unittest +import json +import sys import zlib from unittest import mock import pytest -import aiohttp.multipart -from aiohttp import helpers, payload -from aiohttp.hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, - CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) +import aiohttp +from aiohttp import payload +from aiohttp.hdrs import ( + CONTENT_DISPOSITION, + CONTENT_ENCODING, + CONTENT_TRANSFER_ENCODING, + CONTENT_TYPE, +) from aiohttp.helpers import parse_mimetype -from aiohttp.multipart import (content_disposition_filename, - parse_content_disposition) -from aiohttp.streams import DEFAULT_LIMIT as stream_reader_default_limit +from aiohttp.multipart import MultipartResponseWrapper from aiohttp.streams import StreamReader +from aiohttp.test_utils import make_mocked_coro + +BOUNDARY = b"--:" @pytest.fixture @@ -27,9 +32,8 @@ def buf(): def stream(buf): writer = mock.Mock() - def write(chunk): + async def write(chunk): buf.extend(chunk) - return () writer.write.side_effect = write return writer @@ -37,63 +41,26 @@ def write(chunk): @pytest.fixture def writer(): - return aiohttp.multipart.MultipartWriter(boundary=':') - - -def run_in_loop(f): - @functools.wraps(f) - def wrapper(testcase, *args, **kwargs): - coro = asyncio.coroutine(f) - future = asyncio.wait_for(coro(testcase, *args, **kwargs), timeout=5) - return testcase.loop.run_until_complete(future) - return wrapper - - -class MetaAioTestCase(type): - - def __new__(cls, name, bases, attrs): - for key, obj in attrs.items(): - if key.startswith('test_'): - attrs[key] = run_in_loop(obj) - return super().__new__(cls, name, bases, attrs) - - -class TestCase(unittest.TestCase, metaclass=MetaAioTestCase): + return aiohttp.MultipartWriter(boundary=":") - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def future(self, obj): - fut = helpers.create_future(self.loop) - fut.set_result(obj) - return fut - - -class Response(object): +class Response: def __init__(self, headers, content): self.headers = headers self.content = content -class Stream(object): - +class 
Stream: def __init__(self, content): self.content = io.BytesIO(content) - @asyncio.coroutine - def read(self, size=None): + async def read(self, size=None): return self.content.read(size) def at_eof(self): return self.content.tell() == len(self.content.getbuffer()) - @asyncio.coroutine - def readline(self): + async def readline(self): return self.content.readline() def unread_data(self, data): @@ -101,1595 +68,1194 @@ def unread_data(self, data): class StreamWithShortenRead(Stream): - def __init__(self, content): self._first = True super().__init__(content) - @asyncio.coroutine - def read(self, size=None): + async def read(self, size=None): if size is not None and self._first: self._first = False size = size // 2 - return (yield from super().read(size)) - - -class MultipartResponseWrapperTestCase(TestCase): - - def setUp(self): - super().setUp() - wrapper = aiohttp.multipart.MultipartResponseWrapper(mock.Mock(), - mock.Mock()) - self.wrapper = wrapper - - def test_at_eof(self): - self.wrapper.at_eof() - self.assertTrue(self.wrapper.resp.content.at_eof.called) - - def test_next(self): - self.wrapper.stream.next.return_value = self.future(b'') - self.wrapper.stream.at_eof.return_value = False - yield from self.wrapper.next() - self.assertTrue(self.wrapper.stream.next.called) - - def test_release(self): - self.wrapper.resp.release.return_value = self.future(None) - yield from self.wrapper.release() - self.assertTrue(self.wrapper.resp.release.called) - - def test_release_when_stream_at_eof(self): - self.wrapper.resp.release.return_value = self.future(None) - self.wrapper.stream.next.return_value = self.future(b'') - self.wrapper.stream.at_eof.return_value = True - yield from self.wrapper.next() - self.assertTrue(self.wrapper.stream.next.called) - self.assertTrue(self.wrapper.resp.release.called) - - -class PartReaderTestCase(TestCase): - - def setUp(self): - super().setUp() - self.boundary = b'--:' - - def test_next(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello, world!\r\n--:')) - result = yield from obj.next() - self.assertEqual(b'Hello, world!', result) - self.assertTrue(obj.at_eof()) - - def test_next_next(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello, world!\r\n--:')) - result = yield from obj.next() - self.assertEqual(b'Hello, world!', result) - self.assertTrue(obj.at_eof()) - result = yield from obj.next() - self.assertIsNone(result) - - def test_read(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello, world!\r\n--:')) - result = yield from obj.read() - self.assertEqual(b'Hello, world!', result) - self.assertTrue(obj.at_eof()) - - def test_read_chunk_at_eof(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'--:')) + return await super().read(size) + + +class TestMultipartResponseWrapper: + def test_at_eof(self) -> None: + wrapper = MultipartResponseWrapper(mock.Mock(), mock.Mock()) + wrapper.at_eof() + assert wrapper.resp.content.at_eof.called + + async def test_next(self) -> None: + wrapper = MultipartResponseWrapper(mock.Mock(), mock.Mock()) + wrapper.stream.next = make_mocked_coro(b"") + wrapper.stream.at_eof.return_value = False + await wrapper.next() + assert wrapper.stream.next.called + + async def test_release(self) -> None: + wrapper = MultipartResponseWrapper(mock.Mock(), mock.Mock()) + wrapper.resp.release = make_mocked_coro(None) + await wrapper.release() + assert wrapper.resp.release.called + + async def 
test_release_when_stream_at_eof(self) -> None: + wrapper = MultipartResponseWrapper(mock.Mock(), mock.Mock()) + wrapper.resp.release = make_mocked_coro(None) + wrapper.stream.next = make_mocked_coro(b"") + wrapper.stream.at_eof.return_value = True + await wrapper.next() + assert wrapper.stream.next.called + assert wrapper.resp.release.called + + +class TestPartReader: + async def test_next(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"Hello, world!\r\n--:")) + result = await obj.next() + assert b"Hello, world!" == result + assert obj.at_eof() + + async def test_next_next(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"Hello, world!\r\n--:")) + result = await obj.next() + assert b"Hello, world!" == result + assert obj.at_eof() + result = await obj.next() + assert result is None + + async def test_read(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"Hello, world!\r\n--:")) + result = await obj.read() + assert b"Hello, world!" == result + assert obj.at_eof() + + async def test_read_chunk_at_eof(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"--:")) obj._at_eof = True - result = yield from obj.read_chunk() - self.assertEqual(b'', result) - - def test_read_chunk_without_content_length(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello, world!\r\n--:')) - c1 = yield from obj.read_chunk(8) - c2 = yield from obj.read_chunk(8) - c3 = yield from obj.read_chunk(8) - self.assertEqual(c1 + c2, b'Hello, world!') - self.assertEqual(c3, b'') - - def test_read_incomplete_chunk(self): - stream = Stream(b'') - - def prepare(data): - f = helpers.create_future(self.loop) - f.set_result(data) - return f - - with mock.patch.object(stream, 'read', side_effect=[ - prepare(b'Hello, '), - prepare(b'World'), - prepare(b'!\r\n--:'), - prepare(b'') - ]): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - c1 = yield from obj.read_chunk(8) - self.assertEqual(c1, b'Hello, ') - c2 = yield from obj.read_chunk(8) - self.assertEqual(c2, b'World') - c3 = yield from obj.read_chunk(8) - self.assertEqual(c3, b'!') - - def test_read_all_at_once(self): - stream = Stream(b'Hello, World!\r\n--:--\r\n') - obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) - result = yield from obj.read_chunk() - self.assertEqual(b'Hello, World!', result) - result = yield from obj.read_chunk() - self.assertEqual(b'', result) - self.assertTrue(obj.at_eof()) - - def test_read_incomplete_body_chunked(self): - stream = Stream(b'Hello, World!\r\n-') - obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) - result = b'' - with self.assertRaises(AssertionError): + result = await obj.read_chunk() + assert b"" == result + + async def test_read_chunk_without_content_length(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"Hello, world!\r\n--:")) + c1 = await obj.read_chunk(8) + c2 = await obj.read_chunk(8) + c3 = await obj.read_chunk(8) + assert c1 + c2 == b"Hello, world!" 
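# --- editorial aside (illustration only, not part of this patch) -----------
# The assertion that follows relies on BodyPartReader.read_chunk() returning
# b"" once the b"--:" boundary has been consumed: the empty-bytes value, not
# an exception, is the end-of-part sentinel. A minimal self-contained sketch
# of that consumption contract (drain_part is a hypothetical helper, named
# here for illustration only):

async def drain_part(part, n=8):
    # Read a body part n bytes at a time until the b"" sentinel appears.
    chunks = []
    while True:
        chunk = await part.read_chunk(n)
        if not chunk:  # b"" -> this part is exhausted
            return b"".join(chunks)
        chunks.append(chunk)

# Usage would mirror the test above: data = await drain_part(obj).
# --- end aside --------------------------------------------------------------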
+ assert c3 == b"" + + async def test_read_incomplete_chunk(self) -> None: + stream = Stream(b"") + + if sys.version_info >= (3, 8, 1): + # Workaround for a weird behavior of patch.object + def prepare(data): + return data + + else: + + async def prepare(data): + return data + + with mock.patch.object( + stream, + "read", + side_effect=[ + prepare(b"Hello, "), + prepare(b"World"), + prepare(b"!\r\n--:"), + prepare(b""), + ], + ): + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + c1 = await obj.read_chunk(8) + assert c1 == b"Hello, " + c2 = await obj.read_chunk(8) + assert c2 == b"World" + c3 = await obj.read_chunk(8) + assert c3 == b"!" + + async def test_read_all_at_once(self) -> None: + stream = Stream(b"Hello, World!\r\n--:--\r\n") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + result = await obj.read_chunk() + assert b"Hello, World!" == result + result = await obj.read_chunk() + assert b"" == result + assert obj.at_eof() + + async def test_read_incomplete_body_chunked(self) -> None: + stream = Stream(b"Hello, World!\r\n-") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + result = b"" + with pytest.raises(AssertionError): for _ in range(4): - result += yield from obj.read_chunk(7) - self.assertEqual(b'Hello, World!\r\n-', result) - - def test_read_boundary_with_incomplete_chunk(self): - stream = Stream(b'') - - def prepare(data): - f = helpers.create_future(self.loop) - f.set_result(data) - return f - - with mock.patch.object(stream, 'read', side_effect=[ - prepare(b'Hello, World'), - prepare(b'!\r\n'), - prepare(b'--:'), - prepare(b'') - ]): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - c1 = yield from obj.read_chunk(12) - self.assertEqual(c1, b'Hello, World') - c2 = yield from obj.read_chunk(8) - self.assertEqual(c2, b'!') - c3 = yield from obj.read_chunk(8) - self.assertEqual(c3, b'') - - def test_multi_read_chunk(self): - stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') - obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) - result = yield from obj.read_chunk(8) - self.assertEqual(b'Hello,', result) - result = yield from obj.read_chunk(8) - self.assertEqual(b'', result) - self.assertTrue(obj.at_eof()) - - def test_read_chunk_properly_counts_read_bytes(self): - expected = b'.' * 10 + result += await obj.read_chunk(7) + assert b"Hello, World!\r\n-" == result + + async def test_read_boundary_with_incomplete_chunk(self) -> None: + stream = Stream(b"") + + if sys.version_info >= (3, 8, 1): + # Workaround for weird 3.8.1 patch.object() behavior + def prepare(data): + return data + + else: + + async def prepare(data): + return data + + with mock.patch.object( + stream, + "read", + side_effect=[ + prepare(b"Hello, World"), + prepare(b"!\r\n"), + prepare(b"--:"), + prepare(b""), + ], + ): + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + c1 = await obj.read_chunk(12) + assert c1 == b"Hello, World" + c2 = await obj.read_chunk(8) + assert c2 == b"!" + c3 = await obj.read_chunk(8) + assert c3 == b"" + + async def test_multi_read_chunk(self) -> None: + stream = Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + result = await obj.read_chunk(8) + assert b"Hello," == result + result = await obj.read_chunk(8) + assert b"" == result + assert obj.at_eof() + + async def test_read_chunk_properly_counts_read_bytes(self) -> None: + expected = b"." 
* 10 size = len(expected) - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {'CONTENT-LENGTH': size}, - StreamWithShortenRead(expected + b'\r\n--:--')) + obj = aiohttp.BodyPartReader( + BOUNDARY, + {"CONTENT-LENGTH": size}, + StreamWithShortenRead(expected + b"\r\n--:--"), + ) result = bytearray() while True: - chunk = yield from obj.read_chunk() + chunk = await obj.read_chunk() if not chunk: break result.extend(chunk) - self.assertEqual(size, len(result)) - self.assertEqual(b'.' * size, result) - self.assertTrue(obj.at_eof()) - - def test_read_does_not_read_boundary(self): - stream = Stream(b'Hello, world!\r\n--:') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - result = yield from obj.read() - self.assertEqual(b'Hello, world!', result) - self.assertEqual(b'--:', (yield from stream.read())) - - def test_multiread(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--')) - result = yield from obj.read() - self.assertEqual(b'Hello,', result) - result = yield from obj.read() - self.assertEqual(b'', result) - self.assertTrue(obj.at_eof()) - - def test_read_multiline(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello\n,\r\nworld!\r\n--:--')) - result = yield from obj.read() - self.assertEqual(b'Hello\n,\r\nworld!', result) - result = yield from obj.read() - self.assertEqual(b'', result) - self.assertTrue(obj.at_eof()) - - def test_read_respects_content_length(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {'CONTENT-LENGTH': 100500}, - Stream(b'.' * 100500 + b'\r\n--:--')) - result = yield from obj.read() - self.assertEqual(b'.' * 100500, result) - self.assertTrue(obj.at_eof()) - - def test_read_with_content_encoding_gzip(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'gzip'}, - Stream(b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU' - b'(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00' - b'\r\n--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(b'Time to Relax!', result) - - def test_read_with_content_encoding_deflate(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'deflate'}, - Stream(b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(b'Time to Relax!', result) - - def test_read_with_content_encoding_identity(self): - thing = (b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU' - b'(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00' - b'\r\n') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'identity'}, - Stream(thing + b'--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(thing[:-2], result) - - def test_read_with_content_encoding_unknown(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'snappy'}, - Stream(b'\x0e4Time to Relax!\r\n--:--')) - with self.assertRaises(RuntimeError): - yield from obj.read(decode=True) - - def test_read_with_content_transfer_encoding_base64(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TRANSFER_ENCODING: 'base64'}, - Stream(b'VGltZSB0byBSZWxheCE=\r\n--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(b'Time to Relax!', result) - - def test_read_with_content_transfer_encoding_quoted_printable(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TRANSFER_ENCODING: 
'quoted-printable'}, - Stream(b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' - b' =D0=BC=D0=B8=D1=80!\r\n--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(b'\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82,' - b' \xd0\xbc\xd0\xb8\xd1\x80!', result) - - @pytest.mark.parametrize('encoding', []) - def test_read_with_content_transfer_encoding_binary(self): - data = b'\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82,' \ - b' \xd0\xbc\xd0\xb8\xd1\x80!' - for encoding in ('binary', '8bit', '7bit'): - with self.subTest(encoding): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TRANSFER_ENCODING: encoding}, - Stream(data + b'\r\n--:--')) - result = yield from obj.read(decode=True) - self.assertEqual(data, result) - - def test_read_with_content_transfer_encoding_unknown(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TRANSFER_ENCODING: 'unknown'}, - Stream(b'\x0e4Time to Relax!\r\n--:--')) - with self.assertRaises(RuntimeError): - yield from obj.read(decode=True) - - def test_read_text(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello, world!\r\n--:--')) - result = yield from obj.text() - self.assertEqual('Hello, world!', result) - - def test_read_text_default_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, - Stream('Привет, Мир!\r\n--:--'.encode('utf-8'))) - result = yield from obj.text() - self.assertEqual('Привет, Мир!', result) - - def test_read_text_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, - Stream('Привет, Мир!\r\n--:--'.encode('cp1251'))) - result = yield from obj.text(encoding='cp1251') - self.assertEqual('Привет, Мир!', result) - - def test_read_text_guess_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'text/plain;charset=cp1251'}, - Stream('Привет, Мир!\r\n--:--'.encode('cp1251'))) - result = yield from obj.text() - self.assertEqual('Привет, Мир!', result) - - def test_read_text_compressed(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'deflate', - CONTENT_TYPE: 'text/plain'}, - Stream(b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--')) - result = yield from obj.text() - self.assertEqual('Time to Relax!', result) - - def test_read_text_while_closed(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'text/plain'}, Stream(b'')) + assert size == len(result) + assert b"." * size == result + assert obj.at_eof() + + async def test_read_does_not_read_boundary(self) -> None: + stream = Stream(b"Hello, world!\r\n--:") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + result = await obj.read() + assert b"Hello, world!" == result + assert b"--:" == (await stream.read()) + + async def test_multiread(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {}, Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") + ) + result = await obj.read() + assert b"Hello," == result + result = await obj.read() + assert b"" == result + assert obj.at_eof() + + async def test_read_multiline(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {}, Stream(b"Hello\n,\r\nworld!\r\n--:--") + ) + result = await obj.read() + assert b"Hello\n,\r\nworld!" == result + result = await obj.read() + assert b"" == result + assert obj.at_eof() + + async def test_read_respects_content_length(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {"CONTENT-LENGTH": 100500}, Stream(b"." 
* 100500 + b"\r\n--:--") + ) + result = await obj.read() + assert b"." * 100500 == result + assert obj.at_eof() + + async def test_read_with_content_encoding_gzip(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_ENCODING: "gzip"}, + Stream( + b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU" + b"(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00" + b"\r\n--:--" + ), + ) + result = await obj.read(decode=True) + assert b"Time to Relax!" == result + + async def test_read_with_content_encoding_deflate(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_ENCODING: "deflate"}, + Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--"), + ) + result = await obj.read(decode=True) + assert b"Time to Relax!" == result + + async def test_read_with_content_encoding_identity(self) -> None: + thing = ( + b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU" + b"(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00" + b"\r\n" + ) + obj = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_ENCODING: "identity"}, Stream(thing + b"--:--") + ) + result = await obj.read(decode=True) + assert thing[:-2] == result + + async def test_read_with_content_encoding_unknown(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_ENCODING: "snappy"}, + Stream(b"\x0e4Time to Relax!\r\n--:--"), + ) + with pytest.raises(RuntimeError): + await obj.read(decode=True) + + async def test_read_with_content_transfer_encoding_base64(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TRANSFER_ENCODING: "base64"}, + Stream(b"VGltZSB0byBSZWxheCE=\r\n--:--"), + ) + result = await obj.read(decode=True) + assert b"Time to Relax!" == result + + async def test_read_with_content_transfer_encoding_quoted_printable(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TRANSFER_ENCODING: "quoted-printable"}, + Stream( + b"=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82," + b" =D0=BC=D0=B8=D1=80!\r\n--:--" + ), + ) + result = await obj.read(decode=True) + expected = ( + b"\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82," + b" \xd0\xbc\xd0\xb8\xd1\x80!" + ) + assert result == expected + + @pytest.mark.parametrize("encoding", ("binary", "8bit", "7bit")) + async def test_read_with_content_transfer_encoding_binary(self, encoding) -> None: + data = ( + b"\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82," + b" \xd0\xbc\xd0\xb8\xd1\x80!" + ) + obj = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_TRANSFER_ENCODING: encoding}, Stream(data + b"\r\n--:--") + ) + result = await obj.read(decode=True) + assert data == result + + async def test_read_with_content_transfer_encoding_unknown(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TRANSFER_ENCODING: "unknown"}, + Stream(b"\x0e4Time to Relax!\r\n--:--"), + ) + with pytest.raises(RuntimeError): + await obj.read(decode=True) + + async def test_read_text(self) -> None: + obj = aiohttp.BodyPartReader(BOUNDARY, {}, Stream(b"Hello, world!\r\n--:--")) + result = await obj.text() + assert "Hello, world!" == result + + async def test_read_text_default_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {}, Stream("Привет, Мир!\r\n--:--".encode()) + ) + result = await obj.text() + assert "Привет, Мир!" == result + + async def test_read_text_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {}, Stream("Привет, Мир!\r\n--:--".encode("cp1251")) + ) + result = await obj.text(encoding="cp1251") + assert "Привет, Мир!" 
== result + + async def test_read_text_guess_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "text/plain;charset=cp1251"}, + Stream("Привет, Мир!\r\n--:--".encode("cp1251")), + ) + result = await obj.text() + assert "Привет, Мир!" == result + + async def test_read_text_compressed(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_ENCODING: "deflate", CONTENT_TYPE: "text/plain"}, + Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--"), + ) + result = await obj.text() + assert "Time to Relax!" == result + + async def test_read_text_while_closed(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_TYPE: "text/plain"}, Stream(b"") + ) obj._at_eof = True - result = yield from obj.text() - self.assertEqual('', result) - - def test_read_json(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/json'}, - Stream(b'{"test": "passed"}\r\n--:--')) - result = yield from obj.json() - self.assertEqual({'test': 'passed'}, result) - - def test_read_json_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/json'}, - Stream('{"тест": "пассед"}\r\n--:--'.encode('cp1251'))) - result = yield from obj.json(encoding='cp1251') - self.assertEqual({'тест': 'пассед'}, result) - - def test_read_json_guess_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/json; charset=cp1251'}, - Stream('{"тест": "пассед"}\r\n--:--'.encode('cp1251'))) - result = yield from obj.json() - self.assertEqual({'тест': 'пассед'}, result) - - def test_read_json_compressed(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_ENCODING: 'deflate', - CONTENT_TYPE: 'application/json'}, - Stream(b'\xabV*I-.Q\xb2RP*H,.NMQ\xaa\x05\x00\r\n--:--')) - result = yield from obj.json() - self.assertEqual({'test': 'passed'}, result) - - def test_read_json_while_closed(self): - stream = Stream(b'') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/json'}, stream) + result = await obj.text() + assert "" == result + + async def test_read_json(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/json"}, + Stream(b'{"test": "passed"}\r\n--:--'), + ) + result = await obj.json() + assert {"test": "passed"} == result + + async def test_read_json_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/json"}, + Stream('{"тест": "пассед"}\r\n--:--'.encode("cp1251")), + ) + result = await obj.json(encoding="cp1251") + assert {"тест": "пассед"} == result + + async def test_read_json_guess_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/json; charset=cp1251"}, + Stream('{"тест": "пассед"}\r\n--:--'.encode("cp1251")), + ) + result = await obj.json() + assert {"тест": "пассед"} == result + + async def test_read_json_compressed(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_ENCODING: "deflate", CONTENT_TYPE: "application/json"}, + Stream(b"\xabV*I-.Q\xb2RP*H,.NMQ\xaa\x05\x00\r\n--:--"), + ) + result = await obj.json() + assert {"test": "passed"} == result + + async def test_read_json_while_closed(self) -> None: + stream = Stream(b"") + obj = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_TYPE: "application/json"}, stream + ) obj._at_eof = True - result = yield from obj.json() - self.assertEqual(None, result) - - def test_read_form(self): - obj = 
aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/x-www-form-urlencoded'}, - Stream(b'foo=bar&foo=baz&boo=\r\n--:--')) - result = yield from obj.form() - self.assertEqual([('foo', 'bar'), ('foo', 'baz'), ('boo', '')], - result) - - def test_read_form_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {CONTENT_TYPE: 'application/x-www-form-urlencoded'}, - Stream('foo=bar&foo=baz&boo=\r\n--:--'.encode('cp1251'))) - result = yield from obj.form(encoding='cp1251') - self.assertEqual([('foo', 'bar'), ('foo', 'baz'), ('boo', '')], - result) - - def test_read_form_guess_encoding(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, - {CONTENT_TYPE: 'application/x-www-form-urlencoded; charset=utf-8'}, - Stream('foo=bar&foo=baz&boo=\r\n--:--'.encode('utf-8'))) - result = yield from obj.form() - self.assertEqual([('foo', 'bar'), ('foo', 'baz'), ('boo', '')], - result) - - def test_read_form_while_closed(self): - stream = Stream(b'') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, - {CONTENT_TYPE: 'application/x-www-form-urlencoded'}, stream) + result = await obj.json() + assert result is None + + async def test_read_form(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/x-www-form-urlencoded"}, + Stream(b"foo=bar&foo=baz&boo=\r\n--:--"), + ) + result = await obj.form() + assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result + + async def test_read_form_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/x-www-form-urlencoded"}, + Stream("foo=bar&foo=baz&boo=\r\n--:--".encode("cp1251")), + ) + result = await obj.form(encoding="cp1251") + assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result + + async def test_read_form_guess_encoding(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, + {CONTENT_TYPE: "application/x-www-form-urlencoded; charset=utf-8"}, + Stream(b"foo=bar&foo=baz&boo=\r\n--:--"), + ) + result = await obj.form() + assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result + + async def test_read_form_while_closed(self) -> None: + stream = Stream(b"") + obj = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_TYPE: "application/x-www-form-urlencoded"}, stream + ) obj._at_eof = True - result = yield from obj.form() - self.assertEqual(None, result) - - def test_readline(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, Stream(b'Hello\n,\r\nworld!\r\n--:--')) - result = yield from obj.readline() - self.assertEqual(b'Hello\n', result) - result = yield from obj.readline() - self.assertEqual(b',\r\n', result) - result = yield from obj.readline() - self.assertEqual(b'world!', result) - result = yield from obj.readline() - self.assertEqual(b'', result) - self.assertTrue(obj.at_eof()) - - def test_release(self): - stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - yield from obj.release() - self.assertTrue(obj.at_eof()) - self.assertEqual(b'--:\r\n\r\nworld!\r\n--:--', stream.content.read()) - - def test_release_respects_content_length(self): - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {'CONTENT-LENGTH': 100500}, - Stream(b'.' 
* 100500 + b'\r\n--:--')) - result = yield from obj.release() - self.assertIsNone(result) - self.assertTrue(obj.at_eof()) - - def test_release_release(self): - stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - yield from obj.release() - yield from obj.release() - self.assertEqual(b'--:\r\n\r\nworld!\r\n--:--', stream.content.read()) - - def test_filename(self): - part = aiohttp.multipart.BodyPartReader( - self.boundary, - {CONTENT_DISPOSITION: 'attachment; filename=foo.html'}, - None) - self.assertEqual('foo.html', part.filename) - - def test_reading_long_part(self): - size = 2 * stream_reader_default_limit - stream = StreamReader() - stream.feed_data(b'0' * size + b'\r\n--:--') + result = await obj.form() + assert not result + + async def test_readline(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {}, Stream(b"Hello\n,\r\nworld!\r\n--:--") + ) + result = await obj.readline() + assert b"Hello\n" == result + result = await obj.readline() + assert b",\r\n" == result + result = await obj.readline() + assert b"world!" == result + result = await obj.readline() + assert b"" == result + assert obj.at_eof() + + async def test_release(self) -> None: + stream = Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + await obj.release() + assert obj.at_eof() + assert b"--:\r\n\r\nworld!\r\n--:--" == stream.content.read() + + async def test_release_respects_content_length(self) -> None: + obj = aiohttp.BodyPartReader( + BOUNDARY, {"CONTENT-LENGTH": 100500}, Stream(b"." * 100500 + b"\r\n--:--") + ) + result = await obj.release() + assert result is None + assert obj.at_eof() + + async def test_release_release(self) -> None: + stream = Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + await obj.release() + await obj.release() + assert b"--:\r\n\r\nworld!\r\n--:--" == stream.content.read() + + async def test_filename(self) -> None: + part = aiohttp.BodyPartReader( + BOUNDARY, {CONTENT_DISPOSITION: "attachment; filename=foo.html"}, None + ) + assert "foo.html" == part.filename + + async def test_reading_long_part(self) -> None: + size = 2 * 2 ** 16 + protocol = mock.Mock(_reading_paused=False) + stream = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + stream.feed_data(b"0" * size + b"\r\n--:--") stream.feed_eof() - obj = aiohttp.multipart.BodyPartReader( - self.boundary, {}, stream) - data = yield from obj.read() - self.assertEqual(len(data), size) + obj = aiohttp.BodyPartReader(BOUNDARY, {}, stream) + data = await obj.read() + assert len(data) == size -class MultipartReaderTestCase(TestCase): - - def test_from_response(self): - resp = Response({CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\nhello\r\n--:--')) - res = aiohttp.multipart.MultipartReader.from_response(resp) - self.assertIsInstance(res, - aiohttp.multipart.MultipartResponseWrapper) - self.assertIsInstance(res.stream, - aiohttp.multipart.MultipartReader) +class TestMultipartReader: + def test_from_response(self) -> None: + resp = Response( + {CONTENT_TYPE: 'multipart/related;boundary=":"'}, + Stream(b"--:\r\n\r\nhello\r\n--:--"), + ) + res = aiohttp.MultipartReader.from_response(resp) + assert isinstance(res, MultipartResponseWrapper) + assert isinstance(res.stream, aiohttp.MultipartReader) - def test_bad_boundary(self): + def test_bad_boundary(self) -> None: resp = Response( - {CONTENT_TYPE: 
'multipart/related;boundary=' + 'a' * 80}, - Stream(b'')) - with self.assertRaises(ValueError): - aiohttp.multipart.MultipartReader.from_response(resp) + {CONTENT_TYPE: "multipart/related;boundary=" + "a" * 80}, Stream(b"") + ) + with pytest.raises(ValueError): + aiohttp.MultipartReader.from_response(resp) - def test_dispatch(self): - reader = aiohttp.multipart.MultipartReader( + def test_dispatch(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\necho\r\n--:--')) - res = reader._get_part_reader({CONTENT_TYPE: 'text/plain'}) - self.assertIsInstance(res, reader.part_reader_cls) + Stream(b"--:\r\n\r\necho\r\n--:--"), + ) + res = reader._get_part_reader({CONTENT_TYPE: "text/plain"}) + assert isinstance(res, reader.part_reader_cls) - def test_dispatch_bodypart(self): - reader = aiohttp.multipart.MultipartReader( + def test_dispatch_bodypart(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\necho\r\n--:--')) - res = reader._get_part_reader({CONTENT_TYPE: 'text/plain'}) - self.assertIsInstance(res, reader.part_reader_cls) + Stream(b"--:\r\n\r\necho\r\n--:--"), + ) + res = reader._get_part_reader({CONTENT_TYPE: "text/plain"}) + assert isinstance(res, reader.part_reader_cls) - def test_dispatch_multipart(self): - reader = aiohttp.multipart.MultipartReader( + def test_dispatch_multipart(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'----:--\r\n' - b'\r\n' - b'test\r\n' - b'----:--\r\n' - b'\r\n' - b'passed\r\n' - b'----:----\r\n' - b'--:--')) + Stream( + b"----:--\r\n" + b"\r\n" + b"test\r\n" + b"----:--\r\n" + b"\r\n" + b"passed\r\n" + b"----:----\r\n" + b"--:--" + ), + ) res = reader._get_part_reader( - {CONTENT_TYPE: 'multipart/related;boundary=--:--'}) - self.assertIsInstance(res, reader.__class__) + {CONTENT_TYPE: "multipart/related;boundary=--:--"} + ) + assert isinstance(res, reader.__class__) - def test_dispatch_custom_multipart_reader(self): - class CustomReader(aiohttp.multipart.MultipartReader): + def test_dispatch_custom_multipart_reader(self) -> None: + class CustomReader(aiohttp.MultipartReader): pass - reader = aiohttp.multipart.MultipartReader( + + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'----:--\r\n' - b'\r\n' - b'test\r\n' - b'----:--\r\n' - b'\r\n' - b'passed\r\n' - b'----:----\r\n' - b'--:--')) + Stream( + b"----:--\r\n" + b"\r\n" + b"test\r\n" + b"----:--\r\n" + b"\r\n" + b"passed\r\n" + b"----:----\r\n" + b"--:--" + ), + ) reader.multipart_reader_cls = CustomReader res = reader._get_part_reader( - {CONTENT_TYPE: 'multipart/related;boundary=--:--'}) - self.assertIsInstance(res, CustomReader) + {CONTENT_TYPE: "multipart/related;boundary=--:--"} + ) + assert isinstance(res, CustomReader) - def test_emit_next(self): - reader = aiohttp.multipart.MultipartReader( + async def test_emit_next(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\necho\r\n--:--')) - res = yield from reader.next() - self.assertIsInstance(res, reader.part_reader_cls) + Stream(b"--:\r\n\r\necho\r\n--:--"), + ) + res = await reader.next() + assert isinstance(res, reader.part_reader_cls) - def test_invalid_boundary(self): - reader = aiohttp.multipart.MultipartReader( + async def test_invalid_boundary(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 
'multipart/related;boundary=":"'}, - Stream(b'---:\r\n\r\necho\r\n---:--')) - with self.assertRaises(ValueError): - yield from reader.next() + Stream(b"---:\r\n\r\necho\r\n---:--"), + ) + with pytest.raises(ValueError): + await reader.next() - def test_release(self): - reader = aiohttp.multipart.MultipartReader( + async def test_release(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/mixed;boundary=":"'}, - Stream(b'--:\r\n' - b'Content-Type: multipart/related;boundary=--:--\r\n' - b'\r\n' - b'----:--\r\n' - b'\r\n' - b'test\r\n' - b'----:--\r\n' - b'\r\n' - b'passed\r\n' - b'----:----\r\n' - b'\r\n' - b'--:--')) - yield from reader.release() - self.assertTrue(reader.at_eof()) - - def test_release_release(self): - reader = aiohttp.multipart.MultipartReader( + Stream( + b"--:\r\n" + b"Content-Type: multipart/related;boundary=--:--\r\n" + b"\r\n" + b"----:--\r\n" + b"\r\n" + b"test\r\n" + b"----:--\r\n" + b"\r\n" + b"passed\r\n" + b"----:----\r\n" + b"\r\n" + b"--:--" + ), + ) + await reader.release() + assert reader.at_eof() + + async def test_release_release(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\necho\r\n--:--')) - yield from reader.release() - self.assertTrue(reader.at_eof()) - yield from reader.release() - self.assertTrue(reader.at_eof()) - - def test_release_next(self): - reader = aiohttp.multipart.MultipartReader( + Stream(b"--:\r\n\r\necho\r\n--:--"), + ) + await reader.release() + assert reader.at_eof() + await reader.release() + assert reader.at_eof() + + async def test_release_next(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n\r\necho\r\n--:--')) - yield from reader.release() - self.assertTrue(reader.at_eof()) - res = yield from reader.next() - self.assertIsNone(res) - - def test_second_next_releases_previous_object(self): - reader = aiohttp.multipart.MultipartReader( + Stream(b"--:\r\n\r\necho\r\n--:--"), + ) + await reader.release() + assert reader.at_eof() + res = await reader.next() + assert res is None + + async def test_second_next_releases_previous_object(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n' - b'\r\n' - b'test\r\n' - b'--:\r\n' - b'\r\n' - b'passed\r\n' - b'--:--')) - first = yield from reader.next() - self.assertIsInstance(first, aiohttp.multipart.BodyPartReader) - second = yield from reader.next() - self.assertTrue(first.at_eof()) - self.assertFalse(second.at_eof()) - - def test_release_without_read_the_last_object(self): - reader = aiohttp.multipart.MultipartReader( + Stream( + b"--:\r\n" b"\r\n" b"test\r\n" b"--:\r\n" b"\r\n" b"passed\r\n" b"--:--" + ), + ) + first = await reader.next() + assert isinstance(first, aiohttp.BodyPartReader) + second = await reader.next() + assert first.at_eof() + assert not second.at_eof() + + async def test_release_without_read_the_last_object(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n' - b'\r\n' - b'test\r\n' - b'--:\r\n' - b'\r\n' - b'passed\r\n' - b'--:--')) - first = yield from reader.next() - second = yield from reader.next() - third = yield from reader.next() - self.assertTrue(first.at_eof()) - self.assertTrue(second.at_eof()) - self.assertTrue(second.at_eof()) - self.assertIsNone(third) - - def test_read_chunk_by_length_doesnt_breaks_reader(self): - reader = aiohttp.multipart.MultipartReader( 
+ Stream( + b"--:\r\n" b"\r\n" b"test\r\n" b"--:\r\n" b"\r\n" b"passed\r\n" b"--:--" + ), + ) + first = await reader.next() + second = await reader.next() + third = await reader.next() + assert first.at_eof() + assert second.at_eof() + assert second.at_eof() + assert third is None + + async def test_read_chunk_by_length_doesnt_breaks_reader(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n' - b'Content-Length: 4\r\n\r\n' - b'test' - b'\r\n--:\r\n' - b'Content-Length: 6\r\n\r\n' - b'passed' - b'\r\n--:--')) + Stream( + b"--:\r\n" + b"Content-Length: 4\r\n\r\n" + b"test" + b"\r\n--:\r\n" + b"Content-Length: 6\r\n\r\n" + b"passed" + b"\r\n--:--" + ), + ) body_parts = [] while True: - read_part = b'' - part = yield from reader.next() + read_part = b"" + part = await reader.next() if part is None: break while not part.at_eof(): - read_part += yield from part.read_chunk(3) + read_part += await part.read_chunk(3) body_parts.append(read_part) - self.assertListEqual(body_parts, [b'test', b'passed']) + assert body_parts == [b"test", b"passed"] - def test_read_chunk_from_stream_doesnt_breaks_reader(self): - reader = aiohttp.multipart.MultipartReader( + async def test_read_chunk_from_stream_doesnt_breaks_reader(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'--:\r\n' - b'\r\n' - b'chunk' - b'\r\n--:\r\n' - b'\r\n' - b'two_chunks' - b'\r\n--:--')) + Stream( + b"--:\r\n" + b"\r\n" + b"chunk" + b"\r\n--:\r\n" + b"\r\n" + b"two_chunks" + b"\r\n--:--" + ), + ) body_parts = [] while True: - read_part = b'' - part = yield from reader.next() + read_part = b"" + part = await reader.next() if part is None: break while not part.at_eof(): - chunk = yield from part.read_chunk(5) - self.assertTrue(chunk) + chunk = await part.read_chunk(5) + assert chunk read_part += chunk body_parts.append(read_part) - self.assertListEqual(body_parts, [b'chunk', b'two_chunks']) + assert body_parts == [b"chunk", b"two_chunks"] - def test_reading_skips_prelude(self): - reader = aiohttp.multipart.MultipartReader( + async def test_reading_skips_prelude(self) -> None: + reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, - Stream(b'Multi-part data is not supported.\r\n' - b'\r\n' - b'--:\r\n' - b'\r\n' - b'test\r\n' - b'--:\r\n' - b'\r\n' - b'passed\r\n' - b'--:--')) - first = yield from reader.next() - self.assertIsInstance(first, aiohttp.multipart.BodyPartReader) - second = yield from reader.next() - self.assertTrue(first.at_eof()) - self.assertFalse(second.at_eof()) - - -@asyncio.coroutine -def test_writer(writer): - assert writer.size == 0 - assert writer.boundary == b':' - - -@asyncio.coroutine -def test_writer_serialize_io_chunk(buf, stream, writer): - flo = io.BytesIO(b'foobarbaz') + Stream( + b"Multi-part data is not supported.\r\n" + b"\r\n" + b"--:\r\n" + b"\r\n" + b"test\r\n" + b"--:\r\n" + b"\r\n" + b"passed\r\n" + b"--:--" + ), + ) + first = await reader.next() + assert isinstance(first, aiohttp.BodyPartReader) + second = await reader.next() + assert first.at_eof() + assert not second.at_eof() + + +async def test_writer(writer) -> None: + assert writer.size == 7 + assert writer.boundary == ":" + + +async def test_writer_serialize_io_chunk(buf, stream, writer) -> None: + flo = io.BytesIO(b"foobarbaz") writer.append(flo) - yield from writer.write(stream) - assert (buf == b'--:\r\nContent-Type: application/octet-stream' - b'\r\nContent-Length: 
9\r\n\r\nfoobarbaz\r\n--:--\r\n') + await writer.write(stream) + assert ( + buf == b"--:\r\nContent-Type: application/octet-stream" + b"\r\nContent-Length: 9\r\n\r\nfoobarbaz\r\n--:--\r\n" + ) -@asyncio.coroutine -def test_writer_serialize_json(buf, stream, writer): - writer.append_json({'привет': 'мир'}) - yield from writer.write(stream) - assert (b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' - b' "\\u043c\\u0438\\u0440"}' in buf) +async def test_writer_serialize_json(buf, stream, writer) -> None: + writer.append_json({"привет": "мир"}) + await writer.write(stream) + assert ( + b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' + b' "\\u043c\\u0438\\u0440"}' in buf + ) -@asyncio.coroutine -def test_writer_serialize_form(buf, stream, writer): - data = [('foo', 'bar'), ('foo', 'baz'), ('boo', 'zoo')] +async def test_writer_serialize_form(buf, stream, writer) -> None: + data = [("foo", "bar"), ("foo", "baz"), ("boo", "zoo")] writer.append_form(data) - yield from writer.write(stream) + await writer.write(stream) - assert (b'foo=bar&foo=baz&boo=zoo' in buf) + assert b"foo=bar&foo=baz&boo=zoo" in buf -@asyncio.coroutine -def test_writer_serialize_form_dict(buf, stream, writer): - data = {'hello': 'мир'} +async def test_writer_serialize_form_dict(buf, stream, writer) -> None: + data = {"hello": "мир"} writer.append_form(data) - yield from writer.write(stream) + await writer.write(stream) - assert (b'hello=%D0%BC%D0%B8%D1%80' in buf) + assert b"hello=%D0%BC%D0%B8%D1%80" in buf -@asyncio.coroutine -def test_writer_write(buf, stream, writer): - writer.append('foo-bar-baz') - writer.append_json({'test': 'passed'}) - writer.append_form({'test': 'passed'}) - writer.append_form([('one', 1), ('two', 2)]) +async def test_writer_write(buf, stream, writer) -> None: + writer.append("foo-bar-baz") + writer.append_json({"test": "passed"}) + writer.append_form({"test": "passed"}) + writer.append_form([("one", 1), ("two", 2)]) - sub_multipart = aiohttp.multipart.MultipartWriter(boundary='::') - sub_multipart.append('nested content') - sub_multipart.headers['X-CUSTOM'] = 'test' + sub_multipart = aiohttp.MultipartWriter(boundary="::") + sub_multipart.append("nested content") + sub_multipart.headers["X-CUSTOM"] = "test" writer.append(sub_multipart) - yield from writer.write(stream) + await writer.write(stream) + + assert ( + b"--:\r\n" + b"Content-Type: text/plain; charset=utf-8\r\n" + b"Content-Length: 11\r\n\r\n" + b"foo-bar-baz" + b"\r\n" + b"--:\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 18\r\n\r\n" + b'{"test": "passed"}' + b"\r\n" + b"--:\r\n" + b"Content-Type: application/x-www-form-urlencoded\r\n" + b"Content-Length: 11\r\n\r\n" + b"test=passed" + b"\r\n" + b"--:\r\n" + b"Content-Type: application/x-www-form-urlencoded\r\n" + b"Content-Length: 11\r\n\r\n" + b"one=1&two=2" + b"\r\n" + b"--:\r\n" + b'Content-Type: multipart/mixed; boundary="::"\r\n' + b"X-CUSTOM: test\r\nContent-Length: 93\r\n\r\n" + b"--::\r\n" + b"Content-Type: text/plain; charset=utf-8\r\n" + b"Content-Length: 14\r\n\r\n" + b"nested content\r\n" + b"--::--\r\n" + b"\r\n" + b"--:--\r\n" + ) == bytes(buf) + + +async def test_writer_write_no_close_boundary(buf, stream) -> None: + writer = aiohttp.MultipartWriter(boundary=":") + writer.append("foo-bar-baz") + writer.append_json({"test": "passed"}) + writer.append_form({"test": "passed"}) + writer.append_form([("one", 1), ("two", 2)]) + await writer.write(stream, close_boundary=False) + + assert ( + b"--:\r\n" + b"Content-Type: text/plain; charset=utf-8\r\n" + 
b"Content-Length: 11\r\n\r\n" + b"foo-bar-baz" + b"\r\n" + b"--:\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 18\r\n\r\n" + b'{"test": "passed"}' + b"\r\n" + b"--:\r\n" + b"Content-Type: application/x-www-form-urlencoded\r\n" + b"Content-Length: 11\r\n\r\n" + b"test=passed" + b"\r\n" + b"--:\r\n" + b"Content-Type: application/x-www-form-urlencoded\r\n" + b"Content-Length: 11\r\n\r\n" + b"one=1&two=2" + b"\r\n" + ) == bytes(buf) + + +async def test_writer_write_no_parts(buf, stream, writer) -> None: + await writer.write(stream) + assert b"--:--\r\n" == bytes(buf) + + +async def test_writer_serialize_with_content_encoding_gzip(buf, stream, writer): + writer.append("Time to Relax!", {CONTENT_ENCODING: "gzip"}) + await writer.write(stream) + headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( - (b'--:\r\n' - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 11\r\n\r\n' - b'foo-bar-baz' - b'\r\n' - - b'--:\r\n' - b'Content-Type: application/json\r\n' - b'Content-Length: 18\r\n\r\n' - b'{"test": "passed"}' - b'\r\n' - - b'--:\r\n' - b'Content-Type: application/x-www-form-urlencoded\r\n' - b'Content-Length: 11\r\n\r\n' - b'test=passed' - b'\r\n' - - b'--:\r\n' - b'Content-Type: application/x-www-form-urlencoded\r\n' - b'Content-Length: 11\r\n\r\n' - b'one=1&two=2' - b'\r\n' - - b'--:\r\n' - b'Content-Type: multipart/mixed; boundary="::"\r\n' - b'X-Custom: test\r\nContent-Length: 93\r\n\r\n' - b'--::\r\n' - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 14\r\n\r\n' - b'nested content\r\n' - b'--::--\r\n' - b'\r\n' - b'--:--\r\n') == bytes(buf)) - - -@asyncio.coroutine -def test_writer_serialize_with_content_encoding_gzip(buf, stream, writer): - writer.append('Time to Relax!', {CONTENT_ENCODING: 'gzip'}) - yield from writer.write(stream) - headers, message = bytes(buf).split(b'\r\n\r\n', 1) - - assert (b'--:\r\nContent-Encoding: gzip\r\n' - b'Content-Type: text/plain; charset=utf-8' == headers) - - decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) - data = decompressor.decompress(message.split(b'\r\n')[0]) + b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" + b"Content-Encoding: gzip" == headers + ) + + decompressor = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS) + data = decompressor.decompress(message.split(b"\r\n")[0]) data += decompressor.flush() - assert b'Time to Relax!' == data + assert b"Time to Relax!" 
== data -@asyncio.coroutine -def test_writer_serialize_with_content_encoding_deflate(buf, stream, writer): - writer.append('Time to Relax!', {CONTENT_ENCODING: 'deflate'}) - yield from writer.write(stream) - headers, message = bytes(buf).split(b'\r\n\r\n', 1) +async def test_writer_serialize_with_content_encoding_deflate(buf, stream, writer): + writer.append("Time to Relax!", {CONTENT_ENCODING: "deflate"}) + await writer.write(stream) + headers, message = bytes(buf).split(b"\r\n\r\n", 1) - assert (b'--:\r\nContent-Encoding: deflate\r\n' - b'Content-Type: text/plain; charset=utf-8' == headers) + assert ( + b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" + b"Content-Encoding: deflate" == headers + ) - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n' + thing = b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n" assert thing == message -@asyncio.coroutine -def test_writer_serialize_with_content_encoding_identity(buf, stream, writer): - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00' - writer.append(thing, {CONTENT_ENCODING: 'identity'}) - yield from writer.write(stream) - headers, message = bytes(buf).split(b'\r\n\r\n', 1) +async def test_writer_serialize_with_content_encoding_identity(buf, stream, writer): + thing = b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00" + writer.append(thing, {CONTENT_ENCODING: "identity"}) + await writer.write(stream) + headers, message = bytes(buf).split(b"\r\n\r\n", 1) - assert (b'--:\r\nContent-Encoding: identity\r\n' - b'Content-Type: application/octet-stream\r\n' - b'Content-Length: 16' == headers) + assert ( + b"--:\r\nContent-Type: application/octet-stream\r\n" + b"Content-Encoding: identity\r\n" + b"Content-Length: 16" == headers + ) - assert thing == message.split(b'\r\n')[0] + assert thing == message.split(b"\r\n")[0] def test_writer_serialize_with_content_encoding_unknown(buf, stream, writer): with pytest.raises(RuntimeError): - writer.append('Time to Relax!', {CONTENT_ENCODING: 'snappy'}) + writer.append("Time to Relax!", {CONTENT_ENCODING: "snappy"}) -@asyncio.coroutine -def test_writer_with_content_transfer_encoding_base64(buf, stream, writer): - writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'base64'}) - yield from writer.write(stream) - headers, message = bytes(buf).split(b'\r\n\r\n', 1) +async def test_writer_with_content_transfer_encoding_base64(buf, stream, writer): + writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "base64"}) + await writer.write(stream) + headers, message = bytes(buf).split(b"\r\n\r\n", 1) - assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n' - b'Content-Type: text/plain; charset=utf-8' == - headers) + assert ( + b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" + b"Content-Transfer-Encoding: base64" == headers + ) - assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0] + assert b"VGltZSB0byBSZWxheCE=" == message.split(b"\r\n")[0] -@asyncio.coroutine -def test_writer_content_transfer_encoding_quote_printable(buf, stream, writer): - writer.append('Привет, мир!', - {CONTENT_TRANSFER_ENCODING: 'quoted-printable'}) - yield from writer.write(stream) - headers, message = bytes(buf).split(b'\r\n\r\n', 1) +async def test_writer_content_transfer_encoding_quote_printable(buf, stream, writer): + writer.append("Привет, мир!", {CONTENT_TRANSFER_ENCODING: "quoted-printable"}) + await writer.write(stream) + headers, message = bytes(buf).split(b"\r\n\r\n", 1) - assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n' - b'Content-Type: text/plain; 
charset=utf-8' == headers) + assert ( + b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" + b"Content-Transfer-Encoding: quoted-printable" == headers + ) - assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' - b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0]) + assert ( + b"=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82," + b" =D0=BC=D0=B8=D1=80!" == message.split(b"\r\n")[0] + ) -def test_writer_content_transfer_encoding_unknown(buf, stream, writer): +def test_writer_content_transfer_encoding_unknown(buf, stream, writer) -> None: with pytest.raises(RuntimeError): - writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) - - -class MultipartWriterTestCase(unittest.TestCase): - - def setUp(self): - self.buf = bytearray() - self.stream = mock.Mock() - - def write(chunk): - self.buf.extend(chunk) - return () - - self.stream.write.side_effect = write - - self.writer = aiohttp.multipart.MultipartWriter(boundary=':') - - def test_default_subtype(self): - mtype, stype, *_ = parse_mimetype( - self.writer.headers.get(CONTENT_TYPE)) - self.assertEqual('multipart', mtype) - self.assertEqual('mixed', stype) - - def test_bad_boundary(self): - with self.assertRaises(ValueError): - aiohttp.multipart.MultipartWriter(boundary='тест') - - def test_default_headers(self): - self.assertEqual({CONTENT_TYPE: 'multipart/mixed; boundary=":"'}, - self.writer.headers) - - def test_iter_parts(self): - self.writer.append('foo') - self.writer.append('bar') - self.writer.append('baz') - self.assertEqual(3, len(list(self.writer))) - - def test_append(self): - self.assertEqual(0, len(self.writer)) - self.writer.append('hello, world!') - self.assertEqual(1, len(self.writer)) - self.assertIsInstance(self.writer._parts[0][0], payload.Payload) - - def test_append_with_headers(self): - self.writer.append('hello, world!', {'x-foo': 'bar'}) - self.assertEqual(1, len(self.writer)) - self.assertIn('x-foo', self.writer._parts[0][0].headers) - self.assertEqual(self.writer._parts[0][0].headers['x-foo'], 'bar') - - def test_append_json(self): - self.writer.append_json({'foo': 'bar'}) - self.assertEqual(1, len(self.writer)) - part = self.writer._parts[0][0] - self.assertEqual(part.headers[CONTENT_TYPE], 'application/json') - - def test_append_part(self): - part = payload.get_payload( - 'test', headers={CONTENT_TYPE: 'text/plain'}) - self.writer.append(part, {CONTENT_TYPE: 'test/passed'}) - self.assertEqual(1, len(self.writer)) - part = self.writer._parts[0][0] - self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - - def test_append_json_overrides_content_type(self): - self.writer.append_json({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) - self.assertEqual(1, len(self.writer)) - part = self.writer._parts[0][0] - self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - - def test_append_form(self): - self.writer.append_form({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) - self.assertEqual(1, len(self.writer)) - part = self.writer._parts[0][0] - self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - - def test_append_multipart(self): - subwriter = aiohttp.multipart.MultipartWriter(boundary=':') - subwriter.append_json({'foo': 'bar'}) - self.writer.append(subwriter, {CONTENT_TYPE: 'test/passed'}) - self.assertEqual(1, len(self.writer)) - part = self.writer._parts[0][0] - self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - - def test_write(self): - self.assertEqual([], list(self.writer.write(self.stream))) - - def test_with(self): - with aiohttp.multipart.MultipartWriter(boundary=':') as writer: - 
writer.append('foo') - writer.append(b'bar') - writer.append_json({'baz': True}) - self.assertEqual(3, len(writer)) - - def test_append_int_not_allowed(self): - with self.assertRaises(TypeError): - with aiohttp.multipart.MultipartWriter(boundary=':') as writer: + writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "unknown"}) + + +class TestMultipartWriter: + def test_default_subtype(self, writer) -> None: + mimetype = parse_mimetype(writer.headers.get(CONTENT_TYPE)) + + assert "multipart" == mimetype.type + assert "mixed" == mimetype.subtype + + def test_unquoted_boundary(self) -> None: + writer = aiohttp.MultipartWriter(boundary="abc123") + expected = {CONTENT_TYPE: "multipart/mixed; boundary=abc123"} + assert expected == writer.headers + + def test_quoted_boundary(self) -> None: + writer = aiohttp.MultipartWriter(boundary=R"\"") + expected = {CONTENT_TYPE: R'multipart/mixed; boundary="\\\""'} + assert expected == writer.headers + + def test_bad_boundary(self) -> None: + with pytest.raises(ValueError): + aiohttp.MultipartWriter(boundary="тест") + with pytest.raises(ValueError): + aiohttp.MultipartWriter(boundary="test\n") + + def test_default_headers(self, writer) -> None: + expected = {CONTENT_TYPE: 'multipart/mixed; boundary=":"'} + assert expected == writer.headers + + def test_iter_parts(self, writer) -> None: + writer.append("foo") + writer.append("bar") + writer.append("baz") + assert 3 == len(list(writer)) + + def test_append(self, writer) -> None: + assert 0 == len(writer) + writer.append("hello, world!") + assert 1 == len(writer) + assert isinstance(writer._parts[0][0], payload.Payload) + + def test_append_with_headers(self, writer) -> None: + writer.append("hello, world!", {"x-foo": "bar"}) + assert 1 == len(writer) + assert "x-foo" in writer._parts[0][0].headers + assert writer._parts[0][0].headers["x-foo"] == "bar" + + def test_append_json(self, writer) -> None: + writer.append_json({"foo": "bar"}) + assert 1 == len(writer) + part = writer._parts[0][0] + assert part.headers[CONTENT_TYPE] == "application/json" + + def test_append_part(self, writer) -> None: + part = payload.get_payload("test", headers={CONTENT_TYPE: "text/plain"}) + writer.append(part, {CONTENT_TYPE: "test/passed"}) + assert 1 == len(writer) + part = writer._parts[0][0] + assert part.headers[CONTENT_TYPE] == "test/passed" + + def test_append_json_overrides_content_type(self, writer) -> None: + writer.append_json({"foo": "bar"}, {CONTENT_TYPE: "test/passed"}) + assert 1 == len(writer) + part = writer._parts[0][0] + assert part.headers[CONTENT_TYPE] == "test/passed" + + def test_append_form(self, writer) -> None: + writer.append_form({"foo": "bar"}, {CONTENT_TYPE: "test/passed"}) + assert 1 == len(writer) + part = writer._parts[0][0] + assert part.headers[CONTENT_TYPE] == "test/passed" + + def test_append_multipart(self, writer) -> None: + subwriter = aiohttp.MultipartWriter(boundary=":") + subwriter.append_json({"foo": "bar"}) + writer.append(subwriter, {CONTENT_TYPE: "test/passed"}) + assert 1 == len(writer) + part = writer._parts[0][0] + assert part.headers[CONTENT_TYPE] == "test/passed" + + def test_with(self) -> None: + with aiohttp.MultipartWriter(boundary=":") as writer: + writer.append("foo") + writer.append(b"bar") + writer.append_json({"baz": True}) + assert 3 == len(writer) + + def test_append_int_not_allowed(self) -> None: + with pytest.raises(TypeError): + with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(1) - def test_append_float_not_allowed(self): - with 
self.assertRaises(TypeError): - with aiohttp.multipart.MultipartWriter(boundary=':') as writer: + def test_append_float_not_allowed(self) -> None: + with pytest.raises(TypeError): + with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(1.1) - def test_append_none_not_allowed(self): - with self.assertRaises(TypeError): - with aiohttp.multipart.MultipartWriter(boundary=':') as writer: + def test_append_none_not_allowed(self) -> None: + with pytest.raises(TypeError): + with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(None) - -class ParseContentDispositionTestCase(unittest.TestCase): - # http://greenbytes.de/tech/tc2231/ - - def test_parse_empty(self): - disptype, params = parse_content_disposition(None) - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_inlonly(self): - disptype, params = parse_content_disposition('inline') - self.assertEqual('inline', disptype) - self.assertEqual({}, params) - - def test_inlonlyquoted(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition('"inline"') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_inlwithasciifilename(self): - disptype, params = parse_content_disposition( - 'inline; filename="foo.html"') - self.assertEqual('inline', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_inlwithfnattach(self): - disptype, params = parse_content_disposition( - 'inline; filename="Not an attachment!"') - self.assertEqual('inline', disptype) - self.assertEqual({'filename': 'Not an attachment!'}, params) - - def test_attonly(self): - disptype, params = parse_content_disposition('attachment') - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attonlyquoted(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition('"attachment"') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attonlyucase(self): - disptype, params = parse_content_disposition('ATTACHMENT') - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithasciifilename(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_inlwithasciifilenamepdf(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo.pdf"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.pdf'}, params) - - def test_attwithasciifilename25(self): - disptype, params = parse_content_disposition( - 'attachment; filename="0000000000111111111122222"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': '0000000000111111111122222'}, params) - - def test_attwithasciifilename35(self): - disptype, params = parse_content_disposition( - 'attachment; filename="00000000001111111111222222222233333"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': '00000000001111111111222222222233333'}, - params) - - def test_attwithasciifnescapedchar(self): - disptype, params = parse_content_disposition( - r'attachment; filename="f\oo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_attwithasciifnescapedquote(self): - disptype, params = parse_content_disposition( - 'attachment; filename="\"quoting\" 
tested.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': '"quoting" tested.html'}, params) - - @unittest.skip('need more smart parser which respects quoted text') - def test_attwithquotedsemicolon(self): - disptype, params = parse_content_disposition( - 'attachment; filename="Here\'s a semicolon;.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'Here\'s a semicolon;.html'}, params) - - def test_attwithfilenameandextparam(self): - disptype, params = parse_content_disposition( - 'attachment; foo="bar"; filename="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html', 'foo': 'bar'}, params) - - def test_attwithfilenameandextparamescaped(self): - disptype, params = parse_content_disposition( - 'attachment; foo="\"\\";filename="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html', 'foo': '"\\'}, params) - - def test_attwithasciifilenameucase(self): - disptype, params = parse_content_disposition( - 'attachment; FILENAME="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_attwithasciifilenamenq(self): - disptype, params = parse_content_disposition( - 'attachment; filename=foo.html') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_attwithtokfncommanq(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo,bar.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attwithasciifilenamenqs(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo.html ;') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attemptyparam(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; ;filename=foo') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attwithasciifilenamenqws(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo bar.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attwithfntokensq(self): - disptype, params = parse_content_disposition( - "attachment; filename='foo.html'") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': "'foo.html'"}, params) - - def test_attwithisofnplain(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-ä.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-ä.html'}, params) - - def test_attwithutf8fnplain(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-ä.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-ä.html'}, params) - - def test_attwithfnrawpctenca(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-%41.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-%41.html'}, params) - - def test_attwithfnusingpct(self): - disptype, params = parse_content_disposition( - 'attachment; filename="50%.html"') - self.assertEqual('attachment', disptype) - 
self.assertEqual({'filename': '50%.html'}, params) - - def test_attwithfnrawpctencaq(self): - disptype, params = parse_content_disposition( - r'attachment; filename="foo-%\41.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': r'foo-%41.html'}, params) - - def test_attwithnamepct(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-%41.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-%41.html'}, params) - - def test_attwithfilenamepctandiso(self): - disptype, params = parse_content_disposition( - 'attachment; filename="ä-%41.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'ä-%41.html'}, params) - - def test_attwithfnrawpctenclong(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-%c3%a4-%e2%82%ac.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-%c3%a4-%e2%82%ac.html'}, params) - - def test_attwithasciifilenamews1(self): - disptype, params = parse_content_disposition( - 'attachment; filename ="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_attwith2filenames(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename="foo.html"; filename="bar.html"') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attfnbrokentoken(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo[1](2).html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attfnbrokentokeniso(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo-ä.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attfnbrokentokenutf(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo-ä.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdisposition(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'filename=foo.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdisposition2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'x=y; filename=foo.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdisposition3(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - '"foo; filename=bar;baz"; filename=qux') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdisposition4(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'filename=foo.html, filename=bar.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_emptydisposition(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - '; filename=foo.html') - self.assertEqual(None, 
disptype) - self.assertEqual({}, params) - - def test_doublecolon(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - ': inline; attachment; filename=foo.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attandinline(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'inline; attachment; filename=foo.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attandinline2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; inline; filename=foo.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attbrokenquotedfn(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename="foo.html".txt') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attbrokenquotedfn2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename="bar') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attbrokenquotedfn3(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo"bar;baz"qux') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmultinstances(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=foo.html, attachment; filename=bar.html') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdelim(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; foo=foo filename=bar') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdelim2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename=bar foo=foo') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attmissingdelim3(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment filename=bar') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attreversed(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'filename=foo.html; attachment') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attconfusedparam(self): - disptype, params = parse_content_disposition( - 'attachment; xfilename=foo.html') - self.assertEqual('attachment', disptype) - self.assertEqual({'xfilename': 'foo.html'}, params) - - def test_attabspath(self): - disptype, params = parse_content_disposition( - 'attachment; filename="/foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo.html'}, params) - - def test_attabspathwin(self): - disptype, params = parse_content_disposition( - 'attachment; filename="\\foo.html"') - self.assertEqual('attachment', disptype) - 
self.assertEqual({'filename': 'foo.html'}, params) - - def test_attcdate(self): - disptype, params = parse_content_disposition( - 'attachment; creation-date="Wed, 12 Feb 1997 16:29:51 -0500"') - self.assertEqual('attachment', disptype) - self.assertEqual({'creation-date': 'Wed, 12 Feb 1997 16:29:51 -0500'}, - params) - - def test_attmdate(self): - disptype, params = parse_content_disposition( - 'attachment; modification-date="Wed, 12 Feb 1997 16:29:51 -0500"') - self.assertEqual('attachment', disptype) - self.assertEqual( - {'modification-date': 'Wed, 12 Feb 1997 16:29:51 -0500'}, - params) - - def test_dispext(self): - disptype, params = parse_content_disposition('foobar') - self.assertEqual('foobar', disptype) - self.assertEqual({}, params) - - def test_dispextbadfn(self): - disptype, params = parse_content_disposition( - 'attachment; example="filename=example.txt"') - self.assertEqual('attachment', disptype) - self.assertEqual({'example': 'filename=example.txt'}, params) - - def test_attwithisofn2231iso(self): - disptype, params = parse_content_disposition( - "attachment; filename*=iso-8859-1''foo-%E4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä.html'}, params) - - def test_attwithfn2231utf8(self): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''foo-%c3%a4-%e2%82%ac.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä-€.html'}, params) - - def test_attwithfn2231noc(self): - disptype, params = parse_content_disposition( - "attachment; filename*=''foo-%c3%a4-%e2%82%ac.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä-€.html'}, params) - - def test_attwithfn2231utf8comp(self): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''foo-a%cc%88.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä.html'}, params) - - @unittest.skip('should raise decoding error: %82 is invalid for latin1') - def test_attwithfn2231utf8_bad(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=iso-8859-1''foo-%c3%a4-%e2%82%ac.html") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - @unittest.skip('should raise decoding error: %E4 is invalid for utf-8') - def test_attwithfn2231iso_bad(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=utf-8''foo-%E4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithfn2231ws1(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename *=UTF-8''foo-%c3%a4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithfn2231ws2(self): - disptype, params = parse_content_disposition( - "attachment; filename*= UTF-8''foo-%c3%a4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä.html'}, params) - - def test_attwithfn2231ws3(self): - disptype, params = parse_content_disposition( - "attachment; filename* =UTF-8''foo-%c3%a4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'foo-ä.html'}, params) - - def test_attwithfn2231quot(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - 
disptype, params = parse_content_disposition( - "attachment; filename*=\"UTF-8''foo-%c3%a4.html\"") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithfn2231quot2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=\"foo%20bar.html\"") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithfn2231singleqmissing(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8'foo-%c3%a4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - @unittest.skip('urllib.parse.unquote is tolerate to standalone % chars') - def test_attwithfn2231nbadpct1(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''foo%") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - @unittest.skip('urllib.parse.unquote is tolerate to standalone % chars') - def test_attwithfn2231nbadpct2(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''f%oo.html") - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - def test_attwithfn2231dpct(self): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''A-%2541.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'A-%41.html'}, params) - - def test_attwithfn2231abspathdisguised(self): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''%5cfoo.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': '\\foo.html'}, params) - - def test_attfncont(self): - disptype, params = parse_content_disposition( - 'attachment; filename*0="foo."; filename*1="html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo.', - 'filename*1': 'html'}, params) - - def test_attfncontqs(self): - disptype, params = parse_content_disposition( - r'attachment; filename*0="foo"; filename*1="\b\a\r.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo', - 'filename*1': 'bar.html'}, params) - - def test_attfncontenc(self): - disptype, params = parse_content_disposition( - 'attachment; filename*0*=UTF-8''foo-%c3%a4; filename*1=".html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0*': 'UTF-8''foo-%c3%a4', - 'filename*1': '.html'}, params) - - def test_attfncontlz(self): - disptype, params = parse_content_disposition( - 'attachment; filename*0="foo"; filename*01="bar"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo', - 'filename*01': 'bar'}, params) - - def test_attfncontnc(self): - disptype, params = parse_content_disposition( - 'attachment; filename*0="foo"; filename*2="bar"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo', - 'filename*2': 'bar'}, params) - - def test_attfnconts1(self): - disptype, params = parse_content_disposition( - 'attachment; filename*0="foo."; filename*2="html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo.', - 'filename*2': 'html'}, params) - - def test_attfncontord(self): - disptype, params = parse_content_disposition( - 
'attachment; filename*1="bar"; filename*0="foo"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*0': 'foo', - 'filename*1': 'bar'}, params) - - def test_attfnboth(self): - disptype, params = parse_content_disposition( - 'attachment; filename="foo-ae.html";' - " filename*=UTF-8''foo-%c3%a4.html") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-ae.html', - 'filename*': 'foo-ä.html'}, params) - - def test_attfnboth2(self): - disptype, params = parse_content_disposition( - "attachment; filename*=UTF-8''foo-%c3%a4.html;" - ' filename="foo-ae.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': 'foo-ae.html', - 'filename*': 'foo-ä.html'}, params) - - def test_attfnboth3(self): - disptype, params = parse_content_disposition( - "attachment; filename*0*=ISO-8859-15''euro-sign%3d%a4;" - " filename*=ISO-8859-1''currency-sign%3d%a4") - self.assertEqual('attachment', disptype) - self.assertEqual({'filename*': 'currency-sign=¤', - 'filename*0*': "ISO-8859-15''euro-sign%3d%a4"}, - params) - - def test_attnewandfn(self): - disptype, params = parse_content_disposition( - 'attachment; foobar=x; filename="foo.html"') - self.assertEqual('attachment', disptype) - self.assertEqual({'foobar': 'x', - 'filename': 'foo.html'}, params) - - def test_attrfc2047token(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionHeader): - disptype, params = parse_content_disposition( - 'attachment; filename==?ISO-8859-1?Q?foo-=E4.html?=') - self.assertEqual(None, disptype) - self.assertEqual({}, params) - - def test_attrfc2047quoted(self): - disptype, params = parse_content_disposition( - 'attachment; filename="=?ISO-8859-1?Q?foo-=E4.html?="') - self.assertEqual('attachment', disptype) - self.assertEqual({'filename': '=?ISO-8859-1?Q?foo-=E4.html?='}, params) - - def test_bad_continuous_param(self): - with self.assertWarns(aiohttp.multipart.BadContentDispositionParam): - disptype, params = parse_content_disposition( - 'attachment; filename*0=foo bar') - self.assertEqual('attachment', disptype) - self.assertEqual({}, params) - - -class ContentDispositionFilenameTestCase(unittest.TestCase): - # http://greenbytes.de/tech/tc2231/ - - def test_no_filename(self): - self.assertIsNone(content_disposition_filename({})) - self.assertIsNone(content_disposition_filename({'foo': 'bar'})) - - def test_filename(self): - params = {'filename': 'foo.html'} - self.assertEqual('foo.html', content_disposition_filename(params)) - - def test_filename_ext(self): - params = {'filename*': 'файл.html'} - self.assertEqual('файл.html', content_disposition_filename(params)) - - def test_attfncont(self): - params = {'filename*0': 'foo.', 'filename*1': 'html'} - self.assertEqual('foo.html', content_disposition_filename(params)) - - def test_attfncontqs(self): - params = {'filename*0': 'foo', 'filename*1': 'bar.html'} - self.assertEqual('foobar.html', content_disposition_filename(params)) - - def test_attfncontenc(self): - params = {'filename*0*': "UTF-8''foo-%c3%a4", - 'filename*1': '.html'} - self.assertEqual('foo-ä.html', content_disposition_filename(params)) - - def test_attfncontlz(self): - params = {'filename*0': 'foo', - 'filename*01': 'bar'} - self.assertEqual('foo', content_disposition_filename(params)) - - def test_attfncontnc(self): - params = {'filename*0': 'foo', - 'filename*2': 'bar'} - self.assertEqual('foo', content_disposition_filename(params)) - - def test_attfnconts1(self): - params = {'filename*1': 'foo', - 'filename*2': 'bar'} - 
self.assertEqual(None, content_disposition_filename(params)) - - def test_attfnboth(self): - params = {'filename': 'foo-ae.html', - 'filename*': 'foo-ä.html'} - self.assertEqual('foo-ä.html', content_disposition_filename(params)) - - def test_attfnboth3(self): - params = {'filename*0*': "ISO-8859-15''euro-sign%3d%a4", - 'filename*': 'currency-sign=¤'} - self.assertEqual('currency-sign=¤', - content_disposition_filename(params)) - - def test_attrfc2047quoted(self): - params = {'filename': '=?ISO-8859-1?Q?foo-=E4.html?='} - self.assertEqual('=?ISO-8859-1?Q?foo-=E4.html?=', - content_disposition_filename(params)) + async def test_write_preserves_content_disposition(self, buf, stream) -> None: + with aiohttp.MultipartWriter(boundary=":") as writer: + part = writer.append(b"foo", headers={CONTENT_TYPE: "test/passed"}) + part.set_content_disposition("form-data", filename="bug") + await writer.write(stream) + + headers, message = bytes(buf).split(b"\r\n\r\n", 1) + + assert headers == ( + b"--:\r\n" + b"Content-Type: test/passed\r\n" + b"Content-Length: 3\r\n" + b"Content-Disposition:" + b" form-data; filename=\"bug\"; filename*=utf-8''bug" + ) + assert message == b"foo\r\n--:--\r\n" + + async def test_preserve_content_disposition_header(self, buf, stream): + # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 + with open(__file__, "rb") as fobj: + with aiohttp.MultipartWriter("form-data", boundary=":") as writer: + part = writer.append( + fobj, + headers={ + CONTENT_DISPOSITION: 'attachments; filename="bug.py"', + CONTENT_TYPE: "text/python", + }, + ) + content_length = part.size + await writer.write(stream) + + assert part.headers[CONTENT_TYPE] == "text/python" + assert part.headers[CONTENT_DISPOSITION] == ('attachments; filename="bug.py"') + + headers, _ = bytes(buf).split(b"\r\n\r\n", 1) + + assert headers == ( + b"--:\r\n" + b"Content-Type: text/python\r\n" + b'Content-Disposition: attachments; filename="bug.py"\r\n' + b"Content-Length: %s" + b"" % (str(content_length).encode(),) + ) + + async def test_set_content_disposition_override(self, buf, stream): + # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 + with open(__file__, "rb") as fobj: + with aiohttp.MultipartWriter("form-data", boundary=":") as writer: + part = writer.append( + fobj, + headers={ + CONTENT_DISPOSITION: 'attachments; filename="bug.py"', + CONTENT_TYPE: "text/python", + }, + ) + content_length = part.size + await writer.write(stream) + + assert part.headers[CONTENT_TYPE] == "text/python" + assert part.headers[CONTENT_DISPOSITION] == ('attachments; filename="bug.py"') + + headers, _ = bytes(buf).split(b"\r\n\r\n", 1) + + assert headers == ( + b"--:\r\n" + b"Content-Type: text/python\r\n" + b'Content-Disposition: attachments; filename="bug.py"\r\n' + b"Content-Length: %s" + b"" % (str(content_length).encode(),) + ) + + async def test_reset_content_disposition_header(self, buf, stream): + # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 + with open(__file__, "rb") as fobj: + with aiohttp.MultipartWriter("form-data", boundary=":") as writer: + part = writer.append( + fobj, + headers={CONTENT_TYPE: "text/plain"}, + ) + + content_length = part.size + + assert CONTENT_DISPOSITION in part.headers + + part.set_content_disposition("attachments", filename="bug.py") + + await writer.write(stream) + + headers, _ = bytes(buf).split(b"\r\n\r\n", 1) + + assert headers == ( + b"--:\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Disposition:" + b" attachments; 
filename=\"bug.py\"; filename*=utf-8''bug.py\r\n" + b"Content-Length: %s" + b"" % (str(content_length).encode(),) + ) + + +async def test_async_for_reader() -> None: + data = [{"test": "passed"}, 42, b"plain text", b"aiohttp\n", b"no epilogue"] + reader = aiohttp.MultipartReader( + headers={CONTENT_TYPE: 'multipart/mixed; boundary=":"'}, + content=Stream( + b"\r\n".join( + [ + b"--:", + b"Content-Type: application/json", + b"", + json.dumps(data[0]).encode(), + b"--:", + b"Content-Type: application/json", + b"", + json.dumps(data[1]).encode(), + b"--:", + b'Content-Type: multipart/related; boundary="::"', + b"", + b"--::", + b"Content-Type: text/plain", + b"", + data[2], + b"--::", + b'Content-Disposition: attachment; filename="aiohttp"', + b"Content-Type: text/plain", + b"Content-Length: 28", + b"Content-Encoding: gzip", + b"", + b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03K\xcc\xcc\xcf())" + b"\xe0\x02\x00\xd6\x90\xe2O\x08\x00\x00\x00", + b"--::", + b'Content-Type: multipart/related; boundary=":::"', + b"", + b"--:::", + b"Content-Type: text/plain", + b"", + data[4], + b"--:::--", + b"--::--", + b"", + b"--:--", + b"", + ] + ) + ), + ) + idata = iter(data) + + async def check(reader): + async for part in reader: + if isinstance(part, aiohttp.BodyPartReader): + if part.headers[CONTENT_TYPE] == "application/json": + assert next(idata) == (await part.json()) + else: + assert next(idata) == await part.read(decode=True) + else: + await check(part) + + await check(reader) + + +async def test_async_for_bodypart() -> None: + part = aiohttp.BodyPartReader( + boundary=b"--:", headers={}, content=Stream(b"foobarbaz\r\n--:--") + ) + async for data in part: + assert data == b"foobarbaz" diff --git a/tests/test_multipart_helpers.py b/tests/test_multipart_helpers.py new file mode 100644 index 00000000000..9516751cba9 --- /dev/null +++ b/tests/test_multipart_helpers.py @@ -0,0 +1,699 @@ +import pytest + +import aiohttp +from aiohttp import content_disposition_filename, parse_content_disposition + + +class TestParseContentDisposition: + # http://greenbytes.de/tech/tc2231/ + + def test_parse_empty(self) -> None: + disptype, params = parse_content_disposition(None) + assert disptype is None + assert {} == params + + def test_inlonly(self) -> None: + disptype, params = parse_content_disposition("inline") + assert "inline" == disptype + assert {} == params + + def test_inlonlyquoted(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition('"inline"') + assert disptype is None + assert {} == params + + def test_semicolon(self) -> None: + disptype, params = parse_content_disposition( + 'form-data; name="data"; filename="file ; name.mp4"' + ) + assert disptype == "form-data" + assert params == {"name": "data", "filename": "file ; name.mp4"} + + def test_inlwithasciifilename(self) -> None: + disptype, params = parse_content_disposition('inline; filename="foo.html"') + assert "inline" == disptype + assert {"filename": "foo.html"} == params + + def test_inlwithfnattach(self) -> None: + disptype, params = parse_content_disposition( + 'inline; filename="Not an attachment!"' + ) + assert "inline" == disptype + assert {"filename": "Not an attachment!"} == params + + def test_attonly(self) -> None: + disptype, params = parse_content_disposition("attachment") + assert "attachment" == disptype + assert {} == params + + def test_attonlyquoted(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = 
parse_content_disposition('"attachment"') + assert disptype is None + assert {} == params + + def test_attonlyucase(self) -> None: + disptype, params = parse_content_disposition("ATTACHMENT") + assert "attachment" == disptype + assert {} == params + + def test_attwithasciifilename(self) -> None: + disptype, params = parse_content_disposition('attachment; filename="foo.html"') + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_inlwithasciifilenamepdf(self) -> None: + disptype, params = parse_content_disposition('attachment; filename="foo.pdf"') + assert "attachment" == disptype + assert {"filename": "foo.pdf"} == params + + def test_attwithasciifilename25(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="0000000000111111111122222"' + ) + assert "attachment" == disptype + assert {"filename": "0000000000111111111122222"} == params + + def test_attwithasciifilename35(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="00000000001111111111222222222233333"' + ) + assert "attachment" == disptype + assert {"filename": "00000000001111111111222222222233333"} == params + + def test_attwithasciifnescapedchar(self) -> None: + disptype, params = parse_content_disposition( + r'attachment; filename="f\oo.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attwithasciifnescapedquote(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename=""quoting" tested.html"' + ) + assert "attachment" == disptype + assert {"filename": '"quoting" tested.html'} == params + + @pytest.mark.skip("need more smart parser which respects quoted text") + def test_attwithquotedsemicolon(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="Here\'s a semicolon;.html"' + ) + assert "attachment" == disptype + assert {"filename": "Here's a semicolon;.html"} == params + + def test_attwithfilenameandextparam(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; foo="bar"; filename="foo.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo.html", "foo": "bar"} == params + + def test_attwithfilenameandextparamescaped(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; foo=""\\";filename="foo.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo.html", "foo": '"\\'} == params + + def test_attwithasciifilenameucase(self) -> None: + disptype, params = parse_content_disposition('attachment; FILENAME="foo.html"') + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attwithasciifilenamenq(self) -> None: + disptype, params = parse_content_disposition("attachment; filename=foo.html") + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attwithtokfncommanq(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo,bar.html" + ) + assert disptype is None + assert {} == params + + def test_attwithasciifilenamenqs(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo.html ;" + ) + assert disptype is None + assert {} == params + + def test_attemptyparam(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition("attachment; 
;filename=foo") + assert disptype is None + assert {} == params + + def test_attwithasciifilenamenqws(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo bar.html" + ) + assert disptype is None + assert {} == params + + def test_attwithfntokensq(self) -> None: + disptype, params = parse_content_disposition("attachment; filename='foo.html'") + assert "attachment" == disptype + assert {"filename": "'foo.html'"} == params + + def test_attwithisofnplain(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-ä.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-ä.html"} == params + + def test_attwithutf8fnplain(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-ä.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-ä.html"} == params + + def test_attwithfnrawpctenca(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-%41.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-%41.html"} == params + + def test_attwithfnusingpct(self) -> None: + disptype, params = parse_content_disposition('attachment; filename="50%.html"') + assert "attachment" == disptype + assert {"filename": "50%.html"} == params + + def test_attwithfnrawpctencaq(self) -> None: + disptype, params = parse_content_disposition( + r'attachment; filename="foo-%\41.html"' + ) + assert "attachment" == disptype + assert {"filename": r"foo-%41.html"} == params + + def test_attwithnamepct(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-%41.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-%41.html"} == params + + def test_attwithfilenamepctandiso(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="ä-%41.html"' + ) + assert "attachment" == disptype + assert {"filename": "ä-%41.html"} == params + + def test_attwithfnrawpctenclong(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-%c3%a4-%e2%82%ac.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-%c3%a4-%e2%82%ac.html"} == params + + def test_attwithasciifilenamews1(self) -> None: + disptype, params = parse_content_disposition('attachment; filename ="foo.html"') + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attwith2filenames(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + 'attachment; filename="foo.html"; filename="bar.html"' + ) + assert disptype is None + assert {} == params + + def test_attfnbrokentoken(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo[1](2).html" + ) + assert disptype is None + assert {} == params + + def test_attfnbrokentokeniso(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo-ä.html" + ) + assert disptype is None + assert {} == params + + def test_attfnbrokentokenutf(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo-ä.html" + ) + assert disptype is None + assert {} == params + + def test_attmissingdisposition(self) -> None: + 
with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition("filename=foo.html") + assert disptype is None + assert {} == params + + def test_attmissingdisposition2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition("x=y; filename=foo.html") + assert disptype is None + assert {} == params + + def test_attmissingdisposition3(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + '"foo; filename=bar;baz"; filename=qux' + ) + assert disptype is None + assert {} == params + + def test_attmissingdisposition4(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "filename=foo.html, filename=bar.html" + ) + assert disptype is None + assert {} == params + + def test_emptydisposition(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition("; filename=foo.html") + assert disptype is None + assert {} == params + + def test_doublecolon(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + ": inline; attachment; filename=foo.html" + ) + assert disptype is None + assert {} == params + + def test_attandinline(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "inline; attachment; filename=foo.html" + ) + assert disptype is None + assert {} == params + + def test_attandinline2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; inline; filename=foo.html" + ) + assert disptype is None + assert {} == params + + def test_attbrokenquotedfn(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + 'attachment; filename="foo.html".txt' + ) + assert disptype is None + assert {} == params + + def test_attbrokenquotedfn2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition('attachment; filename="bar') + assert disptype is None + assert {} == params + + def test_attbrokenquotedfn3(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + 'attachment; filename=foo"bar;baz"qux' + ) + assert disptype is None + assert {} == params + + def test_attmultinstances(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=foo.html, attachment; filename=bar.html" + ) + assert disptype is None + assert {} == params + + def test_attmissingdelim(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; foo=foo filename=bar" + ) + assert disptype is None + assert {} == params + + def test_attmissingdelim2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename=bar foo=foo" + ) + assert disptype is None + assert {} == params + + def test_attmissingdelim3(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition("attachment filename=bar") + assert disptype is None + assert {} == params + + def 
test_attreversed(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "filename=foo.html; attachment" + ) + assert disptype is None + assert {} == params + + def test_attconfusedparam(self) -> None: + disptype, params = parse_content_disposition("attachment; xfilename=foo.html") + assert "attachment" == disptype + assert {"xfilename": "foo.html"} == params + + def test_attabspath(self) -> None: + disptype, params = parse_content_disposition('attachment; filename="/foo.html"') + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attabspathwin(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="\\foo.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo.html"} == params + + def test_attcdate(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; creation-date="Wed, 12 Feb 1997 16:29:51 -0500"' + ) + assert "attachment" == disptype + assert {"creation-date": "Wed, 12 Feb 1997 16:29:51 -0500"} == params + + def test_attmdate(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; modification-date="Wed, 12 Feb 1997 16:29:51 -0500"' + ) + assert "attachment" == disptype + assert {"modification-date": "Wed, 12 Feb 1997 16:29:51 -0500"} == params + + def test_dispext(self) -> None: + disptype, params = parse_content_disposition("foobar") + assert "foobar" == disptype + assert {} == params + + def test_dispextbadfn(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; example="filename=example.txt"' + ) + assert "attachment" == disptype + assert {"example": "filename=example.txt"} == params + + def test_attwithisofn2231iso(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=iso-8859-1''foo-%E4.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä.html"} == params + + def test_attwithfn2231utf8(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''foo-%c3%a4-%e2%82%ac.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä-€.html"} == params + + def test_attwithfn2231noc(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=''foo-%c3%a4-%e2%82%ac.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä-€.html"} == params + + def test_attwithfn2231utf8comp(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''foo-a%cc%88.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä.html"} == params + + @pytest.mark.skip("should raise decoding error: %82 is invalid for latin1") + def test_attwithfn2231utf8_bad(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=iso-8859-1''foo-%c3%a4-%e2%82%ac.html" + ) + assert "attachment" == disptype + assert {} == params + + @pytest.mark.skip("should raise decoding error: %E4 is invalid for utf-8") + def test_attwithfn2231iso_bad(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=utf-8''foo-%E4.html" + ) + assert "attachment" == disptype + assert {} == params + + def test_attwithfn2231ws1(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename 
*=UTF-8''foo-%c3%a4.html" + ) + assert "attachment" == disptype + assert {} == params + + def test_attwithfn2231ws2(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*= UTF-8''foo-%c3%a4.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä.html"} == params + + def test_attwithfn2231ws3(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename* =UTF-8''foo-%c3%a4.html" + ) + assert "attachment" == disptype + assert {"filename*": "foo-ä.html"} == params + + def test_attwithfn2231quot(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=\"UTF-8''foo-%c3%a4.html\"" + ) + assert "attachment" == disptype + assert {} == params + + def test_attwithfn2231quot2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + 'attachment; filename*="foo%20bar.html"' + ) + assert "attachment" == disptype + assert {} == params + + def test_attwithfn2231singleqmissing(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8'foo-%c3%a4.html" + ) + assert "attachment" == disptype + assert {} == params + + @pytest.mark.skip("urllib.parse.unquote is tolerant of standalone % chars") + def test_attwithfn2231nbadpct1(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''foo%" + ) + assert "attachment" == disptype + assert {} == params + + @pytest.mark.skip("urllib.parse.unquote is tolerant of standalone % chars") + def test_attwithfn2231nbadpct2(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''f%oo.html" + ) + assert "attachment" == disptype + assert {} == params + + def test_attwithfn2231dpct(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''A-%2541.html" + ) + assert "attachment" == disptype + assert {"filename*": "A-%41.html"} == params + + def test_attwithfn2231abspathdisguised(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''%5cfoo.html" + ) + assert "attachment" == disptype + assert {"filename*": "\\foo.html"} == params + + def test_attfncont(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename*0="foo."; filename*1="html"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo.", "filename*1": "html"} == params + + def test_attfncontqs(self) -> None: + disptype, params = parse_content_disposition( + r'attachment; filename*0="foo"; filename*1="\b\a\r.html"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo", "filename*1": "bar.html"} == params + + def test_attfncontenc(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*0*=UTF-8" 'foo-%c3%a4; filename*1=".html"' + ) + assert "attachment" == disptype + assert {"filename*0*": "UTF-8" "foo-%c3%a4", "filename*1": ".html"} == params + + def test_attfncontlz(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename*0="foo"; filename*01="bar"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo", "filename*01": "bar"} == params + + def test_attfncontnc(self) -> None: + disptype, params = parse_content_disposition( 
+ 'attachment; filename*0="foo"; filename*2="bar"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo", "filename*2": "bar"} == params + + def test_attfnconts1(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename*0="foo."; filename*2="html"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo.", "filename*2": "html"} == params + + def test_attfncontord(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename*1="bar"; filename*0="foo"' + ) + assert "attachment" == disptype + assert {"filename*0": "foo", "filename*1": "bar"} == params + + def test_attfnboth(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="foo-ae.html";' " filename*=UTF-8''foo-%c3%a4.html" + ) + assert "attachment" == disptype + assert {"filename": "foo-ae.html", "filename*": "foo-ä.html"} == params + + def test_attfnboth2(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*=UTF-8''foo-%c3%a4.html;" ' filename="foo-ae.html"' + ) + assert "attachment" == disptype + assert {"filename": "foo-ae.html", "filename*": "foo-ä.html"} == params + + def test_attfnboth3(self) -> None: + disptype, params = parse_content_disposition( + "attachment; filename*0*=ISO-8859-15''euro-sign%3d%a4;" + " filename*=ISO-8859-1''currency-sign%3d%a4" + ) + assert "attachment" == disptype + assert { + "filename*": "currency-sign=¤", + "filename*0*": "ISO-8859-15''euro-sign%3d%a4", + } == params + + def test_attnewandfn(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; foobar=x; filename="foo.html"' + ) + assert "attachment" == disptype + assert {"foobar": "x", "filename": "foo.html"} == params + + def test_attrfc2047token(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionHeader): + disptype, params = parse_content_disposition( + "attachment; filename==?ISO-8859-1?Q?foo-=E4.html?=" + ) + assert disptype is None + assert {} == params + + def test_attrfc2047quoted(self) -> None: + disptype, params = parse_content_disposition( + 'attachment; filename="=?ISO-8859-1?Q?foo-=E4.html?="' + ) + assert "attachment" == disptype + assert {"filename": "=?ISO-8859-1?Q?foo-=E4.html?="} == params + + def test_bad_continuous_param(self) -> None: + with pytest.warns(aiohttp.BadContentDispositionParam): + disptype, params = parse_content_disposition( + "attachment; filename*0=foo bar" + ) + assert "attachment" == disptype + assert {} == params + + +class TestContentDispositionFilename: + # http://greenbytes.de/tech/tc2231/ + + def test_no_filename(self) -> None: + assert content_disposition_filename({}) is None + assert content_disposition_filename({"foo": "bar"}) is None + + def test_filename(self) -> None: + params = {"filename": "foo.html"} + assert "foo.html" == content_disposition_filename(params) + + def test_filename_ext(self) -> None: + params = {"filename*": "файл.html"} + assert "файл.html" == content_disposition_filename(params) + + def test_attfncont(self) -> None: + params = {"filename*0": "foo.", "filename*1": "html"} + assert "foo.html" == content_disposition_filename(params) + + def test_attfncontqs(self) -> None: + params = {"filename*0": "foo", "filename*1": "bar.html"} + assert "foobar.html" == content_disposition_filename(params) + + def test_attfncontenc(self) -> None: + params = {"filename*0*": "UTF-8''foo-%c3%a4", "filename*1": ".html"} + assert "foo-ä.html" == content_disposition_filename(params) + + def test_attfncontlz(self) -> None: + 
params = {"filename*0": "foo", "filename*01": "bar"} + assert "foo" == content_disposition_filename(params) + + def test_attfncontnc(self) -> None: + params = {"filename*0": "foo", "filename*2": "bar"} + assert "foo" == content_disposition_filename(params) + + def test_attfnconts1(self) -> None: + params = {"filename*1": "foo", "filename*2": "bar"} + assert content_disposition_filename(params) is None + + def test_attfnboth(self) -> None: + params = {"filename": "foo-ae.html", "filename*": "foo-ä.html"} + assert "foo-ä.html" == content_disposition_filename(params) + + def test_attfnboth3(self) -> None: + params = { + "filename*0*": "ISO-8859-15''euro-sign%3d%a4", + "filename*": "currency-sign=¤", + } + assert "currency-sign=¤" == content_disposition_filename(params) + + def test_attrfc2047quoted(self) -> None: + params = {"filename": "=?ISO-8859-1?Q?foo-=E4.html?="} + assert "=?ISO-8859-1?Q?foo-=E4.html?=" == content_disposition_filename(params) diff --git a/tests/test_payload.py b/tests/test_payload.py index faf8d6de2b0..c075dba3cd3 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1,8 +1,12 @@ +import array import asyncio +from io import StringIO +from unittest import mock import pytest +from async_generator import async_generator -from aiohttp import payload +from aiohttp import payload, streams @pytest.fixture @@ -14,13 +18,11 @@ def registry(): class Payload(payload.Payload): - - @asyncio.coroutine - def write(self, writer): + async def write(self, writer): pass -def test_register_type(registry): +def test_register_type(registry) -> None: class TestProvider: pass @@ -29,30 +31,108 @@ class TestProvider: assert isinstance(p, Payload) -def test_payload_ctor(): - p = Payload('test', encoding='utf-8', filename='test.txt') - assert p._value == 'test' - assert p._encoding == 'utf-8' +def test_register_unsupported_order(registry) -> None: + class TestProvider: + pass + + with pytest.raises(ValueError): + payload.register_payload(Payload, TestProvider, order=object()) + + +def test_payload_ctor() -> None: + p = Payload("test", encoding="utf-8", filename="test.txt") + assert p._value == "test" + assert p._encoding == "utf-8" assert p.size is None - assert p.filename == 'test.txt' - assert p.content_type == 'text/plain' + assert p.filename == "test.txt" + assert p.content_type == "text/plain" + + +def test_payload_content_type() -> None: + p = Payload("test", headers={"content-type": "application/json"}) + assert p.content_type == "application/json" + + +def test_bytes_payload_default_content_type() -> None: + p = payload.BytesPayload(b"data") + assert p.content_type == "application/octet-stream" + + +def test_bytes_payload_explicit_content_type() -> None: + p = payload.BytesPayload(b"data", content_type="application/custom") + assert p.content_type == "application/custom" + + +def test_bytes_payload_bad_type() -> None: + with pytest.raises(TypeError): + payload.BytesPayload(object()) + + +def test_bytes_payload_memoryview_correct_size() -> None: + mv = memoryview(array.array("H", [1, 2, 3])) + p = payload.BytesPayload(mv) + assert p.size == 6 + + +def test_string_payload() -> None: + p = payload.StringPayload("test") + assert p.encoding == "utf-8" + assert p.content_type == "text/plain; charset=utf-8" + + p = payload.StringPayload("test", encoding="koi8-r") + assert p.encoding == "koi8-r" + assert p.content_type == "text/plain; charset=koi8-r" + + p = payload.StringPayload("test", content_type="text/plain; charset=koi8-r") + assert p.encoding == "koi8-r" + assert 
p.content_type == "text/plain; charset=koi8-r" + + +def test_string_io_payload() -> None: + s = StringIO("ű" * 5000) + p = payload.StringIOPayload(s) + assert p.encoding == "utf-8" + assert p.content_type == "text/plain; charset=utf-8" + assert p.size == 10000 + + +def test_async_iterable_payload_default_content_type() -> None: + @async_generator + async def gen(): + pass + + p = payload.AsyncIterablePayload(gen()) + assert p.content_type == "application/octet-stream" + + +def test_async_iterable_payload_explicit_content_type() -> None: + @async_generator + async def gen(): + pass + + p = payload.AsyncIterablePayload(gen(), content_type="application/custom") + assert p.content_type == "application/custom" + +def test_async_iterable_payload_not_async_iterable() -> None: -def test_payload_content_type(): - p = Payload('test', headers={'content-type': 'application/json'}) - assert p.content_type == 'application/json' + with pytest.raises(TypeError): + payload.AsyncIterablePayload(object()) -def test_string_payload(): - p = payload.StringPayload('test') - assert p.encoding == 'utf-8' - assert p.content_type == 'text/plain; charset=utf-8' +async def test_stream_reader_long_lines() -> None: + loop = asyncio.get_event_loop() + DATA = b"0" * 1024 ** 3 - p = payload.StringPayload('test', encoding='koi8-r') - assert p.encoding == 'koi8-r' - assert p.content_type == 'text/plain; charset=koi8-r' + stream = streams.StreamReader(mock.Mock(), 2 ** 16, loop=loop) + stream.feed_data(DATA) + stream.feed_eof() + body = payload.get_payload(stream) - p = payload.StringPayload( - 'test', content_type='text/plain; charset=koi8-r') - assert p.encoding == 'koi8-r' - assert p.content_type == 'text/plain; charset=koi8-r' + writer = mock.Mock() + writer.write.return_value = loop.create_future() + writer.write.return_value.set_result(None) + await body.write(writer) + writer.write.assert_called_once_with(mock.ANY) + (chunk,), _ = writer.write.call_args + assert len(chunk) == len(DATA) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 6be01d3ce37..3b1bf0c052a 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -1,6 +1,7 @@ import asyncio import gc import socket +import ssl import unittest from unittest import mock @@ -8,10 +9,18 @@ import aiohttp from aiohttp.client_reqrep import ClientRequest, ClientResponse +from aiohttp.helpers import TimerNoop from aiohttp.test_utils import make_mocked_coro class TestProxy(unittest.TestCase): + response_mock_attrs = { + "status": 200, + } + mocked_response = mock.Mock(**response_mock_attrs) + clientrequest_mock_attrs = { + "return_value.send.return_value.start": make_mocked_coro(mocked_response), + } def setUp(self): self.loop = asyncio.new_event_loop() @@ -24,440 +33,712 @@ def tearDown(self): self.loop.close() gc.collect() - @mock.patch('aiohttp.connector.ClientRequest') - def test_connect(self, ClientRequestMock): + @mock.patch("aiohttp.connector.ClientRequest") + def test_connect(self, ClientRequestMock) -> None: req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://proxy.example.com'), - loop=self.loop + "GET", + URL("http://www.python.org"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) + self.assertEqual(str(req.proxy), "http://proxy.example.com") + + # mock all the things! 
+ async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + + proto = mock.Mock( + **{ + "transport.get_extra_info.return_value": False, + } + ) + self.loop.create_connection = make_mocked_coro((proto.transport, proto)) + conn = self.loop.run_until_complete( + connector.connect(req, None, aiohttp.ClientTimeout()) + ) + self.assertEqual(req.url, URL("http://www.python.org")) + self.assertIs(conn._protocol, proto) + self.assertIs(conn.transport, proto.transport) + + ClientRequestMock.assert_called_with( + "GET", + URL("http://proxy.example.com"), + auth=None, + headers={"Host": "www.python.org"}, + loop=self.loop, + ssl=None, + ) + + @mock.patch("aiohttp.connector.ClientRequest") + def test_proxy_headers(self, ClientRequestMock) -> None: + req = ClientRequest( + "GET", + URL("http://www.python.org"), + proxy=URL("http://proxy.example.com"), + proxy_headers={"Foo": "Bar"}, + loop=self.loop, ) - self.assertEqual(str(req.proxy), 'http://proxy.example.com') + self.assertEqual(str(req.proxy), "http://proxy.example.com") # mock all the things! - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro([mock.MagicMock()]) - proto = mock.Mock() - self.loop.create_connection = make_mocked_coro( - (proto.transport, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) - self.assertEqual(req.url, URL('http://www.python.org')) + proto = mock.Mock( + **{ + "transport.get_extra_info.return_value": False, + } + ) + self.loop.create_connection = make_mocked_coro((proto.transport, proto)) + conn = self.loop.run_until_complete( + connector.connect(req, None, aiohttp.ClientTimeout()) + ) + self.assertEqual(req.url, URL("http://www.python.org")) self.assertIs(conn._protocol, proto) self.assertIs(conn.transport, proto.transport) ClientRequestMock.assert_called_with( - 'GET', URL('http://proxy.example.com'), + "GET", + URL("http://proxy.example.com"), auth=None, - headers={'Host': 'www.python.org'}, - loop=self.loop) + headers={"Host": "www.python.org", "Foo": "Bar"}, + loop=self.loop, + ssl=None, + ) - def test_proxy_auth(self): + def test_proxy_auth(self) -> None: with self.assertRaises(ValueError) as ctx: ClientRequest( - 'GET', URL('http://python.org'), - proxy=URL('http://proxy.example.com'), - proxy_auth=('user', 'pass'), - loop=mock.Mock()) + "GET", + URL("http://python.org"), + proxy=URL("http://proxy.example.com"), + proxy_auth=("user", "pass"), + loop=mock.Mock(), + ) self.assertEqual( ctx.exception.args[0], "proxy_auth must be None or BasicAuth() tuple", ) - @mock.patch('aiohttp.client_reqrep.PayloadWriter') - def _test_connect_request_with_unicode_host(self, Request_mock): - loop = mock.Mock() - request = ClientRequest("CONNECT", URL("http://éé.com/"), - loop=loop) + def test_proxy_dns_error(self) -> None: + async def make_conn(): + return aiohttp.TCPConnector() - request.response_class = mock.Mock() - request.write_bytes = mock.Mock() - request.write_bytes.return_value = asyncio.Future(loop=loop) - request.write_bytes.return_value.set_result(None) - request.send(mock.Mock()) - - Request_mock.assert_called_with(mock.ANY, mock.ANY, "xn--9caa.com:80", - mock.ANY, loop=loop) - - def test_proxy_connection_error(self): - connector = aiohttp.TCPConnector(loop=self.loop) + connector = self.loop.run_until_complete(make_conn()) 
connector._resolve_host = make_mocked_coro( - raise_exception=OSError('dont take it serious')) + raise_exception=OSError("dont take it serious") + ) req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("http://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) expected_headers = dict(req.headers) with self.assertRaises(aiohttp.ClientConnectorError): - self.loop.run_until_complete(connector.connect(req)) - self.assertEqual(req.url.path, '/') + self.loop.run_until_complete( + connector.connect(req, None, aiohttp.ClientTimeout()) + ) + self.assertEqual(req.url.path, "/") self.assertEqual(dict(req.headers), expected_headers) - @mock.patch('aiohttp.connector.ClientRequest') - def test_auth(self, ClientRequestMock): + def test_proxy_connection_error(self) -> None: + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "www.python.org", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] + ) + connector._loop.create_connection = make_mocked_coro( + raise_exception=OSError("dont take it serious") + ) + + req = ClientRequest( + "GET", + URL("http://www.python.org"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) + with self.assertRaises(aiohttp.ClientProxyConnectionError): + self.loop.run_until_complete( + connector.connect(req, None, aiohttp.ClientTimeout()) + ) + + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect(self, ClientRequestMock) -> None: proxy_req = ClientRequest( - 'GET', URL('http://proxy.example.com'), - auth=aiohttp.helpers.BasicAuth('user', 'pass'), - loop=self.loop + "GET", URL("http://proxy.example.com"), loop=self.loop ) ClientRequestMock.return_value = proxy_req - self.assertIn('AUTHORIZATION', proxy_req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', proxy_req.headers) - connector = aiohttp.TCPConnector(loop=self.loop) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://proxy.example.com'), - proxy_auth=aiohttp.helpers.BasicAuth('user', 'pass'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) - self.assertNotIn('AUTHORIZATION', req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', req.headers) - conn = self.loop.run_until_complete(connector.connect(req)) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) - self.assertEqual(req.url, URL('http://www.python.org')) - self.assertNotIn('AUTHORIZATION', req.headers) - 
self.assertIn('PROXY-AUTHORIZATION', req.headers) - self.assertNotIn('AUTHORIZATION', proxy_req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', proxy_req.headers) + self.assertEqual(req.url.path, "/") + self.assertEqual(proxy_req.method, "CONNECT") + self.assertEqual(proxy_req.url, URL("https://www.python.org")) + tr.close.assert_called_once_with() + tr.get_extra_info.assert_called_with("socket", default=None) - ClientRequestMock.assert_called_with( - 'GET', URL('http://proxy.example.com'), - auth=aiohttp.helpers.BasicAuth('user', 'pass'), - loop=mock.ANY, headers=mock.ANY) - conn.close() + self.loop.run_until_complete(proxy_req.close()) + proxy_resp.close() + self.loop.run_until_complete(req.close()) - def test_auth_utf8(self): + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_certificate_error(self, ClientRequestMock) -> None: proxy_req = ClientRequest( - 'GET', URL('http://proxy.example.com'), - auth=aiohttp.helpers.BasicAuth('юзер', 'пасс', 'utf-8'), - loop=self.loop) - self.assertIn('AUTHORIZATION', proxy_req.headers) - - @mock.patch('aiohttp.connector.ClientRequest') - def test_auth_from_url(self, ClientRequestMock): - proxy_req = ClientRequest('GET', - URL('http://user:pass@proxy.example.com'), - loop=self.loop) + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - self.assertIn('AUTHORIZATION', proxy_req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', proxy_req.headers) - connector = aiohttp.TCPConnector(loop=self.loop) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) - tr, proto = mock.Mock(), mock.Mock() - self.loop.create_connection = make_mocked_coro((tr, proto)) + async def make_conn(): + return aiohttp.TCPConnector() - req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://user:pass@proxy.example.com'), - loop=self.loop, + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] ) - self.assertNotIn('AUTHORIZATION', req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', req.headers) - conn = self.loop.run_until_complete(connector.connect(req)) - self.assertEqual(req.url, URL('http://www.python.org')) - self.assertNotIn('AUTHORIZATION', req.headers) - self.assertIn('PROXY-AUTHORIZATION', req.headers) - self.assertNotIn('AUTHORIZATION', proxy_req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', proxy_req.headers) + seq = 0 - ClientRequestMock.assert_called_with( - 'GET', URL('http://user:pass@proxy.example.com'), - auth=None, loop=mock.ANY, headers=mock.ANY) - conn.close() - - @mock.patch('aiohttp.connector.ClientRequest') - def test_auth__not_modifying_request(self, ClientRequestMock): - proxy_req = ClientRequest('GET', - URL('http://user:pass@proxy.example.com'), - loop=self.loop) - ClientRequestMock.return_value = proxy_req - proxy_req_headers = dict(proxy_req.headers) + async def create_connection(*args, **kwargs): + nonlocal seq + seq += 1 - connector = aiohttp.TCPConnector(loop=self.loop) - connector._resolve_host = make_mocked_coro( - 
raise_exception=OSError('nothing personal')) + # connection to http://proxy.example.com + if seq == 1: + return mock.Mock(), mock.Mock() + # connection to https://www.python.org + elif seq == 2: + raise ssl.CertificateError + else: + assert False + + self.loop.create_connection = create_connection req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://user:pass@proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) - req_headers = dict(req.headers) - with self.assertRaises(aiohttp.ClientConnectorError): - self.loop.run_until_complete(connector.connect(req)) - self.assertEqual(req.headers, req_headers) - self.assertEqual(req.url.path, '/') - self.assertEqual(proxy_req.headers, proxy_req_headers) - - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_connect(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + with self.assertRaises(aiohttp.ClientConnectorCertificateError): + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_ssl_error(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) - tr, proto = mock.Mock(), mock.Mock() - self.loop.create_connection = make_mocked_coro((tr, proto)) + seq = 0 + + async def create_connection(*args, **kwargs): + nonlocal seq + seq += 1 + + # connection to http://proxy.example.com + if seq == 1: + return mock.Mock(), mock.Mock() + # connection to https://www.python.org + elif seq == 2: + raise ssl.SSLError + else: + assert False + + self.loop.create_connection = create_connection req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) - - self.assertEqual(req.url.path, '/') - self.assertEqual(proxy_req.method, 'CONNECT') - self.assertEqual(proxy_req.url, URL('https://www.python.org')) - tr.close.assert_called_once_with() - tr.get_extra_info.assert_called_with('socket', default=None) - - self.loop.run_until_complete(proxy_req.close()) - proxy_resp.close() - self.loop.run_until_complete(req.close()) + with self.assertRaises(aiohttp.ClientConnectorSSLError): + self.loop.run_until_complete( + 
connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_connect_runtime_error(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_runtime_error(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() tr.get_extra_info.return_value = None self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) with self.assertRaisesRegex( - RuntimeError, "Transport does not expose socket instance"): - self.loop.run_until_complete(connector._create_connection(req)) + RuntimeError, "Transport does not expose socket instance" + ): + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() self.loop.run_until_complete(req.close()) - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_connect_http_proxy_error(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp - proxy_resp.start = make_mocked_coro( - mock.Mock(status=400, reason='bad request')) + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=400, reason="bad request")) + + async def make_conn(): + return aiohttp.TCPConnector() - connector = aiohttp.TCPConnector(loop=self.loop) - connector = aiohttp.TCPConnector(loop=self.loop) + connector = 
self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() tr.get_extra_info.return_value = None self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) with self.assertRaisesRegex( - aiohttp.ClientHttpProxyError, "400, message='bad request'"): - self.loop.run_until_complete(connector._create_connection(req)) + aiohttp.ClientHttpProxyError, "400, message='bad request'" + ): + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() self.loop.run_until_complete(req.close()) - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_connect_resp_start_error(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_resp_start_error(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp - proxy_resp.start = make_mocked_coro( - raise_exception=OSError("error message")) + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(raise_exception=OSError("error message")) + + async def make_conn(): + return aiohttp.TCPConnector() - connector = aiohttp.TCPConnector(loop=self.loop) + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() tr.get_extra_info.return_value = None self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) with self.assertRaisesRegex(OSError, "error message"): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) - @mock.patch('aiohttp.connector.ClientRequest') - def test_request_port(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_request_port(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", 
URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() tr.get_extra_info.return_value = None self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('http://localhost:1234/path'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("http://localhost:1234/path"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) - self.assertEqual(req.url, URL('http://localhost:1234/path')) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + self.assertEqual(req.url, URL("http://localhost:1234/path")) - def test_proxy_auth_property(self): + def test_proxy_auth_property(self) -> None: req = aiohttp.ClientRequest( - 'GET', URL('http://localhost:1234/path'), - proxy=URL('http://proxy.example.com'), - proxy_auth=aiohttp.helpers.BasicAuth('user', 'pass'), - loop=self.loop) - self.assertEqual(('user', 'pass', 'latin1'), req.proxy_auth) + "GET", + URL("http://localhost:1234/path"), + proxy=URL("http://proxy.example.com"), + proxy_auth=aiohttp.helpers.BasicAuth("user", "pass"), + loop=self.loop, + ) + self.assertEqual(("user", "pass", "latin1"), req.proxy_auth) - def test_proxy_auth_property_default(self): + def test_proxy_auth_property_default(self) -> None: req = aiohttp.ClientRequest( - 'GET', URL('http://localhost:1234/path'), - proxy=URL('http://proxy.example.com'), - loop=self.loop) + "GET", + URL("http://localhost:1234/path"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) self.assertIsNone(req.proxy_auth) - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_connect_pass_ssl_context(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + 
"flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() self.loop.create_connection = make_mocked_coro((tr, proto)) req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) self.loop.create_connection.assert_called_with( mock.ANY, - ssl=connector.ssl_context, + ssl=connector._make_ssl_context(True), sock=mock.ANY, - server_hostname='www.python.org') + server_hostname="www.python.org", + ) - self.assertEqual(req.url.path, '/') - self.assertEqual(proxy_req.method, 'CONNECT') - self.assertEqual(proxy_req.url, URL('https://www.python.org')) + self.assertEqual(req.url.path, "/") + self.assertEqual(proxy_req.method, "CONNECT") + self.assertEqual(proxy_req.url, URL("https://www.python.org")) tr.close.assert_called_once_with() - tr.get_extra_info.assert_called_with('socket', default=None) + tr.get_extra_info.assert_called_with("socket", default=None) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() self.loop.run_until_complete(req.close()) - @mock.patch('aiohttp.connector.ClientRequest') - def test_https_auth(self, ClientRequestMock): - proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), - auth=aiohttp.helpers.BasicAuth('user', - 'pass'), - loop=self.loop) + @mock.patch("aiohttp.connector.ClientRequest") + def test_https_auth(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", + URL("http://proxy.example.com"), + auth=aiohttp.helpers.BasicAuth("user", "pass"), + loop=self.loop, + ) ClientRequestMock.return_value = proxy_req - proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) - proxy_resp._loop = self.loop - proxy_req.send = send_mock = mock.Mock() - send_mock.return_value = proxy_resp + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) - connector = aiohttp.TCPConnector(loop=self.loop) + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( - [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, - 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) tr, proto = mock.Mock(), mock.Mock() self.loop.create_connection = make_mocked_coro((tr, proto)) - self.assertIn('AUTHORIZATION', proxy_req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', proxy_req.headers) + self.assertIn("AUTHORIZATION", proxy_req.headers) + self.assertNotIn("PROXY-AUTHORIZATION", proxy_req.headers) req = ClientRequest( - 'GET', URL('https://www.python.org'), - proxy=URL('http://proxy.example.com'), - loop=self.loop + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) + self.assertNotIn("AUTHORIZATION", req.headers) + self.assertNotIn("PROXY-AUTHORIZATION", req.headers) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) ) - 
self.assertNotIn('AUTHORIZATION', req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', req.headers) - self.loop.run_until_complete(connector._create_connection(req)) - self.assertEqual(req.url.path, '/') - self.assertNotIn('AUTHORIZATION', req.headers) - self.assertNotIn('PROXY-AUTHORIZATION', req.headers) - self.assertNotIn('AUTHORIZATION', proxy_req.headers) - self.assertIn('PROXY-AUTHORIZATION', proxy_req.headers) + self.assertEqual(req.url.path, "/") + self.assertNotIn("AUTHORIZATION", req.headers) + self.assertNotIn("PROXY-AUTHORIZATION", req.headers) + self.assertNotIn("AUTHORIZATION", proxy_req.headers) + self.assertIn("PROXY-AUTHORIZATION", proxy_req.headers) - connector._resolve_host.assert_called_with('proxy.example.com', 80) + connector._resolve_host.assert_called_with( + "proxy.example.com", 80, traces=mock.ANY + ) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index ee8992ffbef..68763cd446e 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -1,28 +1,26 @@ import asyncio -from functools import partial +import os +import pathlib from unittest import mock import pytest from yarl import URL import aiohttp -import aiohttp.helpers -import aiohttp.web +from aiohttp import web @pytest.fixture -def proxy_test_server(raw_test_server, loop, monkeypatch): - """Handle all proxy requests and imitate remote server response.""" +def proxy_test_server(aiohttp_raw_server, loop, monkeypatch): + # Handle all proxy requests and imitate remote server response. _patch_ssl_transport(monkeypatch) - default_response = dict( - status=200, - headers=None, - body=None) + default_response = dict(status=200, headers=None, body=None) - @asyncio.coroutine - def proxy_handler(request, proxy_mock): + proxy_mock = mock.Mock() + + async def proxy_handler(request): proxy_mock.request = request proxy_mock.requests_list.append(request) @@ -30,475 +28,632 @@ def proxy_handler(request, proxy_mock): if isinstance(proxy_mock.return_value, dict): response.update(proxy_mock.return_value) - headers = response['headers'] + headers = response["headers"] if not headers: headers = {} - if request.method == 'CONNECT': - response['body'] = None + if request.method == "CONNECT": + response["body"] = None - response['headers'] = headers + response["headers"] = headers - resp = aiohttp.web.Response(**response) - yield from resp.prepare(request) - yield from resp.drain() + resp = web.Response(**response) + await resp.prepare(request) + await resp.write_eof() return resp - @asyncio.coroutine - def proxy_server(): - proxy_mock = mock.Mock() + async def proxy_server(): proxy_mock.request = None + proxy_mock.auth = None proxy_mock.requests_list = [] - handler = partial(proxy_handler, proxy_mock=proxy_mock) - server = yield from raw_test_server(handler) + server = await aiohttp_raw_server(proxy_handler) proxy_mock.server = server - proxy_mock.url = server.make_url('/') + proxy_mock.url = server.make_url("/") return proxy_mock return proxy_server -@asyncio.coroutine -def _request(method, url, loop=None, **kwargs): - with aiohttp.ClientSession(loop=loop) as client: - resp = yield from client.request(method, url, **kwargs) - yield from resp.release() - return resp - - @pytest.fixture() def get_request(loop): - return partial(_request, method='GET', loop=loop) + async def _request(method="GET", *, url, trust_env=False, **kwargs): + connector = aiohttp.TCPConnector(ssl=False, loop=loop) + client = 
aiohttp.ClientSession(connector=connector, trust_env=trust_env) + try: + resp = await client.request(method, url, **kwargs) + await resp.release() + return resp + finally: + await client.close() + return _request -@asyncio.coroutine -def test_proxy_http_absolute_path(proxy_test_server, get_request): - url = 'http://aiohttp.io/path?query=yes' - proxy = yield from proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) +async def test_proxy_http_absolute_path(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io/path?query=yes" + proxy = await proxy_test_server() + + await get_request(url=url, proxy=proxy.url) assert len(proxy.requests_list) == 1 - assert proxy.request.method == 'GET' - assert proxy.request.host == 'aiohttp.io' - assert proxy.request.path_qs == 'http://aiohttp.io/path?query=yes' + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "http://aiohttp.io/path?query=yes" -@asyncio.coroutine -def test_proxy_http_raw_path(proxy_test_server, get_request): - url = 'http://aiohttp.io:2561/space sheep?q=can:fly' - raw_url = 'http://aiohttp.io:2561/space%20sheep?q=can:fly' - proxy = yield from proxy_test_server() +async def test_proxy_http_raw_path(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io:2561/space sheep?q=can:fly" + raw_url = "http://aiohttp.io:2561/space%20sheep?q=can:fly" + proxy = await proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) - assert proxy.request.host == 'aiohttp.io:2561' + assert proxy.request.host == "aiohttp.io:2561" assert proxy.request.path_qs == raw_url -@asyncio.coroutine -def test_proxy_http_idna_support(proxy_test_server, get_request): - url = 'http://éé.com/' - raw_url = 'http://xn--9caa.com/' - proxy = yield from proxy_test_server() +async def test_proxy_http_idna_support(proxy_test_server, get_request) -> None: + url = "http://éé.com/" + raw_url = "http://xn--9caa.com/" + proxy = await proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) - assert proxy.request.host == 'xn--9caa.com' + assert proxy.request.host == "xn--9caa.com" assert proxy.request.path_qs == raw_url -@asyncio.coroutine -def test_proxy_http_connection_error(get_request): - url = 'http://aiohttp.io/path' - proxy_url = 'http://localhost:2242/' +async def test_proxy_http_connection_error(get_request) -> None: + url = "http://aiohttp.io/path" + proxy_url = "http://localhost:2242/" with pytest.raises(aiohttp.ClientConnectorError): - yield from get_request(url=url, proxy=proxy_url) + await get_request(url=url, proxy=proxy_url) -@asyncio.coroutine -def test_proxy_http_bad_response(proxy_test_server, get_request): - url = 'http://aiohttp.io/path' - proxy = yield from proxy_test_server() - proxy.return_value = dict( - status=502, - headers={'Proxy-Agent': 'TestProxy'}) +async def test_proxy_http_bad_response(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + proxy.return_value = dict(status=502, headers={"Proxy-Agent": "TestProxy"}) - resp = yield from get_request(url=url, proxy=proxy.url) + resp = await get_request(url=url, proxy=proxy.url) assert resp.status == 502 - assert resp.headers['Proxy-Agent'] == 'TestProxy' + assert resp.headers["Proxy-Agent"] == "TestProxy" -@asyncio.coroutine -def test_proxy_http_auth(proxy_test_server, get_request): - url = 'http://aiohttp.io/path' - proxy = yield 
from proxy_test_server() +async def test_proxy_http_auth(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) - assert 'Authorization' not in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" not in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - auth = aiohttp.helpers.BasicAuth('user', 'pass') - yield from get_request(url=url, auth=auth, proxy=proxy.url) + auth = aiohttp.BasicAuth("user", "pass") + await get_request(url=url, auth=auth, proxy=proxy.url) - assert 'Authorization' in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - yield from get_request(url=url, proxy_auth=auth, proxy=proxy.url) + await get_request(url=url, proxy_auth=auth, proxy=proxy.url) - assert 'Authorization' not in proxy.request.headers - assert 'Proxy-Authorization' in proxy.request.headers + assert "Authorization" not in proxy.request.headers + assert "Proxy-Authorization" in proxy.request.headers - yield from get_request(url=url, auth=auth, - proxy_auth=auth, proxy=proxy.url) + await get_request(url=url, auth=auth, proxy_auth=auth, proxy=proxy.url) - assert 'Authorization' in proxy.request.headers - assert 'Proxy-Authorization' in proxy.request.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" in proxy.request.headers -@asyncio.coroutine -def test_proxy_http_auth_utf8(proxy_test_server, get_request): - url = 'http://aiohttp.io/path' - auth = aiohttp.helpers.BasicAuth('юзер', 'пасс', 'utf-8') - proxy = yield from proxy_test_server() +async def test_proxy_http_auth_utf8(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io/path" + auth = aiohttp.BasicAuth("юзер", "пасс", "utf-8") + proxy = await proxy_test_server() - yield from get_request(url=url, auth=auth, proxy=proxy.url) + await get_request(url=url, auth=auth, proxy=proxy.url) - assert 'Authorization' in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers -@asyncio.coroutine -def test_proxy_http_auth_from_url(proxy_test_server, get_request): - url = 'http://aiohttp.io/path' - proxy = yield from proxy_test_server() +async def test_proxy_http_auth_from_url(proxy_test_server, get_request) -> None: + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() - auth_url = URL(url).with_user('user').with_password('pass') - yield from get_request(url=auth_url, proxy=proxy.url) + auth_url = URL(url).with_user("user").with_password("pass") + await get_request(url=auth_url, proxy=proxy.url) - assert 'Authorization' in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - proxy_url = URL(proxy.url).with_user('user').with_password('pass') - yield from get_request(url=url, proxy=proxy_url) + proxy_url = URL(proxy.url).with_user("user").with_password("pass") + await get_request(url=url, proxy=proxy_url) - assert 'Authorization' not in proxy.request.headers - assert 'Proxy-Authorization' in proxy.request.headers + assert "Authorization" not in 
proxy.request.headers + assert "Proxy-Authorization" in proxy.request.headers -@asyncio.coroutine -def test_proxy_http_acquired_cleanup(proxy_test_server, loop): - url = 'http://aiohttp.io/path' +async def test_proxy_http_acquired_cleanup(proxy_test_server, loop) -> None: + url = "http://aiohttp.io/path" conn = aiohttp.TCPConnector(loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() assert 0 == len(conn._acquired) - resp = yield from sess.get(url, proxy=proxy.url) + resp = await sess.get(url, proxy=proxy.url) assert resp.closed assert 0 == len(conn._acquired) - sess.close() + await sess.close() -@pytest.mark.skip('we need to reconsider how we test this') -@asyncio.coroutine -def test_proxy_http_acquired_cleanup_force(proxy_test_server, loop): - url = 'http://aiohttp.io/path' +@pytest.mark.skip("we need to reconsider how we test this") +async def test_proxy_http_acquired_cleanup_force(proxy_test_server, loop) -> None: + url = "http://aiohttp.io/path" conn = aiohttp.TCPConnector(force_close=True, loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() assert 0 == len(conn._acquired) - @asyncio.coroutine - def request(): - resp = yield from sess.get(url, proxy=proxy.url) + async def request(): + resp = await sess.get(url, proxy=proxy.url) assert 1 == len(conn._acquired) - yield from resp.release() + await resp.release() - yield from request() + await request() assert 0 == len(conn._acquired) - yield from sess.close() + await sess.close() -@pytest.mark.skip('we need to reconsider how we test this') -@asyncio.coroutine -def test_proxy_http_multi_conn_limit(proxy_test_server, loop): - url = 'http://aiohttp.io/path' +@pytest.mark.skip("we need to reconsider how we test this") +async def test_proxy_http_multi_conn_limit(proxy_test_server, loop) -> None: + url = "http://aiohttp.io/path" limit, multi_conn_num = 1, 5 conn = aiohttp.TCPConnector(limit=limit, loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() current_pid = None - @asyncio.coroutine - def request(pid): + async def request(pid): # process requests only one by one nonlocal current_pid - resp = yield from sess.get(url, proxy=proxy.url) + resp = await sess.get(url, proxy=proxy.url) current_pid = pid - yield from asyncio.sleep(0.2, loop=loop) + await asyncio.sleep(0.2, loop=loop) assert current_pid == pid - yield from resp.release() + await resp.release() return resp requests = [request(pid) for pid in range(multi_conn_num)] - responses = yield from asyncio.gather(*requests, loop=loop) + responses = await asyncio.gather(*requests, loop=loop) assert len(responses) == multi_conn_num - assert set(resp.status for resp in responses) == {200} + assert {resp.status for resp in responses} == {200} - yield from sess.close() + await sess.close() -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_connect(proxy_test_server, get_request): - proxy = yield from proxy_test_server() - url = 'https://www.google.com.ua/search?q=aiohttp proxy' +@pytest.mark.xfail +async def xtest_proxy_https_connect(proxy_test_server, get_request): + proxy = await proxy_test_server() + url = "https://www.google.com.ua/search?q=aiohttp proxy" - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) connect = proxy.requests_list[0] - assert connect.method == 
'CONNECT' - assert connect.path == 'www.google.com.ua:443' - assert connect.host == 'www.google.com.ua' + assert connect.method == "CONNECT" + assert connect.path == "www.google.com.ua:443" + assert connect.host == "www.google.com.ua" - assert proxy.request.host == 'www.google.com.ua' - assert proxy.request.path_qs == '/search?q=aiohttp+proxy' + assert proxy.request.host == "www.google.com.ua" + assert proxy.request.path_qs == "/search?q=aiohttp+proxy" -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_connect_with_port(proxy_test_server, get_request): - proxy = yield from proxy_test_server() - url = 'https://secure.aiohttp.io:2242/path' +@pytest.mark.xfail +async def xtest_proxy_https_connect_with_port(proxy_test_server, get_request): + proxy = await proxy_test_server() + url = "https://secure.aiohttp.io:2242/path" - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) connect = proxy.requests_list[0] - assert connect.method == 'CONNECT' - assert connect.path == 'secure.aiohttp.io:2242' - assert connect.host == 'secure.aiohttp.io:2242' + assert connect.method == "CONNECT" + assert connect.path == "secure.aiohttp.io:2242" + assert connect.host == "secure.aiohttp.io:2242" - assert proxy.request.host == 'secure.aiohttp.io:2242' - assert proxy.request.path_qs == '/path' + assert proxy.request.host == "secure.aiohttp.io:2242" + assert proxy.request.path_qs == "/path" -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_send_body(proxy_test_server, loop): +@pytest.mark.xfail +async def xtest_proxy_https_send_body(proxy_test_server, loop): sess = aiohttp.ClientSession(loop=loop) - proxy = yield from proxy_test_server() - proxy.return_value = {'status': 200, 'body': b'1'*(2**20)} - url = 'https://www.google.com.ua/search?q=aiohttp proxy' + proxy = await proxy_test_server() + proxy.return_value = {"status": 200, "body": b"1" * (2 ** 20)} + url = "https://www.google.com.ua/search?q=aiohttp proxy" - resp = yield from sess.get(url, proxy=proxy.url) - body = yield from resp.read() - yield from resp.release() - yield from sess.close() + resp = await sess.get(url, proxy=proxy.url) + body = await resp.read() + await resp.release() + await sess.close() - assert body == b'1'*(2**20) + assert body == b"1" * (2 ** 20) -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_idna_support(proxy_test_server, get_request): - url = 'https://éé.com/' - proxy = yield from proxy_test_server() +@pytest.mark.xfail +async def xtest_proxy_https_idna_support(proxy_test_server, get_request): + url = "https://éé.com/" + proxy = await proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) connect = proxy.requests_list[0] - assert connect.method == 'CONNECT' - assert connect.path == 'xn--9caa.com:443' - assert connect.host == 'xn--9caa.com' + assert connect.method == "CONNECT" + assert connect.path == "xn--9caa.com:443" + assert connect.host == "xn--9caa.com" -@asyncio.coroutine -def test_proxy_https_connection_error(get_request): - url = 'https://secure.aiohttp.io/path' - proxy_url = 'http://localhost:2242/' +async def test_proxy_https_connection_error(get_request) -> None: + url = "https://secure.aiohttp.io/path" + proxy_url = "http://localhost:2242/" with pytest.raises(aiohttp.ClientConnectorError): - yield from get_request(url=url, proxy=proxy_url) + await get_request(url=url, proxy=proxy_url) -@asyncio.coroutine -def test_proxy_https_bad_response(proxy_test_server, get_request): - url 
= 'https://secure.aiohttp.io/path' - proxy = yield from proxy_test_server() - proxy.return_value = dict( - status=502, - headers={'Proxy-Agent': 'TestProxy'}) +async def test_proxy_https_bad_response(proxy_test_server, get_request) -> None: + url = "https://secure.aiohttp.io/path" + proxy = await proxy_test_server() + proxy.return_value = dict(status=502, headers={"Proxy-Agent": "TestProxy"}) with pytest.raises(aiohttp.ClientHttpProxyError): - yield from get_request(url=url, proxy=proxy.url) + await get_request(url=url, proxy=proxy.url) assert len(proxy.requests_list) == 1 - assert proxy.request.method == 'CONNECT' - assert proxy.request.path == 'secure.aiohttp.io:443' + assert proxy.request.method == "CONNECT" + # The following check fails on MacOS + # assert proxy.request.path == 'secure.aiohttp.io:443' -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_auth(proxy_test_server, get_request): - url = 'https://secure.aiohttp.io/path' - auth = aiohttp.helpers.BasicAuth('user', 'pass') +@pytest.mark.xfail +async def xtest_proxy_https_auth(proxy_test_server, get_request): + url = "https://secure.aiohttp.io/path" + auth = aiohttp.BasicAuth("user", "pass") - proxy = yield from proxy_test_server() - yield from get_request(url=url, proxy=proxy.url) + proxy = await proxy_test_server() + await get_request(url=url, proxy=proxy.url) connect = proxy.requests_list[0] - assert 'Authorization' not in connect.headers - assert 'Proxy-Authorization' not in connect.headers - assert 'Authorization' not in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" not in connect.headers + assert "Proxy-Authorization" not in connect.headers + assert "Authorization" not in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - proxy = yield from proxy_test_server() - yield from get_request(url=url, auth=auth, proxy=proxy.url) + proxy = await proxy_test_server() + await get_request(url=url, auth=auth, proxy=proxy.url) connect = proxy.requests_list[0] - assert 'Authorization' not in connect.headers - assert 'Proxy-Authorization' not in connect.headers - assert 'Authorization' in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" not in connect.headers + assert "Proxy-Authorization" not in connect.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - proxy = yield from proxy_test_server() - yield from get_request(url=url, proxy_auth=auth, proxy=proxy.url) + proxy = await proxy_test_server() + await get_request(url=url, proxy_auth=auth, proxy=proxy.url) connect = proxy.requests_list[0] - assert 'Authorization' not in connect.headers - assert 'Proxy-Authorization' in connect.headers - assert 'Authorization' not in proxy.request.headers - assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" not in connect.headers + assert "Proxy-Authorization" in connect.headers + assert "Authorization" not in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers - proxy = yield from proxy_test_server() - yield from get_request(url=url, auth=auth, - proxy_auth=auth, proxy=proxy.url) + proxy = await proxy_test_server() + await get_request(url=url, auth=auth, proxy_auth=auth, proxy=proxy.url) connect = proxy.requests_list[0] - assert 'Authorization' not in connect.headers - assert 'Proxy-Authorization' in connect.headers - assert 'Authorization' in proxy.request.headers - 
assert 'Proxy-Authorization' not in proxy.request.headers + assert "Authorization" not in connect.headers + assert "Proxy-Authorization" in connect.headers + assert "Authorization" in proxy.request.headers + assert "Proxy-Authorization" not in proxy.request.headers -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_acquired_cleanup(proxy_test_server, loop): - url = 'https://secure.aiohttp.io/path' +@pytest.mark.xfail +async def xtest_proxy_https_acquired_cleanup(proxy_test_server, loop): + url = "https://secure.aiohttp.io/path" conn = aiohttp.TCPConnector(loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() assert 0 == len(conn._acquired) - @asyncio.coroutine - def request(): - resp = yield from sess.get(url, proxy=proxy.url) + async def request(): + resp = await sess.get(url, proxy=proxy.url) assert 1 == len(conn._acquired) - yield from resp.release() + await resp.release() - yield from request() + await request() assert 0 == len(conn._acquired) - yield from sess.close() + await sess.close() -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_acquired_cleanup_force(proxy_test_server, loop): - url = 'https://secure.aiohttp.io/path' +@pytest.mark.xfail +async def xtest_proxy_https_acquired_cleanup_force(proxy_test_server, loop): + url = "https://secure.aiohttp.io/path" conn = aiohttp.TCPConnector(force_close=True, loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() assert 0 == len(conn._acquired) - @asyncio.coroutine - def request(): - resp = yield from sess.get(url, proxy=proxy.url) + async def request(): + resp = await sess.get(url, proxy=proxy.url) assert 1 == len(conn._acquired) - yield from resp.release() + await resp.release() - yield from request() + await request() assert 0 == len(conn._acquired) - yield from sess.close() + await sess.close() -# @pytest.mark.xfail -@asyncio.coroutine -def _test_proxy_https_multi_conn_limit(proxy_test_server, loop): - url = 'https://secure.aiohttp.io/path' +@pytest.mark.xfail +async def xtest_proxy_https_multi_conn_limit(proxy_test_server, loop): + url = "https://secure.aiohttp.io/path" limit, multi_conn_num = 1, 5 conn = aiohttp.TCPConnector(limit=limit, loop=loop) sess = aiohttp.ClientSession(connector=conn, loop=loop) - proxy = yield from proxy_test_server() + proxy = await proxy_test_server() current_pid = None - @asyncio.coroutine - def request(pid): + async def request(pid): # process requests only one by one nonlocal current_pid - resp = yield from sess.get(url, proxy=proxy.url) + resp = await sess.get(url, proxy=proxy.url) current_pid = pid - yield from asyncio.sleep(0.2, loop=loop) + await asyncio.sleep(0.2, loop=loop) assert current_pid == pid - yield from resp.release() + await resp.release() return resp requests = [request(pid) for pid in range(multi_conn_num)] - responses = yield from asyncio.gather(*requests, loop=loop) + responses = await asyncio.gather(*requests, loop=loop) assert len(responses) == multi_conn_num - assert set(resp.status for resp in responses) == {200} + assert {resp.status for resp in responses} == {200} - yield from sess.close() + await sess.close() def _patch_ssl_transport(monkeypatch): - """Make ssl transport substitution to prevent ssl handshake.""" - def _make_ssl_transport_dummy(self, rawsock, protocol, sslcontext, - waiter=None, **kwargs): - return self._make_socket_transport(rawsock, protocol, waiter, - 
extra=kwargs.get('extra'), - server=kwargs.get('server')) + # Make ssl transport substitution to prevent ssl handshake. + def _make_ssl_transport_dummy( + self, rawsock, protocol, sslcontext, waiter=None, **kwargs + ): + return self._make_socket_transport( + rawsock, + protocol, + waiter, + extra=kwargs.get("extra"), + server=kwargs.get("server"), + ) monkeypatch.setattr( "asyncio.selector_events.BaseSelectorEventLoop._make_ssl_transport", - _make_ssl_transport_dummy) + _make_ssl_transport_dummy, + ) + + +original_is_file = pathlib.Path.is_file + + +def mock_is_file(self): + # make real netrc file invisible in home dir + if self.name in ["_netrc", ".netrc"] and self.parent == self.home(): + return False + else: + return original_is_file(self) + + +async def test_proxy_from_env_http(proxy_test_server, get_request, mocker) -> None: + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + mocker.patch.dict(os.environ, {"http_proxy": str(proxy.url)}) + mocker.patch("pathlib.Path.is_file", mock_is_file) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 1 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "http://aiohttp.io/path" + assert "Proxy-Authorization" not in proxy.request.headers + + +async def test_proxy_from_env_http_with_auth(proxy_test_server, get_request, mocker): + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + auth = aiohttp.BasicAuth("user", "pass") + mocker.patch.dict( + os.environ, + { + "http_proxy": str( + proxy.url.with_user(auth.login).with_password(auth.password) + ) + }, + ) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 1 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "http://aiohttp.io/path" + assert proxy.request.headers["Proxy-Authorization"] == auth.encode() + + +async def test_proxy_from_env_http_with_auth_from_netrc( + proxy_test_server, get_request, tmpdir, mocker +): + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + auth = aiohttp.BasicAuth("user", "pass") + netrc_file = tmpdir.join("test_netrc") + netrc_file_data = "machine 127.0.0.1 login {} password {}".format( + auth.login, + auth.password, + ) + with open(str(netrc_file), "w") as f: + f.write(netrc_file_data) + mocker.patch.dict( + os.environ, {"http_proxy": str(proxy.url), "NETRC": str(netrc_file)} + ) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 1 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "http://aiohttp.io/path" + assert proxy.request.headers["Proxy-Authorization"] == auth.encode() + + +async def test_proxy_from_env_http_without_auth_from_netrc( + proxy_test_server, get_request, tmpdir, mocker +): + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + auth = aiohttp.BasicAuth("user", "pass") + netrc_file = tmpdir.join("test_netrc") + netrc_file_data = "machine 127.0.0.2 login {} password {}".format( + auth.login, + auth.password, + ) + with open(str(netrc_file), "w") as f: + f.write(netrc_file_data) + mocker.patch.dict( + os.environ, {"http_proxy": str(proxy.url), "NETRC": str(netrc_file)} + ) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 1 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == 
"http://aiohttp.io/path" + assert "Proxy-Authorization" not in proxy.request.headers + + +async def test_proxy_from_env_http_without_auth_from_wrong_netrc( + proxy_test_server, get_request, tmpdir, mocker +): + url = "http://aiohttp.io/path" + proxy = await proxy_test_server() + auth = aiohttp.BasicAuth("user", "pass") + netrc_file = tmpdir.join("test_netrc") + invalid_data = f"machine 127.0.0.1 {auth.login} pass {auth.password}" + with open(str(netrc_file), "w") as f: + f.write(invalid_data) + + mocker.patch.dict( + os.environ, {"http_proxy": str(proxy.url), "NETRC": str(netrc_file)} + ) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 1 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "http://aiohttp.io/path" + assert "Proxy-Authorization" not in proxy.request.headers + + +@pytest.mark.xfail +async def xtest_proxy_from_env_https(proxy_test_server, get_request, mocker): + url = "https://aiohttp.io/path" + proxy = await proxy_test_server() + mocker.patch.dict(os.environ, {"https_proxy": str(proxy.url)}) + mock.patch("pathlib.Path.is_file", mock_is_file) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 2 + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "https://aiohttp.io/path" + assert "Proxy-Authorization" not in proxy.request.headers + + +@pytest.mark.xfail +async def xtest_proxy_from_env_https_with_auth(proxy_test_server, get_request, mocker): + url = "https://aiohttp.io/path" + proxy = await proxy_test_server() + auth = aiohttp.BasicAuth("user", "pass") + mocker.patch.dict( + os.environ, + { + "https_proxy": str( + proxy.url.with_user(auth.login).with_password(auth.password) + ) + }, + ) + + await get_request(url=url, trust_env=True) + + assert len(proxy.requests_list) == 2 + + assert proxy.request.method == "GET" + assert proxy.request.host == "aiohttp.io" + assert proxy.request.path_qs == "/path" + assert "Proxy-Authorization" not in proxy.request.headers + + r2 = proxy.requests_list[0] + assert r2.method == "CONNECT" + assert r2.host == "aiohttp.io" + assert r2.path_qs == "/path" + assert r2.headers["Proxy-Authorization"] == auth.encode() + + +async def test_proxy_auth() -> None: + async with aiohttp.ClientSession() as session: + with pytest.raises( + ValueError, match=r"proxy_auth must be None or BasicAuth\(\) tuple" + ): + await session.get( + "http://python.org", + proxy="http://proxy.example.com", + proxy_auth=("user", "pass"), + ) diff --git a/tests/test_py35/test_cbv35.py b/tests/test_py35/test_cbv35.py deleted file mode 100644 index 1a8dd65a6a9..00000000000 --- a/tests/test_py35/test_cbv35.py +++ /dev/null @@ -1,17 +0,0 @@ -from unittest import mock - -from aiohttp import web -from aiohttp.web_urldispatcher import View - - -async def test_render_ok(): - resp = web.Response(text='OK') - - class MyView(View): - async def get(self): - return resp - - request = mock.Mock() - request._method = 'GET' - resp2 = await MyView(request) - assert resp is resp2 diff --git a/tests/test_py35/test_client.py b/tests/test_py35/test_client.py deleted file mode 100644 index e33f657eb4a..00000000000 --- a/tests/test_py35/test_client.py +++ /dev/null @@ -1,101 +0,0 @@ -import asyncio - -import pytest - -import aiohttp -from aiohttp import web - - -async def test_async_with_session(loop): - async with aiohttp.ClientSession(loop=loop) as session: - pass - - assert session.closed - - -async def 
test_close_resp_on_error_async_with_session(loop, test_server): - async def handler(request): - resp = web.StreamResponse(headers={'content-length': '100'}) - await resp.prepare(request) - await resp.drain() - await asyncio.sleep(0.1, loop=request.app.loop) - return resp - - app = web.Application() - app.router.add_get('/', handler) - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as session: - with pytest.raises(RuntimeError): - async with session.get(server.make_url('/')) as resp: - resp.content.set_exception(RuntimeError()) - await resp.read() - - assert len(session._connector._conns) == 0 - - -async def test_release_resp_on_normal_exit_from_cm(loop, test_server): - async def handler(request): - return web.Response() - - app = web.Application() - app.router.add_get('/', handler) - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as session: - async with session.get(server.make_url('/')) as resp: - await resp.read() - - assert len(session._connector._conns) == 1 - - -async def test_non_close_detached_session_on_error_cm(loop, test_server): - async def handler(request): - resp = web.StreamResponse(headers={'content-length': '100'}) - await resp.prepare(request) - await resp.drain() - await asyncio.sleep(0.1, loop=request.app.loop) - return resp - - app = web.Application() - app.router.add_get('/', handler) - server = await test_server(app) - - session = aiohttp.ClientSession(loop=loop) - cm = session.get(server.make_url('/')) - assert not session.closed - with pytest.raises(RuntimeError): - async with cm as resp: - resp.content.set_exception(RuntimeError()) - await resp.read() - assert not session.closed - - -async def test_close_detached_session_on_non_existing_addr(loop): - session = aiohttp.ClientSession(loop=loop) - - async with session: - cm = session.get('http://non-existing.example.com') - assert not session.closed - with pytest.raises(Exception): - await cm - - assert session.closed - - -async def test_aiohttp_request(loop, test_server): - async def handler(request): - return web.Response() - - app = web.Application() - app.router.add_get('/', handler) - server = await test_server(app) - - async with aiohttp.request('GET', server.make_url('/'), loop=loop) as resp: - await resp.read() - assert resp.status == 200 - - resp = await aiohttp.request('GET', server.make_url('/'), loop=loop) - await resp.read() - assert resp.status == 200 - assert resp.connection is None diff --git a/tests/test_py35/test_client_websocket_35.py b/tests/test_py35/test_client_websocket_35.py deleted file mode 100644 index 240fbfc3b46..00000000000 --- a/tests/test_py35/test_client_websocket_35.py +++ /dev/null @@ -1,137 +0,0 @@ -import pytest - -import aiohttp -from aiohttp import helpers, web - - -async def test_client_ws_async_for(loop, test_client): - items = ['q1', 'q2', 'q3'] - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - for i in items: - ws.send_str(i) - await ws.close() - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - - client = await test_client(app) - resp = await client.ws_connect('/') - it = iter(items) - async for msg in resp: - assert msg.data == next(it) - - with pytest.raises(StopIteration): - next(it) - - assert resp.closed - - -async def test_client_ws_async_with(loop, test_server): - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - msg = await ws.receive() - ws.send_str(msg.data + '/answer') - await 
ws.close() - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as client: - async with client.ws_connect(server.make_url('/')) as ws: - ws.send_str('request') - msg = await ws.receive() - assert msg.data == 'request/answer' - - assert ws.closed - - -async def test_client_ws_async_with_send(loop, test_server): - # send_xxx methods have to return awaitable objects - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - msg = await ws.receive() - ws.send_str(msg.data + '/answer') - await ws.close() - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as client: - async with client.ws_connect(server.make_url('/')) as ws: - await ws.send_str('request') - msg = await ws.receive() - assert msg.data == 'request/answer' - - assert ws.closed - - -async def test_client_ws_async_with_shortcut(loop, test_server): - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - msg = await ws.receive() - ws.send_str(msg.data + '/answer') - await ws.close() - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as client: - async with client.ws_connect(server.make_url('/')) as ws: - ws.send_str('request') - msg = await ws.receive() - assert msg.data == 'request/answer' - - assert ws.closed - - -async def test_closed_async_for(loop, test_client): - - closed = helpers.create_future(loop) - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - - try: - ws.send_bytes(b'started') - await ws.receive_bytes() - finally: - closed.set_result(1) - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - client = await test_client(app) - resp = await client.ws_connect('/') - - messages = [] - async for msg in resp: - messages.append(msg) - if b'started' == msg.data: - resp.send_bytes(b'ask') - await resp.close() - - assert 1 == len(messages) - assert messages[0].type == aiohttp.WSMsgType.BINARY - assert messages[0].data == b'started' - assert resp.closed - - await closed diff --git a/tests/test_py35/test_multipart_35.py b/tests/test_py35/test_multipart_35.py deleted file mode 100644 index a9a24229fbb..00000000000 --- a/tests/test_py35/test_multipart_35.py +++ /dev/null @@ -1,92 +0,0 @@ -import io -import json - -import aiohttp -import aiohttp.hdrs as h - - -class Stream(object): - - def __init__(self, content): - self.content = io.BytesIO(content) - - async def read(self, size=None): - return self.content.read(size) - - def at_eof(self): - return self.content.tell() == len(self.content.getbuffer()) - - async def readline(self): - return self.content.readline() - - def unread_data(self, data): - self.content = io.BytesIO(data + self.content.read()) - - -async def test_async_for_reader(loop): - data = [ - {"test": "passed"}, - 42, - b'plain text', - b'aiohttp\n', - b'no epilogue'] - reader = aiohttp.MultipartReader( - headers={h.CONTENT_TYPE: 'multipart/mixed; boundary=":"'}, - content=Stream(b'\r\n'.join([ - b'--:', - b'Content-Type: application/json', - b'', - json.dumps(data[0]).encode(), - b'--:', - b'Content-Type: application/json', - b'', - json.dumps(data[1]).encode(), - b'--:', - b'Content-Type: multipart/related; boundary="::"', - b'', - b'--::', 
- b'Content-Type: text/plain', - b'', - data[2], - b'--::', - b'Content-Disposition: attachment; filename="aiohttp"', - b'Content-Type: text/plain', - b'Content-Length: 28', - b'Content-Encoding: gzip', - b'', - b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03K\xcc\xcc\xcf())' - b'\xe0\x02\x00\xd6\x90\xe2O\x08\x00\x00\x00', - b'--::', - b'Content-Type: multipart/related; boundary=":::"', - b'', - b'--:::', - b'Content-Type: text/plain', - b'', - data[4], - b'--:::--', - b'--::--', - b'', - b'--:--', - b'']))) - idata = iter(data) - - async def check(reader): - async for part in reader: - if isinstance(part, aiohttp.BodyPartReader): - if part.headers[h.CONTENT_TYPE] == 'application/json': - assert next(idata) == (await part.json()) - else: - assert next(idata) == await part.read(decode=True) - else: - await check(part) - - await check(reader) - - -async def test_async_for_bodypart(loop): - part = aiohttp.BodyPartReader( - boundary=b'--:', - headers={}, - content=Stream(b'foobarbaz\r\n--:--')) - async for data in part: - assert data == b'foobarbaz' diff --git a/tests/test_py35/test_resp.py b/tests/test_py35/test_resp.py deleted file mode 100644 index f98fc2f0d7f..00000000000 --- a/tests/test_py35/test_resp.py +++ /dev/null @@ -1,133 +0,0 @@ -import asyncio -from collections.abc import Coroutine - -import pytest - -import aiohttp -from aiohttp import web -from aiohttp.client import _RequestContextManager - - -async def test_await(test_server, loop): - - async def handler(request): - resp = web.StreamResponse(headers={'content-length': str(4)}) - await resp.prepare(request) - await resp.drain() - await asyncio.sleep(0.01, loop=loop) - resp.write(b'test') - await asyncio.sleep(0.01, loop=loop) - await resp.write_eof() - return resp - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - - with aiohttp.ClientSession(loop=loop) as session: - resp = await session.get(server.make_url('/')) - assert resp.status == 200 - assert resp.connection is not None - await resp.read() - await resp.release() - assert resp.connection is None - - -async def test_response_context_manager(test_server, loop): - - async def handler(request): - return web.HTTPOk() - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - resp = await aiohttp.ClientSession(loop=loop).get(server.make_url('/')) - async with resp: - assert resp.status == 200 - assert resp.connection is None - assert resp.connection is None - - -async def test_response_context_manager_error(test_server, loop): - - async def handler(request): - return web.HTTPOk() - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - session = aiohttp.ClientSession(loop=loop) - cm = session.get(server.make_url('/')) - resp = await cm - with pytest.raises(RuntimeError): - async with resp: - assert resp.status == 200 - resp.content.set_exception(RuntimeError()) - await resp.read() - assert resp.closed - - assert len(session._connector._conns) == 1 - - -async def test_client_api_context_manager(test_server, loop): - - async def handler(request): - return web.HTTPOk() - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - - async with aiohttp.ClientSession(loop=loop) as session: - async with session.get(server.make_url('/')) as resp: - assert resp.status == 200 - assert resp.connection is None - assert resp.connection is None - - -def test_ctx_manager_is_coroutine(): - assert 
issubclass(_RequestContextManager, Coroutine) - - -async def test_context_manager_close_on_release(test_server, loop, mocker): - - async def handler(request): - resp = web.StreamResponse() - await resp.prepare(request) - await resp.drain() - await asyncio.sleep(10, loop=loop) - return resp - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await test_server(app) - - with aiohttp.ClientSession(loop=loop) as session: - resp = await session.get(server.make_url('/')) - proto = resp.connection._protocol - mocker.spy(proto, 'close') - async with resp: - assert resp.status == 200 - assert resp.connection is not None - assert resp.connection is None - assert proto.close.called - - -async def test_iter_any(test_server, loop): - - data = b'0123456789' * 1024 - - async def handler(request): - buf = [] - async for raw in request.content.iter_any(): - buf.append(raw) - assert b''.join(buf) == data - return web.Response() - - app = web.Application() - app.router.add_route('POST', '/', handler) - server = await test_server(app) - - with aiohttp.ClientSession(loop=loop) as session: - async with await session.post(server.make_url('/'), data=data) as resp: - assert resp.status == 200 diff --git a/tests/test_py35/test_streams_35.py b/tests/test_py35/test_streams_35.py deleted file mode 100644 index ef25bcca52d..00000000000 --- a/tests/test_py35/test_streams_35.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest - -from aiohttp import streams - -DATA = b'line1\nline2\nline3\n' - - -def chunkify(seq, n): - for i in range(0, len(seq), n): - yield seq[i:i+n] - - -def create_stream(loop): - stream = streams.StreamReader(loop=loop) - stream.feed_data(DATA) - stream.feed_eof() - return stream - - -async def test_stream_reader_lines(loop): - line_iter = iter(DATA.splitlines(keepends=True)) - async for line in create_stream(loop): - assert line == next(line_iter, None) - pytest.raises(StopIteration, next, line_iter) - - -async def test_stream_reader_chunks_complete(loop): - """Tests if chunked iteration works if the chunking works out - (i.e. 
the data is divisible by the chunk size) - """ - chunk_iter = chunkify(DATA, 9) - async for data in create_stream(loop).iter_chunked(9): - assert data == next(chunk_iter, None) - pytest.raises(StopIteration, next, chunk_iter) - - -async def test_stream_reader_chunks_incomplete(loop): - """Tests if chunked iteration works if the last chunk is incomplete""" - chunk_iter = chunkify(DATA, 8) - async for data in create_stream(loop).iter_chunked(8): - assert data == next(chunk_iter, None) - pytest.raises(StopIteration, next, chunk_iter) - - -async def test_data_queue_empty(loop): - """Tests that async looping yields nothing if nothing is there""" - buffer = streams.DataQueue(loop=loop) - buffer.feed_eof() - - async for _ in buffer: # NOQA - assert False - - -async def test_data_queue_items(loop): - """Tests that async looping yields objects identically""" - buffer = streams.DataQueue(loop=loop) - - items = [object(), object()] - buffer.feed_data(items[0], 1) - buffer.feed_data(items[1], 1) - buffer.feed_eof() - - item_iter = iter(items) - async for item in buffer: - assert item is next(item_iter, None) - pytest.raises(StopIteration, next, item_iter) - - -async def test_stream_reader_iter_any(loop): - it = iter([b'line1\nline2\nline3\n']) - async for raw in create_stream(loop).iter_any(): - assert raw == next(it) - pytest.raises(StopIteration, next, it) - - -async def test_stream_reader_iter(loop): - it = iter([b'line1\n', - b'line2\n', - b'line3\n']) - async for raw in create_stream(loop): - assert raw == next(it) - pytest.raises(StopIteration, next, it) diff --git a/tests/test_py35/test_test_utils_35.py b/tests/test_py35/test_test_utils_35.py deleted file mode 100644 index 73cd98248bc..00000000000 --- a/tests/test_py35/test_test_utils_35.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest - -import aiohttp -from aiohttp import web -from aiohttp.test_utils import TestClient as _TestClient -from aiohttp.test_utils import TestServer as _TestServer - - -@pytest.fixture -def app(): - async def handler(request): - return web.Response(body=b"OK") - - app = web.Application() - app.router.add_route('*', '/', handler) - return app - - -async def test_server_context_manager(app, loop): - async with _TestServer(app, loop=loop) as server: - async with aiohttp.ClientSession(loop=loop) as client: - async with client.head(server.make_url('/')) as resp: - assert resp.status == 200 - - -@pytest.mark.parametrize("method", [ - "head", "get", "post", "options", "post", "put", "patch", "delete" -]) -async def test_client_context_manager_response(method, app, loop): - async with _TestClient(app, loop=loop) as client: - async with getattr(client, method)('/') as resp: - assert resp.status == 200 - if method != 'head': - text = await resp.text() - assert "OK" in text diff --git a/tests/test_py35/test_web_websocket_35.py b/tests/test_py35/test_web_websocket_35.py deleted file mode 100644 index c0d54fa74db..00000000000 --- a/tests/test_py35/test_web_websocket_35.py +++ /dev/null @@ -1,71 +0,0 @@ -import aiohttp -from aiohttp import helpers, web -from aiohttp.http import WSMsgType - - -async def test_server_ws_async_for(loop, test_server): - closed = helpers.create_future(loop) - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - async for msg in ws: - assert msg.type == aiohttp.WSMsgType.TEXT - s = msg.data - await ws.send_str(s + '/answer') - await ws.close() - closed.set_result(1) - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - server = await 
test_server(app) - - async with aiohttp.ClientSession(loop=loop) as sm: - async with sm.ws_connect(server.make_url('/')) as resp: - - items = ['q1', 'q2', 'q3'] - for item in items: - resp.send_str(item) - msg = await resp.receive() - assert msg.type == aiohttp.WSMsgType.TEXT - assert item + '/answer' == msg.data - - await resp.close() - await closed - - -async def test_closed_async_for(loop, test_client): - - closed = helpers.create_future(loop) - - async def handler(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - - messages = [] - async for msg in ws: - messages.append(msg) - if 'stop' == msg.data: - ws.send_str('stopping') - await ws.close() - - assert 1 == len(messages) - assert messages[0].type == WSMsgType.TEXT - assert messages[0].data == 'stop' - - closed.set_result(None) - return ws - - app = web.Application() - app.router.add_get('/', handler) - client = await test_client(app) - - ws = await client.ws_connect('/') - ws.send_str('stop') - msg = await ws.receive() - assert msg.type == WSMsgType.TEXT - assert msg.data == 'stopping' - - await ws.close() - await closed diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 606a4f3a87e..0d2641525ab 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -1,97 +1,84 @@ -pytest_plugins = 'pytester' +import os +import platform +import sys + +import pytest + +pytest_plugins = "pytester" + +CONFTEST = """ +pytest_plugins = 'aiohttp.pytest_plugin' +""" -def test_myplugin(testdir): - testdir.makepyfile("""\ -import asyncio +IS_PYPY = platform.python_implementation() == "PyPy" + + +def test_aiohttp_plugin(testdir) -> None: + testdir.makepyfile( + """\ import pytest from unittest import mock from aiohttp import web -pytest_plugins = 'aiohttp.pytest_plugin' - - -@asyncio.coroutine -def hello(request): +async def hello(request): return web.Response(body=b'Hello, world') -def create_app(loop): +def create_app(loop=None): app = web.Application() app.router.add_route('GET', '/', hello) return app -@asyncio.coroutine -def test_hello(test_client): - client = yield from test_client(create_app) - resp = yield from client.get('/') +async def test_hello(aiohttp_client) -> None: + client = await aiohttp_client(create_app) + resp = await client.get('/') assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert 'Hello, world' in text -@asyncio.coroutine -def test_hello_from_app(test_client, loop): +async def test_hello_from_app(aiohttp_client, loop) -> None: app = web.Application() app.router.add_get('/', hello) - client = yield from test_client(app) - resp = yield from client.get('/') + client = await aiohttp_client(app) + resp = await client.get('/') assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert 'Hello, world' in text -@asyncio.coroutine -def test_hello_with_loop(test_client, loop): - client = yield from test_client(create_app) - resp = yield from client.get('/') +async def test_hello_with_loop(aiohttp_client, loop) -> None: + client = await aiohttp_client(create_app) + resp = await client.get('/') assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert 'Hello, world' in text -@asyncio.coroutine -def test_hello_fails(test_client): - client = yield from test_client(create_app) - resp = yield from client.get('/') - assert resp.status == 200 - text = yield from resp.text() - assert 'Hello, wield' in text - - -@asyncio.coroutine -def test_hello_with_fake_loop(test_client): - with 
pytest.raises(AssertionError): - fake_loop = mock.Mock() - yield from test_client(web.Application(loop=fake_loop)) - - -@asyncio.coroutine -def test_set_args(test_client, loop): +async def test_set_args(aiohttp_client, loop) -> None: with pytest.raises(AssertionError): app = web.Application() - yield from test_client(app, 1, 2, 3) + await aiohttp_client(app, 1, 2, 3) -@asyncio.coroutine -def test_set_keyword_args(test_client, loop): +async def test_set_keyword_args(aiohttp_client, loop) -> None: app = web.Application() with pytest.raises(TypeError): - yield from test_client(app, param=1) + await aiohttp_client(app, param=1) -@asyncio.coroutine -def test_noop(): +async def test_noop() -> None: pass -@asyncio.coroutine -def previous(request): +async def previous(request): if request.method == 'POST': - request.app['value'] = (yield from request.post())['value'] + with pytest.warns(DeprecationWarning): + request.app['value'] = (await request.post())['value'] return web.Response(body=b'thanks for the data') else: v = request.app.get('value', 'unknown') @@ -99,54 +86,201 @@ def previous(request): def create_stateful_app(loop): - app = web.Application(loop=loop) + app = web.Application() app.router.add_route('*', '/', previous) return app @pytest.fixture -def cli(loop, test_client): - return loop.run_until_complete(test_client(create_stateful_app)) +def cli(loop, aiohttp_client): + return loop.run_until_complete(aiohttp_client(create_stateful_app)) -@asyncio.coroutine -def test_set_value(cli): - resp = yield from cli.post('/', data={'value': 'foo'}) +async def test_set_value(cli) -> None: + resp = await cli.post('/', data={'value': 'foo'}) assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert text == 'thanks for the data' assert cli.server.app['value'] == 'foo' -@asyncio.coroutine -def test_get_value(cli): - resp = yield from cli.get('/') +async def test_get_value(cli) -> None: + resp = await cli.get('/') assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert text == 'value: unknown' - cli.server.app['value'] = 'bar' - resp = yield from cli.get('/') + with pytest.warns(DeprecationWarning): + cli.server.app['value'] = 'bar' + resp = await cli.get('/') assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert text == 'value: bar' -def test_noncoro(): +def test_noncoro() -> None: assert True -@asyncio.coroutine -def test_client_failed_to_create(test_client): +async def test_failed_to_create_client(aiohttp_client) -> None: def make_app(loop): raise RuntimeError() with pytest.raises(RuntimeError): - yield from test_client(make_app) + await aiohttp_client(make_app) + + +async def test_custom_port_aiohttp_client(aiohttp_client, aiohttp_unused_port): + port = aiohttp_unused_port() + client = await aiohttp_client(create_app, server_kwargs={'port': port}) + assert client.port == port + resp = await client.get('/') + assert resp.status == 200 + text = await resp.text() + assert 'Hello, world' in text + + +async def test_custom_port_test_server(aiohttp_server, aiohttp_unused_port): + app = create_app() + port = aiohttp_unused_port() + server = await aiohttp_server(app, port=port) + assert server.port == port + +""" + ) + testdir.makeconftest(CONFTEST) + result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") + result.assert_outcomes(passed=12) + + +def test_warning_checks(testdir) -> None: + testdir.makepyfile( + """\ + +async def foobar(): + return 123 + +async def test_good() -> None: 
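An editorial aside on the `test_warning_checks` case being added here: it relies on a CPython implementation detail. A coroutine object that is finalized without ever having been awaited emits a `RuntimeWarning`, and the run with `-W default` below counts on that warning to mark `test_bad` as failed on CPython; PyPy's garbage collector gives no such guarantee, hence the `IS_PYPY` special case. A minimal standalone sketch of that behaviour, separate from the patch itself:

```python
# Standalone sketch (not part of the patch): the CPython behaviour that
# test_warning_checks depends on. A coroutine created but never awaited
# emits "RuntimeWarning: coroutine 'foobar' was never awaited" when it is
# garbage-collected.
import gc
import warnings


async def foobar():
    return 123


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    foobar()      # coroutine object created and dropped, never awaited
    gc.collect()  # force finalization so the demo is deterministic

assert any("never awaited" in str(w.message) for w in caught)
```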
+ v = await foobar() + assert v == 123 + +async def test_bad() -> None: + foobar() +""" + ) + testdir.makeconftest(CONFTEST) + result = testdir.runpytest( + "-p", "no:sugar", "-s", "-W", "default", "--aiohttp-loop=pyloop" + ) + expected_outcomes = ( + {"failed": 0, "passed": 2} + if IS_PYPY and bool(os.environ.get("PYTHONASYNCIODEBUG")) + else {"failed": 1, "passed": 1} + ) + # Under PyPy "coroutine 'foobar' was never awaited" does not happen. + result.assert_outcomes(**expected_outcomes) + + +def test_aiohttp_plugin_async_fixture(testdir, capsys) -> None: + testdir.makepyfile( + """\ +import pytest + +from aiohttp import web + + +async def hello(request): + return web.Response(body=b'Hello, world') + + +def create_app(): + app = web.Application() + app.router.add_route('GET', '/', hello) + return app + + +@pytest.fixture +async def cli(aiohttp_client, loop): + client = await aiohttp_client(create_app()) + return client + + +@pytest.fixture +async def foo(): + return 42 + + +@pytest.fixture +async def bar(request): + # request should be accessible in async fixtures if needed + return request.function + + +async def test_hello(cli, loop) -> None: + resp = await cli.get('/') + assert resp.status == 200 + + +def test_foo(loop, foo) -> None: + assert foo == 42 + + +def test_foo_without_loop(foo) -> None: + # will raise an error because there is no loop + pass + + +def test_bar(loop, bar) -> None: + assert bar is test_bar +""" + ) + testdir.makeconftest(CONFTEST) + result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") + result.assert_outcomes(passed=3, errors=1) + result.stdout.fnmatch_lines( + "*Asynchronous fixtures must depend on the 'loop' fixture " + "or be used in tests depending from it." + ) + + +@pytest.mark.skipif(sys.version_info < (3, 6), reason="old python") +def test_aiohttp_plugin_async_gen_fixture(testdir) -> None: + testdir.makepyfile( + """\ +import pytest +from unittest import mock + +from aiohttp import web + + +canary = mock.Mock() + + +async def hello(request): + return web.Response(body=b'Hello, world') + + +def create_app(loop): + app = web.Application() + app.router.add_route('GET', '/', hello) + return app + + +@pytest.fixture +async def cli(aiohttp_client): + yield await aiohttp_client(create_app) + canary() + + +async def test_hello(cli) -> None: + resp = await cli.get('/') + assert resp.status == 200 -""") - testdir.runpytest('-p', 'no:sugar') - # i dont know how to fix this - # result = testdir.runpytest('-p', 'no:sugar') - # result.assert_outcomes(passed=11, failed=1) +def test_finalized() -> None: + assert canary.called is True +""" + ) + testdir.makeconftest(CONFTEST) + result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") + result.assert_outcomes(passed=2) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 32245098f89..199707e7a42 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,7 +1,7 @@ import asyncio import ipaddress import socket -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -9,7 +9,8 @@ try: import aiodns - gethostbyname = hasattr(aiodns.DNSResolver, 'gethostbyname') + + gethostbyname = hasattr(aiodns.DNSResolver, "gethostbyname") except ImportError: aiodns = None gethostbyname = False @@ -25,184 +26,185 @@ def __init__(self, host): self.host = host -@asyncio.coroutine -def fake_result(addresses): +async def fake_result(addresses): return FakeResult(addresses=tuple(addresses)) -@asyncio.coroutine -def fake_query_result(result): - return 
[FakeQueryResult(host=h) - for h in result] +async def fake_query_result(result): + return [FakeQueryResult(host=h) for h in result] def fake_addrinfo(hosts): - @asyncio.coroutine - def fake(*args, **kwargs): + async def fake(*args, **kwargs): if not hosts: raise socket.gaierror - return list([(None, None, None, None, [h, 0]) - for h in hosts]) + return list([(None, None, None, None, [h, 0]) for h in hosts]) return fake @pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -@asyncio.coroutine -def test_async_resolver_positive_lookup(loop): - with patch('aiodns.DNSResolver') as mock: - mock().gethostbyname.return_value = fake_result(['127.0.0.1']) +async def test_async_resolver_positive_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: + mock().gethostbyname.return_value = fake_result(["127.0.0.1"]) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.python.org') - ipaddress.ip_address(real[0]['host']) - mock().gethostbyname.assert_called_with('www.python.org', - socket.AF_INET) + real = await resolver.resolve("www.python.org") + ipaddress.ip_address(real[0]["host"]) + mock().gethostbyname.assert_called_with("www.python.org", socket.AF_INET) @pytest.mark.skipif(aiodns is None, reason="aiodns required") -@asyncio.coroutine -def test_async_resolver_query_positive_lookup(loop): - with patch('aiodns.DNSResolver') as mock: +async def test_async_resolver_query_positive_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: del mock().gethostbyname - mock().query.return_value = fake_query_result(['127.0.0.1']) + mock().query.return_value = fake_query_result(["127.0.0.1"]) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.python.org') - ipaddress.ip_address(real[0]['host']) - mock().query.assert_called_with('www.python.org', 'A') + real = await resolver.resolve("www.python.org") + ipaddress.ip_address(real[0]["host"]) + mock().query.assert_called_with("www.python.org", "A") @pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -@asyncio.coroutine -def test_async_resolver_multiple_replies(loop): - with patch('aiodns.DNSResolver') as mock: - ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3', '127.0.0.4'] +async def test_async_resolver_multiple_replies(loop) -> None: + with patch("aiodns.DNSResolver") as mock: + ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] mock().gethostbyname.return_value = fake_result(ips) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.google.com') - ips = [ipaddress.ip_address(x['host']) for x in real] + real = await resolver.resolve("www.google.com") + ips = [ipaddress.ip_address(x["host"]) for x in real] assert len(ips) > 3, "Expecting multiple addresses" @pytest.mark.skipif(aiodns is None, reason="aiodns required") -@asyncio.coroutine -def test_async_resolver_query_multiple_replies(loop): - with patch('aiodns.DNSResolver') as mock: +async def test_async_resolver_query_multiple_replies(loop) -> None: + with patch("aiodns.DNSResolver") as mock: del mock().gethostbyname - ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3', '127.0.0.4'] + ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] mock().query.return_value = fake_query_result(ips) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.google.com') - ips = [ipaddress.ip_address(x['host']) for x in real] + real = await resolver.resolve("www.google.com") + ips = [ipaddress.ip_address(x["host"]) for x in real] @pytest.mark.skipif(not gethostbyname, 
reason="aiodns 1.1 required") -@asyncio.coroutine -def test_async_resolver_negative_lookup(loop): - with patch('aiodns.DNSResolver') as mock: +async def test_async_resolver_negative_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: mock().gethostbyname.side_effect = aiodns.error.DNSError() resolver = AsyncResolver(loop=loop) - with pytest.raises(aiodns.error.DNSError): - yield from resolver.resolve('doesnotexist.bla') + with pytest.raises(OSError): + await resolver.resolve("doesnotexist.bla") @pytest.mark.skipif(aiodns is None, reason="aiodns required") -@asyncio.coroutine -def test_async_resolver_query_negative_lookup(loop): - with patch('aiodns.DNSResolver') as mock: +async def test_async_resolver_query_negative_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: del mock().gethostbyname mock().query.side_effect = aiodns.error.DNSError() resolver = AsyncResolver(loop=loop) - with pytest.raises(aiodns.error.DNSError): - yield from resolver.resolve('doesnotexist.bla') + with pytest.raises(OSError): + await resolver.resolve("doesnotexist.bla") + + +@pytest.mark.skipif(aiodns is None, reason="aiodns required") +async def test_async_resolver_no_hosts_in_query(loop) -> None: + with patch("aiodns.DNSResolver") as mock: + del mock().gethostbyname + mock().query.return_value = fake_query_result([]) + resolver = AsyncResolver(loop=loop) + with pytest.raises(OSError): + await resolver.resolve("doesnotexist.bla") + + +@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") +async def test_async_resolver_no_hosts_in_gethostbyname(loop) -> None: + with patch("aiodns.DNSResolver") as mock: + mock().gethostbyname.return_value = fake_result([]) + resolver = AsyncResolver(loop=loop) + with pytest.raises(OSError): + await resolver.resolve("doesnotexist.bla") -@asyncio.coroutine -def test_threaded_resolver_positive_lookup(loop): +async def test_threaded_resolver_positive_lookup() -> None: + loop = Mock() loop.getaddrinfo = fake_addrinfo(["127.0.0.1"]) resolver = ThreadedResolver(loop=loop) - real = yield from resolver.resolve('www.python.org') - ipaddress.ip_address(real[0]['host']) + real = await resolver.resolve("www.python.org") + assert real[0]["hostname"] == "www.python.org" + ipaddress.ip_address(real[0]["host"]) -@asyncio.coroutine -def test_threaded_resolver_multiple_replies(loop): - ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3', '127.0.0.4'] +async def test_threaded_resolver_multiple_replies() -> None: + loop = Mock() + ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] loop.getaddrinfo = fake_addrinfo(ips) resolver = ThreadedResolver(loop=loop) - real = yield from resolver.resolve('www.google.com') - ips = [ipaddress.ip_address(x['host']) for x in real] + real = await resolver.resolve("www.google.com") + ips = [ipaddress.ip_address(x["host"]) for x in real] assert len(ips) > 3, "Expecting multiple addresses" -@asyncio.coroutine -def test_threaded_negative_lookup(loop): +async def test_threaded_negative_lookup() -> None: + loop = Mock() ips = [] loop.getaddrinfo = fake_addrinfo(ips) resolver = ThreadedResolver(loop=loop) with pytest.raises(socket.gaierror): - yield from resolver.resolve('doesnotexist.bla') + await resolver.resolve("doesnotexist.bla") -@asyncio.coroutine -def test_close_for_threaded_resolver(loop): +async def test_close_for_threaded_resolver(loop) -> None: resolver = ThreadedResolver(loop=loop) - yield from resolver.close() + await resolver.close() @pytest.mark.skipif(aiodns is None, reason="aiodns required") -@asyncio.coroutine -def 
test_close_for_async_resolver(loop): +async def test_close_for_async_resolver(loop) -> None: resolver = AsyncResolver(loop=loop) - yield from resolver.close() + await resolver.close() -def test_default_loop_for_threaded_resolver(loop): +async def test_default_loop_for_threaded_resolver(loop) -> None: asyncio.set_event_loop(loop) resolver = ThreadedResolver() assert resolver._loop is loop @pytest.mark.skipif(aiodns is None, reason="aiodns required") -def test_default_loop_for_async_resolver(loop): +async def test_default_loop_for_async_resolver(loop) -> None: asyncio.set_event_loop(loop) resolver = AsyncResolver() assert resolver._loop is loop @pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -@asyncio.coroutine -def test_async_resolver_ipv6_positive_lookup(loop): - with patch('aiodns.DNSResolver') as mock: - mock().gethostbyname.return_value = fake_result(['::1']) +async def test_async_resolver_ipv6_positive_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: + mock().gethostbyname.return_value = fake_result(["::1"]) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.python.org', - family=socket.AF_INET6) - ipaddress.ip_address(real[0]['host']) - mock().gethostbyname.assert_called_with('www.python.org', - socket.AF_INET6) + real = await resolver.resolve("www.python.org", family=socket.AF_INET6) + ipaddress.ip_address(real[0]["host"]) + mock().gethostbyname.assert_called_with("www.python.org", socket.AF_INET6) @pytest.mark.skipif(aiodns is None, reason="aiodns required") -@asyncio.coroutine -def test_async_resolver_query_ipv6_positive_lookup(loop): - with patch('aiodns.DNSResolver') as mock: +async def test_async_resolver_query_ipv6_positive_lookup(loop) -> None: + with patch("aiodns.DNSResolver") as mock: del mock().gethostbyname - mock().query.return_value = fake_query_result(['::1']) + mock().query.return_value = fake_query_result(["::1"]) resolver = AsyncResolver(loop=loop) - real = yield from resolver.resolve('www.python.org', - family=socket.AF_INET6) - ipaddress.ip_address(real[0]['host']) - mock().query.assert_called_with('www.python.org', 'AAAA') + real = await resolver.resolve("www.python.org", family=socket.AF_INET6) + ipaddress.ip_address(real[0]["host"]) + mock().query.assert_called_with("www.python.org", "AAAA") -def test_async_resolver_aiodns_not_present(loop, monkeypatch): +async def test_async_resolver_aiodns_not_present(loop, monkeypatch) -> None: monkeypatch.setattr("aiohttp.resolver.aiodns", None) with pytest.raises(RuntimeError): AsyncResolver(loop=loop) -def test_default_resolver(): +def test_default_resolver() -> None: # if gethostbyname: # assert DefaultResolver is AsyncResolver # else: diff --git a/tests/test_route_def.py b/tests/test_route_def.py new file mode 100644 index 00000000000..49c6c4cb68f --- /dev/null +++ b/tests/test_route_def.py @@ -0,0 +1,304 @@ +import pathlib + +import pytest +from yarl import URL + +from aiohttp import web +from aiohttp.web_urldispatcher import UrlDispatcher + + +@pytest.fixture +def router(): + return UrlDispatcher() + + +def test_get(router) -> None: + async def handler(request): + pass + + router.add_routes([web.get("/", handler)]) + assert len(router.routes()) == 2 # GET and HEAD + + route = list(router.routes())[1] + assert route.handler is handler + assert route.method == "GET" + assert str(route.url_for()) == "/" + + route2 = list(router.routes())[0] + assert route2.handler is handler + assert route2.method == "HEAD" + + +def test_head(router) -> None: + async def 
handler(request): + pass + + router.add_routes([web.head("/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "HEAD" + assert str(route.url_for()) == "/" + + +def test_options(router) -> None: + async def handler(request): + pass + + router.add_routes([web.options("/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "OPTIONS" + assert str(route.url_for()) == "/" + + +def test_post(router) -> None: + async def handler(request): + pass + + router.add_routes([web.post("/", handler)]) + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "POST" + assert str(route.url_for()) == "/" + + +def test_put(router) -> None: + async def handler(request): + pass + + router.add_routes([web.put("/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "PUT" + assert str(route.url_for()) == "/" + + +def test_patch(router) -> None: + async def handler(request): + pass + + router.add_routes([web.patch("/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "PATCH" + assert str(route.url_for()) == "/" + + +def test_delete(router) -> None: + async def handler(request): + pass + + router.add_routes([web.delete("/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "DELETE" + assert str(route.url_for()) == "/" + + +def test_route(router) -> None: + async def handler(request): + pass + + router.add_routes([web.route("OTHER", "/", handler)]) + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.handler is handler + assert route.method == "OTHER" + assert str(route.url_for()) == "/" + + +def test_static(router) -> None: + folder = pathlib.Path(__file__).parent + router.add_routes([web.static("/prefix", folder)]) + assert len(router.resources()) == 1 # 2 routes: for HEAD and GET + + resource = list(router.resources())[0] + info = resource.get_info() + assert info["prefix"] == "/prefix" + assert info["directory"] == folder + url = resource.url_for(filename="aiohttp.png") + assert url == URL("/prefix/aiohttp.png") + + +def test_head_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.head("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "HEAD" + assert str(route.url_for()) == "/path" + + +def test_get_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.get("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 2 + + route1 = list(router.routes())[0] + assert route1.method == "HEAD" + assert str(route1.url_for()) == "/path" + + route2 = list(router.routes())[1] + assert route2.method == "GET" + assert str(route2.url_for()) == "/path" + + +def test_post_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.post("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "POST" + assert str(route.url_for()) == "/path" + + +def test_put_deco(router) -> None: + routes = 
web.RouteTableDef() + + @routes.put("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "PUT" + assert str(route.url_for()) == "/path" + + +def test_patch_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.patch("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "PATCH" + assert str(route.url_for()) == "/path" + + +def test_delete_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.delete("/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "DELETE" + assert str(route.url_for()) == "/path" + + +def test_route_deco(router) -> None: + routes = web.RouteTableDef() + + @routes.route("OTHER", "/path") + async def handler(request): + pass + + router.add_routes(routes) + + assert len(router.routes()) == 1 + + route = list(router.routes())[0] + assert route.method == "OTHER" + assert str(route.url_for()) == "/path" + + +def test_routedef_sequence_protocol() -> None: + routes = web.RouteTableDef() + + @routes.delete("/path") + async def handler(request): + pass + + assert len(routes) == 1 + + info = routes[0] + assert isinstance(info, web.RouteDef) + assert info in routes + assert list(routes)[0] is info + + +def test_repr_route_def() -> None: + routes = web.RouteTableDef() + + @routes.get("/path") + async def handler(request): + pass + + rd = routes[0] + assert repr(rd) == "<RouteDef GET /path -> 'handler'>" + + +def test_repr_route_def_with_extra_info() -> None: + routes = web.RouteTableDef() + + @routes.get("/path", extra="info") + async def handler(request): + pass + + rd = routes[0] + assert repr(rd) == "<RouteDef GET /path -> 'handler', extra='info'>" + + +def test_repr_static_def() -> None: + routes = web.RouteTableDef() + + routes.static("/prefix", "/path", name="name") + + rd = routes[0] + assert repr(rd) == "<StaticDef /prefix -> /path, name='name'>" + + +def test_repr_route_table_def() -> None: + routes = web.RouteTableDef() + + @routes.get("/path") + async def handler(request): + pass + + assert repr(routes) == "<RouteTableDef count=1>" diff --git a/tests/test_run_app.py b/tests/test_run_app.py index c79a55a542d..d2ba2262ac2 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -1,24 +1,28 @@ +import asyncio import contextlib +import logging import os +import platform +import signal import socket import ssl - -from io import StringIO +import subprocess +import sys from unittest import mock from uuid import uuid4 import pytest from aiohttp import web -from aiohttp.test_utils import loop_context - +from aiohttp.helpers import PY_37 +from aiohttp.test_utils import make_mocked_coro # Test for features of OS' socket support -_has_unix_domain_socks = hasattr(socket, 'AF_UNIX') +_has_unix_domain_socks = hasattr(socket, "AF_UNIX") if _has_unix_domain_socks: _abstract_path_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: - _abstract_path_sock.bind(b"\x00" + uuid4().hex.encode('ascii')) + _abstract_path_sock.bind(b"\x00" + uuid4().hex.encode("ascii")) # type: ignore except FileNotFoundError: _abstract_path_failed = True else: @@ -30,113 +34,211 @@ _abstract_path_failed = True skip_if_no_abstract_paths = pytest.mark.skipif( - _abstract_path_failed, - reason="Linux-style abstract paths are not supported." 
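Stepping back from the hunk for a moment: the new `tests/test_route_def.py` above checks both the individual route factories and the container behaviour of `web.RouteTableDef`. A minimal sketch of that public API, separate from the patch (the asserted values mirror the ones in the tests):

```python
# Minimal sketch (not part of the patch) of the RouteTableDef API that
# tests/test_route_def.py exercises. A route table collects RouteDef
# entries and registers them on an application in a single call.
from aiohttp import web

routes = web.RouteTableDef()


@routes.get("/path")
async def handler(request: web.Request) -> web.Response:
    return web.Response(text="OK")


app = web.Application()
app.add_routes(routes)  # same effect as app.router.add_routes(routes)

assert len(routes) == 1                     # sequence protocol: len, index, "in"
assert isinstance(routes[0], web.RouteDef)
assert repr(routes) == "<RouteTableDef count=1>"
```

Keeping the table as a plain sequence is what lets the tests index into `routes[0]` and compare reprs without ever starting a server.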
+ _abstract_path_failed, reason="Linux-style abstract paths are not supported." ) skip_if_no_unix_socks = pytest.mark.skipif( - not _has_unix_domain_socks, - reason="Unix domain sockets are not supported" + not _has_unix_domain_socks, reason="Unix domain sockets are not supported" ) del _has_unix_domain_socks, _abstract_path_failed +HAS_IPV6 = socket.has_ipv6 +if HAS_IPV6: + # The socket.has_ipv6 flag may be True if Python was built with IPv6 + # support, but the target system still may not have it. + # So let's ensure that we really have IPv6 support. + try: + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError: + HAS_IPV6 = False + + +# tokio event loop does not allow to override attributes +def skip_if_no_dict(loop): + if not hasattr(loop, "__dict__"): + pytest.skip("can not override loop attributes") + + +def skip_if_on_windows(): + if platform.system() == "Windows": + pytest.skip("the test is not valid for Windows") + + +@pytest.fixture +def patched_loop(loop): + skip_if_no_dict(loop) + server = mock.Mock() + server.wait_closed = make_mocked_coro(None) + loop.create_server = make_mocked_coro(server) + unix_server = mock.Mock() + unix_server.wait_closed = make_mocked_coro(None) + loop.create_unix_server = make_mocked_coro(unix_server) + asyncio.set_event_loop(loop) + return loop + + +def stopper(loop): + def raiser(): + raise KeyboardInterrupt -def test_run_app_http(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) + def f(*args): + loop.call_soon(raiser) + return f + + +def test_run_app_http(patched_loop) -> None: app = web.Application() - mocker.spy(app, 'startup') + startup_handler = make_mocked_coro() + app.on_startup.append(startup_handler) + cleanup_handler = make_mocked_coro() + app.on_cleanup.append(cleanup_handler) + + web.run_app(app, print=stopper(patched_loop)) - web.run_app(app, loop=loop, print=lambda *args: None) + patched_loop.create_server.assert_called_with( + mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + ) + startup_handler.assert_called_once_with(app) + cleanup_handler.assert_called_once_with(app) + + +def test_run_app_close_loop(patched_loop) -> None: + app = web.Application() + web.run_app(app, print=stopper(patched_loop)) - assert loop.is_closed() - loop.create_server.assert_called_with(mock.ANY, '0.0.0.0', 8080, - ssl=None, backlog=128) - app.startup.assert_called_once_with() + patched_loop.create_server.assert_called_with( + mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + ) + assert patched_loop.is_closed() mock_unix_server_single = [ - mock.call(mock.ANY, '/tmp/testsock1.sock', ssl=None, backlog=128), + mock.call(mock.ANY, "/tmp/testsock1.sock", ssl=None, backlog=128), ] mock_unix_server_multi = [ - mock.call(mock.ANY, '/tmp/testsock1.sock', ssl=None, backlog=128), - mock.call(mock.ANY, '/tmp/testsock2.sock', ssl=None, backlog=128), + mock.call(mock.ANY, "/tmp/testsock1.sock", ssl=None, backlog=128), + mock.call(mock.ANY, "/tmp/testsock2.sock", ssl=None, backlog=128), ] mock_server_single = [ - mock.call(mock.ANY, '127.0.0.1', 8080, ssl=None, backlog=128), + mock.call( + mock.ANY, + "127.0.0.1", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), ] mock_server_multi = [ - mock.call(mock.ANY, ('127.0.0.1', '192.168.1.1'), 8080, ssl=None, - backlog=128), + mock.call( + mock.ANY, + "127.0.0.1", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + mock.call( + mock.ANY, + 
"192.168.1.1", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), ] mock_server_default_8989 = [ - mock.call(mock.ANY, '0.0.0.0', 8989, ssl=None, backlog=128) + mock.call( + mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=None + ) ] -mock_socket = mock.Mock(getsockname=lambda: ('mock-socket', 123)) +mock_socket = mock.Mock(getsockname=lambda: ("mock-socket", 123)) mixed_bindings_tests = ( - ( + ( # type: ignore "Nothing Specified", {}, - [mock.call(mock.ANY, '0.0.0.0', 8080, ssl=None, backlog=128)], - [] - ), - ( - "Port Only", - {'port': 8989}, - mock_server_default_8989, - [] - ), - ( - "Multiple Hosts", - {'host': ('127.0.0.1', '192.168.1.1')}, - mock_server_multi, - [] + [ + mock.call( + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ) + ], + [], ), + ("Port Only", {"port": 8989}, mock_server_default_8989, []), + ("Multiple Hosts", {"host": ("127.0.0.1", "192.168.1.1")}, mock_server_multi, []), ( "Multiple Paths", - {'path': ('/tmp/testsock1.sock', '/tmp/testsock2.sock')}, + {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock")}, [], - mock_unix_server_multi + mock_unix_server_multi, ), ( "Multiple Paths, Port", - {'path': ('/tmp/testsock1.sock', '/tmp/testsock2.sock'), - 'port': 8989}, + {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "port": 8989}, mock_server_default_8989, mock_unix_server_multi, ), ( "Multiple Paths, Single Host", - {'path': ('/tmp/testsock1.sock', '/tmp/testsock2.sock'), - 'host': '127.0.0.1'}, + {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "host": "127.0.0.1"}, mock_server_single, - mock_unix_server_multi + mock_unix_server_multi, ), ( "Single Path, Single Host", - {'path': '/tmp/testsock1.sock', 'host': '127.0.0.1'}, + {"path": "/tmp/testsock1.sock", "host": "127.0.0.1"}, mock_server_single, - mock_unix_server_single + mock_unix_server_single, ), ( "Single Path, Multiple Hosts", - {'path': '/tmp/testsock1.sock', 'host': ('127.0.0.1', '192.168.1.1')}, + {"path": "/tmp/testsock1.sock", "host": ("127.0.0.1", "192.168.1.1")}, mock_server_multi, - mock_unix_server_single + mock_unix_server_single, ), ( "Single Path, Port", - {'path': '/tmp/testsock1.sock', 'port': 8989}, + {"path": "/tmp/testsock1.sock", "port": 8989}, mock_server_default_8989, - mock_unix_server_single + mock_unix_server_single, ), ( "Multiple Paths, Multiple Hosts, Port", - {'path': ('/tmp/testsock1.sock', '/tmp/testsock2.sock'), - 'host': ('127.0.0.1', '192.168.1.1'), 'port': 8000}, - [mock.call(mock.ANY, ('127.0.0.1', '192.168.1.1'), 8000, ssl=None, - backlog=128)], - mock_unix_server_multi + { + "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), + "host": ("127.0.0.1", "192.168.1.1"), + "port": 8000, + }, + [ + mock.call( + mock.ANY, + "127.0.0.1", + 8000, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + mock.call( + mock.ANY, + "192.168.1.1", + 8000, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + ], + mock_unix_server_multi, ), ( "Only socket", @@ -147,315 +249,613 @@ def test_run_app_http(loop, mocker): ( "Socket, port", {"sock": [mock_socket], "port": 8765}, - [mock.call(mock.ANY, '0.0.0.0', 8765, ssl=None, backlog=128), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128)], + [ + mock.call( + mock.ANY, + None, + 8765, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + ], [], ), ( "Socket, Host, No port", - 
{"sock": [mock_socket], "host": 'localhost'}, - [mock.call(mock.ANY, 'localhost', 8080, ssl=None, backlog=128), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128)], + {"sock": [mock_socket], "host": "localhost"}, + [ + mock.call( + mock.ANY, + "localhost", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + ], [], ), + ( + "reuse_port", + {"reuse_port": True}, + [ + mock.call( + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=True, + ) + ], + [], + ), + ( + "reuse_address", + {"reuse_address": False}, + [ + mock.call( + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=False, + reuse_port=None, + ) + ], + [], + ), + ( + "reuse_port, reuse_address", + {"reuse_address": True, "reuse_port": True}, + [ + mock.call( + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=True, + reuse_port=True, + ) + ], + [], + ), + ( + "Port, reuse_port", + {"port": 8989, "reuse_port": True}, + [ + mock.call( + mock.ANY, + None, + 8989, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=True, + ) + ], + [], + ), + ( + "Multiple Hosts, reuse_port", + {"host": ("127.0.0.1", "192.168.1.1"), "reuse_port": True}, + [ + mock.call( + mock.ANY, + "127.0.0.1", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=True, + ), + mock.call( + mock.ANY, + "192.168.1.1", + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=True, + ), + ], + [], + ), + ( + "Multiple Paths, Port, reuse_address", + { + "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), + "port": 8989, + "reuse_address": False, + }, + [ + mock.call( + mock.ANY, + None, + 8989, + ssl=None, + backlog=128, + reuse_address=False, + reuse_port=None, + ) + ], + mock_unix_server_multi, + ), + ( + "Multiple Paths, Single Host, reuse_address, reuse_port", + { + "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), + "host": "127.0.0.1", + "reuse_address": True, + "reuse_port": True, + }, + [ + mock.call( + mock.ANY, + "127.0.0.1", + 8080, + ssl=None, + backlog=128, + reuse_address=True, + reuse_port=True, + ), + ], + mock_unix_server_multi, + ), ) mixed_bindings_test_ids = [test[0] for test in mixed_bindings_tests] mixed_bindings_test_params = [test[1:] for test in mixed_bindings_tests] @pytest.mark.parametrize( - 'run_app_kwargs, expected_server_calls, expected_unix_server_calls', + "run_app_kwargs, expected_server_calls, expected_unix_server_calls", mixed_bindings_test_params, - ids=mixed_bindings_test_ids + ids=mixed_bindings_test_ids, ) -def test_run_app_mixed_bindings(mocker, run_app_kwargs, expected_server_calls, - expected_unix_server_calls): - app = mocker.MagicMock() - loop = mocker.MagicMock() - mocker.patch('asyncio.gather') - - web.run_app(app, loop=loop, print=lambda *args: None, **run_app_kwargs) - - assert loop.create_unix_server.mock_calls == expected_unix_server_calls - assert loop.create_server.mock_calls == expected_server_calls +def test_run_app_mixed_bindings( + run_app_kwargs, expected_server_calls, expected_unix_server_calls, patched_loop +): + app = web.Application() + web.run_app(app, print=stopper(patched_loop), **run_app_kwargs) + assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls + assert patched_loop.create_server.mock_calls == expected_server_calls -def test_run_app_http_access_format(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) +def 
test_run_app_https(patched_loop) -> None: app = web.Application() - mocker.spy(app, 'startup') - web.run_app(app, loop=loop, - print=lambda *args: None, access_log_format='%a') + ssl_context = ssl.create_default_context() + web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop)) - assert loop.is_closed() - cs = loop.create_server - cs.assert_called_with(mock.ANY, '0.0.0.0', 8080, ssl=None, backlog=128) - assert cs.call_args[0][0]._kwargs['access_log_format'] == '%a' - app.startup.assert_called_once_with() + patched_loop.create_server.assert_called_with( + mock.ANY, + None, + 8443, + ssl=ssl_context, + backlog=128, + reuse_address=None, + reuse_port=None, + ) -def test_run_app_https(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) +def test_run_app_nondefault_host_port(patched_loop, aiohttp_unused_port) -> None: + port = aiohttp_unused_port() + host = "127.0.0.1" app = web.Application() - mocker.spy(app, 'startup') + web.run_app(app, host=host, port=port, print=stopper(patched_loop)) - ssl_context = ssl.create_default_context() + patched_loop.create_server.assert_called_with( + mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None + ) - web.run_app(app, loop=loop, - ssl_context=ssl_context, print=lambda *args: None) - assert loop.is_closed() - loop.create_server.assert_called_with(mock.ANY, '0.0.0.0', 8443, - ssl=ssl_context, backlog=128) - app.startup.assert_called_once_with() +def test_run_app_multiple_hosts(patched_loop) -> None: + hosts = ("127.0.0.1", "127.0.0.2") + app = web.Application() + web.run_app(app, host=hosts, print=stopper(patched_loop)) -def test_run_app_nondefault_host_port(loop, unused_port, mocker): - port = unused_port() - host = 'localhost' + calls = map( + lambda h: mock.call( + mock.ANY, + h, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + hosts, + ) + patched_loop.create_server.assert_has_calls(calls) - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) +def test_run_app_custom_backlog(patched_loop) -> None: app = web.Application() - mocker.spy(app, 'startup') + web.run_app(app, backlog=10, print=stopper(patched_loop)) - web.run_app(app, loop=loop, - host=host, port=port, print=lambda *args: None) + patched_loop.create_server.assert_called_with( + mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None + ) - assert loop.is_closed() - loop.create_server.assert_called_with(mock.ANY, host, port, - ssl=None, backlog=128) - app.startup.assert_called_once_with() +def test_run_app_custom_backlog_unix(patched_loop) -> None: + app = web.Application() + web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop)) -def test_run_app_custom_backlog(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) + patched_loop.create_unix_server.assert_called_with( + mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10 + ) + +@skip_if_no_unix_socks +def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None: app = web.Application() - mocker.spy(app, 'startup') - web.run_app(app, loop=loop, backlog=10, print=lambda *args: None) + sock_path = str(shorttmpdir / "socket.sock") + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, path=sock_path, print=printer) - assert loop.is_closed() - loop.create_server.assert_called_with(mock.ANY, '0.0.0.0', 8080, - ssl=None, backlog=10) - app.startup.assert_called_once_with() + 
patched_loop.create_unix_server.assert_called_with( + mock.ANY, sock_path, ssl=None, backlog=128 + ) + assert f"http://unix:{sock_path}:" in printer.call_args[0][0] @skip_if_no_unix_socks -def test_run_app_http_unix_socket(loop, mocker, shorttmpdir): - mocker.spy(loop, 'create_unix_server') - loop.call_later(0.05, loop.stop) - +def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None: app = web.Application() - mocker.spy(app, 'startup') - sock_path = str(shorttmpdir.join('socket.sock')) - printed = StringIO() - web.run_app(app, loop=loop, path=sock_path, print=printed.write) + sock_path = str(shorttmpdir / "socket.sock") + ssl_context = ssl.create_default_context() + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer) - assert loop.is_closed() - loop.create_unix_server.assert_called_with(mock.ANY, sock_path, - ssl=None, backlog=128) - app.startup.assert_called_once_with() - assert "http://unix:{}:".format(sock_path) in printed.getvalue() + patched_loop.create_unix_server.assert_called_with( + mock.ANY, sock_path, ssl=ssl_context, backlog=128 + ) + assert f"https://unix:{sock_path}:" in printer.call_args[0][0] @skip_if_no_unix_socks -def test_run_app_https_unix_socket(loop, mocker, shorttmpdir): - mocker.spy(loop, 'create_unix_server') - loop.call_later(0.05, loop.stop) +@skip_if_no_abstract_paths +def test_run_app_abstract_linux_socket(patched_loop) -> None: + sock_path = b"\x00" + uuid4().hex.encode("ascii") + app = web.Application() + web.run_app( + app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop) + ) + + patched_loop.create_unix_server.assert_called_with( + mock.ANY, sock_path.decode("ascii"), ssl=None, backlog=128 + ) + +def test_run_app_preexisting_inet_socket(patched_loop, mocker) -> None: app = web.Application() - mocker.spy(app, 'startup') - sock_path = str(shorttmpdir.join('socket.sock')) - printed = StringIO() - ssl_context = ssl.create_default_context() - web.run_app(app, loop=loop, path=sock_path, ssl_context=ssl_context, - print=printed.write) + sock = socket.socket() + with contextlib.closing(sock): + sock.bind(("0.0.0.0", 0)) + _, port = sock.getsockname() - assert loop.is_closed() - loop.create_unix_server.assert_called_with(mock.ANY, sock_path, - ssl=ssl_context, backlog=128) - app.startup.assert_called_once_with() - assert "https://unix:{}:".format(sock_path) in printed.getvalue() + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, sock=sock, print=printer) + patched_loop.create_server.assert_called_with( + mock.ANY, sock=sock, backlog=128, ssl=None + ) + assert f"http://0.0.0.0:{port}" in printer.call_args[0][0] -@skip_if_no_unix_socks -def test_run_app_stale_unix_socket(loop, mocker, shorttmpdir): - """Older asyncio event loop implementations are known to halt server - creation when a socket path from a previous server bind still exists. 
- """ - loop.call_later(0.05, loop.stop) +@pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not available") +def test_run_app_preexisting_inet6_socket(patched_loop) -> None: app = web.Application() - sock_path = shorttmpdir.join('socket.sock') - sock_path_string = str(sock_path) + sock = socket.socket(socket.AF_INET6) + with contextlib.closing(sock): + sock.bind(("::", 0)) + port = sock.getsockname()[1] - web.run_app(app, loop=loop, - path=sock_path_string, print=lambda *args: None) - assert loop.is_closed() + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, sock=sock, print=printer) - if sock_path.check(): - # New app run using same socket path - with loop_context() as loop: - mocker.spy(loop, 'create_unix_server') - loop.call_later(0.05, loop.stop) + patched_loop.create_server.assert_called_with( + mock.ANY, sock=sock, backlog=128, ssl=None + ) + assert f"http://[::]:{port}" in printer.call_args[0][0] - app = web.Application() - mocker.spy(app, 'startup') - mocker.spy(os, 'remove') - printed = StringIO() +@skip_if_no_unix_socks +def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None: + app = web.Application() - web.run_app(app, loop=loop, - path=sock_path_string, print=printed.write) - os.remove.assert_called_with(sock_path_string) - loop.create_unix_server.assert_called_with( - mock.ANY, - sock_path_string, - ssl=None, - backlog=128 - ) - app.startup.assert_called_once_with() - assert "http://unix:{}:".format(sock_path) in \ - printed.getvalue() + sock_path = "/tmp/test_preexisting_sock1" + sock = socket.socket(socket.AF_UNIX) + with contextlib.closing(sock): + sock.bind(sock_path) + os.unlink(sock_path) + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, sock=sock, print=printer) + + patched_loop.create_server.assert_called_with( + mock.ANY, sock=sock, backlog=128, ssl=None + ) + assert f"http://unix:{sock_path}:" in printer.call_args[0][0] -@skip_if_no_unix_socks -@skip_if_no_abstract_paths -def test_run_app_abstract_linux_socket(loop, mocker): - sock_path = b"\x00" + uuid4().hex.encode('ascii') - loop.call_later(0.05, loop.stop) +def test_run_app_multiple_preexisting_sockets(patched_loop) -> None: app = web.Application() - web.run_app( - app, path=sock_path.decode('ascii', 'ignore'), loop=loop, - print=lambda *args: None) - assert loop.is_closed() - # New app run using same socket path - with loop_context() as loop: - mocker.spy(loop, 'create_unix_server') - loop.call_later(0.05, loop.stop) + sock1 = socket.socket() + sock2 = socket.socket() + with contextlib.closing(sock1), contextlib.closing(sock2): + sock1.bind(("0.0.0.0", 0)) + _, port1 = sock1.getsockname() + sock2.bind(("0.0.0.0", 0)) + _, port2 = sock2.getsockname() - app = web.Application() + printer = mock.Mock(wraps=stopper(patched_loop)) + web.run_app(app, sock=(sock1, sock2), print=printer) - mocker.spy(app, 'startup') - mocker.spy(os, 'remove') - printed = StringIO() + patched_loop.create_server.assert_has_calls( + [ + mock.call(mock.ANY, sock=sock1, backlog=128, ssl=None), + mock.call(mock.ANY, sock=sock2, backlog=128, ssl=None), + ] + ) + assert f"http://0.0.0.0:{port1}" in printer.call_args[0][0] + assert f"http://0.0.0.0:{port2}" in printer.call_args[0][0] - web.run_app(app, path=sock_path, print=printed.write, loop=loop) - # Abstract paths don't exist on the file system, so no attempt should - # be made to remove. 
- assert mock.call([sock_path]) not in os.remove.mock_calls +_script_test_signal = """ +from aiohttp import web - loop.create_unix_server.assert_called_with( - mock.ANY, - sock_path, - ssl=None, - backlog=128 - ) - app.startup.assert_called_once_with() +app = web.Application() +web.run_app(app, host=()) +""" -@skip_if_no_unix_socks -def test_run_app_existing_file_conflict(loop, mocker, shorttmpdir): +def test_sigint() -> None: + skip_if_on_windows() + + proc = subprocess.Popen( + [sys.executable, "-u", "-c", _script_test_signal], stdout=subprocess.PIPE + ) + for line in proc.stdout: + if line.startswith(b"======== Running on"): + break + proc.send_signal(signal.SIGINT) + assert proc.wait() == 0 + + +def test_sigterm() -> None: + skip_if_on_windows() + + proc = subprocess.Popen( + [sys.executable, "-u", "-c", _script_test_signal], stdout=subprocess.PIPE + ) + for line in proc.stdout: + if line.startswith(b"======== Running on"): + break + proc.terminate() + assert proc.wait() == 0 + + +def test_startup_cleanup_signals_even_on_failure(patched_loop) -> None: + patched_loop.create_server = mock.Mock(side_effect=RuntimeError()) + app = web.Application() - sock_path = shorttmpdir.join('socket.sock') - sock_path.ensure() - sock_path_str = str(sock_path) - mocker.spy(os, 'remove') + startup_handler = make_mocked_coro() + app.on_startup.append(startup_handler) + cleanup_handler = make_mocked_coro() + app.on_cleanup.append(cleanup_handler) + + with pytest.raises(RuntimeError): + web.run_app(app, print=stopper(patched_loop)) - with pytest.raises(OSError): - web.run_app(app, loop=loop, - path=sock_path_str, print=lambda *args: None) + startup_handler.assert_called_once_with(app) + cleanup_handler.assert_called_once_with(app) - # No attempt should be made to remove a non-socket file - assert mock.call([sock_path_str]) not in os.remove.mock_calls +def test_run_app_coro(patched_loop) -> None: + startup_handler = cleanup_handler = None -def test_run_app_preexisting_inet_socket(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) + async def make_app(): + nonlocal startup_handler, cleanup_handler + app = web.Application() + startup_handler = make_mocked_coro() + app.on_startup.append(startup_handler) + cleanup_handler = make_mocked_coro() + app.on_cleanup.append(cleanup_handler) + return app + + web.run_app(make_app(), print=stopper(patched_loop)) + + patched_loop.create_server.assert_called_with( + mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + ) + startup_handler.assert_called_once_with(mock.ANY) + cleanup_handler.assert_called_once_with(mock.ANY) + + +def test_run_app_default_logger(monkeypatch, patched_loop): + patched_loop.set_debug(True) + logger = web.access_logger + attrs = { + "hasHandlers.return_value": False, + "level": logging.NOTSET, + "name": "aiohttp.access", + } + mock_logger = mock.create_autospec(logger, name="mock_access_logger") + mock_logger.configure_mock(**attrs) app = web.Application() - mocker.spy(app, 'startup') + web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + mock_logger.setLevel.assert_any_call(logging.DEBUG) + mock_logger.hasHandlers.assert_called_with() + assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler) + + +def test_run_app_default_logger_setup_requires_debug(patched_loop): + patched_loop.set_debug(False) + logger = web.access_logger + attrs = { + "hasHandlers.return_value": False, + "level": logging.NOTSET, + "name": "aiohttp.access", + } + 
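
The logger tests in this block triangulate a single rule: run_app auto-configures the access logger only when the loop is in debug mode, the logger is the default "aiohttp.access" one, and no handlers have been attached yet. A sketch of that rule; maybe_setup_access_logger is an illustrative stand-in, not the function aiohttp actually calls internally:

import logging

def maybe_setup_access_logger(loop, logger):
    # Illustrative reconstruction of what these tests assert: configure
    # the logger only in debug mode, only for the default access logger,
    # and only if the user has not configured it already.
    if not loop.get_debug():
        return
    if logger.name != "aiohttp.access":
        return
    if logger.hasHandlers():
        return
    logger.addHandler(logging.StreamHandler())
    if logger.level == logging.NOTSET:
        logger.setLevel(logging.DEBUG)
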
mock_logger = mock.create_autospec(logger, name="mock_access_logger") + mock_logger.configure_mock(**attrs) - sock = socket.socket() - with contextlib.closing(sock): - sock.bind(('0.0.0.0', 0)) - _, port = sock.getsockname() + app = web.Application() + web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + mock_logger.setLevel.assert_not_called() + mock_logger.hasHandlers.assert_not_called() + mock_logger.addHandler.assert_not_called() + + +def test_run_app_default_logger_setup_requires_default_logger(patched_loop): + patched_loop.set_debug(True) + logger = web.access_logger + attrs = { + "hasHandlers.return_value": False, + "level": logging.NOTSET, + "name": None, + } + mock_logger = mock.create_autospec(logger, name="mock_access_logger") + mock_logger.configure_mock(**attrs) - printed = StringIO() - web.run_app(app, loop=loop, sock=sock, print=printed.write) + app = web.Application() + web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + mock_logger.setLevel.assert_not_called() + mock_logger.hasHandlers.assert_not_called() + mock_logger.addHandler.assert_not_called() + + +def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop): + patched_loop.set_debug(True) + logger = web.access_logger + attrs = { + "hasHandlers.return_value": True, + "level": None, + "name": "aiohttp.access", + } + mock_logger = mock.create_autospec(logger, name="mock_access_logger") + mock_logger.configure_mock(**attrs) - assert loop.is_closed() - loop.create_server.assert_called_with( - mock.ANY, sock=sock, backlog=128, ssl=None - ) - app.startup.assert_called_once_with() - assert "http://0.0.0.0:{}".format(port) in printed.getvalue() + app = web.Application() + web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + mock_logger.setLevel.assert_not_called() + mock_logger.hasHandlers.assert_called_with() + mock_logger.addHandler.assert_not_called() -@skip_if_no_unix_socks -def test_run_app_preexisting_unix_socket(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) +def test_run_app_cancels_all_pending_tasks(patched_loop): + app = web.Application() + task = None + + async def on_startup(app): + nonlocal task + loop = asyncio.get_event_loop() + task = loop.create_task(asyncio.sleep(1000)) + app.on_startup.append(on_startup) + + web.run_app(app, print=stopper(patched_loop)) + assert task.cancelled() + + +def test_run_app_cancels_done_tasks(patched_loop): app = web.Application() - mocker.spy(app, 'startup') + task = None - sock_path = '/tmp/test_preexisting_sock1' - sock = socket.socket(socket.AF_UNIX) - with contextlib.closing(sock): - sock.bind(sock_path) - os.unlink(sock_path) + async def coro(): + return 123 - printed = StringIO() - web.run_app(app, loop=loop, sock=sock, print=printed.write) + async def on_startup(app): + nonlocal task + loop = asyncio.get_event_loop() + task = loop.create_task(coro()) - assert loop.is_closed() - loop.create_server.assert_called_with( - mock.ANY, sock=sock, backlog=128, ssl=None - ) - app.startup.assert_called_once_with() - assert "http://unix:{}:".format(sock_path) in printed.getvalue() + app.on_startup.append(on_startup) + web.run_app(app, print=stopper(patched_loop)) + assert task.done() -def test_run_app_multiple_preexisting_sockets(loop, mocker): - mocker.spy(loop, 'create_server') - loop.call_later(0.05, loop.stop) +def test_run_app_cancels_failed_tasks(patched_loop): app = web.Application() - mocker.spy(app, 'startup') + task = None + + exc = RuntimeError("FAIL") + + async def 
fail(): + try: + await asyncio.sleep(1000) + except asyncio.CancelledError: + raise exc + + async def on_startup(app): + nonlocal task + loop = asyncio.get_event_loop() + task = loop.create_task(fail()) + await asyncio.sleep(0.01) + + app.on_startup.append(on_startup) + + exc_handler = mock.Mock() + patched_loop.set_exception_handler(exc_handler) + web.run_app(app, print=stopper(patched_loop)) + assert task.done() + + msg = { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": exc, + "task": task, + } + exc_handler.assert_called_with(patched_loop, msg) + + +@pytest.mark.skipif(not PY_37, reason="contextvars support is required") +def test_run_app_context_vars(patched_loop): + from contextvars import ContextVar + + count = 0 + VAR = ContextVar("VAR", default="default") + + async def on_startup(app): + nonlocal count + assert "init" == VAR.get() + VAR.set("on_startup") + count += 1 + + async def on_cleanup(app): + nonlocal count + assert "on_startup" == VAR.get() + count += 1 + + async def init(): + nonlocal count + assert "default" == VAR.get() + VAR.set("init") + app = web.Application() - sock1 = socket.socket() - sock2 = socket.socket() - with contextlib.closing(sock1), contextlib.closing(sock2): - sock1.bind(('0.0.0.0', 0)) - _, port1 = sock1.getsockname() - sock2.bind(('0.0.0.0', 0)) - _, port2 = sock2.getsockname() + app.on_startup.append(on_startup) + app.on_cleanup.append(on_cleanup) + count += 1 + return app - printed = StringIO() - web.run_app(app, loop=loop, sock=(sock1, sock2), print=printed.write) - - assert loop.is_closed() - loop.create_server.assert_has_calls([ - mock.call(mock.ANY, sock=sock1, backlog=128, ssl=None), - mock.call(mock.ANY, sock=sock2, backlog=128, ssl=None) - ]) - app.startup.assert_called_once_with() - assert "http://0.0.0.0:{}".format(port1) in printed.getvalue() - assert "http://0.0.0.0:{}".format(port2) in printed.getvalue() + web.run_app(init(), print=stopper(patched_loop)) + assert count == 3 diff --git a/tests/test_signals.py b/tests/test_signals.py index 12f5cd6c901..971cab5c448 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,11 +1,11 @@ -import asyncio from unittest import mock import pytest from multidict import CIMultiDict +from re_assert import Matches from aiohttp.signals import Signal -from aiohttp.test_utils import make_mocked_request +from aiohttp.test_utils import make_mocked_coro, make_mocked_request from aiohttp.web import Application, Response @@ -14,108 +14,81 @@ def app(): return Application() -@pytest.fixture -def debug_app(): - return Application(debug=True) - - def make_request(app, method, path, headers=CIMultiDict()): return make_mocked_request(method, path, headers, app=app) -@asyncio.coroutine -def test_add_signal_handler_not_a_callable(app): +async def test_add_signal_handler_not_a_callable(app) -> None: callback = True app.on_response_prepare.append(callback) + app.on_response_prepare.freeze() with pytest.raises(TypeError): - yield from app.on_response_prepare(None, None) + await app.on_response_prepare(None, None) -@asyncio.coroutine -def test_function_signal_dispatch(app): +async def test_function_signal_dispatch(app) -> None: signal = Signal(app) - kwargs = {'foo': 1, 'bar': 2} + kwargs = {"foo": 1, "bar": 2} callback_mock = mock.Mock() - @asyncio.coroutine - def callback(**kwargs): + async def callback(**kwargs): callback_mock(**kwargs) signal.append(callback) + signal.freeze() - yield from signal.send(**kwargs) + await signal.send(**kwargs) 
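
The dispatch tests around here all follow the same life cycle: register callbacks, freeze the signal, then send. A minimal self-contained sketch of that contract, using only the public Signal API exercised by these tests:

import asyncio

from aiohttp import web
from aiohttp.signals import Signal

async def main():
    app = web.Application()
    signal = Signal(app)

    async def on_event(**kwargs):
        print("fired with", kwargs)

    signal.append(on_event)
    signal.freeze()  # mutation ends here; send() refuses non-frozen signals
    await signal.send(foo=1, bar=2)

asyncio.run(main())
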
callback_mock.assert_called_once_with(**kwargs) -@asyncio.coroutine -def test_function_signal_dispatch2(app): +async def test_function_signal_dispatch2(app) -> None: signal = Signal(app) - args = {'a', 'b'} - kwargs = {'foo': 1, 'bar': 2} + args = {"a", "b"} + kwargs = {"foo": 1, "bar": 2} callback_mock = mock.Mock() - @asyncio.coroutine - def callback(*args, **kwargs): + async def callback(*args, **kwargs): callback_mock(*args, **kwargs) signal.append(callback) + signal.freeze() - yield from signal.send(*args, **kwargs) + await signal.send(*args, **kwargs) callback_mock.assert_called_once_with(*args, **kwargs) -@asyncio.coroutine -def test_response_prepare(app): +async def test_response_prepare(app) -> None: callback = mock.Mock() - @asyncio.coroutine - def cb(*args, **kwargs): + async def cb(*args, **kwargs): callback(*args, **kwargs) app.on_response_prepare.append(cb) + app.on_response_prepare.freeze() - request = make_request(app, 'GET', '/') - response = Response(body=b'') - yield from response.prepare(request) + request = make_request(app, "GET", "/") + response = Response(body=b"") + await response.prepare(request) callback.assert_called_once_with(request, response) -@asyncio.coroutine -def test_non_coroutine(app): +async def test_non_coroutine(app) -> None: signal = Signal(app) - kwargs = {'foo': 1, 'bar': 2} + kwargs = {"foo": 1, "bar": 2} callback = mock.Mock() signal.append(callback) + signal.freeze() - yield from signal.send(**kwargs) - callback.assert_called_once_with(**kwargs) - - -@asyncio.coroutine -def test_debug_signal(debug_app): - assert debug_app.debug, "Should be True" - signal = Signal(debug_app) - - callback = mock.Mock() - pre = mock.Mock() - post = mock.Mock() - - signal.append(callback) - debug_app.on_pre_signal.append(pre) - debug_app.on_post_signal.append(post) - - yield from signal.send(1, a=2) - callback.assert_called_once_with(1, a=2) - pre.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2) - post.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2) + with pytest.raises(TypeError): + await signal.send(**kwargs) -def test_setitem(app): +def test_setitem(app) -> None: signal = Signal(app) m1 = mock.Mock() signal.append(m1) @@ -125,7 +98,7 @@ def test_setitem(app): assert signal[0] is m2 -def test_delitem(app): +def test_delitem(app) -> None: signal = Signal(app) m1 = mock.Mock() signal.append(m1) @@ -134,7 +107,7 @@ def test_delitem(app): assert len(signal) == 0 -def test_cannot_append_to_frozen_signal(app): +def test_cannot_append_to_frozen_signal(app) -> None: signal = Signal(app) m1 = mock.Mock() m2 = mock.Mock() @@ -146,7 +119,7 @@ def test_cannot_append_to_frozen_signal(app): assert list(signal) == [m1] -def test_cannot_setitem_in_frozen_signal(app): +def test_cannot_setitem_in_frozen_signal(app) -> None: signal = Signal(app) m1 = mock.Mock() m2 = mock.Mock() @@ -158,7 +131,7 @@ def test_cannot_setitem_in_frozen_signal(app): assert list(signal) == [m1] -def test_cannot_delitem_in_frozen_signal(app): +def test_cannot_delitem_in_frozen_signal(app) -> None: signal = Signal(app) m1 = mock.Mock() signal.append(m1) @@ -167,3 +140,28 @@ def test_cannot_delitem_in_frozen_signal(app): del signal[0] assert list(signal) == [m1] + + +async def test_cannot_send_non_frozen_signal(app) -> None: + signal = Signal(app) + + callback = make_mocked_coro() + + signal.append(callback) + + with pytest.raises(RuntimeError): + await signal.send() + + assert not callback.called + + +async def test_repr(app) -> None: + signal = Signal(app) + + callback = 
make_mocked_coro()
+
+    signal.append(callback)
+
+    assert Matches(
+        r"<Signal owner=<Application .+>, frozen=False, " r"\[<Mock id=\d+>\]>"
+    ) == repr(signal)
diff --git a/tests/test_streams.py b/tests/test_streams.py
index a928e0d29f4..d83941bec3e 100644
--- a/tests/test_streams.py
+++ b/tests/test_streams.py
@@ -1,799 +1,1134 @@
-"""Tests for streams.py"""
+# Tests for streams.py
+import abc
 import asyncio
-import unittest
+import gc
+import types
+from collections import defaultdict
+from itertools import groupby
 from unittest import mock
 
-from aiohttp import helpers, streams, test_utils
+import pytest
+from re_assert import Matches
 
+from aiohttp import streams
 
-class TestStreamReader(unittest.TestCase):
+DATA = b"line1\nline2\nline3\n"
 
-    DATA = b'line1\nline2\nline3\n'
 
-    def setUp(self):
-        self.time_service = None
-        self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
+def chunkify(seq, n):
+    for i in range(0, len(seq), n):
+        yield seq[i : i + n]
 
-    def tearDown(self):
-        self.loop.close()
+
+async def create_stream():
+    loop = asyncio.get_event_loop()
+    protocol = mock.Mock(_reading_paused=False)
+    stream = streams.StreamReader(protocol, 2 ** 16, loop=loop)
+    stream.feed_data(DATA)
+    stream.feed_eof()
+    return stream
+
+
+@pytest.fixture
+def protocol():
+    return mock.Mock(_reading_paused=False)
+
+
+MEMLEAK_SKIP_TYPES = (
+    *(getattr(types, name) for name in types.__all__ if name.endswith("Type")),
+    mock.Mock,
+    abc.ABCMeta,
+)
+
+
+def get_memory_usage(obj):
+    objs = [obj]
+    # Memory leak may be caused by leaked links to same objects.
+    # Without link counting, [1,2,3] is indistinguishable from [1,2,3,3,3,3,3,3]
+    known = defaultdict(int)
+    known[id(obj)] += 1
+
+    while objs:
+        refs = gc.get_referents(*objs)
+        objs = []
+        for obj in refs:
+            if isinstance(obj, MEMLEAK_SKIP_TYPES):
+                continue
+            i = id(obj)
+            known[i] += 1
+            if known[i] == 1:
+                objs.append(obj)
+
+        # Make list of unhashable objects uniq
+        objs.sort(key=id)
+        objs = [next(g) for (i, g) in groupby(objs, id)]
+
+    return sum(known.values())
+
+
+class TestStreamReader:
+
+    DATA = b"line1\nline2\nline3\n"
 
     def _make_one(self, *args, **kwargs):
-        if 'timeout' in kwargs:
-            self.time_service = helpers.TimeService(self.loop, interval=0.01)
-            self.addCleanup(self.time_service.close)
-            kwargs['timer'] = self.time_service.timeout(kwargs.pop('timeout'))
+        kwargs.setdefault("limit", 2 ** 16)
+        return streams.StreamReader(mock.Mock(_reading_paused=False), *args, **kwargs)
 
-        return streams.StreamReader(loop=self.loop, *args, **kwargs)
+    async def test_create_waiter(self) -> None:
+        loop = asyncio.get_event_loop()
+        stream = self._make_one(loop=loop)
+        stream._waiter = loop.create_future()
+        with pytest.raises(RuntimeError):
+            await stream._wait("test")
 
-    def test_create_waiter(self):
-        stream = self._make_one()
-        stream._waiter = helpers.create_future(self.loop)
-        with self.assertRaises(RuntimeError):
-            self.loop.run_until_complete(stream._wait('test'))
+    def test_ctor_global_loop(self) -> None:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        stream = streams.StreamReader(mock.Mock(_reading_paused=False), 2 ** 16)
 
-    @mock.patch('aiohttp.streams.asyncio')
-    def test_ctor_global_loop(self, m_asyncio):
-        stream = streams.StreamReader()
-        self.assertIs(stream._loop, m_asyncio.get_event_loop.return_value)
+        assert stream._loop is loop
 
-    def test_at_eof(self):
+    async def test_at_eof(self) -> None:
         stream = self._make_one()
-        self.assertFalse(stream.at_eof())
+        assert not stream.at_eof()
 
-        stream.feed_data(b'some data\n')
-
self.assertFalse(stream.at_eof()) + stream.feed_data(b"some data\n") + assert not stream.at_eof() - self.loop.run_until_complete(stream.readline()) - self.assertFalse(stream.at_eof()) + await stream.readline() + assert not stream.at_eof() - stream.feed_data(b'some data\n') + stream.feed_data(b"some data\n") stream.feed_eof() - self.loop.run_until_complete(stream.readline()) - self.assertTrue(stream.at_eof()) + await stream.readline() + assert stream.at_eof() - def test_wait_eof(self): + async def test_wait_eof(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() - wait_task = asyncio.Task(stream.wait_eof(), loop=self.loop) + wait_task = loop.create_task(stream.wait_eof()) - def cb(): - yield from asyncio.sleep(0.1, loop=self.loop) + async def cb(): + await asyncio.sleep(0.1) stream.feed_eof() - asyncio.Task(cb(), loop=self.loop) - self.loop.run_until_complete(wait_task) - self.assertTrue(stream.is_eof()) - self.assertIsNone(stream._eof_waiter) + loop.create_task(cb()) + await wait_task + assert stream.is_eof() + assert stream._eof_waiter is None - def test_wait_eof_eof(self): + async def test_wait_eof_eof(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() stream.feed_eof() - wait_task = asyncio.Task(stream.wait_eof(), loop=self.loop) - self.loop.run_until_complete(wait_task) - self.assertTrue(stream.is_eof()) + wait_task = loop.create_task(stream.wait_eof()) + await wait_task + assert stream.is_eof() - def test_feed_empty_data(self): + async def test_feed_empty_data(self) -> None: stream = self._make_one() - stream.feed_data(b'') + stream.feed_data(b"") stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await stream.read() + assert b"" == data - def test_feed_nonempty_data(self): + async def test_feed_nonempty_data(self) -> None: stream = self._make_one() stream.feed_data(self.DATA) stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(self.DATA, data) + data = await stream.read() + assert self.DATA == data - def test_read_zero(self): + async def test_read_zero(self) -> None: # Read zero bytes. stream = self._make_one() stream.feed_data(self.DATA) - data = self.loop.run_until_complete(stream.read(0)) - self.assertEqual(b'', data) + data = await stream.read(0) + assert b"" == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(self.DATA, data) + data = await stream.read() + assert self.DATA == data - def test_read(self): + async def test_read(self) -> None: + loop = asyncio.get_event_loop() # Read bytes. stream = self._make_one() - read_task = asyncio.Task(stream.read(30), loop=self.loop) + read_task = loop.create_task(stream.read(30)) def cb(): stream.feed_data(self.DATA) - self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertEqual(self.DATA, data) + loop.call_soon(cb) + + data = await read_task + assert self.DATA == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await stream.read() + assert b"" == data - def test_read_line_breaks(self): + async def test_read_line_breaks(self) -> None: # Read bytes without line breaks. 
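
The pattern in all of these StreamReader tests is push-then-pull: the transport side feeds bytes and EOF in, while coroutines await them out. A compact sketch using the same constructor shape as _make_one above:

import asyncio
from unittest import mock

from aiohttp import streams

async def main():
    loop = asyncio.get_event_loop()
    protocol = mock.Mock(_reading_paused=False)
    reader = streams.StreamReader(protocol, 2 ** 16, loop=loop)
    reader.feed_data(b"line1\nline2\n")  # transport side pushes bytes
    reader.feed_eof()                    # ... and signals the end
    print(await reader.readline())       # b'line1\n'
    print(await reader.read())           # b'line2\n'

asyncio.run(main())
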
stream = self._make_one() - stream.feed_data(b'line1') - stream.feed_data(b'line2') + stream.feed_data(b"line1") + stream.feed_data(b"line2") - data = self.loop.run_until_complete(stream.read(5)) - self.assertEqual(b'line1', data) + data = await stream.read(5) + assert b"line1" == data - data = self.loop.run_until_complete(stream.read(5)) - self.assertEqual(b'line2', data) + data = await stream.read(5) + assert b"line2" == data - def test_read_all(self): - # Read all avaliable buffered bytes + async def test_read_all(self) -> None: + # Read all available buffered bytes stream = self._make_one() - stream.feed_data(b'line1') - stream.feed_data(b'line2') + stream.feed_data(b"line1") + stream.feed_data(b"line2") stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'line1line2', data) + data = await stream.read() + assert b"line1line2" == data - def test_read_up_to(self): + async def test_read_up_to(self) -> None: # Read available buffered bytes up to requested amount stream = self._make_one() - stream.feed_data(b'line1') - stream.feed_data(b'line2') + stream.feed_data(b"line1") + stream.feed_data(b"line2") - data = self.loop.run_until_complete(stream.read(8)) - self.assertEqual(b'line1lin', data) + data = await stream.read(8) + assert b"line1lin" == data - data = self.loop.run_until_complete(stream.read(8)) - self.assertEqual(b'e2', data) + data = await stream.read(8) + assert b"e2" == data - def test_read_eof(self): + async def test_read_eof(self) -> None: + loop = asyncio.get_event_loop() # Read bytes, stop at eof. stream = self._make_one() - read_task = asyncio.Task(stream.read(1024), loop=self.loop) + read_task = loop.create_task(stream.read(1024)) def cb(): stream.feed_eof() - self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertEqual(b'', data) + loop.call_soon(cb) - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(data, b'') + data = await read_task + assert b"" == data - @mock.patch('aiohttp.streams.internal_logger') - def test_read_eof_infinit(self, internal_logger): + data = await stream.read() + assert data == b"" + + async def test_read_eof_infinite(self) -> None: # Read bytes. stream = self._make_one() stream.feed_eof() - self.loop.run_until_complete(stream.read()) - self.loop.run_until_complete(stream.read()) - self.loop.run_until_complete(stream.read()) - self.loop.run_until_complete(stream.read()) - self.loop.run_until_complete(stream.read()) - self.loop.run_until_complete(stream.read()) - self.assertTrue(internal_logger.warning.called) + with mock.patch("aiohttp.streams.internal_logger") as internal_logger: + await stream.read() + await stream.read() + await stream.read() + await stream.read() + await stream.read() + await stream.read() + assert internal_logger.warning.called + + async def test_read_eof_unread_data_no_warning(self) -> None: + # Read bytes. + stream = self._make_one() + stream.feed_eof() - def test_read_until_eof(self): + with mock.patch("aiohttp.streams.internal_logger") as internal_logger: + await stream.read() + await stream.read() + await stream.read() + await stream.read() + await stream.read() + with pytest.warns(DeprecationWarning): + stream.unread_data(b"data") + await stream.read() + await stream.read() + assert not internal_logger.warning.called + + async def test_read_until_eof(self) -> None: + loop = asyncio.get_event_loop() # Read all bytes until eof. 
stream = self._make_one() - read_task = asyncio.Task(stream.read(-1), loop=self.loop) + read_task = loop.create_task(stream.read(-1)) def cb(): - stream.feed_data(b'chunk1\n') - stream.feed_data(b'chunk2') + stream.feed_data(b"chunk1\n") + stream.feed_data(b"chunk2") stream.feed_eof() - self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertEqual(b'chunk1\nchunk2', data) + loop.call_soon(cb) - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await read_task + assert b"chunk1\nchunk2" == data - def test_read_exception(self): + data = await stream.read() + assert b"" == data + + async def test_read_exception(self) -> None: stream = self._make_one() - stream.feed_data(b'line\n') + stream.feed_data(b"line\n") - data = self.loop.run_until_complete(stream.read(2)) - self.assertEqual(b'li', data) + data = await stream.read(2) + assert b"li" == data stream.set_exception(ValueError()) - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.read(2)) + with pytest.raises(ValueError): + await stream.read(2) - def test_readline(self): + async def test_readline(self) -> None: + loop = asyncio.get_event_loop() # Read one line. 'readline' will need to wait for the data # to come from 'cb' stream = self._make_one() - stream.feed_data(b'chunk1 ') - read_task = asyncio.Task(stream.readline(), loop=self.loop) + stream.feed_data(b"chunk1 ") + read_task = loop.create_task(stream.readline()) def cb(): - stream.feed_data(b'chunk2 ') - stream.feed_data(b'chunk3 ') - stream.feed_data(b'\n chunk4') - self.loop.call_soon(cb) + stream.feed_data(b"chunk2 ") + stream.feed_data(b"chunk3 ") + stream.feed_data(b"\n chunk4") + + loop.call_soon(cb) - line = self.loop.run_until_complete(read_task) - self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + line = await read_task + assert b"chunk1 chunk2 chunk3 \n" == line stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b' chunk4', data) + data = await stream.read() + assert b" chunk4" == data - def test_readline_limit_with_existing_data(self): + async def test_readline_limit_with_existing_data(self) -> None: # Read one line. The data is in StreamReader's buffer # before the event loop is run. - stream = self._make_one(limit=3) - stream.feed_data(b'li') - stream.feed_data(b'ne1\nline2\n') + stream = self._make_one(limit=2) + stream.feed_data(b"li") + stream.feed_data(b"ne1\nline2\n") - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readline()) + with pytest.raises(ValueError): + await stream.readline() # The buffer should contain the remaining data after exception stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'line2\n', data) + data = await stream.read() + assert b"line2\n" == data - def test_readline_limit(self): + async def test_readline_limit(self) -> None: + loop = asyncio.get_event_loop() # Read one line. StreamReaders are fed with data after # their 'readline' methods are called. 
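
As the two limit tests here show, readline() refuses lines that outgrow the configured limit but leaves the rest of the buffer intact. A standalone sketch of that behaviour:

import asyncio
from unittest import mock

from aiohttp import streams

async def main():
    loop = asyncio.get_event_loop()
    reader = streams.StreamReader(mock.Mock(_reading_paused=False), 2, loop=loop)
    reader.feed_data(b"line1\nline2\n")
    reader.feed_eof()
    try:
        await reader.readline()  # b"line1\n" is longer than the limit of 2
    except ValueError:
        print(await reader.read())  # the rest survives: b'line2\n'

asyncio.run(main())
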
- - stream = self._make_one(limit=7) + stream = self._make_one(limit=4) def cb(): - stream.feed_data(b'chunk1') - stream.feed_data(b'chunk2') - stream.feed_data(b'chunk3\n') + stream.feed_data(b"chunk1") + stream.feed_data(b"chunk2\n") + stream.feed_data(b"chunk3\n") stream.feed_eof() - self.loop.call_soon(cb) - - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readline()) - - stream = self._make_one(limit=7) - def cb(): - stream.feed_data(b'chunk1') - stream.feed_data(b'chunk2\n') - stream.feed_data(b'chunk3\n') - stream.feed_eof() - self.loop.call_soon(cb) + loop.call_soon(cb) - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readline()) - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'chunk3\n', data) + with pytest.raises(ValueError): + await stream.readline() + data = await stream.read() + assert b"chunk3\n" == data - def test_readline_nolimit_nowait(self): + async def test_readline_nolimit_nowait(self) -> None: # All needed data for the first 'readline' call will be # in the buffer. stream = self._make_one() stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) - line = self.loop.run_until_complete(stream.readline()) - self.assertEqual(b'line1\n', line) + line = await stream.readline() + assert b"line1\n" == line stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'line2\nline3\n', data) + data = await stream.read() + assert b"line2\nline3\n" == data - def test_readline_eof(self): + async def test_readline_eof(self) -> None: stream = self._make_one() - stream.feed_data(b'some data') + stream.feed_data(b"some data") stream.feed_eof() - line = self.loop.run_until_complete(stream.readline()) - self.assertEqual(b'some data', line) + line = await stream.readline() + assert b"some data" == line - def test_readline_empty_eof(self): + async def test_readline_empty_eof(self) -> None: stream = self._make_one() stream.feed_eof() - line = self.loop.run_until_complete(stream.readline()) - self.assertEqual(b'', line) + line = await stream.readline() + assert b"" == line - def test_readline_read_byte_count(self): + async def test_readline_read_byte_count(self) -> None: stream = self._make_one() stream.feed_data(self.DATA) - self.loop.run_until_complete(stream.readline()) + await stream.readline() - data = self.loop.run_until_complete(stream.read(7)) - self.assertEqual(b'line2\nl', data) + data = await stream.read(7) + assert b"line2\nl" == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'ine3\n', data) + data = await stream.read() + assert b"ine3\n" == data - def test_readline_exception(self): + async def test_readline_exception(self) -> None: stream = self._make_one() - stream.feed_data(b'line\n') + stream.feed_data(b"line\n") - data = self.loop.run_until_complete(stream.readline()) - self.assertEqual(b'line\n', data) + data = await stream.readline() + assert b"line\n" == data stream.set_exception(ValueError()) - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readline()) + with pytest.raises(ValueError): + await stream.readline() - def test_readexactly_zero_or_less(self): + async def test_readexactly_zero_or_less(self) -> None: # Read exact number of bytes (zero or less). 
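
readexactly() has a sharper failure mode than read(): hitting EOF early raises asyncio.IncompleteReadError, which still carries the bytes that did arrive. A short sketch:

import asyncio
from unittest import mock

from aiohttp import streams

async def main():
    loop = asyncio.get_event_loop()
    reader = streams.StreamReader(mock.Mock(_reading_paused=False), 2 ** 16, loop=loop)
    reader.feed_data(b"abc")
    reader.feed_eof()
    try:
        await reader.readexactly(5)
    except asyncio.IncompleteReadError as exc:
        print(exc.partial, exc.expected)  # b'abc' 5

asyncio.run(main())
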
stream = self._make_one() stream.feed_data(self.DATA) - data = self.loop.run_until_complete(stream.readexactly(0)) - self.assertEqual(b'', data) + data = await stream.readexactly(0) + assert b"" == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(self.DATA, data) + data = await stream.read() + assert self.DATA == data stream = self._make_one() stream.feed_data(self.DATA) - data = self.loop.run_until_complete(stream.readexactly(-1)) - self.assertEqual(b'', data) + data = await stream.readexactly(-1) + assert b"" == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(self.DATA, data) + data = await stream.read() + assert self.DATA == data - def test_readexactly(self): + async def test_readexactly(self) -> None: + loop = asyncio.get_event_loop() # Read exact number of bytes. stream = self._make_one() n = 2 * len(self.DATA) - read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) + read_task = loop.create_task(stream.readexactly(n)) def cb(): stream.feed_data(self.DATA) stream.feed_data(self.DATA) stream.feed_data(self.DATA) - self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertEqual(self.DATA + self.DATA, data) + loop.call_soon(cb) + + data = await read_task + assert self.DATA + self.DATA == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(self.DATA, data) + data = await stream.read() + assert self.DATA == data - def test_readexactly_eof(self): + async def test_readexactly_eof(self) -> None: + loop = asyncio.get_event_loop() # Read exact number of bytes (eof). - stream = self._make_one() + stream = self._make_one(loop=loop) n = 2 * len(self.DATA) - read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) + read_task = loop.create_task(stream.readexactly(n)) def cb(): stream.feed_data(self.DATA) stream.feed_eof() - self.loop.call_soon(cb) - with self.assertRaises(asyncio.IncompleteReadError) as cm: - self.loop.run_until_complete(read_task) - self.assertEqual(cm.exception.partial, self.DATA) - self.assertEqual(cm.exception.expected, n) - self.assertEqual(str(cm.exception), - '18 bytes read on a total of 36 expected bytes') - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + loop.call_soon(cb) + + with pytest.raises(asyncio.IncompleteReadError) as cm: + await read_task + assert cm.value.partial == self.DATA + assert cm.value.expected == n + assert str(cm.value) == "18 bytes read on a total of 36 expected bytes" + data = await stream.read() + assert b"" == data - def test_readexactly_exception(self): + async def test_readexactly_exception(self) -> None: stream = self._make_one() - stream.feed_data(b'line\n') + stream.feed_data(b"line\n") - data = self.loop.run_until_complete(stream.readexactly(2)) - self.assertEqual(b'li', data) + data = await stream.readexactly(2) + assert b"li" == data stream.set_exception(ValueError()) - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readexactly(2)) + with pytest.raises(ValueError): + await stream.readexactly(2) - def test_unread_data(self): + async def test_unread_data(self) -> None: stream = self._make_one() - stream.feed_data(b'line1') - stream.feed_data(b'line2') - stream.feed_data(b'onemoreline') + stream.feed_data(b"line1") + stream.feed_data(b"line2") + stream.feed_data(b"onemoreline") - data = self.loop.run_until_complete(stream.read(5)) - self.assertEqual(b'line1', data) + data = await stream.read(5) + assert 
b"line1" == data - stream.unread_data(data) + with pytest.warns(DeprecationWarning): + stream.unread_data(data) - data = self.loop.run_until_complete(stream.read(5)) - self.assertEqual(b'line1', data) + data = await stream.read(5) + assert b"line1" == data - data = self.loop.run_until_complete(stream.read(4)) - self.assertEqual(b'line', data) + data = await stream.read(4) + assert b"line" == data - stream.unread_data(b'line1line') + with pytest.warns(DeprecationWarning): + stream.unread_data(b"line1line") - data = b'' + data = b"" while len(data) < 10: - data += self.loop.run_until_complete(stream.read(10)) - self.assertEqual(b'line1line2', data) + data += await stream.read(10) + assert b"line1line2" == data - data = self.loop.run_until_complete(stream.read(7)) - self.assertEqual(b'onemore', data) + data = await stream.read(7) + assert b"onemore" == data - stream.unread_data(data) + with pytest.warns(DeprecationWarning): + stream.unread_data(data) - data = b'' + data = b"" while len(data) < 11: - data += self.loop.run_until_complete(stream.read(11)) - self.assertEqual(b'onemoreline', data) + data += await stream.read(11) + assert b"onemoreline" == data - stream.unread_data(b'line') - data = self.loop.run_until_complete(stream.read(4)) - self.assertEqual(b'line', data) + with pytest.warns(DeprecationWarning): + stream.unread_data(b"line") + data = await stream.read(4) + assert b"line" == data stream.feed_eof() - stream.unread_data(b'at_eof') - data = self.loop.run_until_complete(stream.read(6)) - self.assertEqual(b'at_eof', data) + with pytest.warns(DeprecationWarning): + stream.unread_data(b"at_eof") + data = await stream.read(6) + assert b"at_eof" == data - def test_exception(self): + async def test_exception(self) -> None: stream = self._make_one() - self.assertIsNone(stream.exception()) + assert stream.exception() is None exc = ValueError() stream.set_exception(exc) - self.assertIs(stream.exception(), exc) + assert stream.exception() is exc - def test_exception_waiter(self): + async def test_exception_waiter(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() - @asyncio.coroutine - def set_err(): + async def set_err(): stream.set_exception(ValueError()) - t1 = asyncio.Task(stream.readline(), loop=self.loop) - t2 = asyncio.Task(set_err(), loop=self.loop) + t1 = loop.create_task(stream.readline()) + t2 = loop.create_task(set_err()) - self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop)) - self.assertRaises(ValueError, t1.result) + await asyncio.wait([t1, t2]) + with pytest.raises(ValueError): + t1.result() - def test_exception_cancel(self): + async def test_exception_cancel(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() - @asyncio.coroutine - def read_a_line(): - yield from stream.readline() + async def read_a_line(): + await stream.readline() - t = asyncio.Task(read_a_line(), loop=self.loop) - test_utils.run_briefly(self.loop) + t = loop.create_task(read_a_line()) + await asyncio.sleep(0) t.cancel() - test_utils.run_briefly(self.loop) + await asyncio.sleep(0) # The following line fails if set_exception() isn't careful. 
- stream.set_exception(RuntimeError('message')) - test_utils.run_briefly(self.loop) - self.assertIs(stream._waiter, None) + stream.set_exception(RuntimeError("message")) + await asyncio.sleep(0) + assert stream._waiter is None - def test_readany_eof(self): + async def test_readany_eof(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() - read_task = asyncio.Task(stream.readany(), loop=self.loop) - self.loop.call_soon(stream.feed_data, b'chunk1\n') + read_task = loop.create_task(stream.readany()) + loop.call_soon(stream.feed_data, b"chunk1\n") - data = self.loop.run_until_complete(read_task) - self.assertEqual(b'chunk1\n', data) + data = await read_task + assert b"chunk1\n" == data stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await stream.read() + assert b"" == data - def test_readany_empty_eof(self): + async def test_readany_empty_eof(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() stream.feed_eof() - read_task = asyncio.Task(stream.readany(), loop=self.loop) + read_task = loop.create_task(stream.readany()) - data = self.loop.run_until_complete(read_task) + data = await read_task - self.assertEqual(b'', data) + assert b"" == data - def test_readany_exception(self): + async def test_readany_exception(self) -> None: stream = self._make_one() - stream.feed_data(b'line\n') + stream.feed_data(b"line\n") - data = self.loop.run_until_complete(stream.readany()) - self.assertEqual(b'line\n', data) + data = await stream.readany() + assert b"line\n" == data stream.set_exception(ValueError()) - self.assertRaises( - ValueError, self.loop.run_until_complete, stream.readany()) + with pytest.raises(ValueError): + await stream.readany() - def test_read_nowait(self): + async def test_read_nowait(self) -> None: stream = self._make_one() - stream.feed_data(b'line1\nline2\n') + stream.feed_data(b"line1\nline2\n") - self.assertEqual(stream.read_nowait(), b'line1\nline2\n') - self.assertEqual(stream.read_nowait(), b'') + assert stream.read_nowait() == b"line1\nline2\n" + assert stream.read_nowait() == b"" stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await stream.read() + assert b"" == data - def test_read_nowait_n(self): + async def test_read_nowait_n(self) -> None: stream = self._make_one() - stream.feed_data(b'line1\nline2\n') + stream.feed_data(b"line1\nline2\n") - self.assertEqual( - stream.read_nowait(4), b'line') - self.assertEqual( - stream.read_nowait(), b'1\nline2\n') - self.assertEqual(stream.read_nowait(), b'') + assert stream.read_nowait(4) == b"line" + assert stream.read_nowait() == b"1\nline2\n" + assert stream.read_nowait() == b"" stream.feed_eof() - data = self.loop.run_until_complete(stream.read()) - self.assertEqual(b'', data) + data = await stream.read() + assert b"" == data - def test_read_nowait_exception(self): + async def test_read_nowait_exception(self) -> None: stream = self._make_one() - stream.feed_data(b'line\n') + stream.feed_data(b"line\n") stream.set_exception(ValueError()) - self.assertRaises(ValueError, stream.read_nowait) + with pytest.raises(ValueError): + stream.read_nowait() - def test_read_nowait_waiter(self): + async def test_read_nowait_waiter(self) -> None: + loop = asyncio.get_event_loop() stream = self._make_one() - stream.feed_data(b'line\n') - stream._waiter = helpers.create_future(self.loop) + stream.feed_data(b"line\n") + stream._waiter = loop.create_future() - 
self.assertRaises(RuntimeError, stream.read_nowait)
+        with pytest.raises(RuntimeError):
+            stream.read_nowait()
 
-    def test___repr__(self):
+    async def test_readchunk(self) -> None:
+        loop = asyncio.get_event_loop()
         stream = self._make_one()
-        self.assertEqual("<StreamReader>", repr(stream))
 
-    def test___repr__nondefault_limit(self):
-        stream = self._make_one(limit=123)
-        self.assertEqual("<StreamReader l=123>", repr(stream))
+        def cb():
+            stream.feed_data(b"chunk1")
+            stream.feed_data(b"chunk2")
+            stream.feed_eof()
+
+        loop.call_soon(cb)
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"chunk1" == data
+        assert not end_of_chunk
 
-    def test___repr__eof(self):
+        data, end_of_chunk = await stream.readchunk()
+        assert b"chunk2" == data
+        assert not end_of_chunk
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"" == data
+        assert not end_of_chunk
+
+    async def test_readchunk_wait_eof(self) -> None:
+        loop = asyncio.get_event_loop()
         stream = self._make_one()
-        stream.feed_eof()
-        self.assertEqual("<StreamReader eof>", repr(stream))
 
-    def test___repr__data(self):
+        async def cb():
+            await asyncio.sleep(0.1)
+            stream.feed_eof()
+
+        loop.create_task(cb())
+        data, end_of_chunk = await stream.readchunk()
+        assert b"" == data
+        assert not end_of_chunk
+        assert stream.is_eof()
+
+    async def test_begin_and_end_chunk_receiving(self) -> None:
         stream = self._make_one()
-        stream.feed_data(b'data')
-        self.assertEqual("<StreamReader 4 bytes>", repr(stream))
 
-    def test___repr__exception(self):
+        stream.begin_http_chunk_receiving()
+        stream.feed_data(b"part1")
+        stream.feed_data(b"part2")
+        stream.end_http_chunk_receiving()
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"part1part2" == data
+        assert end_of_chunk
+
+        stream.begin_http_chunk_receiving()
+        stream.feed_data(b"part3")
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"part3" == data
+        assert not end_of_chunk
+
+        stream.end_http_chunk_receiving()
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"" == data
+        assert end_of_chunk
+
+        stream.feed_eof()
+
+        data, end_of_chunk = await stream.readchunk()
+        assert b"" == data
+        assert not end_of_chunk
+
+    async def test_readany_chunk_end_race(self) -> None:
         stream = self._make_one()
-        exc = RuntimeError()
-        stream.set_exception(exc)
-        self.assertEqual("<StreamReader e=RuntimeError()>", repr(stream))
+        stream.begin_http_chunk_receiving()
+        stream.feed_data(b"part1")
+
+        data = await stream.readany()
+        assert data == b"part1"
+
+        loop = asyncio.get_event_loop()
+        task = loop.create_task(stream.readany())
+
+        # Give a chance for task to create waiter and start waiting for it.
+        await asyncio.sleep(0.1)
+        assert stream._waiter is not None
+        assert not task.done()  # Just for sure.
+
+        # This will trigger waiter, but without feeding any data.
+        # The stream should re-create waiter again.
+        stream.end_http_chunk_receiving()
 
-    def test___repr__waiter(self):
+        # Give a chance for task to resolve.
+        # If everything is OK, previous action SHOULD NOT resolve the task.
+        await asyncio.sleep(0.1)
+        assert not task.done()  # The actual test.
+
+        stream.begin_http_chunk_receiving()
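
The race test above leans on the chunk-boundary protocol that this block of tests drives directly: begin_http_chunk_receiving() and end_http_chunk_receiving() bracket each chunk's payload, and readchunk() reports a (data, end_of_chunk) pair so callers can observe chunk boundaries. A minimal end-to-end sketch:

import asyncio
from unittest import mock

from aiohttp import streams

async def main():
    loop = asyncio.get_event_loop()
    reader = streams.StreamReader(mock.Mock(_reading_paused=False), 2 ** 16, loop=loop)
    reader.begin_http_chunk_receiving()  # parser: chunk starts
    reader.feed_data(b"payload")
    reader.end_http_chunk_receiving()    # parser: chunk complete
    reader.feed_eof()
    print(await reader.readchunk())      # (b'payload', True)

asyncio.run(main())
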
+ stream.feed_data(b"part2") + stream.end_http_chunk_receiving() + + data = await task + assert data == b"part2" + + async def test_end_chunk_receiving_without_begin(self) -> None: stream = self._make_one() - stream._waiter = helpers.create_future(self.loop) - self.assertRegex( - repr(stream), - ">") - stream._waiter.set_result(None) - self.loop.run_until_complete(stream._waiter) - stream._waiter = None - self.assertEqual("", repr(stream)) + with pytest.raises(RuntimeError): + stream.end_http_chunk_receiving() - def test_unread_empty(self): + async def test_readchunk_with_unread(self) -> None: + # Test that stream.unread does not break controlled chunk receiving. stream = self._make_one() - stream.feed_data(b'line1') + + # Send 2 chunks + stream.begin_http_chunk_receiving() + stream.feed_data(b"part1") + stream.end_http_chunk_receiving() + stream.begin_http_chunk_receiving() + stream.feed_data(b"part2") + stream.end_http_chunk_receiving() + + # Read only one chunk + data, end_of_chunk = await stream.readchunk() + + # Try to unread a part of the first chunk + with pytest.warns(DeprecationWarning): + stream.unread_data(b"rt1") + + # The end_of_chunk signal was already received for the first chunk, + # so we receive up to the second one + data, end_of_chunk = await stream.readchunk() + assert b"rt1part2" == data + assert end_of_chunk + + # Unread a part of the second chunk + with pytest.warns(DeprecationWarning): + stream.unread_data(b"rt2") + + data, end_of_chunk = await stream.readchunk() + assert b"rt2" == data + # end_of_chunk was already received for this chunk + assert not end_of_chunk + stream.feed_eof() - stream.unread_data(b'') - - data = self.loop.run_until_complete(stream.read(5)) - self.assertEqual(b'line1', data) - self.assertTrue(stream.at_eof()) - - -class TestEmptyStreamReader(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_empty_stream_reader(self): - s = streams.EmptyStreamReader() - self.assertIsNone(s.set_exception(ValueError())) - self.assertIsNone(s.exception()) - self.assertIsNone(s.feed_eof()) - self.assertIsNone(s.feed_data(b'data')) - self.assertTrue(s.at_eof()) - self.assertIsNone( - self.loop.run_until_complete(s.wait_eof())) - self.assertEqual( - self.loop.run_until_complete(s.read()), b'') - self.assertEqual( - self.loop.run_until_complete(s.readline()), b'') - self.assertEqual( - self.loop.run_until_complete(s.readany()), b'') - self.assertRaises( - asyncio.IncompleteReadError, - self.loop.run_until_complete, s.readexactly(10)) - self.assertEqual(s.read_nowait(), b'') - - -class DataQueueMixin: - - def test_is_eof(self): - self.assertFalse(self.buffer.is_eof()) - self.buffer.feed_eof() - self.assertTrue(self.buffer.is_eof()) - - def test_at_eof(self): - self.assertFalse(self.buffer.at_eof()) - self.buffer.feed_eof() - self.assertTrue(self.buffer.at_eof()) - self.buffer._buffer.append(object()) - self.assertFalse(self.buffer.at_eof()) - - def test_feed_data(self): - item = object() - self.buffer.feed_data(item, 1) - self.assertEqual([(item, 1)], list(self.buffer._buffer)) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert not end_of_chunk - def test_feed_eof(self): - self.buffer.feed_eof() - self.assertTrue(self.buffer._eof) + async def test_readchunk_with_other_read_calls(self) -> None: + # Test that stream.readchunk works when other read calls are made on + # the stream. 
+ stream = self._make_one() - def test_read(self): - item = object() - read_task = asyncio.Task(self.buffer.read(), loop=self.loop) + stream.begin_http_chunk_receiving() + stream.feed_data(b"part1") + stream.end_http_chunk_receiving() + stream.begin_http_chunk_receiving() + stream.feed_data(b"part2") + stream.end_http_chunk_receiving() + stream.begin_http_chunk_receiving() + stream.feed_data(b"part3") + stream.end_http_chunk_receiving() - def cb(): - self.buffer.feed_data(item, 1) - self.loop.call_soon(cb) + data = await stream.read(7) + assert b"part1pa" == data - data = self.loop.run_until_complete(read_task) - self.assertIs(item, data) + data, end_of_chunk = await stream.readchunk() + assert b"rt2" == data + assert end_of_chunk - def test_read_eof(self): - read_task = asyncio.Task(self.buffer.read(), loop=self.loop) + # Corner case between read/readchunk + data = await stream.read(5) + assert b"part3" == data - def cb(): - self.buffer.feed_eof() - self.loop.call_soon(cb) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert end_of_chunk - self.assertRaises( - streams.EofStream, self.loop.run_until_complete, read_task) + stream.feed_eof() - def test_read_cancelled(self): - read_task = asyncio.Task(self.buffer.read(), loop=self.loop) - test_utils.run_briefly(self.loop) - waiter = self.buffer._waiter - self.assertIsInstance(waiter, asyncio.Future) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert not end_of_chunk - read_task.cancel() - self.assertRaises( - asyncio.CancelledError, - self.loop.run_until_complete, read_task) - self.assertTrue(waiter.cancelled()) - self.assertIsNone(self.buffer._waiter) + async def test_chunksplits_memory_leak(self) -> None: + # Test for memory leak on chunksplits + stream = self._make_one() - self.buffer.feed_data(b'test', 4) - self.assertIsNone(self.buffer._waiter) + N = 500 - def test_read_until_eof(self): - item = object() - self.buffer.feed_data(item, 1) - self.buffer.feed_eof() + # Warm-up variables + stream.begin_http_chunk_receiving() + stream.feed_data(b"Y" * N) + stream.end_http_chunk_receiving() + await stream.read(N) - data = self.loop.run_until_complete(self.buffer.read()) - self.assertIs(data, item) + N = 300 - self.assertRaises( - streams.EofStream, - self.loop.run_until_complete, self.buffer.read()) + before = get_memory_usage(stream) + for _ in range(N): + stream.begin_http_chunk_receiving() + stream.feed_data(b"X") + stream.end_http_chunk_receiving() + await stream.read(N) + after = get_memory_usage(stream) - def test_read_exception(self): - self.buffer.set_exception(ValueError()) + assert abs(after - before) == 0 - self.assertRaises( - ValueError, self.loop.run_until_complete, self.buffer.read()) + async def test_read_empty_chunks(self) -> None: + # Test that feeding empty chunks does not break stream + stream = self._make_one() - def test_read_exception_with_data(self): - val = object() - self.buffer.feed_data(val, 1) - self.buffer.set_exception(ValueError()) + # Simulate empty first chunk. 
This is a significant special case. + stream.begin_http_chunk_receiving() + stream.end_http_chunk_receiving() - self.assertIs(val, self.loop.run_until_complete(self.buffer.read())) - self.assertRaises( - ValueError, self.loop.run_until_complete, self.buffer.read()) + stream.begin_http_chunk_receiving() + stream.feed_data(b"ungzipped") + stream.end_http_chunk_receiving() - def test_read_exception_on_wait(self): - read_task = asyncio.Task(self.buffer.read(), loop=self.loop) - test_utils.run_briefly(self.loop) - self.assertIsInstance(self.buffer._waiter, asyncio.Future) + # Possible when compression is enabled. + stream.begin_http_chunk_receiving() + stream.end_http_chunk_receiving() - self.buffer.feed_eof() - self.buffer.set_exception(ValueError()) + # This is also possible. + stream.begin_http_chunk_receiving() + stream.end_http_chunk_receiving() - self.assertRaises( - ValueError, self.loop.run_until_complete, read_task) + stream.begin_http_chunk_receiving() + stream.feed_data(b" data") + stream.end_http_chunk_receiving() - def test_exception(self): - self.assertIsNone(self.buffer.exception()) + stream.feed_eof() - exc = ValueError() - self.buffer.set_exception(exc) - self.assertIs(self.buffer.exception(), exc) + data = await stream.read() + assert data == b"ungzipped data" - def test_exception_waiter(self): - @asyncio.coroutine - def set_err(): - self.buffer.set_exception(ValueError()) + async def test_readchunk_separate_http_chunk_tail(self) -> None: + # Test that stream.readchunk returns (b'', True) when the end of + # an HTTP chunk is received after the body. + loop = asyncio.get_event_loop() + stream = self._make_one() - t1 = asyncio.Task(self.buffer.read(), loop=self.loop) - t2 = asyncio.Task(set_err(), loop=self.loop) + stream.begin_http_chunk_receiving() + stream.feed_data(b"part1") + + data, end_of_chunk = await stream.readchunk() + assert b"part1" == data + assert not end_of_chunk + + async def cb(): + await asyncio.sleep(0.1) + stream.end_http_chunk_receiving() + + loop.create_task(cb()) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert end_of_chunk + + stream.begin_http_chunk_receiving() + stream.feed_data(b"part2") + data, end_of_chunk = await stream.readchunk() + assert b"part2" == data + assert not end_of_chunk + + stream.end_http_chunk_receiving() + stream.begin_http_chunk_receiving() + stream.feed_data(b"part3") + stream.end_http_chunk_receiving() + + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert end_of_chunk + + data, end_of_chunk = await stream.readchunk() + assert b"part3" == data + assert end_of_chunk + + stream.begin_http_chunk_receiving() + stream.feed_data(b"part4") + data, end_of_chunk = await stream.readchunk() + assert b"part4" == data + assert not end_of_chunk + + async def cb(): + await asyncio.sleep(0.1) + stream.end_http_chunk_receiving() + stream.feed_eof() - self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop)) + loop.create_task(cb()) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert end_of_chunk - self.assertRaises(ValueError, t1.result) + data, end_of_chunk = await stream.readchunk() + assert b"" == data + assert not end_of_chunk + async def test___repr__(self) -> None: + stream = self._make_one() + assert "<StreamReader>" == repr(stream) -class TestDataQueue(unittest.TestCase, DataQueueMixin): + async def test___repr__nondefault_limit(self) -> None: + stream = self._make_one(limit=123) + assert "<StreamReader low=123 high=246>" == repr(stream) - def setUp(self): - self.loop = asyncio.new_event_loop() - 
asyncio.set_event_loop(None) - self.buffer = streams.DataQueue(loop=self.loop) + async def test___repr__eof(self) -> None: + stream = self._make_one() + stream.feed_eof() + assert "<StreamReader eof>" == repr(stream) - def tearDown(self): - self.loop.close() + async def test___repr__data(self) -> None: + stream = self._make_one() + stream.feed_data(b"data") + assert "<StreamReader 4 bytes>" == repr(stream) + async def test___repr__exception(self) -> None: + loop = asyncio.get_event_loop() + stream = self._make_one(loop=loop) + exc = RuntimeError() + stream.set_exception(exc) + assert "<StreamReader e=RuntimeError()>" == repr(stream) -class TestChunksQueue(unittest.TestCase, DataQueueMixin): + async def test___repr__waiter(self) -> None: + loop = asyncio.get_event_loop() + stream = self._make_one() + stream._waiter = loop.create_future() + assert Matches(r"<StreamReader w=<Future pending[\S ]*>>") == repr(stream) + stream._waiter.set_result(None) + await stream._waiter + stream._waiter = None + assert "<StreamReader>" == repr(stream) + + async def test_unread_empty(self) -> None: + stream = self._make_one() + stream.feed_data(b"line1") + stream.feed_eof() + with pytest.warns(DeprecationWarning): + stream.unread_data(b"") + + data = await stream.read(5) + assert b"line1" == data + assert stream.at_eof() + + +async def test_empty_stream_reader() -> None: + s = streams.EmptyStreamReader() + assert s.set_exception(ValueError()) is None + assert s.exception() is None + assert s.feed_eof() is None + assert s.feed_data(b"data") is None + assert s.at_eof() + assert (await s.wait_eof()) is None + assert await s.read() == b"" + assert await s.readline() == b"" + assert await s.readany() == b"" + assert await s.readchunk() == (b"", True) + with pytest.raises(asyncio.IncompleteReadError): + await s.readexactly(10) + assert s.read_nowait() == b"" + + +@pytest.fixture +async def buffer(loop): + return streams.DataQueue(loop) + + +class TestDataQueue: + def test_is_eof(self, buffer) -> None: + assert not buffer.is_eof() + buffer.feed_eof() + assert buffer.is_eof() + + def test_at_eof(self, buffer) -> None: + assert not buffer.at_eof() + buffer.feed_eof() + assert buffer.at_eof() + buffer._buffer.append(object()) + assert not buffer.at_eof() + + def test_feed_data(self, buffer) -> None: + item = object() + buffer.feed_data(item, 1) + assert [(item, 1)] == list(buffer._buffer) + + def test_feed_eof(self, buffer) -> None: + buffer.feed_eof() + assert buffer._eof + + async def test_read(self, buffer) -> None: + loop = asyncio.get_event_loop() + item = object() + + def cb(): + buffer.feed_data(item, 1) - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - self.buffer = streams.ChunksQueue(loop=self.loop) + loop.call_soon(cb) - def tearDown(self): - self.loop.close() + data = await buffer.read() + assert item is data - def test_read_eof(self): - read_task = asyncio.Task(self.buffer.read(), loop=self.loop) + async def test_read_eof(self, buffer) -> None: + loop = asyncio.get_event_loop() def cb(): - self.buffer.feed_eof() - self.loop.call_soon(cb) + buffer.feed_eof() - self.loop.run_until_complete(read_task) - self.assertTrue(self.buffer.at_eof()) + loop.call_soon(cb) - def test_read_until_eof(self): + with pytest.raises(streams.EofStream): + await buffer.read() + + async def test_read_cancelled(self, buffer) -> None: + loop = asyncio.get_event_loop() + read_task = loop.create_task(buffer.read()) + await asyncio.sleep(0) + waiter = buffer._waiter + assert asyncio.isfuture(waiter) + + read_task.cancel() + with pytest.raises(asyncio.CancelledError): + await read_task + assert waiter.cancelled() + 
assert buffer._waiter is None + + buffer.feed_data(b"test", 4) + assert buffer._waiter is None + + async def test_read_until_eof(self, buffer) -> None: item = object() - self.buffer.feed_data(item, 1) - self.buffer.feed_eof() + buffer.feed_data(item, 1) + buffer.feed_eof() + + data = await buffer.read() + assert data is item + + with pytest.raises(streams.EofStream): + await buffer.read() + + async def test_read_exc(self, buffer) -> None: + item = object() + buffer.feed_data(item) + buffer.set_exception(ValueError) + + data = await buffer.read() + assert item is data + + with pytest.raises(ValueError): + await buffer.read() + + async def test_read_exception(self, buffer) -> None: + buffer.set_exception(ValueError()) + + with pytest.raises(ValueError): + await buffer.read() + + async def test_read_exception_with_data(self, buffer) -> None: + val = object() + buffer.feed_data(val, 1) + buffer.set_exception(ValueError()) + + assert val is (await buffer.read()) + with pytest.raises(ValueError): + await buffer.read() + + async def test_read_exception_on_wait(self, buffer) -> None: + loop = asyncio.get_event_loop() + read_task = loop.create_task(buffer.read()) + await asyncio.sleep(0) + assert asyncio.isfuture(buffer._waiter) + + buffer.feed_eof() + buffer.set_exception(ValueError()) + + with pytest.raises(ValueError): + await read_task + + def test_exception(self, buffer) -> None: + assert buffer.exception() is None - data = self.loop.run_until_complete(self.buffer.read()) - self.assertIs(data, item) + exc = ValueError() + buffer.set_exception(exc) + assert buffer.exception() is exc + + async def test_exception_waiter(self, buffer) -> None: + loop = asyncio.get_event_loop() + + async def set_err(): + buffer.set_exception(ValueError()) - thing = self.loop.run_until_complete(self.buffer.read()) - self.assertEqual(thing, b'') - self.assertTrue(self.buffer.at_eof()) + t1 = loop.create_task(buffer.read()) + t2 = loop.create_task(set_err()) - def test_readany(self): - self.assertIs(self.buffer.read.__func__, self.buffer.readany.__func__) + await asyncio.wait([t1, t2]) + with pytest.raises(ValueError): + t1.result() -def test_feed_data_waiters(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) - eof_waiter = reader._eof_waiter = helpers.create_future(loop) - reader.feed_data(b'1') - assert list(reader._buffer) == [b'1'] +async def test_feed_data_waiters(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() + eof_waiter = reader._eof_waiter = loop.create_future() + + reader.feed_data(b"1") + assert list(reader._buffer) == [b"1"] assert reader._size == 1 assert reader.total_bytes == 1 @@ -803,20 +1138,22 @@ def test_feed_data_waiters(loop): assert reader._eof_waiter is eof_waiter -def test_feed_data_completed_waiters(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) +async def test_feed_data_completed_waiters(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() waiter.set_result(1) - reader.feed_data(b'1') + reader.feed_data(b"1") assert reader._waiter is None -def test_feed_eof_waiters(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) - eof_waiter = reader._eof_waiter = helpers.create_future(loop) +async def 
test_feed_eof_waiters(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() + eof_waiter = reader._eof_waiter = loop.create_future() reader.feed_eof() assert reader._eof @@ -827,10 +1164,11 @@ def test_feed_eof_waiters(loop): assert reader._eof_waiter is None -def test_feed_eof_cancelled(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) - eof_waiter = reader._eof_waiter = helpers.create_future(loop) +async def test_feed_eof_cancelled(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() + eof_waiter = reader._eof_waiter = loop.create_future() waiter.set_result(1) eof_waiter.set_result(1) @@ -843,8 +1181,9 @@ def test_feed_eof_cancelled(loop): assert reader._eof_waiter is None -def test_on_eof(loop): - reader = streams.StreamReader(loop=loop) +async def test_on_eof(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) @@ -854,8 +1193,18 @@ def test_on_eof(loop): assert on_eof.called -def test_on_eof_exc_in_callback(loop): - reader = streams.StreamReader(loop=loop) +async def test_on_eof_empty_reader() -> None: + reader = streams.EmptyStreamReader() + + on_eof = mock.Mock() + reader.on_eof(on_eof) + + assert on_eof.called + + +async def test_on_eof_exc_in_callback(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) on_eof = mock.Mock() on_eof.side_effect = ValueError @@ -867,8 +1216,19 @@ def test_on_eof_exc_in_callback(loop): assert not reader._eof_callbacks -def test_on_eof_eof_is_set(loop): - reader = streams.StreamReader(loop=loop) +async def test_on_eof_exc_in_callback_empty_stream_reader() -> None: + reader = streams.EmptyStreamReader() + + on_eof = mock.Mock() + on_eof.side_effect = ValueError + + reader.on_eof(on_eof) + assert on_eof.called + + +async def test_on_eof_eof_is_set(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) reader.feed_eof() on_eof = mock.Mock() @@ -877,8 +1237,9 @@ def test_on_eof_eof_is_set(loop): assert not reader._eof_callbacks -def test_on_eof_eof_is_set_exception(loop): - reader = streams.StreamReader(loop=loop) +async def test_on_eof_eof_is_set_exception(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) reader.feed_eof() on_eof = mock.Mock() @@ -889,10 +1250,11 @@ def test_on_eof_eof_is_set_exception(loop): assert not reader._eof_callbacks -def test_set_exception(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) - eof_waiter = reader._eof_waiter = helpers.create_future(loop) +async def test_set_exception(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() + eof_waiter = reader._eof_waiter = loop.create_future() exc = ValueError() reader.set_exception(exc) @@ -903,10 +1265,11 @@ def test_set_exception(loop): assert reader._eof_waiter is None -def test_set_exception_cancelled(loop): - reader = streams.StreamReader(loop=loop) - waiter = reader._waiter = helpers.create_future(loop) - eof_waiter = reader._eof_waiter = 
helpers.create_future(loop) +async def test_set_exception_cancelled(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) + waiter = reader._waiter = loop.create_future() + eof_waiter = reader._eof_waiter = loop.create_future() waiter.set_result(1) eof_waiter.set_result(1) @@ -920,8 +1283,9 @@ def test_set_exception_cancelled(loop): assert reader._eof_waiter is None -def test_set_exception_eof_callbacks(loop): - reader = streams.StreamReader(loop=loop) +async def test_set_exception_eof_callbacks(protocol) -> None: + loop = asyncio.get_event_loop() + reader = streams.StreamReader(protocol, 2 ** 16, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) @@ -929,3 +1293,89 @@ def test_set_exception_eof_callbacks(loop): reader.set_exception(ValueError()) assert not on_eof.called assert not reader._eof_callbacks + + +async def test_stream_reader_lines() -> None: + line_iter = iter(DATA.splitlines(keepends=True)) + async for line in await create_stream(): + assert line == next(line_iter, None) + pytest.raises(StopIteration, next, line_iter) + + +async def test_stream_reader_chunks_complete() -> None: + # Tests if chunked iteration works if the chunking works out + # (i.e. the data is divisible by the chunk size) + chunk_iter = chunkify(DATA, 9) + async for data in (await create_stream()).iter_chunked(9): + assert data == next(chunk_iter, None) + pytest.raises(StopIteration, next, chunk_iter) + + +async def test_stream_reader_chunks_incomplete() -> None: + # Tests if chunked iteration works if the last chunk is incomplete + chunk_iter = chunkify(DATA, 8) + async for data in (await create_stream()).iter_chunked(8): + assert data == next(chunk_iter, None) + pytest.raises(StopIteration, next, chunk_iter) + + +async def test_data_queue_empty() -> None: + # Tests that async looping yields nothing if nothing is there + loop = asyncio.get_event_loop() + buffer = streams.DataQueue(loop) + buffer.feed_eof() + + async for _ in buffer: + assert False + + +async def test_data_queue_items() -> None: + # Tests that async looping yields objects identically + loop = asyncio.get_event_loop() + buffer = streams.DataQueue(loop) + + items = [object(), object()] + buffer.feed_data(items[0], 1) + buffer.feed_data(items[1], 1) + buffer.feed_eof() + + item_iter = iter(items) + async for item in buffer: + assert item is next(item_iter, None) + pytest.raises(StopIteration, next, item_iter) + + +async def test_stream_reader_iter_any() -> None: + it = iter([b"line1\nline2\nline3\n"]) + async for raw in (await create_stream()).iter_any(): + assert raw == next(it) + pytest.raises(StopIteration, next, it) + + +async def test_stream_reader_iter() -> None: + it = iter([b"line1\n", b"line2\n", b"line3\n"]) + async for raw in await create_stream(): + assert raw == next(it) + pytest.raises(StopIteration, next, it) + + +async def test_stream_reader_iter_chunks_no_chunked_encoding() -> None: + it = iter([b"line1\nline2\nline3\n"]) + async for data, end_of_chunk in (await create_stream()).iter_chunks(): + assert (data, end_of_chunk) == (next(it), False) + pytest.raises(StopIteration, next, it) + + +async def test_stream_reader_iter_chunks_chunked_encoding(protocol) -> None: + loop = asyncio.get_event_loop() + stream = streams.StreamReader(protocol, 2 ** 16, loop=loop) + for line in DATA.splitlines(keepends=True): + stream.begin_http_chunk_receiving() + stream.feed_data(line) + stream.end_http_chunk_receiving() + stream.feed_eof() + + it = iter([b"line1\n", b"line2\n", 
b"line3\n"]) + async for data, end_of_chunk in stream.iter_chunks(): + assert (data, end_of_chunk) == (next(it), True) + pytest.raises(StopIteration, next, it) diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py new file mode 100644 index 00000000000..18fedb93a25 --- /dev/null +++ b/tests/test_tcp_helpers.py @@ -0,0 +1,73 @@ +import socket +from unittest import mock + +import pytest + +from aiohttp.tcp_helpers import tcp_nodelay + +has_ipv6 = socket.has_ipv6 +if has_ipv6: + # The socket.has_ipv6 flag may be True if Python was built with IPv6 + # support, but the target system still may not have it. + # So let's ensure that we really have IPv6 support. + try: + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError: + has_ipv6 = False + + +# nodelay + + +def test_tcp_nodelay_exception() -> None: + transport = mock.Mock() + s = mock.Mock() + s.setsockopt = mock.Mock() + s.family = socket.AF_INET + s.setsockopt.side_effect = OSError + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + s.setsockopt.assert_called_with(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + +def test_tcp_nodelay_enable() -> None: + transport = mock.Mock() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +def test_tcp_nodelay_enable_and_disable() -> None: + transport = mock.Mock() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + tcp_nodelay(transport, False) + assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +def test_tcp_nodelay_enable_ipv6() -> None: + transport = mock.Mock() + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix sockets") +def test_tcp_nodelay_enable_unix() -> None: + # do not set nodelay for unix socket + transport = mock.Mock() + s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert not s.setsockopt.called + + +def test_tcp_nodelay_enable_no_socket() -> None: + transport = mock.Mock() + transport.get_extra_info.return_value = None + tcp_nodelay(transport, True) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 8aef4e89db8..cbaed33bccd 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -1,278 +1,316 @@ -import asyncio +import gzip from unittest import mock import pytest -from multidict import CIMultiDict +from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL import aiohttp from aiohttp import web -from aiohttp.test_utils import TestClient as _TestClient -from aiohttp.test_utils import TestServer as _TestServer -from aiohttp.test_utils import (AioHTTPTestCase, loop_context, - make_mocked_request, setup_test_loop, - teardown_test_loop, unittest_run_loop) +from aiohttp.test_utils import ( + AioHTTPTestCase, + RawTestServer as _RawTestServer, + TestClient as _TestClient, + TestServer as _TestServer, + loop_context, + make_mocked_request, + unittest_run_loop, +) +_hello_world_str = "Hello, 
world" +_hello_world_bytes = _hello_world_str.encode("utf-8") +_hello_world_gz = gzip.compress(_hello_world_bytes) -def _create_example_app(): - @asyncio.coroutine - def hello(request): - return web.Response(body=b"Hello, world") +def _create_example_app(): + async def hello(request): + return web.Response(body=_hello_world_bytes) - @asyncio.coroutine - def websocket_handler(request): + async def websocket_handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - msg = yield from ws.receive() + await ws.prepare(request) + msg = await ws.receive() if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == 'close': - yield from ws.close() + if msg.data == "close": + await ws.close() else: - ws.send_str(msg.data + '/answer') + await ws.send_str(msg.data + "/answer") return ws - @asyncio.coroutine - def cookie_handler(request): - resp = web.Response(body=b"Hello, world") - resp.set_cookie('cookie', 'val') + async def cookie_handler(request): + resp = web.Response(body=_hello_world_bytes) + resp.set_cookie("cookie", "val") return resp app = web.Application() - app.router.add_route('*', '/', hello) - app.router.add_route('*', '/websocket', websocket_handler) - app.router.add_route('*', '/cookie', cookie_handler) + app.router.add_route("*", "/", hello) + app.router.add_route("*", "/websocket", websocket_handler) + app.router.add_route("*", "/cookie", cookie_handler) return app -def test_full_server_scenario(): +# these exist to test the pytest scenario +@pytest.fixture +def loop(): with loop_context() as loop: - app = _create_example_app() - with _TestClient(app, loop=loop) as client: + yield loop - @asyncio.coroutine - def test_get_route(): - nonlocal client - resp = yield from client.request("GET", "/") - assert resp.status == 200 - text = yield from resp.text() - assert "Hello, world" in text - loop.run_until_complete(test_get_route()) +@pytest.fixture +def app(): + return _create_example_app() -def test_server_with_create_test_teardown(): - with loop_context() as loop: - app = _create_example_app() - with _TestClient(app, loop=loop) as client: +@pytest.fixture +def test_client(loop, app) -> None: + async def make_client(): + return _TestClient(_TestServer(app, loop=loop), loop=loop) - @asyncio.coroutine - def test_get_route(): - resp = yield from client.request("GET", "/") - assert resp.status == 200 - text = yield from resp.text() - assert "Hello, world" in text + client = loop.run_until_complete(make_client()) - loop.run_until_complete(test_get_route()) + loop.run_until_complete(client.start_server()) + yield client + loop.run_until_complete(client.close()) -def test_test_client_close_is_idempotent(): - """ - a test client, called multiple times, should - not attempt to close the server again. - """ - loop = setup_test_loop() +def test_with_test_server_fails(loop) -> None: app = _create_example_app() - client = _TestClient(app, loop=loop) - loop.run_until_complete(client.close()) - loop.run_until_complete(client.close()) - teardown_test_loop(loop) + with pytest.raises(TypeError): + with _TestServer(app, loop=loop): + pass -class TestAioHTTPTestCase(AioHTTPTestCase): +async def test_with_client_fails(loop) -> None: + app = _create_example_app() + with pytest.raises(TypeError): + with _TestClient(_TestServer(app, loop=loop), loop=loop): + pass + +async def test_aiohttp_client_close_is_idempotent() -> None: + # a test client, called multiple times, should + # not attempt to close the server again. 
+ app = _create_example_app() + client = _TestClient(_TestServer(app)) + await client.close() + await client.close() + + +class TestAioHTTPTestCase(AioHTTPTestCase): def get_app(self): return _create_example_app() @unittest_run_loop - @asyncio.coroutine - def test_example_with_loop(self): - request = yield from self.client.request("GET", "/") + async def test_example_with_loop(self) -> None: + request = await self.client.request("GET", "/") assert request.status == 200 - text = yield from request.text() - assert "Hello, world" in text + text = await request.text() + assert _hello_world_str == text - def test_example(self): - @asyncio.coroutine - def test_get_route(): - resp = yield from self.client.request("GET", "/") + def test_example(self) -> None: + async def test_get_route() -> None: + resp = await self.client.request("GET", "/") assert resp.status == 200 - text = yield from resp.text() - assert "Hello, world" in text + text = await resp.text() + assert _hello_world_str == text self.loop.run_until_complete(test_get_route()) -# these exist to test the pytest scenario -@pytest.yield_fixture -def loop(): - with loop_context() as loop: - yield loop - - -@pytest.fixture -def app(): - return _create_example_app() - - -@pytest.yield_fixture -def test_client(loop, app): - client = _TestClient(app, loop=loop) - loop.run_until_complete(client.start_server()) - yield client - loop.run_until_complete(client.close()) - - -def test_get_route(loop, test_client): - @asyncio.coroutine - def test_get_route(): - resp = yield from test_client.request("GET", "/") +def test_get_route(loop, test_client) -> None: + async def test_get_route() -> None: + resp = await test_client.request("GET", "/") assert resp.status == 200 - text = yield from resp.text() - assert "Hello, world" in text + text = await resp.text() + assert _hello_world_str == text loop.run_until_complete(test_get_route()) -@asyncio.coroutine -def test_client_websocket(loop, test_client): - resp = yield from test_client.ws_connect("/websocket") - resp.send_str("foo") - msg = yield from resp.receive() +async def test_client_websocket(loop, test_client) -> None: + resp = await test_client.ws_connect("/websocket") + await resp.send_str("foo") + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.TEXT assert "foo" in msg.data - resp.send_str("close") - msg = yield from resp.receive() + await resp.send_str("close") + msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE -@asyncio.coroutine -def test_client_cookie(loop, test_client): +async def test_client_cookie(loop, test_client) -> None: assert not test_client.session.cookie_jar - yield from test_client.get("/cookie") + await test_client.get("/cookie") cookies = list(test_client.session.cookie_jar) - assert cookies[0].key == 'cookie' - assert cookies[0].value == 'val' + assert cookies[0].key == "cookie" + assert cookies[0].value == "val" -@asyncio.coroutine -@pytest.mark.parametrize("method", [ - "get", "post", "options", "post", "put", "patch", "delete" -]) -@asyncio.coroutine -def test_test_client_methods(method, loop, test_client): - resp = yield from getattr(test_client, method)("/") +@pytest.mark.parametrize( + "method", ["get", "post", "options", "post", "put", "patch", "delete"] +) +async def test_test_client_methods(method, loop, test_client) -> None: + resp = await getattr(test_client, method)("/") assert resp.status == 200 - text = yield from resp.text() - assert "Hello, world" in text + text = await resp.text() + assert _hello_world_str == text 
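A short orientation note, not part of the diff hunks themselves: the fixtures and tests above all build on the same composition, a TestServer wrapping the Application and a TestClient wrapping the server. Below is a minimal self-contained sketch of that pattern, assuming only the public aiohttp.test_utils API; the handler and route are illustrative, not taken from the diff.

import asyncio

from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer


async def main():
    async def hello(request):
        return web.Response(text="Hello, world")

    app = web.Application()
    app.router.add_get("/", hello)

    # TestClient starts the wrapped TestServer on an ephemeral port and
    # resolves relative URLs such as "/" against it.
    client = TestClient(TestServer(app))
    await client.start_server()
    try:
        resp = await client.get("/")
        assert resp.status == 200
        assert "Hello, world" in await resp.text()
    finally:
        await client.close()  # also shuts down the underlying server


asyncio.run(main())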
-@asyncio.coroutine -def test_test_client_head(loop, test_client): - resp = yield from test_client.head("/") +async def test_test_client_head(loop, test_client) -> None: + resp = await test_client.head("/") assert resp.status == 200 -@pytest.mark.parametrize( - "headers", [{'token': 'x'}, CIMultiDict({'token': 'x'}), {}]) -def test_make_mocked_request(headers): - req = make_mocked_request('GET', '/', headers=headers) +@pytest.mark.parametrize("headers", [{"token": "x"}, CIMultiDict({"token": "x"}), {}]) +def test_make_mocked_request(headers) -> None: + req = make_mocked_request("GET", "/", headers=headers) assert req.method == "GET" assert req.path == "/" assert isinstance(req, web.Request) - assert isinstance(req.headers, CIMultiDict) + assert isinstance(req.headers, CIMultiDictProxy) -def test_make_mocked_request_sslcontext(): - req = make_mocked_request('GET', '/') - assert req.transport.get_extra_info('sslcontext') is None +def test_make_mocked_request_sslcontext() -> None: + req = make_mocked_request("GET", "/") + assert req.transport.get_extra_info("sslcontext") is None -def test_make_mocked_request_unknown_extra_info(): - req = make_mocked_request('GET', '/') - assert req.transport.get_extra_info('unknown_extra_info') is None +def test_make_mocked_request_unknown_extra_info() -> None: + req = make_mocked_request("GET", "/") + assert req.transport.get_extra_info("unknown_extra_info") is None -def test_make_mocked_request_app(): +def test_make_mocked_request_app() -> None: app = mock.Mock() - req = make_mocked_request('GET', '/', app=app) + req = make_mocked_request("GET", "/", app=app) assert req.app is app -def test_make_mocked_request_content(): +def test_make_mocked_request_app_can_store_values() -> None: + req = make_mocked_request("GET", "/") + req.app["a_field"] = "a_value" + assert req.app["a_field"] == "a_value" + + +def test_make_mocked_request_match_info() -> None: + req = make_mocked_request("GET", "/", match_info={"a": "1", "b": "2"}) + assert req.match_info == {"a": "1", "b": "2"} + + +def test_make_mocked_request_content() -> None: payload = mock.Mock() - req = make_mocked_request('GET', '/', payload=payload) + req = make_mocked_request("GET", "/", payload=payload) assert req.content is payload -def test_make_mocked_request_transport(): +def test_make_mocked_request_transport() -> None: transport = mock.Mock() - req = make_mocked_request('GET', '/', transport=transport) + req = make_mocked_request("GET", "/", transport=transport) assert req.transport is transport -def test_test_client_props(loop): +async def test_test_client_props(loop) -> None: app = _create_example_app() - client = _TestClient(app, loop=loop, host='localhost') - assert client.host == 'localhost' + client = _TestClient(_TestServer(app, host="127.0.0.1", loop=loop), loop=loop) + assert client.host == "127.0.0.1" assert client.port is None - with client: + async with client: assert isinstance(client.port, int) assert client.server is not None + assert client.app is not None assert client.port is None -def test_test_server_context_manager(loop): - app = _create_example_app() - with _TestServer(app, loop=loop) as server: - @asyncio.coroutine - def go(): - client = aiohttp.ClientSession(loop=loop) - resp = yield from client.head(server.make_url('/')) - assert resp.status == 200 - resp.close() - client.close() +async def test_test_client_raw_server_props(loop) -> None: + async def hello(request): + return web.Response(body=_hello_world_bytes) - loop.run_until_complete(go()) + client = 
_TestClient(_RawTestServer(hello, host="127.0.0.1", loop=loop), loop=loop) + assert client.host == "127.0.0.1" + assert client.port is None + async with client: + assert isinstance(client.port, int) + assert client.server is not None + assert client.app is None + assert client.port is None -def test_client_scheme_mutually_exclusive_with_server(): +async def test_test_server_context_manager(loop) -> None: app = _create_example_app() - server = _TestServer(app) - with pytest.raises(ValueError): - _TestClient(server, scheme='http') + async with _TestServer(app, loop=loop) as server: + client = aiohttp.ClientSession(loop=loop) + resp = await client.head(server.make_url("/")) + assert resp.status == 200 + resp.close() + await client.close() -def test_client_host_mutually_exclusive_with_server(): - app = _create_example_app() - server = _TestServer(app) - with pytest.raises(ValueError): - _TestClient(server, host='127.0.0.1') +def test_client_unsupported_arg() -> None: + with pytest.raises(TypeError) as e: + _TestClient("string") + assert ( + str(e.value) == "server must be TestServer instance, found type: <class 'str'>" + ) -def test_client_unsupported_arg(): - with pytest.raises(TypeError): - _TestClient('string') - -def test_server_make_url_yarl_compatibility(loop): +async def test_server_make_url_yarl_compatibility(loop) -> None: app = _create_example_app() - with _TestServer(app, loop=loop) as server: + async with _TestServer(app, loop=loop) as server: make_url = server.make_url - assert make_url(URL('/foo')) == make_url('/foo') + assert make_url(URL("/foo")) == make_url("/foo") with pytest.raises(AssertionError): - make_url('http://foo.com') + make_url("http://foo.com") with pytest.raises(AssertionError): - make_url(URL('http://foo.com')) + make_url(URL("http://foo.com")) + + +def test_testcase_no_app(testdir, loop) -> None: + testdir.makepyfile( + """ + from aiohttp.test_utils import AioHTTPTestCase + + + class InvalidTestCase(AioHTTPTestCase): + def test_noop(self) -> None: + pass + """ + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines(["*RuntimeError*"]) + + +async def test_server_context_manager(app, loop) -> None: + async with _TestServer(app, loop=loop) as server: + async with aiohttp.ClientSession(loop=loop) as client: + async with client.head(server.make_url("/")) as resp: + assert resp.status == 200 + + +@pytest.mark.parametrize( + "method", ["head", "get", "post", "options", "post", "put", "patch", "delete"] +) +async def test_client_context_manager_response(method, app, loop) -> None: + async with _TestClient(_TestServer(app), loop=loop) as client: + async with getattr(client, method)("/") as resp: + assert resp.status == 200 + if method != "head": + text = await resp.text() + assert "Hello, world" in text + + +async def test_custom_port(loop, app, aiohttp_unused_port) -> None: + port = aiohttp_unused_port() + client = _TestClient(_TestServer(app, loop=loop, port=port), loop=loop) + await client.start_server() + + assert client.server.port == port + + resp = await client.get("/") + assert resp.status == 200 + text = await resp.text() + assert _hello_world_str == text + + await client.close() diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 00000000000..5523fe9589f --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from aiohttp.test_utils import make_mocked_coro +from aiohttp.tracing import ( + Trace, + TraceConfig, + 
TraceConnectionCreateEndParams, + TraceConnectionCreateStartParams, + TraceConnectionQueuedEndParams, + TraceConnectionQueuedStartParams, + TraceConnectionReuseconnParams, + TraceDnsCacheHitParams, + TraceDnsCacheMissParams, + TraceDnsResolveHostEndParams, + TraceDnsResolveHostStartParams, + TraceRequestChunkSentParams, + TraceRequestEndParams, + TraceRequestExceptionParams, + TraceRequestRedirectParams, + TraceRequestStartParams, + TraceResponseChunkReceivedParams, +) + + +class TestTraceConfig: + def test_trace_config_ctx_default(self) -> None: + trace_config = TraceConfig() + assert isinstance(trace_config.trace_config_ctx(), SimpleNamespace) + + def test_trace_config_ctx_factory(self) -> None: + trace_config = TraceConfig(trace_config_ctx_factory=dict) + assert isinstance(trace_config.trace_config_ctx(), dict) + + def test_trace_config_ctx_request_ctx(self) -> None: + trace_request_ctx = Mock() + trace_config = TraceConfig() + trace_config_ctx = trace_config.trace_config_ctx( + trace_request_ctx=trace_request_ctx + ) + assert trace_config_ctx.trace_request_ctx is trace_request_ctx + + def test_freeze(self) -> None: + trace_config = TraceConfig() + trace_config.freeze() + + assert trace_config.on_request_start.frozen + assert trace_config.on_request_chunk_sent.frozen + assert trace_config.on_response_chunk_received.frozen + assert trace_config.on_request_end.frozen + assert trace_config.on_request_exception.frozen + assert trace_config.on_request_redirect.frozen + assert trace_config.on_connection_queued_start.frozen + assert trace_config.on_connection_queued_end.frozen + assert trace_config.on_connection_create_start.frozen + assert trace_config.on_connection_create_end.frozen + assert trace_config.on_connection_reuseconn.frozen + assert trace_config.on_dns_resolvehost_start.frozen + assert trace_config.on_dns_resolvehost_end.frozen + assert trace_config.on_dns_cache_hit.frozen + assert trace_config.on_dns_cache_miss.frozen + + +class TestTrace: + @pytest.mark.parametrize( + "signal,params,param_obj", + [ + ("request_start", (Mock(), Mock(), Mock()), TraceRequestStartParams), + ( + "request_chunk_sent", + (Mock(), Mock(), Mock()), + TraceRequestChunkSentParams, + ), + ( + "response_chunk_received", + (Mock(), Mock(), Mock()), + TraceResponseChunkReceivedParams, + ), + ("request_end", (Mock(), Mock(), Mock(), Mock()), TraceRequestEndParams), + ( + "request_exception", + (Mock(), Mock(), Mock(), Mock()), + TraceRequestExceptionParams, + ), + ( + "request_redirect", + (Mock(), Mock(), Mock(), Mock()), + TraceRequestRedirectParams, + ), + ("connection_queued_start", (), TraceConnectionQueuedStartParams), + ("connection_queued_end", (), TraceConnectionQueuedEndParams), + ("connection_create_start", (), TraceConnectionCreateStartParams), + ("connection_create_end", (), TraceConnectionCreateEndParams), + ("connection_reuseconn", (), TraceConnectionReuseconnParams), + ("dns_resolvehost_start", (Mock(),), TraceDnsResolveHostStartParams), + ("dns_resolvehost_end", (Mock(),), TraceDnsResolveHostEndParams), + ("dns_cache_hit", (Mock(),), TraceDnsCacheHitParams), + ("dns_cache_miss", (Mock(),), TraceDnsCacheMissParams), + ], + ) + async def test_send(self, signal, params, param_obj) -> None: + session = Mock() + trace_request_ctx = Mock() + callback = Mock(side_effect=make_mocked_coro(Mock())) + + trace_config = TraceConfig() + getattr(trace_config, "on_%s" % signal).append(callback) + trace_config.freeze() + trace = Trace( + session, + trace_config, + 
trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx), + ) + await getattr(trace, "send_%s" % signal)(*params) + + callback.assert_called_once_with( + session, + SimpleNamespace(trace_request_ctx=trace_request_ctx), + param_obj(*params), + ) diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index 6a0af0df40a..588daed8d40 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -1,4 +1,3 @@ -import asyncio import os import pathlib import re @@ -6,35 +5,38 @@ from urllib.parse import unquote import pytest +from re_assert import Matches from yarl import URL import aiohttp from aiohttp import hdrs, web from aiohttp.test_utils import make_mocked_request from aiohttp.web import HTTPMethodNotAllowed, HTTPNotFound, Response -from aiohttp.web_urldispatcher import (AbstractResource, ResourceRoute, - SystemRoute, View, - _defaultExpectHandler) - - -def make_request(method, path): - return make_mocked_request(method, path) +from aiohttp.web_urldispatcher import ( + PATH_SEP, + AbstractResource, + Domain, + DynamicResource, + MaskDomain, + PlainResource, + ResourceRoute, + StaticResource, + SystemRoute, + View, + _default_expect_handler, +) def make_handler(): - - @asyncio.coroutine - def handler(request): + async def handler(request): return Response(request) # pragma: no cover return handler @pytest.fixture -def app(loop): - app = web.Application() - app._set_loop(loop) - return app +def app(): + return web.Application() @pytest.fixture @@ -45,949 +47,1131 @@ def router(app): @pytest.fixture def fill_routes(router): def go(): - route1 = router.add_route('GET', '/plain', make_handler()) - route2 = router.add_route('GET', '/variable/{name}', - make_handler()) - resource = router.add_static('/static', - os.path.dirname(aiohttp.__file__)) + route1 = router.add_route("GET", "/plain", make_handler()) + route2 = router.add_route("GET", "/variable/{name}", make_handler()) + resource = router.add_static("/static", os.path.dirname(aiohttp.__file__)) return [route1, route2] + list(resource) + return go -def test_register_uncommon_http_methods(router): +def test_register_uncommon_http_methods(router) -> None: uncommon_http_methods = { - 'PROPFIND', - 'PROPPATCH', - 'COPY', - 'LOCK', - 'UNLOCK' - 'MOVE', - 'SUBSCRIBE', - 'UNSUBSCRIBE', - 'NOTIFY' + "PROPFIND", + "PROPPATCH", + "COPY", + "LOCK", + "UNLOCK", + "MOVE", + "SUBSCRIBE", + "UNSUBSCRIBE", + "NOTIFY", } for method in uncommon_http_methods: - router.add_route(method, '/handler/to/path', make_handler()) + router.add_route(method, "/handler/to/path", make_handler()) -@asyncio.coroutine -def test_add_route_root(router): +async def test_add_route_root(router) -> None: handler = make_handler() - router.add_route('GET', '/', handler) - req = make_request('GET', '/') - info = yield from router.resolve(req) + router.add_route("GET", "/", handler) + req = make_mocked_request("GET", "/") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_simple(router): +async def test_add_route_simple(router) -> None: handler = make_handler() - router.add_route('GET', '/handler/to/path', handler) - req = make_request('GET', '/handler/to/path') - info = yield from router.resolve(req) + router.add_route("GET", "/handler/to/path", handler) + req = make_mocked_request("GET", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler 
assert info.route.name is None -@asyncio.coroutine -def test_add_with_matchdict(router): +async def test_add_with_matchdict(router) -> None: + handler = make_handler() + router.add_route("GET", "/handler/{to}", handler) + req = make_mocked_request("GET", "/handler/tail") + info = await router.resolve(req) + assert info is not None + assert {"to": "tail"} == info + assert handler is info.handler + assert info.route.name is None + + +async def test_add_with_matchdict_with_colon(router) -> None: handler = make_handler() - router.add_route('GET', '/handler/{to}', handler) - req = make_request('GET', '/handler/tail') - info = yield from router.resolve(req) + router.add_route("GET", "/handler/{to}", handler) + req = make_mocked_request("GET", "/handler/1:2:3") + info = await router.resolve(req) assert info is not None - assert {'to': 'tail'} == info + assert {"to": "1:2:3"} == info assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_get_shortcut(router): +async def test_add_route_with_add_get_shortcut(router) -> None: handler = make_handler() - router.add_get('/handler/to/path', handler) - req = make_request('GET', '/handler/to/path') - info = yield from router.resolve(req) + router.add_get("/handler/to/path", handler) + req = make_mocked_request("GET", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_post_shortcut(router): +async def test_add_route_with_add_post_shortcut(router) -> None: handler = make_handler() - router.add_post('/handler/to/path', handler) - req = make_request('POST', '/handler/to/path') - info = yield from router.resolve(req) + router.add_post("/handler/to/path", handler) + req = make_mocked_request("POST", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_put_shortcut(router): +async def test_add_route_with_add_put_shortcut(router) -> None: handler = make_handler() - router.add_put('/handler/to/path', handler) - req = make_request('PUT', '/handler/to/path') - info = yield from router.resolve(req) + router.add_put("/handler/to/path", handler) + req = make_mocked_request("PUT", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_patch_shortcut(router): +async def test_add_route_with_add_patch_shortcut(router) -> None: handler = make_handler() - router.add_patch('/handler/to/path', handler) - req = make_request('PATCH', '/handler/to/path') - info = yield from router.resolve(req) + router.add_patch("/handler/to/path", handler) + req = make_mocked_request("PATCH", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_delete_shortcut(router): +async def test_add_route_with_add_delete_shortcut(router) -> None: handler = make_handler() - router.add_delete('/handler/to/path', handler) - req = make_request('DELETE', '/handler/to/path') - info = yield from router.resolve(req) + router.add_delete("/handler/to/path", handler) + req = make_mocked_request("DELETE", "/handler/to/path") + info = await 
router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_route_with_add_head_shortcut(router): +async def test_add_route_with_add_head_shortcut(router) -> None: handler = make_handler() - router.add_head('/handler/to/path', handler) - req = make_request('HEAD', '/handler/to/path') - info = yield from router.resolve(req) + router.add_head("/handler/to/path", handler) + req = make_mocked_request("HEAD", "/handler/to/path") + info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None -@asyncio.coroutine -def test_add_with_name(router): +async def test_add_with_name(router) -> None: handler = make_handler() - router.add_route('GET', '/handler/to/path', handler, - name='name') - req = make_request('GET', '/handler/to/path') - info = yield from router.resolve(req) + router.add_route("GET", "/handler/to/path", handler, name="name") + req = make_mocked_request("GET", "/handler/to/path") + info = await router.resolve(req) assert info is not None - assert 'name' == info.route.name + assert "name" == info.route.name -@asyncio.coroutine -def test_add_with_tailing_slash(router): +async def test_add_with_tailing_slash(router) -> None: handler = make_handler() - router.add_route('GET', '/handler/to/path/', handler) - req = make_request('GET', '/handler/to/path/') - info = yield from router.resolve(req) + router.add_route("GET", "/handler/to/path/", handler) + req = make_mocked_request("GET", "/handler/to/path/") + info = await router.resolve(req) assert info is not None assert {} == info assert handler is info.handler -def test_add_invalid_path(router): +def test_add_invalid_path(router) -> None: handler = make_handler() with pytest.raises(ValueError): - router.add_route('GET', '/{/', handler) + router.add_route("GET", "/{/", handler) -def test_add_url_invalid1(router): +def test_add_url_invalid1(router) -> None: handler = make_handler() with pytest.raises(ValueError): - router.add_route('post', '/post/{id', handler) + router.add_route("post", "/post/{id", handler) -def test_add_url_invalid2(router): +def test_add_url_invalid2(router) -> None: handler = make_handler() with pytest.raises(ValueError): - router.add_route('post', '/post/{id{}}', handler) + router.add_route("post", "/post/{id{}}", handler) -def test_add_url_invalid3(router): +def test_add_url_invalid3(router) -> None: handler = make_handler() with pytest.raises(ValueError): - router.add_route('post', '/post/{id{}', handler) + router.add_route("post", "/post/{id{}", handler) -def test_add_url_invalid4(router): +def test_add_url_invalid4(router) -> None: handler = make_handler() with pytest.raises(ValueError): - router.add_route('post', '/post/{id"}', handler) + router.add_route("post", '/post/{id"}', handler) -@asyncio.coroutine -def test_add_url_escaping(router): +async def test_add_url_escaping(router) -> None: handler = make_handler() - router.add_route('GET', '/+$', handler) + router.add_route("GET", "/+$", handler) - req = make_request('GET', '/+$') - info = yield from router.resolve(req) + req = make_mocked_request("GET", "/+$") + info = await router.resolve(req) assert info is not None assert handler is info.handler -@asyncio.coroutine -def test_any_method(router): +async def test_any_method(router) -> None: handler = make_handler() - route = router.add_route(hdrs.METH_ANY, '/', handler) + route = router.add_route(hdrs.METH_ANY, "/", handler) - req = 
make_request('GET', '/') - info1 = yield from router.resolve(req) + req = make_mocked_request("GET", "/") + info1 = await router.resolve(req) assert info1 is not None assert route is info1.route - req = make_request('POST', '/') - info2 = yield from router.resolve(req) + req = make_mocked_request("POST", "/") + info2 = await router.resolve(req) assert info2 is not None assert info1.route is info2.route -@asyncio.coroutine -def test_match_second_result_in_table(router): +async def test_match_second_result_in_table(router) -> None: handler1 = make_handler() handler2 = make_handler() - router.add_route('GET', '/h1', handler1) - router.add_route('POST', '/h2', handler2) - req = make_request('POST', '/h2') - info = yield from router.resolve(req) + router.add_route("GET", "/h1", handler1) + router.add_route("POST", "/h2", handler2) + req = make_mocked_request("POST", "/h2") + info = await router.resolve(req) assert info is not None assert {} == info assert handler2 is info.handler -@asyncio.coroutine -def test_raise_method_not_allowed(router): +async def test_raise_method_not_allowed(router) -> None: handler1 = make_handler() handler2 = make_handler() - router.add_route('GET', '/', handler1) - router.add_route('POST', '/', handler2) - req = make_request('PUT', '/') + router.add_route("GET", "/", handler1) + router.add_route("POST", "/", handler2) + req = make_mocked_request("PUT", "/") - match_info = yield from router.resolve(req) + match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(HTTPMethodNotAllowed) as ctx: - yield from match_info.handler(req) + await match_info.handler(req) exc = ctx.value - assert 'PUT' == exc.method + assert "PUT" == exc.method assert 405 == exc.status - assert {'POST', 'GET'} == exc.allowed_methods + assert {"POST", "GET"} == exc.allowed_methods -@asyncio.coroutine -def test_raise_method_not_found(router): +async def test_raise_method_not_found(router) -> None: handler = make_handler() - router.add_route('GET', '/a', handler) - req = make_request('GET', '/b') + router.add_route("GET", "/a", handler) + req = make_mocked_request("GET", "/b") - match_info = yield from router.resolve(req) + match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(HTTPNotFound) as ctx: - yield from match_info.handler(req) + await match_info.handler(req) exc = ctx.value assert 404 == exc.status -def test_double_add_url_with_the_same_name(router): +def test_double_add_url_with_the_same_name(router) -> None: handler1 = make_handler() handler2 = make_handler() - router.add_route('GET', '/get', handler1, name='name') + router.add_route("GET", "/get", handler1, name="name") - regexp = ("Duplicate 'name', already handled by") + regexp = "Duplicate 'name', already handled by" with pytest.raises(ValueError) as ctx: - router.add_route('GET', '/get_other', handler2, name='name') - assert re.match(regexp, str(ctx.value)) + router.add_route("GET", "/get_other", handler2, name="name") + assert Matches(regexp) == str(ctx.value) -def test_route_plain(router): +def test_route_plain(router) -> None: handler = make_handler() - route = router.add_route('GET', '/get', handler, name='name') - route2 = next(iter(router['name'])) - url = route2.url() - assert '/get' == url + route = router.add_route("GET", "/get", handler, name="name") + route2 = next(iter(router["name"])) + url = route2.url_for() + assert "/get" == str(url) assert route is route2 -def 
test_route_unknown_route_name(router): +def test_route_unknown_route_name(router) -> None: with pytest.raises(KeyError): - router['unknown'] + router["unknown"] -def test_route_dynamic(router): +def test_route_dynamic(router) -> None: handler = make_handler() - route = router.add_route('GET', '/get/{name}', handler, - name='name') + route = router.add_route("GET", "/get/{name}", handler, name="name") - route2 = next(iter(router['name'])) - url = route2.url(parts={'name': 'John'}) - assert '/get/John' == url + route2 = next(iter(router["name"])) + url = route2.url_for(name="John") + assert "/get/John" == str(url) assert route is route2 -def test_route_with_qs(router): - handler = make_handler() - router.add_route('GET', '/get', handler, name='name') - - url = router['name'].url(query=[('a', 'b'), ('c', '1')]) - assert '/get?a=b&c=1' == url +def test_add_static(router) -> None: + resource = router.add_static( + "/st", os.path.dirname(aiohttp.__file__), name="static" + ) + assert router["static"] is resource + url = resource.url_for(filename="/dir/a.txt") + assert "/st/dir/a.txt" == str(url) + assert len(resource) == 2 -def test_add_static(router): - resource = router.add_static('/st', - os.path.dirname(aiohttp.__file__), - name='static') - assert router['static'] is resource - url = resource.url(filename='/dir/a.txt') - assert '/st/dir/a.txt' == url +def test_add_static_append_version(router) -> None: + resource = router.add_static("/st", os.path.dirname(__file__), name="static") + url = resource.url_for(filename="/data.unknown_mime_type", append_version=True) + expect_url = ( + "/st/data.unknown_mime_type?" "v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" + ) + assert expect_url == str(url) + + +def test_add_static_append_version_set_from_constructor(router) -> None: + resource = router.add_static( + "/st", os.path.dirname(__file__), append_version=True, name="static" + ) + url = resource.url_for(filename="/data.unknown_mime_type") + expect_url = ( + "/st/data.unknown_mime_type?" "v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" + ) + assert expect_url == str(url) + + +def test_add_static_append_version_override_constructor(router) -> None: + resource = router.add_static( + "/st", os.path.dirname(__file__), append_version=True, name="static" + ) + url = resource.url_for(filename="/data.unknown_mime_type", append_version=False) + expect_url = "/st/data.unknown_mime_type" + assert expect_url == str(url) + + +def test_add_static_append_version_filename_without_slash(router) -> None: + resource = router.add_static("/st", os.path.dirname(__file__), name="static") + url = resource.url_for(filename="data.unknown_mime_type", append_version=True) + expect_url = ( + "/st/data.unknown_mime_type?" 
"v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" + ) + assert expect_url == str(url) + + +def test_add_static_append_version_non_exists_file(router) -> None: + resource = router.add_static("/st", os.path.dirname(__file__), name="static") + url = resource.url_for(filename="/non_exists_file", append_version=True) + assert "/st/non_exists_file" == str(url) + + +def test_add_static_append_version_non_exists_file_without_slash(router) -> None: + resource = router.add_static("/st", os.path.dirname(__file__), name="static") + url = resource.url_for(filename="non_exists_file", append_version=True) + assert "/st/non_exists_file" == str(url) + + +def test_add_static_append_version_follow_symlink(router, tmpdir) -> None: + # Tests the access to a symlink, in static folder with apeend_version + tmp_dir_path = str(tmpdir) + symlink_path = os.path.join(tmp_dir_path, "append_version_symlink") + symlink_target_path = os.path.dirname(__file__) + os.symlink(symlink_target_path, symlink_path, True) + + # Register global static route: + resource = router.add_static( + "/st", tmp_dir_path, follow_symlinks=True, append_version=True + ) + + url = resource.url_for(filename="/append_version_symlink/data.unknown_mime_type") + + expect_url = ( + "/st/append_version_symlink/data.unknown_mime_type?" + "v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" + ) + assert expect_url == str(url) + + +def test_add_static_append_version_not_follow_symlink(router, tmpdir) -> None: + # Tests the access to a symlink, in static folder with apeend_version + tmp_dir_path = str(tmpdir) + symlink_path = os.path.join(tmp_dir_path, "append_version_symlink") + symlink_target_path = os.path.dirname(__file__) + os.symlink(symlink_target_path, symlink_path, True) + + # Register global static route: + resource = router.add_static( + "/st", tmp_dir_path, follow_symlinks=False, append_version=True + ) + + filename = "/append_version_symlink/data.unknown_mime_type" + url = resource.url_for(filename=filename) + assert "/st/append_version_symlink/data.unknown_mime_type" == str(url) + + +def test_add_static_quoting(router) -> None: + resource = router.add_static( + "/пре %2Fфикс", pathlib.Path(aiohttp.__file__).parent, name="static" + ) + assert router["static"] is resource + url = resource.url_for(filename="/1 2/файл%2F.txt") + assert url.path == "/пре /фикс/1 2/файл%2F.txt" + assert str(url) == ( + "/%D0%BF%D1%80%D0%B5%20%2F%D1%84%D0%B8%D0%BA%D1%81" + "/1%202/%D1%84%D0%B0%D0%B9%D0%BB%252F.txt" + ) assert len(resource) == 2 -def test_plain_not_match(router): +def test_plain_not_match(router) -> None: handler = make_handler() - router.add_route('GET', '/get/path', handler, name='name') - route = router['name'] - assert route._match('/another/path') is None + router.add_route("GET", "/get/path", handler, name="name") + route = router["name"] + assert route._match("/another/path") is None -def test_dynamic_not_match(router): +def test_dynamic_not_match(router) -> None: handler = make_handler() - router.add_route('GET', '/get/{name}', handler, name='name') - route = router['name'] - assert route._match('/another/path') is None + router.add_route("GET", "/get/{name}", handler, name="name") + route = router["name"] + assert route._match("/another/path") is None -@asyncio.coroutine -def test_static_not_match(router): - router.add_static('/pre', os.path.dirname(aiohttp.__file__), - name='name') - resource = router['name'] - ret = yield from resource.resolve( - make_mocked_request('GET', '/another/path')) +async def test_static_not_match(router) -> None: 
+ router.add_static("/pre", os.path.dirname(aiohttp.__file__), name="name") + resource = router["name"] + ret = await resource.resolve(make_mocked_request("GET", "/another/path")) assert (None, set()) == ret -def test_dynamic_with_trailing_slash(router): +def test_dynamic_with_trailing_slash(router) -> None: handler = make_handler() - router.add_route('GET', '/get/{name}/', handler, name='name') - route = router['name'] - assert {'name': 'John'} == route._match('/get/John/') + router.add_route("GET", "/get/{name}/", handler, name="name") + route = router["name"] + assert {"name": "John"} == route._match("/get/John/") -def test_len(router): +def test_len(router) -> None: handler = make_handler() - router.add_route('GET', '/get1', handler, name='name1') - router.add_route('GET', '/get2', handler, name='name2') + router.add_route("GET", "/get1", handler, name="name1") + router.add_route("GET", "/get2", handler, name="name2") assert 2 == len(router) -def test_iter(router): +def test_iter(router) -> None: handler = make_handler() - router.add_route('GET', '/get1', handler, name='name1') - router.add_route('GET', '/get2', handler, name='name2') - assert {'name1', 'name2'} == set(iter(router)) + router.add_route("GET", "/get1", handler, name="name1") + router.add_route("GET", "/get2", handler, name="name2") + assert {"name1", "name2"} == set(iter(router)) -def test_contains(router): +def test_contains(router) -> None: handler = make_handler() - router.add_route('GET', '/get1', handler, name='name1') - router.add_route('GET', '/get2', handler, name='name2') - assert 'name1' in router - assert 'name3' not in router + router.add_route("GET", "/get1", handler, name="name1") + router.add_route("GET", "/get2", handler, name="name2") + assert "name1" in router + assert "name3" not in router -def test_static_repr(router): - router.add_static('/get', os.path.dirname(aiohttp.__file__), - name='name') - assert re.match(r"<StaticResource 'name' /get", repr(router['name'])) +def test_static_repr(router) -> None: + router.add_static("/get", os.path.dirname(aiohttp.__file__), name="name") + assert Matches(r"<StaticResource 'name' /get") == repr(router["name"]) -def test_static_adds_slash(router): - route = router.add_static('/prefix', - os.path.dirname(aiohttp.__file__)) - assert '/prefix' == route._prefix +def test_static_adds_slash(router) -> None: + route = router.add_static("/prefix", os.path.dirname(aiohttp.__file__)) + assert "/prefix" == route._prefix -def test_static_remove_trailing_slash(router): - route = router.add_static('/prefix/', - os.path.dirname(aiohttp.__file__)) - assert '/prefix' == route._prefix +def test_static_remove_trailing_slash(router) -> None: + route = router.add_static("/prefix/", os.path.dirname(aiohttp.__file__)) + assert "/prefix" == route._prefix -@asyncio.coroutine -def test_add_route_with_re(router): +async def test_add_route_with_re(router) -> None: handler = make_handler() - router.add_route('GET', r'/handler/{to:\d+}', handler) + router.add_route("GET", r"/handler/{to:\d+}", handler) - req = make_request('GET', '/handler/1234') - info = yield from router.resolve(req) + req = make_mocked_request("GET", "/handler/1234") + info = await router.resolve(req) assert info is not None - assert {'to': '1234'} == info + assert {"to": "1234"} == info - router.add_route('GET', r'/handler/{name}.html', handler) - req = make_request('GET', '/handler/test.html') - info = yield from router.resolve(req) - assert {'name': 'test'} == info + router.add_route("GET", r"/handler/{name}.html", handler) + req = make_mocked_request("GET", "/handler/test.html") + info = await router.resolve(req) + assert {"name": "test"} == info -@asyncio.coroutine -def test_add_route_with_re_and_slashes(router): +async def test_add_route_with_re_and_slashes(router) -> None: handler = make_handler() - router.add_route('GET', 
r'/handler/{to:[^/]+/?}', handler) - req = make_request('GET', '/handler/1234/') - info = yield from router.resolve(req) + router.add_route("GET", r"/handler/{to:[^/]+/?}", handler) + req = make_mocked_request("GET", "/handler/1234/") + info = await router.resolve(req) assert info is not None - assert {'to': '1234/'} == info + assert {"to": "1234/"} == info - router.add_route('GET', r'/handler/{to:.+}', handler) - req = make_request('GET', '/handler/1234/5/6/7') - info = yield from router.resolve(req) + router.add_route("GET", r"/handler/{to:.+}", handler) + req = make_mocked_request("GET", "/handler/1234/5/6/7") + info = await router.resolve(req) assert info is not None - assert {'to': '1234/5/6/7'} == info + assert {"to": "1234/5/6/7"} == info -@asyncio.coroutine -def test_add_route_with_re_not_match(router): +async def test_add_route_with_re_not_match(router) -> None: handler = make_handler() - router.add_route('GET', r'/handler/{to:\d+}', handler) + router.add_route("GET", r"/handler/{to:\d+}", handler) - req = make_request('GET', '/handler/tail') - match_info = yield from router.resolve(req) + req = make_mocked_request("GET", "/handler/tail") + match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(HTTPNotFound): - yield from match_info.handler(req) + await match_info.handler(req) -@asyncio.coroutine -def test_add_route_with_re_including_slashes(router): +async def test_add_route_with_re_including_slashes(router) -> None: handler = make_handler() - router.add_route('GET', r'/handler/{to:.+}/tail', handler) - req = make_request('GET', '/handler/re/with/slashes/tail') - info = yield from router.resolve(req) + router.add_route("GET", r"/handler/{to:.+}/tail", handler) + req = make_mocked_request("GET", "/handler/re/with/slashes/tail") + info = await router.resolve(req) assert info is not None - assert {'to': 're/with/slashes'} == info + assert {"to": "re/with/slashes"} == info -def test_add_route_with_invalid_re(router): +def test_add_route_with_invalid_re(router) -> None: handler = make_handler() with pytest.raises(ValueError) as ctx: - router.add_route('GET', r'/handler/{to:+++}', handler) + router.add_route("GET", r"/handler/{to:+++}", handler) s = str(ctx.value) assert s.startswith( - "Bad pattern '\/handler\/(?P<to>+++)': nothing to repeat") + "Bad pattern '" + + PATH_SEP + + "handler" + + PATH_SEP + + "(?P<to>+++)': nothing to repeat" + ) assert ctx.value.__cause__ is None -def test_route_dynamic_with_regex_spec(router): +def test_route_dynamic_with_regex_spec(router) -> None: + handler = make_handler() + route = router.add_route("GET", r"/get/{num:^\d+}", handler, name="name") + + url = route.url_for(num="123") + assert "/get/123" == str(url) + + +def test_route_dynamic_with_regex_spec_and_trailing_slash(router) -> None: + handler = make_handler() + route = router.add_route("GET", r"/get/{num:^\d+}/", handler, name="name") + + url = route.url_for(num="123") + assert "/get/123/" == str(url) + + +def test_route_dynamic_with_regex(router) -> None: handler = make_handler() - route = router.add_route('GET', '/get/{num:^\d+}', handler, - name='name') + route = router.add_route("GET", r"/{one}/{two:.+}", handler) - url = route.url(parts={'num': '123'}) - assert '/get/123' == url + url = route.url_for(one="1", two="2") + assert "/1/2" == str(url) -def test_route_dynamic_with_regex_spec_and_trailing_slash(router): +def test_route_dynamic_quoting(router) -> None: handler = make_handler() - route = router.add_route('GET', 
'/get/{num:^\d+}/', handler, - name='name') + route = router.add_route("GET", r"/пре %2Fфикс/{arg}", handler) - url = route.url(parts={'num': '123'}) - assert '/get/123/' == url + url = route.url_for(arg="1 2/текст%2F") + assert url.path == "/пре /фикс/1 2/текст%2F" + assert str(url) == ( + "/%D0%BF%D1%80%D0%B5%20%2F%D1%84%D0%B8%D0%BA%D1%81" + "/1%202/%D1%82%D0%B5%D0%BA%D1%81%D1%82%252F" + ) -def test_route_dynamic_with_regex(router): +async def test_regular_match_info(router) -> None: handler = make_handler() - route = router.add_route('GET', r'/{one}/{two:.+}', handler) + router.add_route("GET", "/get/{name}", handler) - url = route.url(parts={'one': 1, 'two': 2}) - assert '/1/2' == url + req = make_mocked_request("GET", "/get/john") + match_info = await router.resolve(req) + assert {"name": "john"} == match_info + assert Matches("<MatchInfo {'name': 'john'}: .+<Dynamic.+>>") == repr(match_info) -@asyncio.coroutine -def test_regular_match_info(router): +async def test_match_info_with_plus(router) -> None: handler = make_handler() - router.add_route('GET', '/get/{name}', handler) + router.add_route("GET", "/get/{version}", handler) - req = make_request('GET', '/get/john') - match_info = yield from router.resolve(req) - assert {'name': 'john'} == match_info - assert re.match("<MatchInfo {'name': 'john'}: .+<Dynamic.+>>", - repr(match_info)) + req = make_mocked_request("GET", "/get/1.0+test") + match_info = await router.resolve(req) + assert {"version": "1.0+test"} == match_info -@asyncio.coroutine -def test_not_found_repr(router): - req = make_request('POST', '/path/to') - match_info = yield from router.resolve(req) +async def test_not_found_repr(router) -> None: + req = make_mocked_request("POST", "/path/to") + match_info = await router.resolve(req) assert "<MatchInfoError 404: Not Found>" == repr(match_info) -@asyncio.coroutine -def test_not_allowed_repr(router): +async def test_not_allowed_repr(router) -> None: handler = make_handler() - router.add_route('GET', '/path/to', handler) + router.add_route("GET", "/path/to", handler) handler2 = make_handler() - router.add_route('POST', '/path/to', handler2) + router.add_route("POST", "/path/to", handler2) - req = make_request('PUT', '/path/to') - match_info = yield from router.resolve(req) + req = make_mocked_request("PUT", "/path/to") + match_info = await router.resolve(req) assert "<MatchInfoError 405: Method Not Allowed>" == repr(match_info) -def test_default_expect_handler(router): - route = router.add_route('GET', '/', make_handler()) - assert route._expect_handler is _defaultExpectHandler +def test_default_expect_handler(router) -> None: + route = router.add_route("GET", "/", make_handler()) + assert route._expect_handler is _default_expect_handler -def test_custom_expect_handler_plain(router): - - @asyncio.coroutine - def handler(request): +def test_custom_expect_handler_plain(router) -> None: + async def handler(request): pass - route = router.add_route( - 'GET', '/', make_handler(), expect_handler=handler) + route = router.add_route("GET", "/", make_handler(), expect_handler=handler) assert route._expect_handler is handler assert isinstance(route, ResourceRoute) -def test_custom_expect_handler_dynamic(router): - - @asyncio.coroutine - def handler(request): +def test_custom_expect_handler_dynamic(router) -> None: + async def handler(request): pass route = router.add_route( - 'GET', '/get/{name}', make_handler(), expect_handler=handler) + "GET", "/get/{name}", make_handler(), expect_handler=handler + ) assert route._expect_handler is handler assert isinstance(route, ResourceRoute) -def test_expect_handler_non_coroutine(router): - +def test_expect_handler_non_coroutine(router) -> None: 
def handler(request): pass with pytest.raises(AssertionError): - router.add_route('GET', '/', make_handler(), - expect_handler=handler) + router.add_route("GET", "/", make_handler(), expect_handler=handler) -@asyncio.coroutine -def test_dynamic_match_non_ascii(router): +async def test_dynamic_match_non_ascii(router) -> None: handler = make_handler() - router.add_route('GET', '/{var}', handler) - req = make_request( - 'GET', - '/%D1%80%D1%83%D1%81%20%D1%82%D0%B5%D0%BA%D1%81%D1%82') - match_info = yield from router.resolve(req) - assert {'var': 'рус текст'} == match_info + router.add_route("GET", "/{var}", handler) + req = make_mocked_request( + "GET", "/%D1%80%D1%83%D1%81%20%D1%82%D0%B5%D0%BA%D1%81%D1%82" + ) + match_info = await router.resolve(req) + assert {"var": "рус текст"} == match_info -@asyncio.coroutine -def test_dynamic_match_with_static_part(router): +async def test_dynamic_match_with_static_part(router) -> None: handler = make_handler() - router.add_route('GET', '/{name}.html', handler) - req = make_request('GET', '/file.html') - match_info = yield from router.resolve(req) - assert {'name': 'file'} == match_info + router.add_route("GET", "/{name}.html", handler) + req = make_mocked_request("GET", "/file.html") + match_info = await router.resolve(req) + assert {"name": "file"} == match_info -@asyncio.coroutine -def test_dynamic_match_two_part2(router): +async def test_dynamic_match_two_part2(router) -> None: handler = make_handler() - router.add_route('GET', '/{name}.{ext}', handler) - req = make_request('GET', '/file.html') - match_info = yield from router.resolve(req) - assert {'name': 'file', 'ext': 'html'} == match_info + router.add_route("GET", "/{name}.{ext}", handler) + req = make_mocked_request("GET", "/file.html") + match_info = await router.resolve(req) + assert {"name": "file", "ext": "html"} == match_info -@asyncio.coroutine -def test_dynamic_match_unquoted_path(router): +async def test_dynamic_match_unquoted_path(router) -> None: handler = make_handler() - router.add_route('GET', '/{path}/{subpath}', handler) - resource_id = 'my%2Fpath%7Cwith%21some%25strange%24characters' - req = make_request('GET', '/path/{0}'.format(resource_id)) - match_info = yield from router.resolve(req) - assert match_info == { - 'path': 'path', - 'subpath': unquote(resource_id) - } + router.add_route("GET", "/{path}/{subpath}", handler) + resource_id = "my%2Fpath%7Cwith%21some%25strange%24characters" + req = make_mocked_request("GET", f"/path/{resource_id}") + match_info = await router.resolve(req) + assert match_info == {"path": "path", "subpath": unquote(resource_id)} -def test_add_route_not_started_with_slash(router): +def test_add_route_not_started_with_slash(router) -> None: with pytest.raises(ValueError): handler = make_handler() - router.add_route('GET', 'invalid_path', handler) + router.add_route("GET", "invalid_path", handler) -def test_add_route_invalid_method(router): +def test_add_route_invalid_method(router) -> None: sample_bad_methods = { - 'BAD METHOD', - 'B@D_METHOD', - '[BAD_METHOD]', - '{BAD_METHOD}', - '(BAD_METHOD)', - 'B?D_METHOD', + "BAD METHOD", + "B@D_METHOD", + "[BAD_METHOD]", + "{BAD_METHOD}", + "(BAD_METHOD)", + "B?D_METHOD", } for bad_method in sample_bad_methods: with pytest.raises(ValueError): handler = make_handler() - router.add_route(bad_method, '/path', handler) + router.add_route(bad_method, "/path", handler) -def test_routes_view_len(router, fill_routes): +def test_routes_view_len(router, fill_routes) -> None: fill_routes() assert 4 == len(router.routes()) 
-def test_routes_view_iter(router, fill_routes): +def test_routes_view_iter(router, fill_routes) -> None: routes = fill_routes() assert list(routes) == list(router.routes()) -def test_routes_view_contains(router, fill_routes): +def test_routes_view_contains(router, fill_routes) -> None: routes = fill_routes() for route in routes: assert route in router.routes() -def test_routes_abc(router): +def test_routes_abc(router) -> None: assert isinstance(router.routes(), Sized) assert isinstance(router.routes(), Iterable) assert isinstance(router.routes(), Container) -def test_named_resources_abc(router): +def test_named_resources_abc(router) -> None: assert isinstance(router.named_resources(), Mapping) assert not isinstance(router.named_resources(), MutableMapping) -def test_named_resources(router): - route1 = router.add_route('GET', '/plain', make_handler(), - name='route1') - route2 = router.add_route('GET', '/variable/{name}', - make_handler(), name='route2') - route3 = router.add_static('/static', - os.path.dirname(aiohttp.__file__), - name='route3') +def test_named_resources(router) -> None: + route1 = router.add_route("GET", "/plain", make_handler(), name="route1") + route2 = router.add_route("GET", "/variable/{name}", make_handler(), name="route2") + route3 = router.add_static( + "/static", os.path.dirname(aiohttp.__file__), name="route3" + ) names = {route1.name, route2.name, route3.name} assert 3 == len(router.named_resources()) for name in names: assert name in router.named_resources() - assert isinstance(router.named_resources()[name], - AbstractResource) + assert isinstance(router.named_resources()[name], AbstractResource) + +def test_resource_iter(router) -> None: + async def handler(request): + pass -def test_resource_iter(router): - resource = router.add_resource('/path') - r1 = resource.add_route('GET', lambda req: None) - r2 = resource.add_route('POST', lambda req: None) + resource = router.add_resource("/path") + r1 = resource.add_route("GET", handler) + r2 = resource.add_route("POST", handler) assert 2 == len(resource) assert [r1, r2] == list(resource) -def test_deprecate_bare_generators(router): - resource = router.add_resource('/path') +def test_deprecate_bare_generators(router) -> None: + resource = router.add_resource("/path") def gen(request): yield with pytest.warns(DeprecationWarning): - resource.add_route('GET', gen) + resource.add_route("GET", gen) -def test_view_route(router): - resource = router.add_resource('/path') +def test_view_route(router) -> None: + resource = router.add_resource("/path") - route = resource.add_route('GET', View) + route = resource.add_route("GET", View) assert View is route.handler -def test_resource_route_match(router): - resource = router.add_resource('/path') - route = resource.add_route('GET', lambda req: None) - assert {} == route.resource._match('/path') +def test_resource_route_match(router) -> None: + async def handler(request): + pass + + resource = router.add_resource("/path") + route = resource.add_route("GET", handler) + assert {} == route.resource._match("/path") -def test_error_on_double_route_adding(router): - resource = router.add_resource('/path') +def test_error_on_double_route_adding(router) -> None: + async def handler(request): + pass - resource.add_route('GET', lambda: None) + resource = router.add_resource("/path") + + resource.add_route("GET", handler) with pytest.raises(RuntimeError): - resource.add_route('GET', lambda: None) + resource.add_route("GET", handler) + +def test_error_on_adding_route_after_wildcard(router) 
-> None: + async def handler(request): + pass -def test_error_on_adding_route_after_wildcard(router): - resource = router.add_resource('/path') + resource = router.add_resource("/path") - resource.add_route('*', lambda: None) + resource.add_route("*", handler) with pytest.raises(RuntimeError): - resource.add_route('GET', lambda: None) + resource.add_route("GET", handler) -@asyncio.coroutine -def test_http_exception_is_none_when_resolved(router): +async def test_http_exception_is_none_when_resolved(router) -> None: handler = make_handler() - router.add_route('GET', '/', handler) - req = make_request('GET', '/') - info = yield from router.resolve(req) + router.add_route("GET", "/", handler) + req = make_mocked_request("GET", "/") + info = await router.resolve(req) assert info.http_exception is None -@asyncio.coroutine -def test_http_exception_is_not_none_when_not_resolved(router): +async def test_http_exception_is_not_none_when_not_resolved(router) -> None: handler = make_handler() - router.add_route('GET', '/', handler) - req = make_request('GET', '/abc') - info = yield from router.resolve(req) + router.add_route("GET", "/", handler) + req = make_mocked_request("GET", "/abc") + info = await router.resolve(req) assert info.http_exception.status == 404 -@asyncio.coroutine -def test_match_info_get_info_plain(router): +async def test_match_info_get_info_plain(router) -> None: handler = make_handler() - router.add_route('GET', '/', handler) - req = make_request('GET', '/') - info = yield from router.resolve(req) - assert info.get_info() == {'path': '/'} + router.add_route("GET", "/", handler) + req = make_mocked_request("GET", "/") + info = await router.resolve(req) + assert info.get_info() == {"path": "/"} -@asyncio.coroutine -def test_match_info_get_info_dynamic(router): +async def test_match_info_get_info_dynamic(router) -> None: handler = make_handler() - router.add_route('GET', '/{a}', handler) - req = make_request('GET', '/value') - info = yield from router.resolve(req) + router.add_route("GET", "/{a}", handler) + req = make_mocked_request("GET", "/value") + info = await router.resolve(req) assert info.get_info() == { - 'pattern': re.compile('\\/(?P<a>[^{}/]+)'), - 'formatter': '/{a}'} + "pattern": re.compile(PATH_SEP + "(?P<a>[^{}/]+)"), + "formatter": "/{a}", + } -@asyncio.coroutine -def test_match_info_get_info_dynamic2(router): +async def test_match_info_get_info_dynamic2(router) -> None: handler = make_handler() - router.add_route('GET', '/{a}/{b}', handler) - req = make_request('GET', '/path/to') - info = yield from router.resolve(req) + router.add_route("GET", "/{a}/{b}", handler) + req = make_mocked_request("GET", "/path/to") + info = await router.resolve(req) assert info.get_info() == { - 'pattern': re.compile('\\/(?P<a>[^{}/]+)\\/(?P<b>[^{}/]+)'), - 'formatter': '/{a}/{b}'} + "pattern": re.compile( + PATH_SEP + "(?P<a>[^{}/]+)" + PATH_SEP + "(?P<b>[^{}/]+)" + ), + "formatter": "/{a}/{b}", + } -def test_static_resource_get_info(router): - directory = pathlib.Path(aiohttp.__file__).parent - resource = router.add_static('/st', directory) - assert resource.get_info() == {'directory': directory, - 'prefix': '/st'} +def test_static_resource_get_info(router) -> None: + directory = pathlib.Path(aiohttp.__file__).parent.resolve() + resource = router.add_static("/st", directory) + info = resource.get_info() + assert len(info) == 3 + assert info["directory"] == directory + assert info["prefix"] == "/st" + assert all([type(r) is ResourceRoute for r in info["routes"].values()]) -@asyncio.coroutine -def 
test_system_route_get_info(router): +async def test_system_route_get_info(router) -> None: handler = make_handler() - router.add_route('GET', '/', handler) - req = make_request('GET', '/abc') - info = yield from router.resolve(req) - assert info.get_info()['http_exception'].status == 404 + router.add_route("GET", "/", handler) + req = make_mocked_request("GET", "/abc") + info = await router.resolve(req) + assert info.get_info()["http_exception"].status == 404 -def test_resources_view_len(router): - router.add_resource('/plain') - router.add_resource('/variable/{name}') +def test_resources_view_len(router) -> None: + router.add_resource("/plain") + router.add_resource("/variable/{name}") assert 2 == len(router.resources()) -def test_resources_view_iter(router): - resource1 = router.add_resource('/plain') - resource2 = router.add_resource('/variable/{name}') +def test_resources_view_iter(router) -> None: + resource1 = router.add_resource("/plain") + resource2 = router.add_resource("/variable/{name}") resources = [resource1, resource2] assert list(resources) == list(router.resources()) -def test_resources_view_contains(router): - resource1 = router.add_resource('/plain') - resource2 = router.add_resource('/variable/{name}') +def test_resources_view_contains(router) -> None: + resource1 = router.add_resource("/plain") + resource2 = router.add_resource("/variable/{name}") resources = [resource1, resource2] for resource in resources: assert resource in router.resources() -def test_resources_abc(router): +def test_resources_abc(router) -> None: assert isinstance(router.resources(), Sized) assert isinstance(router.resources(), Iterable) assert isinstance(router.resources(), Container) -def test_static_route_user_home(router): +def test_static_route_user_home(router) -> None: here = pathlib.Path(aiohttp.__file__).parent - home = pathlib.Path(os.path.expanduser('~')) + home = pathlib.Path(os.path.expanduser("~")) if not str(here).startswith(str(home)): # pragma: no cover pytest.skip("aiohttp folder is not placed in user's HOME") - static_dir = '~/' + str(here.relative_to(home)) - route = router.add_static('/st', static_dir) - assert here == route.get_info()['directory'] + static_dir = "~/" + str(here.relative_to(home)) + route = router.add_static("/st", static_dir) + assert here == route.get_info()["directory"] -def test_static_route_points_to_file(router): - here = pathlib.Path(aiohttp.__file__).parent / '__init__.py' +def test_static_route_points_to_file(router) -> None: + here = pathlib.Path(aiohttp.__file__).parent / "__init__.py" with pytest.raises(ValueError): - router.add_static('/st', here) + router.add_static("/st", here) -@asyncio.coroutine -def test_404_for_static_resource(router): - resource = router.add_static('/st', - os.path.dirname(aiohttp.__file__)) - ret = yield from resource.resolve( - make_mocked_request('GET', '/unknown/path')) +async def test_404_for_static_resource(router) -> None: + resource = router.add_static("/st", os.path.dirname(aiohttp.__file__)) + ret = await resource.resolve(make_mocked_request("GET", "/unknown/path")) assert (None, set()) == ret -@asyncio.coroutine -def test_405_for_resource_adapter(router): - resource = router.add_static('/st', - os.path.dirname(aiohttp.__file__)) - ret = yield from resource.resolve( - make_mocked_request('POST', '/st/abc.py')) - assert (None, {'HEAD', 'GET'}) == ret +async def test_405_for_resource_adapter(router) -> None: + resource = router.add_static("/st", os.path.dirname(aiohttp.__file__)) + ret = await 
resource.resolve(make_mocked_request("POST", "/st/abc.py")) + assert (None, {"HEAD", "GET"}) == ret -@asyncio.coroutine -def test_check_allowed_method_for_found_resource(router): +async def test_check_allowed_method_for_found_resource(router) -> None: handler = make_handler() - resource = router.add_resource('/') - resource.add_route('GET', handler) - ret = yield from resource.resolve(make_mocked_request('GET', '/')) + resource = router.add_resource("/") + resource.add_route("GET", handler) + ret = await resource.resolve(make_mocked_request("GET", "/")) assert ret[0] is not None - assert {'GET'} == ret[1] + assert {"GET"} == ret[1] -def test_url_for_in_static_resource(router): - resource = router.add_static('/static', - os.path.dirname(aiohttp.__file__)) - assert URL('/static/file.txt') == resource.url_for(filename='file.txt') +def test_url_for_in_static_resource(router) -> None: + resource = router.add_static("/static", os.path.dirname(aiohttp.__file__)) + assert URL("/static/file.txt") == resource.url_for(filename="file.txt") -def test_url_for_in_static_resource_pathlib(router): - resource = router.add_static('/static', - os.path.dirname(aiohttp.__file__)) - assert URL('/static/file.txt') == resource.url_for( - filename=pathlib.Path('file.txt')) +def test_url_for_in_static_resource_pathlib(router) -> None: + resource = router.add_static("/static", os.path.dirname(aiohttp.__file__)) + assert URL("/static/file.txt") == resource.url_for( + filename=pathlib.Path("file.txt") + ) -def test_url_for_in_resource_route(router): - route = router.add_route('GET', '/get/{name}', make_handler(), - name='name') - assert URL('/get/John') == route.url_for(name='John') +def test_url_for_in_resource_route(router) -> None: + route = router.add_route("GET", "/get/{name}", make_handler(), name="name") + assert URL("/get/John") == route.url_for(name="John") -def test_subapp_get_info(app, loop): +def test_subapp_get_info(app) -> None: subapp = web.Application() - resource = subapp.add_subapp('/pre', subapp) - assert resource.get_info() == {'prefix': '/pre', 'app': subapp} - - -def test_subapp_url(app, loop): + resource = subapp.add_subapp("/pre", subapp) + assert resource.get_info() == {"prefix": "/pre", "app": subapp} + + +@pytest.mark.parametrize( + "domain,error", + [ + (None, TypeError), + ("", ValueError), + ("http://dom", ValueError), + ("*.example.com", ValueError), + ("example$com", ValueError), + ], +) +def test_domain_validation_error(domain, error): + with pytest.raises(error): + Domain(domain) + + +def test_domain_valid(): + assert Domain("example.com:81").canonical == "example.com:81" + assert MaskDomain("*.example.com").canonical == r".*\.example\.com" + assert Domain("пуни.код").canonical == "xn--h1ajfq.xn--d1alm" + + +@pytest.mark.parametrize( + "a,b,result", + [ + ("example.com", "example.com", True), + ("example.com:81", "example.com:81", True), + ("example.com:81", "example.com", False), + ("пуникод", "xn--d1ahgkhc2a", True), + ("*.example.com", "jpg.example.com", True), + ("*.example.com", "a.example.com", True), + ("*.example.com", "example.com", False), + ], +) +def test_match_domain(a, b, result): + if "*" in a: + rule = MaskDomain(a) + else: + rule = Domain(a) + assert rule.match_domain(b) is result + + +def test_add_subapp_errors(app): + with pytest.raises(TypeError): + app.add_subapp(1, web.Application()) + + +def test_subapp_rule_resource(app): subapp = web.Application() - resource = app.add_subapp('/pre', subapp) + subapp.router.add_get("/", make_handler()) + rule = 
Domain("example.com") + assert rule.get_info() == {"domain": "example.com"} + resource = app.add_domain("example.com", subapp) + assert resource.canonical == "example.com" + assert resource.get_info() == {"rule": resource._rule, "app": subapp} + resource.add_prefix("/a") + resource.raw_match("/b") + assert len(resource) + assert list(resource) + assert repr(resource).startswith(" None: subapp = web.Application() - resource = app.add_subapp('/pre', subapp) + resource = app.add_subapp("/pre", subapp) with pytest.raises(RuntimeError): resource.url_for() -def test_subapp_repr(app, loop): +def test_subapp_repr(app) -> None: subapp = web.Application() - resource = app.add_subapp('/pre', subapp) - assert repr(resource).startswith( - ' None: subapp = web.Application() - subapp.router.add_get('/', make_handler(), allow_head=False) - subapp.router.add_post('/', make_handler()) - resource = app.add_subapp('/pre', subapp) + subapp.router.add_get("/", make_handler(), allow_head=False) + subapp.router.add_post("/", make_handler()) + resource = app.add_subapp("/pre", subapp) assert len(resource) == 2 -def test_subapp_iter(app, loop): +def test_subapp_iter(app) -> None: subapp = web.Application() - r1 = subapp.router.add_get('/', make_handler(), allow_head=False) - r2 = subapp.router.add_post('/', make_handler()) - resource = app.add_subapp('/pre', subapp) + r1 = subapp.router.add_get("/", make_handler(), allow_head=False) + r2 = subapp.router.add_post("/", make_handler()) + resource = app.add_subapp("/pre", subapp) assert list(resource) == [r1, r2] -def test_invalid_route_name(router): +def test_invalid_route_name(router) -> None: + with pytest.raises(ValueError): + router.add_get("/", make_handler(), name="invalid name") + + +def test_invalid_route_name(router) -> None: with pytest.raises(ValueError): - router.add_get('/', make_handler(), name='invalid name') + router.add_get("/", make_handler(), name="class") # identifier -def test_frozen_router(router): +def test_frozen_router(router) -> None: router.freeze() with pytest.raises(RuntimeError): - router.add_get('/', make_handler()) + router.add_get("/", make_handler()) -def test_frozen_router_subapp(app, loop): +def test_frozen_router_subapp(app) -> None: subapp = web.Application() subapp.freeze() with pytest.raises(RuntimeError): - app.add_subapp('/', subapp) + app.add_subapp("/pre", subapp) -def test_frozen_app_on_subapp(app, loop): +def test_frozen_app_on_subapp(app) -> None: app.freeze() subapp = web.Application() with pytest.raises(RuntimeError): - app.add_subapp('/', subapp) + app.add_subapp("/pre", subapp) -def test_set_options_route(router): - resource = router.add_static('/static', - os.path.dirname(aiohttp.__file__)) +def test_set_options_route(router) -> None: + resource = router.add_static("/static", os.path.dirname(aiohttp.__file__)) options = None for route in resource: - if route.method == 'OPTIONS': + if route.method == "OPTIONS": options = route assert options is None resource.set_options_route(make_handler()) for route in resource: - if route.method == 'OPTIONS': + if route.method == "OPTIONS": options = route assert options is not None @@ -995,28 +1179,110 @@ def test_set_options_route(router): resource.set_options_route(make_handler()) -def test_dynamic_url_with_name_started_from_undescore(router): - route = router.add_route('GET', '/get/{_name}', make_handler()) - assert URL('/get/John') == route.url_for(_name='John') +def test_dynamic_url_with_name_started_from_underscore(router) -> None: + route = router.add_route("GET", 
"/get/{_name}", make_handler()) + assert URL("/get/John") == route.url_for(_name="John") -def test_cannot_add_subapp_with_empty_prefix(app, loop): +def test_cannot_add_subapp_with_empty_prefix(app) -> None: subapp = web.Application() with pytest.raises(ValueError): - app.add_subapp('', subapp) + app.add_subapp("", subapp) -def test_cannot_add_subapp_with_slash_prefix(app, loop): +def test_cannot_add_subapp_with_slash_prefix(app) -> None: subapp = web.Application() with pytest.raises(ValueError): - app.add_subapp('/', subapp) + app.add_subapp("/", subapp) -@asyncio.coroutine -def test_convert_empty_path_to_slash_on_freezing(router): +async def test_convert_empty_path_to_slash_on_freezing(router) -> None: handler = make_handler() - route = router.add_get('', handler) + route = router.add_get("", handler) resource = route.resource - assert resource.get_info() == {'path': ''} + assert resource.get_info() == {"path": ""} router.freeze() - assert resource.get_info() == {'path': '/'} + assert resource.get_info() == {"path": "/"} + + +def test_deprecate_non_coroutine(router) -> None: + def handler(request): + pass + + with pytest.warns(DeprecationWarning): + router.add_route("GET", "/handler", handler) + + +def test_plain_resource_canonical() -> None: + canonical = "/plain/path" + res = PlainResource(path=canonical) + assert res.canonical == canonical + + +def test_dynamic_resource_canonical() -> None: + canonicals = { + "/get/{name}": "/get/{name}", + r"/get/{num:^\d+}": "/get/{num}", + r"/handler/{to:\d+}": r"/handler/{to}", + r"/{one}/{two:.+}": r"/{one}/{two}", + } + for pattern, canonical in canonicals.items(): + res = DynamicResource(path=pattern) + assert res.canonical == canonical + + +def test_static_resource_canonical() -> None: + prefix = "/prefix" + directory = str(os.path.dirname(aiohttp.__file__)) + canonical = prefix + res = StaticResource(prefix=prefix, directory=directory) + assert res.canonical == canonical + + +def test_prefixed_subapp_resource_canonical(app) -> None: + canonical = "/prefix" + subapp = web.Application() + res = subapp.add_subapp(canonical, subapp) + assert res.canonical == canonical + + +async def test_prefixed_subapp_overlap(app) -> None: + # Subapp should not overshadow other subapps with overlapping prefixes + subapp1 = web.Application() + handler1 = make_handler() + subapp1.router.add_get("/a", handler1) + app.add_subapp("/s", subapp1) + + subapp2 = web.Application() + handler2 = make_handler() + subapp2.router.add_get("/b", handler2) + app.add_subapp("/ss", subapp2) + + match_info = await app.router.resolve(make_mocked_request("GET", "/s/a")) + assert match_info.route.handler is handler1 + match_info = await app.router.resolve(make_mocked_request("GET", "/ss/b")) + assert match_info.route.handler is handler2 + + +async def test_prefixed_subapp_empty_route(app) -> None: + subapp = web.Application() + handler = make_handler() + subapp.router.add_get("", handler) + app.add_subapp("/s", subapp) + + match_info = await app.router.resolve(make_mocked_request("GET", "/s")) + assert match_info.route.handler is handler + match_info = await app.router.resolve(make_mocked_request("GET", "/s/")) + assert "" == repr(match_info) + + +async def test_prefixed_subapp_root_route(app) -> None: + subapp = web.Application() + handler = make_handler() + subapp.router.add_get("/", handler) + app.add_subapp("/s", subapp) + + match_info = await app.router.resolve(make_mocked_request("GET", "/s/")) + assert match_info.route.handler is handler + match_info = await 
app.router.resolve(make_mocked_request("GET", "/s")) + assert "<MatchInfoError 404: Not Found>" == repr(match_info) diff --git a/tests/test_web_app.py b/tests/test_web_app.py new file mode 100644 index 00000000000..f48e54bb861 --- /dev/null +++ b/tests/test_web_app.py @@ -0,0 +1,576 @@ +import asyncio +from unittest import mock + +import pytest +from async_generator import async_generator, yield_ + +from aiohttp import log, web +from aiohttp.abc import AbstractAccessLogger, AbstractRouter +from aiohttp.helpers import DEBUG, PY_36 +from aiohttp.test_utils import make_mocked_coro + + +async def test_app_ctor() -> None: + loop = asyncio.get_event_loop() + with pytest.warns(DeprecationWarning): + app = web.Application(loop=loop) + with pytest.warns(DeprecationWarning): + assert loop is app.loop + assert app.logger is log.web_logger + + +def test_app_call() -> None: + app = web.Application() + assert app is app() + + +def test_app_default_loop() -> None: + app = web.Application() + with pytest.warns(DeprecationWarning): + assert app.loop is None + + +async def test_set_loop() -> None: + loop = asyncio.get_event_loop() + app = web.Application() + app._set_loop(loop) + with pytest.warns(DeprecationWarning): + assert app.loop is loop + + +def test_set_loop_default_loop() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + app = web.Application() + app._set_loop(None) + with pytest.warns(DeprecationWarning): + assert app.loop is loop + asyncio.set_event_loop(None) + + +def test_set_loop_with_different_loops() -> None: + loop = asyncio.new_event_loop() + app = web.Application() + app._set_loop(loop) + with pytest.warns(DeprecationWarning): + assert app.loop is loop + + with pytest.raises(RuntimeError): + app._set_loop(loop=object()) + + +@pytest.mark.parametrize("debug", [True, False]) +async def test_app_make_handler_debug_exc(mocker, debug) -> None: + with pytest.warns(DeprecationWarning): + app = web.Application(debug=debug) + srv = mocker.patch("aiohttp.web_app.Server") + + with pytest.warns(DeprecationWarning): + assert app.debug == debug + + app._make_handler() + srv.assert_called_with( + app._handle, + request_factory=app._make_request, + access_log_class=mock.ANY, + loop=asyncio.get_event_loop(), + debug=debug, + ) + + +async def test_app_make_handler_args(mocker) -> None: + app = web.Application(handler_args={"test": True}) + srv = mocker.patch("aiohttp.web_app.Server") + + app._make_handler() + srv.assert_called_with( + app._handle, + request_factory=app._make_request, + access_log_class=mock.ANY, + loop=asyncio.get_event_loop(), + debug=mock.ANY, + test=True, + ) + + +async def test_app_make_handler_access_log_class(mocker) -> None: + class Logger: + pass + + app = web.Application() + + with pytest.raises(TypeError): + app._make_handler(access_log_class=Logger) + + class Logger(AbstractAccessLogger): + def log(self, request, response, time): + self.logger.info("msg") + + srv = mocker.patch("aiohttp.web_app.Server") + + app._make_handler(access_log_class=Logger) + srv.assert_called_with( + app._handle, + access_log_class=Logger, + request_factory=app._make_request, + loop=asyncio.get_event_loop(), + debug=mock.ANY, + ) + + app = web.Application(handler_args={"access_log_class": Logger}) + app._make_handler(access_log_class=Logger) + srv.assert_called_with( + app._handle, + access_log_class=Logger, + request_factory=app._make_request, + loop=asyncio.get_event_loop(), + debug=mock.ANY, + ) + + +async def test_app_make_handler_raises_deprecation_warning() -> None: + app = web.Application() + + with 
pytest.warns(DeprecationWarning): + app.make_handler() + + +async def test_app_register_on_finish() -> None: + app = web.Application() + cb1 = make_mocked_coro(None) + cb2 = make_mocked_coro(None) + app.on_cleanup.append(cb1) + app.on_cleanup.append(cb2) + app.freeze() + await app.cleanup() + cb1.assert_called_once_with(app) + cb2.assert_called_once_with(app) + + +async def test_app_register_coro() -> None: + app = web.Application() + fut = asyncio.get_event_loop().create_future() + + async def cb(app): + await asyncio.sleep(0.001) + fut.set_result(123) + + app.on_cleanup.append(cb) + app.freeze() + await app.cleanup() + assert fut.done() + assert 123 == fut.result() + + +def test_non_default_router() -> None: + router = mock.Mock(spec=AbstractRouter) + with pytest.warns(DeprecationWarning): + app = web.Application(router=router) + assert router is app.router + + +def test_logging() -> None: + logger = mock.Mock() + app = web.Application() + app.logger = logger + assert app.logger is logger + + +async def test_on_shutdown() -> None: + app = web.Application() + called = False + + async def on_shutdown(app_param): + nonlocal called + assert app is app_param + called = True + + app.on_shutdown.append(on_shutdown) + app.freeze() + await app.shutdown() + assert called + + +async def test_on_startup() -> None: + app = web.Application() + + long_running1_called = False + long_running2_called = False + all_long_running_called = False + + async def long_running1(app_param): + nonlocal long_running1_called + assert app is app_param + long_running1_called = True + + async def long_running2(app_param): + nonlocal long_running2_called + assert app is app_param + long_running2_called = True + + async def on_startup_all_long_running(app_param): + nonlocal all_long_running_called + assert app is app_param + all_long_running_called = True + return await asyncio.gather(long_running1(app_param), long_running2(app_param)) + + app.on_startup.append(on_startup_all_long_running) + app.freeze() + + await app.startup() + assert long_running1_called + assert long_running2_called + assert all_long_running_called + + +def test_app_delitem() -> None: + app = web.Application() + app["key"] = "value" + assert len(app) == 1 + del app["key"] + assert len(app) == 0 + + +def test_app_freeze() -> None: + app = web.Application() + subapp = mock.Mock() + subapp._middlewares = () + app._subapps.append(subapp) + + app.freeze() + assert subapp.freeze.called + + app.freeze() + assert len(subapp.freeze.call_args_list) == 1 + + +def test_equality() -> None: + app1 = web.Application() + app2 = web.Application() + + assert app1 == app1 + assert app1 != app2 + + +def test_app_run_middlewares() -> None: + + root = web.Application() + sub = web.Application() + root.add_subapp("/sub", sub) + root.freeze() + assert root._run_middlewares is False + + @web.middleware + async def middleware(request, handler): + return await handler(request) + + root = web.Application(middlewares=[middleware]) + sub = web.Application() + root.add_subapp("/sub", sub) + root.freeze() + assert root._run_middlewares is True + + root = web.Application() + sub = web.Application(middlewares=[middleware]) + root.add_subapp("/sub", sub) + root.freeze() + assert root._run_middlewares is True + + +def test_subapp_pre_frozen_after_adding() -> None: + app = web.Application() + subapp = web.Application() + + app.add_subapp("/prefix", subapp) + assert subapp.pre_frozen + assert not subapp.frozen + + +@pytest.mark.skipif(not PY_36, reason="Python 3.6+ required") +def 
test_app_inheritance() -> None: + with pytest.warns(DeprecationWarning): + + class A(web.Application): + pass + + +@pytest.mark.skipif(not DEBUG, reason="The check is applied in DEBUG mode only") +def test_app_custom_attr() -> None: + app = web.Application() + with pytest.warns(DeprecationWarning): + app.custom = None + + +async def test_cleanup_ctx() -> None: + app = web.Application() + out = [] + + def f(num): + @async_generator + async def inner(app): + out.append("pre_" + str(num)) + await yield_(None) + out.append("post_" + str(num)) + + return inner + + app.cleanup_ctx.append(f(1)) + app.cleanup_ctx.append(f(2)) + app.freeze() + await app.startup() + assert out == ["pre_1", "pre_2"] + await app.cleanup() + assert out == ["pre_1", "pre_2", "post_2", "post_1"] + + +async def test_cleanup_ctx_exception_on_startup() -> None: + app = web.Application() + out = [] + + exc = Exception("fail") + + def f(num, fail=False): + @async_generator + async def inner(app): + out.append("pre_" + str(num)) + if fail: + raise exc + await yield_(None) + out.append("post_" + str(num)) + + return inner + + app.cleanup_ctx.append(f(1)) + app.cleanup_ctx.append(f(2, True)) + app.cleanup_ctx.append(f(3)) + app.freeze() + with pytest.raises(Exception) as ctx: + await app.startup() + assert ctx.value is exc + assert out == ["pre_1", "pre_2"] + await app.cleanup() + assert out == ["pre_1", "pre_2", "post_1"] + + +async def test_cleanup_ctx_exception_on_cleanup() -> None: + app = web.Application() + out = [] + + exc = Exception("fail") + + def f(num, fail=False): + @async_generator + async def inner(app): + out.append("pre_" + str(num)) + await yield_(None) + out.append("post_" + str(num)) + if fail: + raise exc + + return inner + + app.cleanup_ctx.append(f(1)) + app.cleanup_ctx.append(f(2, True)) + app.cleanup_ctx.append(f(3)) + app.freeze() + await app.startup() + assert out == ["pre_1", "pre_2", "pre_3"] + with pytest.raises(Exception) as ctx: + await app.cleanup() + assert ctx.value is exc + assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"] + + +async def test_cleanup_ctx_exception_on_cleanup_multiple() -> None: + app = web.Application() + out = [] + + def f(num, fail=False): + @async_generator + async def inner(app): + out.append("pre_" + str(num)) + await yield_(None) + out.append("post_" + str(num)) + if fail: + raise Exception("fail_" + str(num)) + + return inner + + app.cleanup_ctx.append(f(1)) + app.cleanup_ctx.append(f(2, True)) + app.cleanup_ctx.append(f(3, True)) + app.freeze() + await app.startup() + assert out == ["pre_1", "pre_2", "pre_3"] + with pytest.raises(web.CleanupError) as ctx: + await app.cleanup() + exc = ctx.value + assert len(exc.exceptions) == 2 + assert str(exc.exceptions[0]) == "fail_3" + assert str(exc.exceptions[1]) == "fail_2" + assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"] + + +async def test_cleanup_ctx_multiple_yields() -> None: + app = web.Application() + out = [] + + def f(num): + @async_generator + async def inner(app): + out.append("pre_" + str(num)) + await yield_(None) + out.append("post_" + str(num)) + await yield_(None) + + return inner + + app.cleanup_ctx.append(f(1)) + app.freeze() + await app.startup() + assert out == ["pre_1"] + with pytest.raises(RuntimeError) as ctx: + await app.cleanup() + assert "has more than one 'yield'" in str(ctx.value) + assert out == ["pre_1", "post_1"] + + +async def test_subapp_chained_config_dict_visibility(aiohttp_client) -> None: + async def main_handler(request): + assert 
request.config_dict["key1"] == "val1" + assert "key2" not in request.config_dict + return web.Response(status=200) + + root = web.Application() + root["key1"] = "val1" + root.add_routes([web.get("/", main_handler)]) + + async def sub_handler(request): + assert request.config_dict["key1"] == "val1" + assert request.config_dict["key2"] == "val2" + return web.Response(status=201) + + sub = web.Application() + sub["key2"] = "val2" + sub.add_routes([web.get("/", sub_handler)]) + root.add_subapp("/sub", sub) + + client = await aiohttp_client(root) + + resp = await client.get("/") + assert resp.status == 200 + resp = await client.get("/sub/") + assert resp.status == 201 + + +async def test_subapp_chained_config_dict_overriding(aiohttp_client) -> None: + async def main_handler(request): + assert request.config_dict["key"] == "val1" + return web.Response(status=200) + + root = web.Application() + root["key"] = "val1" + root.add_routes([web.get("/", main_handler)]) + + async def sub_handler(request): + assert request.config_dict["key"] == "val2" + return web.Response(status=201) + + sub = web.Application() + sub["key"] = "val2" + sub.add_routes([web.get("/", sub_handler)]) + root.add_subapp("/sub", sub) + + client = await aiohttp_client(root) + + resp = await client.get("/") + assert resp.status == 200 + resp = await client.get("/sub/") + assert resp.status == 201 + + +async def test_subapp_on_startup(aiohttp_client) -> None: + + subapp = web.Application() + + startup_called = False + + async def on_startup(app): + nonlocal startup_called + startup_called = True + app["startup"] = True + + subapp.on_startup.append(on_startup) + + ctx_pre_called = False + ctx_post_called = False + + @async_generator + async def cleanup_ctx(app): + nonlocal ctx_pre_called, ctx_post_called + ctx_pre_called = True + app["cleanup"] = True + await yield_(None) + ctx_post_called = True + + subapp.cleanup_ctx.append(cleanup_ctx) + + shutdown_called = False + + async def on_shutdown(app): + nonlocal shutdown_called + shutdown_called = True + + subapp.on_shutdown.append(on_shutdown) + + cleanup_called = False + + async def on_cleanup(app): + nonlocal cleanup_called + cleanup_called = True + + subapp.on_cleanup.append(on_cleanup) + + app = web.Application() + + app.add_subapp("/subapp", subapp) + + assert not startup_called + assert not ctx_pre_called + assert not ctx_post_called + assert not shutdown_called + assert not cleanup_called + + assert subapp.on_startup.frozen + assert subapp.cleanup_ctx.frozen + assert subapp.on_shutdown.frozen + assert subapp.on_cleanup.frozen + assert subapp.router.frozen + + client = await aiohttp_client(app) + + assert startup_called + assert ctx_pre_called + assert not ctx_post_called + assert not shutdown_called + assert not cleanup_called + + await client.close() + + assert startup_called + assert ctx_pre_called + assert ctx_post_called + assert shutdown_called + assert cleanup_called + + +def test_app_iter(): + app = web.Application() + app["a"] = "1" + app["b"] = "2" + assert sorted(list(app)) == ["a", "b"] + + +def test_app_boolean() -> None: + app = web.Application() + assert app diff --git a/tests/test_web_application.py b/tests/test_web_application.py deleted file mode 100644 index da80ad5efb0..00000000000 --- a/tests/test_web_application.py +++ /dev/null @@ -1,192 +0,0 @@ -import asyncio -from unittest import mock - -import pytest - -from aiohttp import helpers, log, web -from aiohttp.abc import AbstractRouter - - -def test_app_ctor(loop): - app = web.Application(loop=loop) - assert 
loop is app.loop - assert app.logger is log.web_logger - - -def test_app_call(): - app = web.Application() - assert app is app() - - -def test_app_default_loop(): - app = web.Application() - assert app.loop is None - - -def test_set_loop(loop): - app = web.Application() - app._set_loop(loop) - assert app.loop is loop - - -def test_set_loop_default_loop(loop): - asyncio.set_event_loop(loop) - app = web.Application() - app._set_loop(None) - assert app.loop is loop - - -def test_set_loop_with_different_loops(loop): - app = web.Application() - app._set_loop(loop) - assert app.loop is loop - - with pytest.raises(RuntimeError): - app._set_loop(loop=object()) - - -def test_on_loop_available(loop): - app = web.Application() - - cb = mock.Mock() - app.on_loop_available.append(cb) - - app._set_loop(loop) - cb.assert_called_with(app) - - -@pytest.mark.parametrize('debug', [True, False]) -def test_app_make_handler_debug_exc(loop, mocker, debug): - app = web.Application(debug=debug) - srv = mocker.patch('aiohttp.web.Server') - - app.make_handler(loop=loop) - srv.assert_called_with(app._handle, - request_factory=app._make_request, - loop=loop, - debug=debug) - - -@asyncio.coroutine -def test_app_register_on_finish(): - app = web.Application() - cb1 = mock.Mock() - cb2 = mock.Mock() - app.on_cleanup.append(cb1) - app.on_cleanup.append(cb2) - yield from app.cleanup() - cb1.assert_called_once_with(app) - cb2.assert_called_once_with(app) - - -@asyncio.coroutine -def test_app_register_coro(loop): - app = web.Application() - fut = helpers.create_future(loop) - - @asyncio.coroutine - def cb(app): - yield from asyncio.sleep(0.001, loop=loop) - fut.set_result(123) - - app.on_cleanup.append(cb) - yield from app.cleanup() - assert fut.done() - assert 123 == fut.result() - - -def test_non_default_router(): - router = mock.Mock(spec=AbstractRouter) - app = web.Application(router=router) - assert router is app.router - - -def test_logging(): - logger = mock.Mock() - app = web.Application() - app.logger = logger - assert app.logger is logger - - -@asyncio.coroutine -def test_on_shutdown(): - app = web.Application() - called = False - - @asyncio.coroutine - def on_shutdown(app_param): - nonlocal called - assert app is app_param - called = True - - app.on_shutdown.append(on_shutdown) - - yield from app.shutdown() - assert called - - -@asyncio.coroutine -def test_on_startup(loop): - app = web.Application() - app._set_loop(loop) - - blocking_called = False - long_running1_called = False - long_running2_called = False - all_long_running_called = False - - def on_startup_blocking(app_param): - nonlocal blocking_called - assert app is app_param - blocking_called = True - - @asyncio.coroutine - def long_running1(app_param): - nonlocal long_running1_called - assert app is app_param - long_running1_called = True - - @asyncio.coroutine - def long_running2(app_param): - nonlocal long_running2_called - assert app is app_param - long_running2_called = True - - @asyncio.coroutine - def on_startup_all_long_running(app_param): - nonlocal all_long_running_called - assert app is app_param - all_long_running_called = True - return (yield from asyncio.gather(long_running1(app_param), - long_running2(app_param), - loop=app_param.loop)) - - app.on_startup.append(on_startup_blocking) - app.on_startup.append(on_startup_all_long_running) - - yield from app.startup() - assert blocking_called - assert long_running1_called - assert long_running2_called - assert all_long_running_called - - -def test_app_delitem(): - app = web.Application() - 
app['key'] = 'value' - assert len(app) == 1 - del app['key'] - assert len(app) == 0 - - -def test_secure_proxy_ssl_header_default(): - app = web.Application() - assert app._secure_proxy_ssl_header is None - - -@asyncio.coroutine -def test_secure_proxy_ssl_header_non_default(loop): - app = web.Application() - hdr = ('X-Forwarded-Proto', 'https') - app.make_handler(secure_proxy_ssl_header=hdr, loop=loop) - assert app._secure_proxy_ssl_header is hdr diff --git a/tests/test_web_cli.py b/tests/test_web_cli.py index 5d954f89a4c..12a01dff577 100644 --- a/tests/test_web_cli.py +++ b/tests/test_web_cli.py @@ -3,88 +3,72 @@ from aiohttp import web -def test_entry_func_empty(mocker): - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) +def test_entry_func_empty(mocker) -> None: + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) argv = [""] with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with( - "'entry-func' not in 'module:function' syntax" - ) + error.assert_called_with("'entry-func' not in 'module:function' syntax") -def test_entry_func_only_module(mocker): +def test_entry_func_only_module(mocker) -> None: argv = ["test"] - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with( - "'entry-func' not in 'module:function' syntax" - ) + error.assert_called_with("'entry-func' not in 'module:function' syntax") -def test_entry_func_only_function(mocker): +def test_entry_func_only_function(mocker) -> None: argv = [":test"] - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with( - "'entry-func' not in 'module:function' syntax" - ) + error.assert_called_with("'entry-func' not in 'module:function' syntax") -def test_entry_func_only_seperator(mocker): +def test_entry_func_only_separator(mocker) -> None: argv = [":"] - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with( - "'entry-func' not in 'module:function' syntax" - ) + error.assert_called_with("'entry-func' not in 'module:function' syntax") -def test_entry_func_relative_module(mocker): +def test_entry_func_relative_module(mocker) -> None: argv = [".a.b:c"] - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("relative module names not supported") -def test_entry_func_non_existent_module(mocker): +def test_entry_func_non_existent_module(mocker) -> None: argv = ["alpha.beta:func"] - mocker.patch("aiohttp.web.import_module", - side_effect=ImportError("Test Error")) - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + mocker.patch("aiohttp.web.import_module", side_effect=ImportError("Test Error")) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with('unable to import alpha.beta: Test Error') + 
error.assert_called_with("unable to import alpha.beta: Test Error") -def test_entry_func_non_existent_attribute(mocker): +def test_entry_func_non_existent_attribute(mocker) -> None: argv = ["alpha.beta:func"] import_module = mocker.patch("aiohttp.web.import_module") - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) module = import_module("alpha.beta") del module.func @@ -92,47 +76,49 @@ def test_entry_func_non_existent_attribute(mocker): web.main(argv) error.assert_called_with( - "module %r has no attribute %r" % ("alpha.beta", "func") + "module {!r} has no attribute {!r}".format("alpha.beta", "func") ) -def test_path_when_unsupported(mocker, monkeypatch): +def test_path_when_unsupported(mocker, monkeypatch) -> None: argv = "--path=test_path.sock alpha.beta:func".split() mocker.patch("aiohttp.web.import_module") monkeypatch.delattr("socket.AF_UNIX", raising=False) - error = mocker.patch("aiohttp.web.ArgumentParser.error", - side_effect=SystemExit) + error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) - error.assert_called_with("file system paths not supported by your" - " operating environment") + error.assert_called_with( + "file system paths not supported by your" " operating environment" + ) -def test_entry_func_call(mocker): +def test_entry_func_call(mocker) -> None: mocker.patch("aiohttp.web.run_app") import_module = mocker.patch("aiohttp.web.import_module") - argv = ("-H testhost -P 6666 --extra-optional-eins alpha.beta:func " - "--extra-optional-zwei extra positional args").split() + argv = ( + "-H testhost -P 6666 --extra-optional-eins alpha.beta:func " + "--extra-optional-zwei extra positional args" + ).split() module = import_module("alpha.beta") with pytest.raises(SystemExit): web.main(argv) module.func.assert_called_with( - ("--extra-optional-eins --extra-optional-zwei extra positional " - "args").split() + ("--extra-optional-eins --extra-optional-zwei extra positional " "args").split() ) -def test_running_application(mocker): +def test_running_application(mocker) -> None: run_app = mocker.patch("aiohttp.web.run_app") import_module = mocker.patch("aiohttp.web.import_module") - exit = mocker.patch("aiohttp.web.ArgumentParser.exit", - side_effect=SystemExit) - argv = ("-H testhost -P 6666 --extra-optional-eins alpha.beta:func " - "--extra-optional-zwei extra positional args").split() + exit = mocker.patch("aiohttp.web.ArgumentParser.exit", side_effect=SystemExit) + argv = ( + "-H testhost -P 6666 --extra-optional-eins alpha.beta:func " + "--extra-optional-zwei extra positional args" + ).split() module = import_module("alpha.beta") app = module.func() diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 1b19a69f2b3..43e5029803f 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -1,6 +1,6 @@ -import asyncio import collections import re +from traceback import format_exception from unittest import mock import pytest @@ -15,20 +15,23 @@ def buf(): @pytest.fixture -def request(buf): - method = 'GET' - path = '/' +def http_request(buf): + method = "GET" + path = "/" writer = mock.Mock() writer.drain.return_value = () - def append(data=b''): + def append(data=b""): buf.extend(data) return helpers.noop() - def write_headers(status_line, headers): - headers = status_line + ''.join( - [k + ': ' + v + '\r\n' for k, v in headers.items()]) - headers 
= headers.encode('utf-8') + b'\r\n' + async def write_headers(status_line, headers): + headers = ( + status_line + + "\r\n" + + "".join([k + ": " + v + "\r\n" for k, v in headers.items()]) + ) + headers = headers.encode("utf-8") + b"\r\n" buf.extend(headers) writer.buffer_data.side_effect = append @@ -39,35 +42,40 @@ def write_headers(status_line, headers): app = mock.Mock() app._debug = False app.on_response_prepare = signals.Signal(app) - req = make_mocked_request(method, path, app=app, payload_writer=writer) + app.on_response_prepare.freeze() + req = make_mocked_request(method, path, app=app, writer=writer) return req -def test_all_http_exceptions_exported(): - assert 'HTTPException' in web.__all__ +def test_all_http_exceptions_exported() -> None: + assert "HTTPException" in web.__all__ for name in dir(web): - if name.startswith('_'): + if name.startswith("_"): continue obj = getattr(web, name) if isinstance(obj, type) and issubclass(obj, web.HTTPException): assert name in web.__all__ -@asyncio.coroutine -def test_HTTPOk(buf, request): +async def test_HTTPOk(buf, http_request) -> None: resp = web.HTTPOk() - yield from resp.prepare(request) - yield from resp.write_eof() - txt = buf.decode('utf8') - assert re.match(('HTTP/1.1 200 OK\r\n' - 'Content-Type: text/plain; charset=utf-8\r\n' - 'Content-Length: 7\r\n' - 'Date: .+\r\n' - 'Server: .+\r\n\r\n' - '200: OK'), txt) - - -def test_terminal_classes_has_status_code(): + await resp.prepare(http_request) + await resp.write_eof() + txt = buf.decode("utf8") + assert re.match( + ( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Content-Length: 7\r\n" + "Date: .+\r\n" + "Server: .+\r\n\r\n" + "200: OK" + ), + txt, + ) + + +def test_terminal_classes_has_status_code() -> None: terminals = set() for name in dir(web): obj = getattr(web, name) @@ -87,91 +95,111 @@ def test_terminal_classes_has_status_code(): assert 1 == codes.most_common(1)[0][1] -@asyncio.coroutine -def test_HTTPFound(buf, request): - resp = web.HTTPFound(location='/redirect') - assert '/redirect' == resp.location - assert '/redirect' == resp.headers['location'] - yield from resp.prepare(request) - yield from resp.write_eof() - txt = buf.decode('utf8') - assert re.match('HTTP/1.1 302 Found\r\n' - 'Content-Type: text/plain; charset=utf-8\r\n' - 'Location: /redirect\r\n' - 'Content-Length: 10\r\n' - 'Date: .+\r\n' - 'Server: .+\r\n\r\n' - '302: Found', txt) - - -def test_HTTPFound_empty_location(): +async def test_HTTPFound(buf, http_request) -> None: + resp = web.HTTPFound(location="/redirect") + assert "/redirect" == resp.location + assert "/redirect" == resp.headers["location"] + await resp.prepare(http_request) + await resp.write_eof() + txt = buf.decode("utf8") + assert re.match( + "HTTP/1.1 302 Found\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Location: /redirect\r\n" + "Content-Length: 10\r\n" + "Date: .+\r\n" + "Server: .+\r\n\r\n" + "302: Found", + txt, + ) + + +def test_HTTPFound_empty_location() -> None: with pytest.raises(ValueError): - web.HTTPFound(location='') + web.HTTPFound(location="") with pytest.raises(ValueError): web.HTTPFound(location=None) -@asyncio.coroutine -def test_HTTPMethodNotAllowed(buf, request): - resp = web.HTTPMethodNotAllowed('get', ['POST', 'PUT']) - assert 'GET' == resp.method - assert ['POST', 'PUT'] == resp.allowed_methods - assert 'POST,PUT' == resp.headers['allow'] - yield from resp.prepare(request) - yield from resp.write_eof() - txt = buf.decode('utf8') - assert re.match('HTTP/1.1 405 Method Not 
Allowed\r\n'
-                    'Content-Type: text/plain; charset=utf-8\r\n'
-                    'Allow: POST,PUT\r\n'
-                    'Content-Length: 23\r\n'
-                    'Date: .+\r\n'
-                    'Server: .+\r\n\r\n'
-                    '405: Method Not Allowed', txt)
-
-
-def test_override_body_with_text():
+def test_HTTPFound_location_CRLF() -> None:
+    exc = web.HTTPFound(location="/redirect\r\n")
+    assert "\r\n" not in exc.headers["Location"]
+
+
+async def test_HTTPMethodNotAllowed(buf, http_request) -> None:
+    resp = web.HTTPMethodNotAllowed("get", ["POST", "PUT"])
+    assert "GET" == resp.method
+    assert {"POST", "PUT"} == resp.allowed_methods
+    assert "POST,PUT" == resp.headers["allow"]
+    await resp.prepare(http_request)
+    await resp.write_eof()
+    txt = buf.decode("utf8")
+    assert re.match(
+        "HTTP/1.1 405 Method Not Allowed\r\n"
+        "Content-Type: text/plain; charset=utf-8\r\n"
+        "Allow: POST,PUT\r\n"
+        "Content-Length: 23\r\n"
+        "Date: .+\r\n"
+        "Server: .+\r\n\r\n"
+        "405: Method Not Allowed",
+        txt,
+    )
+
+
+def test_override_body_with_text() -> None:
     resp = web.HTTPNotFound(text="Page not found")
     assert 404 == resp.status
-    assert "Page not found".encode('utf-8') == resp.body
+    assert b"Page not found" == resp.body
     assert "Page not found" == resp.text
     assert "text/plain" == resp.content_type
     assert "utf-8" == resp.charset


-def test_override_body_with_binary():
+def test_override_body_with_binary() -> None:
     txt = "Page not found"
-    resp = web.HTTPNotFound(body=txt.encode('utf-8'),
-                            content_type="text/html")
+    with pytest.warns(DeprecationWarning):
+        resp = web.HTTPNotFound(body=txt.encode("utf-8"), content_type="text/html")
     assert 404 == resp.status
-    assert txt.encode('utf-8') == resp.body
+    assert txt.encode("utf-8") == resp.body
     assert txt == resp.text
     assert "text/html" == resp.content_type
     assert resp.charset is None


-def test_default_body():
+def test_default_body() -> None:
     resp = web.HTTPOk()
-    assert b'200: OK' == resp.body
+    assert b"200: OK" == resp.body


-def test_empty_body_204():
+def test_empty_body_204() -> None:
     resp = web.HTTPNoContent()
     assert resp.body is None


-def test_empty_body_205():
+def test_empty_body_205() -> None:
-    resp = web.HTTPNoContent()
+    resp = web.HTTPResetContent()
     assert resp.body is None


-def test_empty_body_304():
+def test_empty_body_304() -> None:
-    resp = web.HTTPNoContent()
-    resp.body is None
+    resp = web.HTTPNotModified()
+    assert resp.body is None


-def test_link_header_451(buf, request):
-    resp = web.HTTPUnavailableForLegalReasons(link='http://warning.or.kr/')
+def test_link_header_451(buf) -> None:
+    resp = web.HTTPUnavailableForLegalReasons(link="http://warning.or.kr/")

-    assert 'http://warning.or.kr/' == resp.link
-    assert '<http://warning.or.kr/>; rel="blocked-by"' == resp.headers['Link']
+    assert "http://warning.or.kr/" == resp.link
+    assert '<http://warning.or.kr/>; rel="blocked-by"' == resp.headers["Link"]
+
+
+def test_HTTPException_retains_cause() -> None:
+    with pytest.raises(web.HTTPException) as ei:
+        try:
+            raise Exception("CustomException")
+        except Exception as exc:
+            raise web.HTTPException() from exc
+    tb = "".join(format_exception(ei.type, ei.value, ei.tb))
+    assert "CustomException" in tb
+    assert "direct cause" in tb
diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py
index 053a68c0f14..a28fcd4f56b 100644
--- a/tests/test_web_functional.py
+++ b/tests/test_web_functional.py
@@ -2,20 +2,24 @@
 import io
 import json
 import pathlib
+import socket
 import zlib
 from unittest import mock

 import pytest
-from multidict import MultiDict
+from async_generator import async_generator, yield_
+from multidict import CIMultiDictProxy, MultiDict
 from yarl import URL

 import aiohttp
-from aiohttp import FormData, HttpVersion10,
HttpVersion11, multipart, web +from aiohttp import FormData, HttpVersion10, HttpVersion11, TraceConfig, multipart, web +from aiohttp.hdrs import CONTENT_LENGTH, TRANSFER_ENCODING +from aiohttp.test_utils import make_mocked_coro try: import ssl -except: - ssl = False +except ImportError: + ssl = None # type: ignore @pytest.fixture @@ -25,1290 +29,1289 @@ def here(): @pytest.fixture def fname(here): - return here / 'sample.key' + return here / "conftest.py" -@asyncio.coroutine -def test_simple_get(loop, test_client): +def new_dummy_form(): + form = FormData() + form.add_field("name", b"123", content_transfer_encoding="base64") + return form + - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - return web.Response(body=b'OK') +async def test_simple_get(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - txt = yield from resp.text() - assert 'OK' == txt + txt = await resp.text() + assert "OK" == txt -@asyncio.coroutine -def test_simple_get_with_text(loop, test_client): - - @asyncio.coroutine - def handler(request): - body = yield from request.read() - assert b'' == body - return web.Response(text='OK', headers={'content-type': 'text/plain'}) +async def test_simple_get_with_text(aiohttp_client) -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(text="OK", headers={"content-type": "text/plain"}) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - txt = yield from resp.text() - assert 'OK' == txt + txt = await resp.text() + assert "OK" == txt -@asyncio.coroutine -def test_handler_returns_not_response(loop, test_server, test_client): +async def test_handler_returns_not_response(aiohttp_server, aiohttp_client) -> None: + asyncio.get_event_loop().set_debug(True) logger = mock.Mock() - @asyncio.coroutine - def handler(request): - return 'abc' + async def handler(request): + return "abc" app = web.Application() - app.router.add_get('/', handler) - server = yield from test_server(app, logger=logger) - client = yield from test_client(server) + app.router.add_get("/", handler) + server = await aiohttp_server(app, logger=logger) + client = await aiohttp_client(server) - resp = yield from client.get('/') - assert 500 == resp.status + with pytest.raises(aiohttp.ServerDisconnectedError): + await client.get("/") - assert logger.exception.called + logger.exception.assert_called_with( + "Unhandled runtime exception", exc_info=mock.ANY + ) -@asyncio.coroutine -def test_head_returns_empty_body(loop, test_client): +async def test_handler_returns_none(aiohttp_server, aiohttp_client) -> None: + asyncio.get_event_loop().set_debug(True) + logger = mock.Mock() - @asyncio.coroutine - def handler(request): - return web.Response(body=b'test') + async def handler(request): + return None app = web.Application() - app.router.add_head('/', handler) - client = yield from test_client(app, version=HttpVersion11) + app.router.add_get("/", handler) + server = await 
aiohttp_server(app, logger=logger) + client = await aiohttp_client(server) - resp = yield from client.head('/') - assert 200 == resp.status - txt = yield from resp.text() - assert '' == txt + with pytest.raises(aiohttp.ServerDisconnectedError): + await client.get("/") + # Actual error text is placed in exc_info + logger.exception.assert_called_with( + "Unhandled runtime exception", exc_info=mock.ANY + ) -@asyncio.coroutine -def test_response_before_complete(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(body=b'OK') +async def test_head_returns_empty_body(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"test") app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) - - data = b'0' * 1024 * 1024 + app.router.add_head("/", handler) + client = await aiohttp_client(app, version=HttpVersion11) - resp = yield from client.post('/', data=data) + resp = await client.head("/") assert 200 == resp.status - text = yield from resp.text() - assert 'OK' == text + txt = await resp.text() + assert "" == txt -@asyncio.coroutine -def test_post_form(loop, test_client): - - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert {'a': '1', 'b': '2', 'c': ''} == data - return web.Response(body=b'OK') +async def test_response_before_complete(aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + data = b"0" * 1024 * 1024 - resp = yield from client.post('/', data={'a': 1, 'b': 2, 'c': ''}) + resp = await client.post("/", data=data) assert 200 == resp.status - txt = yield from resp.text() - assert 'OK' == txt + text = await resp.text() + assert "OK" == text -@asyncio.coroutine -def test_post_text(loop, test_client): +async def test_post_form(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert {"a": "1", "b": "2", "c": ""} == data + return web.Response(body=b"OK") - @asyncio.coroutine - def handler(request): - data = yield from request.text() - assert 'русский' == data - data2 = yield from request.text() + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + resp = await client.post("/", data={"a": 1, "b": 2, "c": ""}) + assert 200 == resp.status + txt = await resp.text() + assert "OK" == txt + + +async def test_post_text(aiohttp_client) -> None: + async def handler(request): + data = await request.text() + assert "русский" == data + data2 = await request.text() assert data == data2 return web.Response(text=data) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data='русский') + resp = await client.post("/", data="русский") assert 200 == resp.status - txt = yield from resp.text() - assert 'русский' == txt + txt = await resp.text() + assert "русский" == txt -@asyncio.coroutine -def test_post_json(loop, test_client): +async def test_post_json(aiohttp_client) -> None: - dct = {'key': 'текст'} + dct = {"key": "текст"} - @asyncio.coroutine - def handler(request): - data = yield from request.json() + async def handler(request): + data = await request.json() assert dct == data - data2 = yield from 
request.json(loads=json.loads) + data2 = await request.json(loads=json.loads) assert data == data2 resp = web.Response() - resp.content_type = 'application/json' - resp.body = json.dumps(data).encode('utf8') + resp.content_type = "application/json" + resp.body = json.dumps(data).encode("utf8") return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - headers = {'Content-Type': 'application/json'} - resp = yield from client.post('/', data=json.dumps(dct), headers=headers) + headers = {"Content-Type": "application/json"} + resp = await client.post("/", data=json.dumps(dct), headers=headers) assert 200 == resp.status - data = yield from resp.json() + data = await resp.json() assert dct == data -@asyncio.coroutine -def test_multipart(loop, test_client): +async def test_multipart(aiohttp_client) -> None: with multipart.MultipartWriter() as writer: - writer.append('test') - writer.append_json({'passed': True}) + writer.append("test") + writer.append_json({"passed": True}) - @asyncio.coroutine - def handler(request): - reader = yield from request.multipart() + async def handler(request): + reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) - part = yield from reader.next() + part = await reader.next() assert isinstance(part, multipart.BodyPartReader) - thing = yield from part.text() - assert thing == 'test' + thing = await part.text() + assert thing == "test" - part = yield from reader.next() + part = await reader.next() assert isinstance(part, multipart.BodyPartReader) - assert part.headers['Content-Type'] == 'application/json' - thing = yield from part.json() - assert thing == {'passed': True} + assert part.headers["Content-Type"] == "application/json" + thing = await part.json() + assert thing == {"passed": True} resp = web.Response() - resp.content_type = 'application/json' - resp.body = b'' + resp.content_type = "application/json" + resp.body = b"" return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=writer, headers=writer.headers) + resp = await client.post("/", data=writer) assert 200 == resp.status - yield from resp.release() + await resp.release() -@asyncio.coroutine -def test_multipart_content_transfer_encoding(loop, test_client): - """For issue #1168""" +async def test_multipart_empty(aiohttp_client) -> None: with multipart.MultipartWriter() as writer: - writer.append(b'\x00' * 10, - headers={'Content-Transfer-Encoding': 'binary'}) + pass - @asyncio.coroutine - def handler(request): - reader = yield from request.multipart() + async def handler(request): + reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) + async for part in reader: + assert False, f"Unexpected part found in reader: {part!r}" + return web.Response() - part = yield from reader.next() + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + resp = await client.post("/", data=writer) + assert 200 == resp.status + await resp.release() + + +async def test_multipart_content_transfer_encoding(aiohttp_client) -> None: + # For issue #1168 + with multipart.MultipartWriter() as writer: + writer.append(b"\x00" * 10, headers={"Content-Transfer-Encoding": "binary"}) + + async def handler(request): + reader = await 
request.multipart() + assert isinstance(reader, multipart.MultipartReader) + + part = await reader.next() assert isinstance(part, multipart.BodyPartReader) - assert part.headers['Content-Transfer-Encoding'] == 'binary' - thing = yield from part.read() - assert thing == b'\x00' * 10 + assert part.headers["Content-Transfer-Encoding"] == "binary" + thing = await part.read() + assert thing == b"\x00" * 10 resp = web.Response() - resp.content_type = 'application/json' - resp.body = b'' + resp.content_type = "application/json" + resp.body = b"" return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=writer, headers=writer.headers) + resp = await client.post("/", data=writer) assert 200 == resp.status - yield from resp.release() - + await resp.release() -@asyncio.coroutine -def test_render_redirect(loop, test_client): - @asyncio.coroutine - def handler(request): - raise web.HTTPMovedPermanently(location='/path') +async def test_render_redirect(aiohttp_client) -> None: + async def handler(request): + raise web.HTTPMovedPermanently(location="/path") app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', allow_redirects=False) + resp = await client.get("/", allow_redirects=False) assert 301 == resp.status - txt = yield from resp.text() - assert '301: Moved Permanently' == txt - assert '/path' == resp.headers['location'] + txt = await resp.text() + assert "301: Moved Permanently" == txt + assert "/path" == resp.headers["location"] -@asyncio.coroutine -def test_post_single_file(loop, test_client): +async def test_post_single_file(aiohttp_client) -> None: here = pathlib.Path(__file__).parent def check_file(fs): fullname = here / fs.filename - with fullname.open() as f: - test_data = f.read().encode() + with fullname.open("rb") as f: + test_data = f.read() data = fs.file.read() assert test_data == data - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert ['sample.crt'] == list(data.keys()) + async def handler(request): + data = await request.post() + assert ["data.unknown_mime_type"] == list(data.keys()) for fs in data.values(): check_file(fs) fs.file.close() - resp = web.Response(body=b'OK') + resp = web.Response(body=b"OK") return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - fname = here / 'sample.crt' + fname = here / "data.unknown_mime_type" - resp = yield from client.post('/', data=[fname.open()]) + resp = await client.post("/", data=[fname.open("rb")]) assert 200 == resp.status -@asyncio.coroutine -def test_files_upload_with_same_key(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - files = data.getall('file') +async def test_files_upload_with_same_key(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + files = data.getall("file") file_names = set() for _file in files: assert not _file.file.closed - if _file.filename == 'test1.jpeg': - assert _file.file.read() == b'binary data 1' - if _file.filename == 'test2.jpeg': - assert _file.file.read() == b'binary data 2' + if _file.filename == "test1.jpeg": + assert _file.file.read() == 
b"binary data 1" + if _file.filename == "test2.jpeg": + assert _file.file.read() == b"binary data 2" file_names.add(_file.filename) assert len(files) == 2 - assert file_names == {'test1.jpeg', 'test2.jpeg'} - resp = web.Response(body=b'OK') + assert file_names == {"test1.jpeg", "test2.jpeg"} + resp = web.Response(body=b"OK") return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) data = FormData() - data.add_field('file', b'binary data 1', - content_type='image/jpeg', - filename='test1.jpeg') - data.add_field('file', b'binary data 2', - content_type='image/jpeg', - filename='test2.jpeg') - resp = yield from client.post('/', data=data) + data.add_field( + "file", b"binary data 1", content_type="image/jpeg", filename="test1.jpeg" + ) + data.add_field( + "file", b"binary data 2", content_type="image/jpeg", filename="test2.jpeg" + ) + resp = await client.post("/", data=data) assert 200 == resp.status -@asyncio.coroutine -def test_post_files(loop, test_client): +async def test_post_files(aiohttp_client) -> None: here = pathlib.Path(__file__).parent def check_file(fs): fullname = here / fs.filename - with fullname.open() as f: - test_data = f.read().encode() + with fullname.open("rb") as f: + test_data = f.read() data = fs.file.read() assert test_data == data - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert ['sample.crt', 'sample.key'] == list(data.keys()) + async def handler(request): + data = await request.post() + assert ["data.unknown_mime_type", "conftest.py"] == list(data.keys()) for fs in data.values(): check_file(fs) fs.file.close() - resp = web.Response(body=b'OK') + resp = web.Response(body=b"OK") return resp app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - with (here / 'sample.crt').open() as f1: - with (here / 'sample.key').open() as f2: - resp = yield from client.post('/', data=[f1, f2]) + with (here / "data.unknown_mime_type").open("rb") as f1: + with (here / "conftest.py").open("rb") as f2: + resp = await client.post("/", data=[f1, f2]) assert 200 == resp.status -@asyncio.coroutine -def test_release_post_data(loop, test_client): - - @asyncio.coroutine - def handler(request): - yield from request.release() - chunk = yield from request.content.readany() - assert chunk == b'' +async def test_release_post_data(aiohttp_client) -> None: + async def handler(request): + await request.release() + chunk = await request.content.readany() + assert chunk == b"" return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data='post text') + resp = await client.post("/", data="post text") assert 200 == resp.status -@asyncio.coroutine -def test_POST_DATA_with_content_transfer_encoding(loop, test_client): - @asyncio.coroutine - def handler(request): - data = yield from request.post() - assert b'123' == data['name'] +async def test_POST_DATA_with_content_transfer_encoding(aiohttp_client) -> None: + async def handler(request): + data = await request.post() + assert b"123" == data["name"] return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + 
client = await aiohttp_client(app)

     form = FormData()
-    form.add_field('name', b'123',
-                   content_transfer_encoding='base64')
+    form.add_field("name", b"123", content_transfer_encoding="base64")

-    resp = yield from client.post('/', data=form)
+    resp = await client.post("/", data=form)
     assert 200 == resp.status


-@asyncio.coroutine
-def test_post_form_with_duplicate_keys(loop, test_client):
-    @asyncio.coroutine
-    def handler(request):
-        data = yield from request.post()
+async def test_post_form_with_duplicate_keys(aiohttp_client) -> None:
+    async def handler(request):
+        data = await request.post()
         lst = list(data.items())
-        assert [('a', '1'), ('a', '2')] == lst
+        assert [("a", "1"), ("a", "2")] == lst
         return web.Response()

     app = web.Application()
-    app.router.add_post('/', handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)

-    resp = yield from client.post('/', data=MultiDict([('a', 1), ('a', 2)]))
+    resp = await client.post("/", data=MultiDict([("a", 1), ("a", 2)]))
     assert 200 == resp.status


-def test_repr_for_application(loop):
+def test_repr_for_application() -> None:
     app = web.Application()
     assert "<Application 0x{:x}>".format(id(app)) == repr(app)


-@asyncio.coroutine
-def test_expect_default_handler_unknown(loop, test_client):
-    """Test default Expect handler for unknown Expect value.
+async def test_expect_default_handler_unknown(aiohttp_client) -> None:
+    # Test default Expect handler for unknown Expect value.

-    A server that does not understand or is unable to comply with any of
-    the expectation values in the Expect field of a request MUST respond
-    with appropriate error status. The server MUST respond with a 417
-    (Expectation Failed) status if any of the expectations cannot be met
-    or, if there are other problems with the request, some other 4xx
-    status.
+    # A server that does not understand or is unable to comply with any of
+    # the expectation values in the Expect field of a request MUST respond
+    # with appropriate error status. The server MUST respond with a 417
+    # (Expectation Failed) status if any of the expectations cannot be met
+    # or, if there are other problems with the request, some other 4xx
+    # status.
-    http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20
-    """
-    @asyncio.coroutine
-    def handler(request):
-        yield from request.post()
-        pytest.xfail('Handler should not proceed to this point in case of '
-                     'unknown Expect header')
+    # http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20
+    async def handler(request):
+        await request.post()
+        pytest.xfail(
+            "Handler should not proceed to this point in case of "
+            "unknown Expect header"
+        )

     app = web.Application()
-    app.router.add_post('/', handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)

-    resp = yield from client.post('/', headers={'Expect': 'SPAM'})
+    resp = await client.post("/", headers={"Expect": "SPAM"})
     assert 417 == resp.status


-@asyncio.coroutine
-def test_100_continue(loop, test_client):
-    @asyncio.coroutine
-    def handler(request):
-        data = yield from request.post()
-        assert b'123' == data['name']
+async def test_100_continue(aiohttp_client) -> None:
+    async def handler(request):
+        data = await request.post()
+        assert b"123" == data["name"]
         return web.Response()

     form = FormData()
-    form.add_field('name', b'123',
-                   content_transfer_encoding='base64')
+    form.add_field("name", b"123", content_transfer_encoding="base64")

     app = web.Application()
-    app.router.add_post('/', handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)

-    resp = yield from client.post('/', data=form, expect100=True)
+    resp = await client.post("/", data=form, expect100=True)
     assert 200 == resp.status


-@asyncio.coroutine
-def test_100_continue_custom(loop, test_client):
+async def test_100_continue_custom(aiohttp_client) -> None:
     expect_received = False

-    @asyncio.coroutine
-    def handler(request):
-        data = yield from request.post()
-        assert b'123' == data['name']
+    async def handler(request):
+        data = await request.post()
+        assert b"123" == data["name"]
         return web.Response()

-    @asyncio.coroutine
-    def expect_handler(request):
+    async def expect_handler(request):
         nonlocal expect_received
         expect_received = True
         if request.version == HttpVersion11:
-            request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
-
-    form = FormData()
-    form.add_field('name', b'123',
-                   content_transfer_encoding='base64')
+            await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")

     app = web.Application()
-    app.router.add_post('/', handler, expect_handler=expect_handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler, expect_handler=expect_handler)
+    client = await aiohttp_client(app)

-    resp = yield from client.post('/', data=form, expect100=True)
+    resp = await client.post("/", data=new_dummy_form(), expect100=True)
     assert 200 == resp.status
     assert expect_received


-@asyncio.coroutine
-def test_100_continue_custom_response(loop, test_client):
-
-    @asyncio.coroutine
-    def handler(request):
-        data = yield from request.post()
-        assert b'123', data['name']
+async def test_100_continue_custom_response(aiohttp_client) -> None:
+    async def handler(request):
+        data = await request.post()
+        assert b"123" == data["name"]
         return web.Response()

-    @asyncio.coroutine
-    def expect_handler(request):
+    async def expect_handler(request):
         if request.version == HttpVersion11:
             if auth_err:
-                return web.HTTPForbidden()
+                raise web.HTTPForbidden()

-            request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
-
-    form = FormData()
-    form.add_field('name', b'123',
-                   content_transfer_encoding='base64')
+            await
request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") app = web.Application() - app.router.add_post('/', handler, expect_handler=expect_handler) - client = yield from test_client(app) + app.router.add_post("/", handler, expect_handler=expect_handler) + client = await aiohttp_client(app) auth_err = False - resp = yield from client.post('/', data=form, expect100=True) + resp = await client.post("/", data=new_dummy_form(), expect100=True) assert 200 == resp.status auth_err = True - resp = yield from client.post('/', data=form, expect100=True) + resp = await client.post("/", data=new_dummy_form(), expect100=True) assert 403 == resp.status -@asyncio.coroutine -def test_100_continue_for_not_found(loop, test_client): +async def test_100_continue_for_not_found(aiohttp_client) -> None: app = web.Application() - client = yield from test_client(app) + client = await aiohttp_client(app) - resp = yield from client.post('/not_found', data='data', expect100=True) + resp = await client.post("/not_found", data="data", expect100=True) assert 404 == resp.status -@asyncio.coroutine -def test_100_continue_for_not_allowed(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_100_continue_for_not_allowed(aiohttp_client) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/', expect100=True) + resp = await client.get("/", expect100=True) assert 405 == resp.status -@asyncio.coroutine -def test_http11_keep_alive_default(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_http11_keep_alive_default(aiohttp_client) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app, version=HttpVersion11) + app.router.add_get("/", handler) + client = await aiohttp_client(app, version=HttpVersion11) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status assert resp.version == HttpVersion11 - assert 'Connection' not in resp.headers + assert "Connection" not in resp.headers @pytest.mark.xfail -@asyncio.coroutine -def test_http10_keep_alive_default(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_http10_keep_alive_default(aiohttp_client) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app, version=HttpVersion10) + app.router.add_get("/", handler) + client = await aiohttp_client(app, version=HttpVersion10) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status assert resp.version == HttpVersion10 - assert resp.headers['Connection'] == 'keep-alive' + assert resp.headers["Connection"] == "keep-alive" -@asyncio.coroutine -def test_http10_keep_alive_with_headers_close(loop, test_client): - - @asyncio.coroutine - def handler(request): - yield from request.read() - return web.Response(body=b'OK') +async def test_http10_keep_alive_with_headers_close(aiohttp_client) -> None: + async def handler(request): + await request.read() + return web.Response(body=b"OK") app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app, version=HttpVersion10) + app.router.add_get("/", handler) + client = await aiohttp_client(app, 
version=HttpVersion10)

-    headers = {'Connection': 'close'}
-    resp = yield from client.get('/', headers=headers)
+    headers = {"Connection": "close"}
+    resp = await client.get("/", headers=headers)
     assert 200 == resp.status
     assert resp.version == HttpVersion10
-    assert 'Connection' not in resp.headers
-
+    assert "Connection" not in resp.headers

-@asyncio.coroutine
-def test_http10_keep_alive_with_headers(loop, test_client):
-    @asyncio.coroutine
-    def handler(request):
-        yield from request.read()
-        return web.Response(body=b'OK')
+async def test_http10_keep_alive_with_headers(aiohttp_client) -> None:
+    async def handler(request):
+        await request.read()
+        return web.Response(body=b"OK")

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app, version=HttpVersion10)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app, version=HttpVersion10)

-    headers = {'Connection': 'keep-alive'}
-    resp = yield from client.get('/', headers=headers)
+    headers = {"Connection": "keep-alive"}
+    resp = await client.get("/", headers=headers)
     assert 200 == resp.status
     assert resp.version == HttpVersion10
-    assert resp.headers['Connection'] == 'keep-alive'
+    assert resp.headers["Connection"] == "keep-alive"


-@asyncio.coroutine
-def test_upload_file(loop, test_client):
+async def test_upload_file(aiohttp_client) -> None:
     here = pathlib.Path(__file__).parent
-    fname = here / 'aiohttp.png'
-    with fname.open('rb') as f:
+    fname = here / "aiohttp.png"
+    with fname.open("rb") as f:
         data = f.read()

-    @asyncio.coroutine
-    def handler(request):
-        form = yield from request.post()
-        raw_data = form['file'].file.read()
+    async def handler(request):
+        form = await request.post()
+        raw_data = form["file"].file.read()
         assert data == raw_data
         return web.Response()

     app = web.Application()
-    app.router.add_post('/', handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)

-    resp = yield from client.post('/', data={'file': data})
+    resp = await client.post("/", data={"file": data})
     assert 200 == resp.status


-@asyncio.coroutine
-def test_upload_file_object(loop, test_client):
+async def test_upload_file_object(aiohttp_client) -> None:
     here = pathlib.Path(__file__).parent
-    fname = here / 'aiohttp.png'
-    with fname.open('rb') as f:
+    fname = here / "aiohttp.png"
+    with fname.open("rb") as f:
         data = f.read()

-    @asyncio.coroutine
-    def handler(request):
-        form = yield from request.post()
-        raw_data = form['file'].file.read()
+    async def handler(request):
+        form = await request.post()
+        raw_data = form["file"].file.read()
         assert data == raw_data
         return web.Response()

     app = web.Application()
-    app.router.add_post('/', handler)
-    client = yield from test_client(app)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)

-    with fname.open('rb') as f:
-        resp = yield from client.post('/', data={'file': f})
+    with fname.open("rb") as f:
+        resp = await client.post("/", data={"file": f})
     assert 200 == resp.status


-@asyncio.coroutine
-def test_empty_content_for_query_without_body(loop, test_client):
-
-    @asyncio.coroutine
-    def handler(request):
-        assert not request.has_body
+@pytest.mark.parametrize(
+    "method", ["get", "post", "options", "put", "patch", "delete"]
+)
+async def test_empty_content_for_query_without_body(method, aiohttp_client) -> None:
+    async def handler(request):
+        assert not request.body_exists
+        assert not request.can_read_body
+        with pytest.warns(DeprecationWarning):
+            assert not
request.has_body return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_route(method, "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.request(method, "/") assert 200 == resp.status -@asyncio.coroutine -def test_empty_content_for_query_with_body(loop, test_client): - - @asyncio.coroutine - def handler(request): - assert request.has_body - body = yield from request.read() +async def test_empty_content_for_query_with_body(aiohttp_client) -> None: + async def handler(request): + assert request.body_exists + assert request.can_read_body + with pytest.warns(DeprecationWarning): + assert request.has_body + body = await request.read() return web.Response(body=body) app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.post('/', data=b'data') + resp = await client.post("/", data=b"data") assert 200 == resp.status -@asyncio.coroutine -def test_get_with_empty_arg(loop, test_client): - - @asyncio.coroutine - def handler(request): - assert 'arg' in request.query - assert '' == request.query['arg'] +async def test_get_with_empty_arg(aiohttp_client) -> None: + async def handler(request): + assert "arg" in request.query + assert "" == request.query["arg"] return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/?arg') + resp = await client.get("/?arg") assert 200 == resp.status -@asyncio.coroutine -def test_large_header(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def test_large_header(aiohttp_client) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - headers = {'Long-Header': 'ab' * 8129} - resp = yield from client.get('/', headers=headers) + headers = {"Long-Header": "ab" * 8129} + resp = await client.get("/", headers=headers) assert 400 == resp.status -@asyncio.coroutine -def test_large_header_allowed(loop, test_client, test_server): - - @asyncio.coroutine - def handler(request): +async def test_large_header_allowed(aiohttp_client, aiohttp_server) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_get('/', handler) - server = yield from test_server(app, max_field_size=81920) - client = yield from test_client(server) + app.router.add_post("/", handler) + server = await aiohttp_server(app, max_field_size=81920) + client = await aiohttp_client(server) - headers = {'Long-Header': 'ab' * 8129} - resp = yield from client.get('/', headers=headers) + headers = {"Long-Header": "ab" * 8129} + resp = await client.post("/", headers=headers) assert 200 == resp.status -@asyncio.coroutine -def test_get_with_empty_arg_with_equal(loop, test_client): - @asyncio.coroutine - def handler(request): - assert 'arg' in request.query - assert '' == request.query['arg'] +async def test_get_with_empty_arg_with_equal(aiohttp_client) -> None: + async def handler(request): + assert "arg" in request.query + assert "" == request.query["arg"] return web.Response() app = web.Application() - app.router.add_get('/', handler) - client 
= yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/?arg=') + resp = await client.get("/?arg=") assert 200 == resp.status -@asyncio.coroutine -def test_response_with_streamer(loop, test_client, fname): +async def test_response_with_async_gen(aiohttp_client, fname) -> None: - with fname.open('rb') as f: + with fname.open("rb") as f: data = f.read() data_size = len(data) - @aiohttp.streamer - def stream(writer, f_name): - with f_name.open('rb') as f: + @async_generator + async def stream(f_name): + with f_name.open("rb") as f: data = f.read(100) while data: - yield from writer.write(data) + await yield_(data) + data = f.read(100) + + async def handler(request): + headers = {"Content-Length": str(data_size)} + return web.Response(body=stream(fname), headers=headers) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + resp_data = await resp.read() + assert resp_data == data + assert resp.headers.get("Content-Length") == str(len(resp_data)) + + +async def test_response_with_streamer(aiohttp_client, fname) -> None: + + with fname.open("rb") as f: + data = f.read() + + data_size = len(data) + + with pytest.warns(DeprecationWarning): + + @aiohttp.streamer + async def stream(writer, f_name): + with f_name.open("rb") as f: data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) - @asyncio.coroutine - def handler(request): - headers = {'Content-Length': str(data_size)} + async def handler(request): + headers = {"Content-Length": str(data_size)} return web.Response(body=stream(fname), headers=headers) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() + resp_data = await resp.read() assert resp_data == data - assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert resp.headers.get("Content-Length") == str(len(resp_data)) -@asyncio.coroutine -def test_response_with_streamer_no_params(loop, test_client, fname): +async def test_response_with_async_gen_no_params(aiohttp_client, fname) -> None: - with fname.open('rb') as f: + with fname.open("rb") as f: data = f.read() data_size = len(data) - @aiohttp.streamer - def stream(writer): - with fname.open('rb') as f: + @async_generator + async def stream(): + with fname.open("rb") as f: data = f.read(100) while data: - yield from writer.write(data) + await yield_(data) data = f.read(100) - @asyncio.coroutine - def handler(request): - headers = {'Content-Length': str(data_size)} - return web.Response(body=stream, headers=headers) + async def handler(request): + headers = {"Content-Length": str(data_size)} + return web.Response(body=stream(), headers=headers) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() + resp_data = await resp.read() assert resp_data == data - assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert resp.headers.get("Content-Length") == str(len(resp_data)) -@asyncio.coroutine -def 
test_response_with_file(loop, test_client, fname): +async def test_response_with_streamer_no_params(aiohttp_client, fname) -> None: - with fname.open('rb') as f: + with fname.open("rb") as f: data = f.read() - @asyncio.coroutine - def handler(request): - return web.Response(body=fname.open('rb')) + data_size = len(data) + + with pytest.warns(DeprecationWarning): + + @aiohttp.streamer + async def stream(writer): + with fname.open("rb") as f: + data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) + + async def handler(request): + headers = {"Content-Length": str(data_size)} + return web.Response(body=stream, headers=headers) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() + resp_data = await resp.read() assert resp_data == data - assert resp.headers.get('Content-Type') in ( - 'application/octet-stream', 'application/pgp-keys') - assert resp.headers.get('Content-Length') == str(len(resp_data)) - assert (resp.headers.get('Content-Disposition') == - 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + assert resp.headers.get("Content-Length") == str(len(resp_data)) -@asyncio.coroutine -def test_response_with_file_ctype(loop, test_client, fname): +async def test_response_with_file(aiohttp_client, fname) -> None: - with fname.open('rb') as f: + with fname.open("rb") as f: data = f.read() - @asyncio.coroutine - def handler(request): - return web.Response( - body=fname.open('rb'), headers={'content-type': 'text/binary'}) + async def handler(request): + return web.Response(body=fname.open("rb")) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() + resp_data = await resp.read() + expected_content_disposition = ( + "attachment; filename=\"conftest.py\"; filename*=utf-8''conftest.py" + ) assert resp_data == data - assert resp.headers.get('Content-Type') == 'text/binary' - assert resp.headers.get('Content-Length') == str(len(resp_data)) - assert (resp.headers.get('Content-Disposition') == - 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + assert resp.headers.get("Content-Type") in ( + "application/octet-stream", + "text/x-python", + "text/plain", + ) + assert resp.headers.get("Content-Length") == str(len(resp_data)) + assert resp.headers.get("Content-Disposition") == expected_content_disposition -@asyncio.coroutine -def test_response_with_payload_disp(loop, test_client, fname): +async def test_response_with_file_ctype(aiohttp_client, fname) -> None: - with fname.open('rb') as f: + with fname.open("rb") as f: data = f.read() - @asyncio.coroutine - def handler(request): - pl = aiohttp.get_payload(fname.open('rb')) - pl.set_content_disposition('inline', filename='test.txt') + async def handler(request): return web.Response( - body=pl, headers={'content-type': 'text/binary'}) + body=fname.open("rb"), headers={"content-type": "text/binary"} + ) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + 
resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() + resp_data = await resp.read() + expected_content_disposition = ( + "attachment; filename=\"conftest.py\"; filename*=utf-8''conftest.py" + ) assert resp_data == data - assert resp.headers.get('Content-Type') == 'text/binary' - assert resp.headers.get('Content-Length') == str(len(resp_data)) - assert (resp.headers.get('Content-Disposition') == - 'inline; filename="test.txt"; filename*=utf-8\'\'test.txt') + assert resp.headers.get("Content-Type") == "text/binary" + assert resp.headers.get("Content-Length") == str(len(resp_data)) + assert resp.headers.get("Content-Disposition") == expected_content_disposition -@asyncio.coroutine -def test_response_with_payload_stringio(loop, test_client, fname): +async def test_response_with_payload_disp(aiohttp_client, fname) -> None: - @asyncio.coroutine - def handler(request): - return web.Response(body=io.StringIO('test')) + with fname.open("rb") as f: + data = f.read() + + async def handler(request): + pl = aiohttp.get_payload(fname.open("rb")) + pl.set_content_disposition("inline", filename="test.txt") + return web.Response(body=pl, headers={"content-type": "text/binary"}) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - resp_data = yield from resp.read() - assert resp_data == b'test' + resp_data = await resp.read() + assert resp_data == data + assert resp.headers.get("Content-Type") == "text/binary" + assert resp.headers.get("Content-Length") == str(len(resp_data)) + assert ( + resp.headers.get("Content-Disposition") + == "inline; filename=\"test.txt\"; filename*=utf-8''test.txt" + ) + + +async def test_response_with_payload_stringio(aiohttp_client, fname) -> None: + async def handler(request): + return web.Response(body=io.StringIO("test")) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 200 == resp.status + resp_data = await resp.read() + assert resp_data == b"test" -@asyncio.coroutine -def test_response_with_precompressed_body_gzip(loop, test_client): - @asyncio.coroutine - def handler(request): - headers = {'Content-Encoding': 'gzip'} +async def test_response_with_precompressed_body_gzip(aiohttp_client) -> None: + async def handler(request): + headers = {"Content-Encoding": "gzip"} zcomp = zlib.compressobj(wbits=16 + zlib.MAX_WBITS) - data = zcomp.compress(b'mydata') + zcomp.flush() + data = zcomp.compress(b"mydata") + zcomp.flush() return web.Response(body=data, headers=headers) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - data = yield from resp.read() - assert b'mydata' == data - assert resp.headers.get('Content-Encoding') == 'gzip' + data = await resp.read() + assert b"mydata" == data + assert resp.headers.get("Content-Encoding") == "gzip" + + +async def test_response_with_precompressed_body_deflate(aiohttp_client) -> None: + async def handler(request): + headers = {"Content-Encoding": "deflate"} + zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS) + data = zcomp.compress(b"mydata") + zcomp.flush() + return web.Response(body=data, 
headers=headers) + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + data = await resp.read() + assert b"mydata" == data + assert resp.headers.get("Content-Encoding") == "deflate" -@asyncio.coroutine -def test_response_with_precompressed_body_deflate(loop, test_client): - @asyncio.coroutine - def handler(request): - headers = {'Content-Encoding': 'deflate'} +async def test_response_with_precompressed_body_deflate_no_hdrs(aiohttp_client) -> None: + async def handler(request): + headers = {"Content-Encoding": "deflate"} + # Actually, wrong compression format, but + # should be supported for some legacy cases. zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - data = zcomp.compress(b'mydata') + zcomp.flush() + data = zcomp.compress(b"mydata") + zcomp.flush() return web.Response(body=data, headers=headers) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - data = yield from resp.read() - assert b'mydata' == data - assert resp.headers.get('Content-Encoding') == 'deflate' + data = await resp.read() + assert b"mydata" == data + assert resp.headers.get("Content-Encoding") == "deflate" -@asyncio.coroutine -def test_bad_request_payload(loop, test_client): - - @asyncio.coroutine - def handler(request): - assert request.method == 'GET' +async def test_bad_request_payload(aiohttp_client) -> None: + async def handler(request): + assert request.method == "POST" with pytest.raises(aiohttp.web.RequestPayloadError): - yield from request.content.read() + await request.content.read() return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get( - '/', data=b'test', headers={'content-encoding': 'gzip'}) + resp = await client.post("/", data=b"test", headers={"content-encoding": "gzip"}) assert 200 == resp.status -@asyncio.coroutine -def test_stream_response_multiple_chunks(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_stream_response_multiple_chunks(aiohttp_client) -> None: + async def handler(request): resp = web.StreamResponse() resp.enable_chunked_encoding() - yield from resp.prepare(request) - resp.write(b'x') - resp.write(b'y') - resp.write(b'z') + await resp.prepare(request) + await resp.write(b"x") + await resp.write(b"y") + await resp.write(b"z") return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - data = yield from resp.read() - assert b'xyz' == data + data = await resp.read() + assert b"xyz" == data -@asyncio.coroutine -def test_start_without_routes(loop, test_client): +async def test_start_without_routes(aiohttp_client) -> None: app = web.Application() - client = yield from test_client(app) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 404 == resp.status -@asyncio.coroutine -def test_requests_count(loop, test_client): - - @asyncio.coroutine - def handler(request): +async def 
test_requests_count(aiohttp_client) -> None: + async def handler(request): return web.Response() app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) assert client.server.handler.requests_count == 0 - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 1 - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 2 - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 3 -@asyncio.coroutine -def test_redirect_url(loop, test_client): +async def test_redirect_url(aiohttp_client) -> None: + async def redirector(request): + raise web.HTTPFound(location=URL("/redirected")) - @asyncio.coroutine - def redirector(request): - raise web.HTTPFound(location=URL('/redirected')) - - @asyncio.coroutine - def redirected(request): + async def redirected(request): return web.Response() app = web.Application() - app.router.add_get('/redirector', redirector) - app.router.add_get('/redirected', redirected) + app.router.add_get("/redirector", redirector) + app.router.add_get("/redirected", redirected) - client = yield from test_client(app) - resp = yield from client.get('/redirector') + client = await aiohttp_client(app) + resp = await client.get("/redirector") assert resp.status == 200 -@asyncio.coroutine -def test_simple_subapp(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_simple_subapp(aiohttp_client) -> None: + async def handler(request): return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') + client = await aiohttp_client(app) + resp = await client.get("/path/to") assert resp.status == 200 - txt = yield from resp.text() - assert 'OK' == txt + txt = await resp.text() + assert "OK" == txt -@asyncio.coroutine -def test_subapp_reverse_url(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.HTTPMovedPermanently( - location=subapp.router['name'].url_for()) +async def test_subapp_reverse_url(aiohttp_client) -> None: + async def handler(request): + raise web.HTTPMovedPermanently(location=subapp.router["name"].url_for()) - @asyncio.coroutine - def handler2(request): + async def handler2(request): return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - subapp.router.add_get('/final', handler2, name='name') - app.add_subapp('/path', subapp) + subapp.router.add_get("/to", handler) + subapp.router.add_get("/final", handler2, name="name") + app.add_subapp("/path", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') + client = await aiohttp_client(app) + resp = await client.get("/path/to") assert resp.status == 200 - txt = yield from resp.text() - assert 'OK' == txt - assert resp.url.path == '/path/final' + txt = await resp.text() + assert "OK" == txt + assert resp.url.path == "/path/final" -@asyncio.coroutine -def test_subapp_reverse_variable_url(loop, test_client): - @asyncio.coroutine - def handler(request): - return 
web.HTTPMovedPermanently( - location=subapp.router['name'].url_for(part='final')) +async def test_subapp_reverse_variable_url(aiohttp_client) -> None: + async def handler(request): + raise web.HTTPMovedPermanently( + location=subapp.router["name"].url_for(part="final") + ) - @asyncio.coroutine - def handler2(request): + async def handler2(request): return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - subapp.router.add_get('/{part}', handler2, name='name') - app.add_subapp('/path', subapp) + subapp.router.add_get("/to", handler) + subapp.router.add_get("/{part}", handler2, name="name") + app.add_subapp("/path", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') + client = await aiohttp_client(app) + resp = await client.get("/path/to") assert resp.status == 200 - txt = yield from resp.text() - assert 'OK' == txt - assert resp.url.path == '/path/final' + txt = await resp.text() + assert "OK" == txt + assert resp.url.path == "/path/final" -@asyncio.coroutine -def test_subapp_reverse_static_url(loop, test_client): - fname = 'aiohttp.png' +async def test_subapp_reverse_static_url(aiohttp_client) -> None: + fname = "aiohttp.png" - @asyncio.coroutine - def handler(request): - return web.HTTPMovedPermanently( - location=subapp.router['name'].url_for(filename=fname)) + async def handler(request): + raise web.HTTPMovedPermanently( + location=subapp.router["name"].url_for(filename=fname) + ) app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) + subapp.router.add_get("/to", handler) here = pathlib.Path(__file__).parent - subapp.router.add_static('/static', here, name='name') - app.add_subapp('/path', subapp) + subapp.router.add_static("/static", here, name="name") + app.add_subapp("/path", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') - assert resp.url.path == '/path/static/' + fname + client = await aiohttp_client(app) + resp = await client.get("/path/to") + assert resp.url.path == "/path/static/" + fname assert resp.status == 200 - body = yield from resp.read() - with (here / fname).open('rb') as f: + body = await resp.read() + with (here / fname).open("rb") as f: assert body == f.read() -@asyncio.coroutine -def test_subapp_app(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_subapp_app(aiohttp_client) -> None: + async def handler(request): assert request.app is subapp - return web.HTTPOk(text='OK') + return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path/', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path/", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') + client = await aiohttp_client(app) + resp = await client.get("/path/to") assert resp.status == 200 - txt = yield from resp.text() - assert 'OK' == txt + txt = await resp.text() + assert "OK" == txt -@asyncio.coroutine -def test_subapp_not_found(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='OK') +async def test_subapp_not_found(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path/', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path/", subapp) - client = yield 
from test_client(app) - resp = yield from client.get('/path/other') + client = await aiohttp_client(app) + resp = await client.get("/path/other") assert resp.status == 404 -@asyncio.coroutine -def test_subapp_not_found2(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='OK') +async def test_subapp_not_found2(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path/', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path/", subapp) - client = yield from test_client(app) - resp = yield from client.get('/invalid/other') + client = await aiohttp_client(app) + resp = await client.get("/invalid/other") assert resp.status == 404 -@asyncio.coroutine -def test_subapp_not_allowed(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='OK') +async def test_subapp_not_allowed(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path/', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path/", subapp) - client = yield from test_client(app) - resp = yield from client.post('/path/to') + client = await aiohttp_client(app) + resp = await client.post("/path/to") assert resp.status == 405 - assert resp.headers['Allow'] == 'GET,HEAD' + assert resp.headers["Allow"] == "GET,HEAD" -@asyncio.coroutine -def test_subapp_cannot_add_app_in_handler(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_subapp_cannot_add_app_in_handler(aiohttp_client) -> None: + async def handler(request): request.match_info.add_app(app) - return web.HTTPOk(text='OK') + return web.Response(text="OK") app = web.Application() subapp = web.Application() - subapp.router.add_get('/to', handler) - app.add_subapp('/path/', subapp) + subapp.router.add_get("/to", handler) + app.add_subapp("/path/", subapp) - client = yield from test_client(app) - resp = yield from client.get('/path/to') + client = await aiohttp_client(app) + resp = await client.get("/path/to") assert resp.status == 500 -@asyncio.coroutine -def test_subapp_middlewares(loop, test_client): +async def test_subapp_middlewares(aiohttp_client) -> None: order = [] - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='OK') - - @asyncio.coroutine - def middleware_factory(app, handler): + async def handler(request): + return web.Response(text="OK") - @asyncio.coroutine - def middleware(request): + async def middleware_factory(app, handler): + async def middleware(request): order.append((1, app)) - resp = yield from handler(request) + resp = await handler(request) assert 200 == resp.status order.append((2, app)) return resp + return middleware app = web.Application(middlewares=[middleware_factory]) subapp1 = web.Application(middlewares=[middleware_factory]) subapp2 = web.Application(middlewares=[middleware_factory]) - subapp2.router.add_get('/to', handler) - subapp1.add_subapp('/b/', subapp2) - app.add_subapp('/a/', subapp1) + subapp2.router.add_get("/to", handler) + with pytest.warns(DeprecationWarning): + subapp1.add_subapp("/b/", subapp2) + app.add_subapp("/a/", subapp1) + client = await aiohttp_client(app) - client = yield from test_client(app) - resp = yield from client.get('/a/b/to') + resp = await client.get("/a/b/to") assert resp.status == 200 - 
assert [(1, app), (1, subapp1), (1, subapp2), - (2, subapp2), (2, subapp1), (2, app)] == order + assert [ + (1, app), + (1, subapp1), + (1, subapp2), + (2, subapp2), + (2, subapp1), + (2, app), + ] == order -@asyncio.coroutine -def test_subapp_on_response_prepare(loop, test_client): +async def test_subapp_on_response_prepare(aiohttp_client) -> None: order = [] - @asyncio.coroutine - def handler(request): - return web.HTTPOk(text='OK') + async def handler(request): + return web.Response(text="OK") def make_signal(app): - - @asyncio.coroutine - def on_response(request, response): + async def on_response(request, response): order.append(app) return on_response @@ -1319,22 +1322,20 @@ def on_response(request, response): subapp1.on_response_prepare.append(make_signal(subapp1)) subapp2 = web.Application() subapp2.on_response_prepare.append(make_signal(subapp2)) - subapp2.router.add_get('/to', handler) - subapp1.add_subapp('/b/', subapp2) - app.add_subapp('/a/', subapp1) + subapp2.router.add_get("/to", handler) + subapp1.add_subapp("/b/", subapp2) + app.add_subapp("/a/", subapp1) - client = yield from test_client(app) - resp = yield from client.get('/a/b/to') + client = await aiohttp_client(app) + resp = await client.get("/a/b/to") assert resp.status == 200 assert [app, subapp1, subapp2] == order -@asyncio.coroutine -def test_subapp_on_startup(loop, test_server): +async def test_subapp_on_startup(aiohttp_server) -> None: order = [] - @asyncio.coroutine - def on_signal(app): + async def on_signal(app): order.append(app) app = web.Application() @@ -1343,19 +1344,18 @@ def on_signal(app): subapp1.on_startup.append(on_signal) subapp2 = web.Application() subapp2.on_startup.append(on_signal) - subapp1.add_subapp('/b/', subapp2) - app.add_subapp('/a/', subapp1) + subapp1.add_subapp("/b/", subapp2) + app.add_subapp("/a/", subapp1) - yield from test_server(app) + await aiohttp_server(app) assert [app, subapp1, subapp2] == order -@asyncio.coroutine -def test_subapp_on_shutdown(loop, test_server): +async def test_subapp_on_shutdown(aiohttp_server) -> None: order = [] - def on_signal(app): + async def on_signal(app): order.append(app) app = web.Application() @@ -1364,21 +1364,19 @@ def on_signal(app): subapp1.on_shutdown.append(on_signal) subapp2 = web.Application() subapp2.on_shutdown.append(on_signal) - subapp1.add_subapp('/b/', subapp2) - app.add_subapp('/a/', subapp1) + subapp1.add_subapp("/b/", subapp2) + app.add_subapp("/a/", subapp1) - server = yield from test_server(app) - yield from server.close() + server = await aiohttp_server(app) + await server.close() assert [app, subapp1, subapp2] == order -@asyncio.coroutine -def test_subapp_on_cleanup(loop, test_server): +async def test_subapp_on_cleanup(aiohttp_server) -> None: order = [] - @asyncio.coroutine - def on_signal(app): + async def on_signal(app): order.append(app) app = web.Application() @@ -1387,181 +1385,585 @@ def on_signal(app): subapp1.on_cleanup.append(on_signal) subapp2 = web.Application() subapp2.on_cleanup.append(on_signal) - subapp1.add_subapp('/b/', subapp2) - app.add_subapp('/a/', subapp1) + subapp1.add_subapp("/b/", subapp2) + app.add_subapp("/a/", subapp1) - server = yield from test_server(app) - yield from server.close() + server = await aiohttp_server(app) + await server.close() assert [app, subapp1, subapp2] == order -@asyncio.coroutine -def test_custom_date_header(loop, test_client): +@pytest.mark.parametrize( + "route,expected,middlewares", + [ + ("/sub/", ["A: root", "C: sub", "D: sub"], "AC"), + ("/", ["A: root", "B: root"], 
"AC"), + ("/sub/", ["A: root", "D: sub"], "A"), + ("/", ["A: root", "B: root"], "A"), + ("/sub/", ["C: sub", "D: sub"], "C"), + ("/", ["B: root"], "C"), + ("/sub/", ["D: sub"], ""), + ("/", ["B: root"], ""), + ], +) +async def test_subapp_middleware_context(aiohttp_client, route, expected, middlewares): + values = [] + + def show_app_context(appname): + @web.middleware + async def middleware(request, handler): + values.append("{}: {}".format(appname, request.app["my_value"])) + return await handler(request) + + return middleware + + def make_handler(appname): + async def handler(request): + values.append("{}: {}".format(appname, request.app["my_value"])) + return web.Response(text="Ok") - @asyncio.coroutine - def handler(request): - return web.Response(headers={'Date': 'Sun, 30 Oct 2016 03:13:52 GMT'}) + return handler app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app["my_value"] = "root" + if "A" in middlewares: + app.middlewares.append(show_app_context("A")) + app.router.add_get("/", make_handler("B")) - resp = yield from client.get('/') + subapp = web.Application() + subapp["my_value"] = "sub" + if "C" in middlewares: + subapp.middlewares.append(show_app_context("C")) + subapp.router.add_get("/", make_handler("D")) + app.add_subapp("/sub/", subapp) + + client = await aiohttp_client(app) + resp = await client.get(route) assert 200 == resp.status - assert resp.headers['Date'] == 'Sun, 30 Oct 2016 03:13:52 GMT' + assert "Ok" == await resp.text() + assert expected == values + + +async def test_custom_date_header(aiohttp_client) -> None: + async def handler(request): + return web.Response(headers={"Date": "Sun, 30 Oct 2016 03:13:52 GMT"}) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 200 == resp.status + assert resp.headers["Date"] == "Sun, 30 Oct 2016 03:13:52 GMT" -@asyncio.coroutine -def test_response_prepared_with_clone(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_response_prepared_with_clone(aiohttp_client) -> None: + async def handler(request): cloned = request.clone() resp = web.StreamResponse() - yield from resp.prepare(cloned) + await resp.prepare(cloned) return resp app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status -@asyncio.coroutine -def test_app_max_client_size(loop, test_client): - - @asyncio.coroutine - def handler(request): - yield from request.post() - return web.Response(body=b'ok') +async def test_app_max_client_size(aiohttp_client) -> None: + async def handler(request): + await request.post() + return web.Response(body=b"ok") - max_size = 1024**2 + max_size = 1024 ** 2 app = web.Application() - app.router.add_post('/', handler) - client = yield from test_client(app) - data = {"long_string": max_size * 'x' + 'xxx'} - resp = yield from client.post('/', data=data) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + data = {"long_string": max_size * "x" + "xxx"} + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=data) assert 413 == resp.status - resp_text = yield from resp.text() - assert 'Request Entity Too Large' in resp_text + resp_text = await resp.text() + assert ( + "Maximum request body size 1048576 exceeded, " "actual body size" in 
resp_text + ) + # Maximum request body size X exceeded, actual body size X + body_size = int(resp_text.split()[-1]) + assert body_size >= max_size -@asyncio.coroutine -def test_app_max_client_size_adjusted(loop, test_client): +async def test_app_max_client_size_adjusted(aiohttp_client) -> None: + async def handler(request): + await request.post() + return web.Response(body=b"ok") - @asyncio.coroutine - def handler(request): - yield from request.post() - return web.Response(body=b'ok') - - default_max_size = 1024**2 + default_max_size = 1024 ** 2 custom_max_size = default_max_size * 2 app = web.Application(client_max_size=custom_max_size) - app.router.add_post('/', handler) - client = yield from test_client(app) - data = {'long_string': default_max_size * 'x' + 'xxx'} - resp = yield from client.post('/', data=data) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + data = {"long_string": default_max_size * "x" + "xxx"} + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=data) assert 200 == resp.status - resp_text = yield from resp.text() - assert 'ok' == resp_text - too_large_data = {'log_string': custom_max_size * 'x' + "xxx"} - resp = yield from client.post('/', data=too_large_data) + resp_text = await resp.text() + assert "ok" == resp_text + too_large_data = {"log_string": custom_max_size * "x" + "xxx"} + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=too_large_data) assert 413 == resp.status - resp_text = yield from resp.text() - assert 'Request Entity Too Large' in resp_text - + resp_text = await resp.text() + assert ( + "Maximum request body size 2097152 exceeded, " "actual body size" in resp_text + ) + # Maximum request body size X exceeded, actual body size X + body_size = int(resp_text.split()[-1]) + assert body_size >= custom_max_size -@asyncio.coroutine -def test_app_max_client_size_none(loop, test_client): - @asyncio.coroutine - def handler(request): - yield from request.post() - return web.Response(body=b'ok') +async def test_app_max_client_size_none(aiohttp_client) -> None: + async def handler(request): + await request.post() + return web.Response(body=b"ok") - default_max_size = 1024**2 + default_max_size = 1024 ** 2 custom_max_size = None app = web.Application(client_max_size=custom_max_size) - app.router.add_post('/', handler) - client = yield from test_client(app) - data = {'long_string': default_max_size * 'x' + 'xxx'} - resp = yield from client.post('/', data=data) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + data = {"long_string": default_max_size * "x" + "xxx"} + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=data) assert 200 == resp.status - resp_text = yield from resp.text() - assert 'ok' == resp_text - too_large_data = {'log_string': default_max_size * 2 * 'x'} - resp = yield from client.post('/', data=too_large_data) + resp_text = await resp.text() + assert "ok" == resp_text + too_large_data = {"log_string": default_max_size * 2 * "x"} + with pytest.warns(ResourceWarning): + resp = await client.post("/", data=too_large_data) assert 200 == resp.status - resp_text = yield from resp.text() - assert resp_text == 'ok' + resp_text = await resp.text() + assert resp_text == "ok" -@asyncio.coroutine -def test_post_max_client_size(loop, test_client): - - @asyncio.coroutine - def handler(request): - try: - yield from request.post() - except ValueError: - return web.HTTPOk() - return web.HTTPBadRequest() +async def 
test_post_max_client_size(aiohttp_client) -> None: + async def handler(request): + await request.post() + return web.Response() app = web.Application(client_max_size=10) - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + data = {"long_string": 1024 * "x", "file": io.BytesIO(b"test")} + resp = await client.post("/", data=data) + + assert 413 == resp.status + resp_text = await resp.text() + assert ( + "Maximum request body size 10 exceeded, " "actual body size 1024" in resp_text + ) + + +async def test_post_max_client_size_for_file(aiohttp_client) -> None: + async def handler(request): + await request.post() + return web.Response() + + app = web.Application(client_max_size=2) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - data = {"long_string": 1024 * 'x', 'file': io.BytesIO(b'test')} - resp = yield from client.post('/', data=data) + data = {"file": io.BytesIO(b"test")} + resp = await client.post("/", data=data) + + assert 413 == resp.status + + +async def test_response_with_bodypart(aiohttp_client) -> None: + async def handler(request): + reader = await request.multipart() + part = await reader.next() + return web.Response(body=part) + + app = web.Application(client_max_size=2) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + data = {"file": io.BytesIO(b"test")} + resp = await client.post("/", data=data) assert 200 == resp.status + body = await resp.read() + assert body == b"test" + disp = multipart.parse_content_disposition(resp.headers["content-disposition"]) + assert disp == ( + "attachment", + {"name": "file", "filename": "file", "filename*": "file"}, + ) -@asyncio.coroutine -def test_post_max_client_size_for_file(loop, test_client): - @asyncio.coroutine - def handler(request): - try: - yield from request.post() - except ValueError: - return web.HTTPOk() - return web.HTTPBadRequest() +async def test_response_with_bodypart_named(aiohttp_client, tmpdir) -> None: + async def handler(request): + reader = await request.multipart() + part = await reader.next() + return web.Response(body=part) app = web.Application(client_max_size=2) - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - data = {'file': io.BytesIO(b'test')} - resp = yield from client.post('/', data=data) + f = tmpdir.join("foobar.txt") + f.write_text("test", encoding="utf8") + data = {"file": open(str(f), "rb")} + resp = await client.post("/", data=data) assert 200 == resp.status + body = await resp.read() + assert body == b"test" + disp = multipart.parse_content_disposition(resp.headers["content-disposition"]) + assert disp == ( + "attachment", + {"name": "file", "filename": "foobar.txt", "filename*": "foobar.txt"}, + ) -@asyncio.coroutine -def test_response_with_bodypart(loop, test_client): - @asyncio.coroutine - def handler(request): - reader = yield from request.multipart() - part = yield from reader.next() +async def test_response_with_bodypart_invalid_name(aiohttp_client) -> None: + async def handler(request): + reader = await request.multipart() + part = await reader.next() return web.Response(body=part) app = web.Application(client_max_size=2) - app.router.add_post('/', handler) - client = yield from test_client(app) + app.router.add_post("/", handler) + client = await aiohttp_client(app) - data = {'file': io.BytesIO(b'test')} - resp = yield from client.post('/', 
data=data) + with aiohttp.MultipartWriter() as mpwriter: + mpwriter.append(b"test") + resp = await client.post("/", data=mpwriter) assert 200 == resp.status + body = await resp.read() + assert body == b"test" + + assert "content-disposition" not in resp.headers + + +async def test_request_clone(aiohttp_client) -> None: + async def handler(request): + r2 = request.clone(method="POST") + assert r2.method == "POST" + assert r2.match_info is request.match_info + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert 200 == resp.status + + +async def test_await(aiohttp_server) -> None: + async def handler(request): + resp = web.StreamResponse(headers={"content-length": str(4)}) + await resp.prepare(request) + with pytest.warns(DeprecationWarning): + await resp.drain() + await asyncio.sleep(0.01) + await resp.write(b"test") + await asyncio.sleep(0.01) + await resp.write_eof() + return resp + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + resp = await session.get(server.make_url("/")) + assert resp.status == 200 + assert resp.connection is not None + await resp.read() + await resp.release() + assert resp.connection is None + + +async def test_response_context_manager(aiohttp_server) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + resp = await aiohttp.ClientSession().get(server.make_url("/")) + async with resp: + assert resp.status == 200 + assert resp.connection is None + assert resp.connection is None + + +async def test_response_context_manager_error(aiohttp_server) -> None: + async def handler(request): + return web.Response(text="some text") + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + session = aiohttp.ClientSession() + cm = session.get(server.make_url("/")) + resp = await cm + with pytest.raises(RuntimeError): + async with resp: + assert resp.status == 200 + resp.content.set_exception(RuntimeError()) + await resp.read() + assert resp.closed + + assert len(session._connector._conns) == 1 + + +async def test_client_api_context_manager(aiohttp_server) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + assert resp.connection is None + assert resp.connection is None + + +async def test_context_manager_close_on_release(aiohttp_server, mocker) -> None: + async def handler(request): + resp = web.StreamResponse() + await resp.prepare(request) + with pytest.warns(DeprecationWarning): + await resp.drain() + await asyncio.sleep(10) + return resp + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + resp = await session.get(server.make_url("/")) + proto = resp.connection._protocol + mocker.spy(proto, "close") + async with resp: + assert resp.status == 200 + assert resp.connection is not None + assert resp.connection is None + assert proto.close.called + + +async def 
test_iter_any(aiohttp_server) -> None: + + data = b"0123456789" * 1024 + + async def handler(request): + buf = [] + async for raw in request.content.iter_any(): + buf.append(raw) + assert b"".join(buf) == data + return web.Response() + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + async with aiohttp.ClientSession() as session: + async with session.post(server.make_url("/"), data=data) as resp: + assert resp.status == 200 + + +async def test_request_tracing(aiohttp_server) -> None: + + on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_redirect = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_create_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_create_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - disp = multipart.parse_content_disposition( - resp.headers['content-disposition']) - assert disp == ('attachment', - {'name': 'file', 'filename': 'file', 'filename*': 'file'}) + async def redirector(request): + raise web.HTTPFound(location=URL("/redirected")) + + async def redirected(request): + return web.Response() + + trace_config = TraceConfig() + + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + trace_config.on_request_redirect.append(on_request_redirect) + trace_config.on_connection_create_start.append(on_connection_create_start) + trace_config.on_connection_create_end.append(on_connection_create_end) + trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) + trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) + + app = web.Application() + app.router.add_get("/redirector", redirector) + app.router.add_get("/redirected", redirected) + server = await aiohttp_server(app) + + class FakeResolver: + _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1"} + + def __init__(self, fakes): + # fakes -- dns -> port dict + self._fakes = fakes + self._resolver = aiohttp.DefaultResolver() + + async def resolve(self, host, port=0, family=socket.AF_INET): + fake_port = self._fakes.get(host) + if fake_port is not None: + return [ + { + "hostname": host, + "host": self._LOCAL_HOST[family], + "port": fake_port, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] + else: + return await self._resolver.resolve(host, port, family) + + resolver = FakeResolver({"example.com": server.port}) + connector = aiohttp.TCPConnector(resolver=resolver) + client = aiohttp.ClientSession(connector=connector, trace_configs=[trace_config]) + + await client.get("http://example.com/redirector", data="foo") + + assert on_request_start.called + assert on_request_end.called + assert on_dns_resolvehost_start.called + assert on_dns_resolvehost_end.called + assert on_request_redirect.called + assert on_connection_create_start.called + assert on_connection_create_end.called + await client.close() + + +async def test_return_http_exception_deprecated(aiohttp_client) -> None: + async def handler(request): + return web.HTTPForbidden() + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + with pytest.warns(DeprecationWarning): + await client.get("/") + + +async def 
test_request_path(aiohttp_client) -> None: + async def handler(request): + assert request.path_qs == "/path%20to?a=1" + assert request.path == "/path to" + assert request.raw_path == "/path%20to?a=1" + return web.Response(body=b"OK") + + app = web.Application() + app.router.add_get("/path to", handler) + client = await aiohttp_client(app) + + resp = await client.get("/path to", params={"a": "1"}) + assert 200 == resp.status + txt = await resp.text() + assert "OK" == txt + + +async def test_app_add_routes(aiohttp_client) -> None: + async def handler(request): + return web.Response() + + app = web.Application() + app.add_routes([web.get("/get", handler)]) + + client = await aiohttp_client(app) + resp = await client.get("/get") + assert resp.status == 200 + + +async def test_request_headers_type(aiohttp_client) -> None: + async def handler(request): + assert isinstance(request.headers, CIMultiDictProxy) + return web.Response() + + app = web.Application() + app.add_routes([web.get("/get", handler)]) + + client = await aiohttp_client(app) + resp = await client.get("/get") + assert resp.status == 200 + + +async def test_signal_on_error_handler(aiohttp_client) -> None: + async def on_prepare(request, response): + response.headers["X-Custom"] = "val" + + app = web.Application() + app.on_response_prepare.append(on_prepare) + + client = await aiohttp_client(app) + resp = await client.get("/") + assert resp.status == 404 + assert resp.headers["X-Custom"] == "val" + + +@pytest.mark.skipif( + "HttpRequestParserC" not in dir(aiohttp.http_parser), + reason="C based HTTP parser not available", +) +async def test_bad_method_for_c_http_parser_not_hangs(aiohttp_client) -> None: + app = web.Application() + timeout = aiohttp.ClientTimeout(sock_read=0.2) + client = await aiohttp_client(app, timeout=timeout) + resp = await client.request("GET1", "/") + assert 400 == resp.status + + +async def test_read_bufsize(aiohttp_client) -> None: + async def handler(request): + ret = request.content.get_read_buffer_limits() + data = await request.text() # read posted data + return web.Response(text=f"{data} {ret!r}") + + app = web.Application(handler_args={"read_bufsize": 2}) + app.router.add_post("/", handler) + + client = await aiohttp_client(app) + resp = await client.post("/", data=b"data") + assert resp.status == 200 + assert await resp.text() == "data (2, 4)" + + +@pytest.mark.parametrize( + "status", + [101, 204], +) +async def test_response_101_204_no_content_length_http11( + status, aiohttp_client +) -> None: + async def handler(_): + return web.Response(status=status) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app, version="1.1") + resp = await client.get("/") + assert CONTENT_LENGTH not in resp.headers + assert TRANSFER_ENCODING not in resp.headers diff --git a/tests/test_web_log.py b/tests/test_web_log.py new file mode 100644 index 00000000000..0a4168ae72e --- /dev/null +++ b/tests/test_web_log.py @@ -0,0 +1,197 @@ +import datetime +import platform +from unittest import mock + +import pytest + +import aiohttp +from aiohttp import web +from aiohttp.abc import AbstractAccessLogger +from aiohttp.helpers import PY_37 +from aiohttp.web_log import AccessLogger + +try: + from contextvars import ContextVar +except ImportError: + ContextVar = None + + +IS_PYPY = platform.python_implementation() == "PyPy" + + +def test_access_logger_format() -> None: + log_format = '%T "%{ETag}o" %X {X} %%P' + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, 
log_format) + expected = '%s "%s" %%X {X} %%%s' + assert expected == access_logger._log_format + + +@pytest.mark.skipif( + IS_PYPY, + reason=""" + Because of patching :py:class:`datetime.datetime`, under PyPy it + fails in :py:func:`isinstance` call in + :py:meth:`datetime.datetime.__sub__` (called from + :py:meth:`aiohttp.AccessLogger._format_t`): + + *** TypeError: isinstance() arg 2 must be a class, type, or tuple of classes and types + + (Pdb) from datetime import datetime + (Pdb) isinstance(now, datetime) + *** TypeError: isinstance() arg 2 must be a class, type, or tuple of classes and types + (Pdb) datetime.__class__ + <class 'datetime.datetime'> + (Pdb) isinstance(now, datetime.__class__) + False + + Ref: https://bitbucket.org/pypy/pypy/issues/1187/call-to-isinstance-in-__sub__-self-other + Ref: https://github.com/celery/celery/issues/811 + Ref: https://stackoverflow.com/a/46102240/595220 + """, # noqa: E501 +) +def test_access_logger_atoms(mocker) -> None: + utcnow = datetime.datetime(1843, 1, 1, 0, 30) + mock_datetime = mocker.patch("datetime.datetime") + mock_getpid = mocker.patch("os.getpid") + mock_datetime.utcnow.return_value = utcnow + mock_getpid.return_value = 42 + log_format = '%a %t %P %r %s %b %T %Tf %D "%{H1}i" "%{H2}i"' + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, log_format) + request = mock.Mock( + headers={"H1": "a", "H2": "b"}, + method="GET", + path_qs="/path", + version=aiohttp.HttpVersion(1, 1), + remote="127.0.0.2", + ) + response = mock.Mock(headers={}, body_length=42, status=200) + access_logger.log(request, response, 3.1415926) + assert not mock_logger.exception.called + expected = ( + "127.0.0.2 [01/Jan/1843:00:29:56 +0000] <42> " + 'GET /path HTTP/1.1 200 42 3 3.141593 3141593 "a" "b"' + ) + extra = { + "first_request_line": "GET /path HTTP/1.1", + "process_id": "<42>", + "remote_address": "127.0.0.2", + "request_start_time": "[01/Jan/1843:00:29:56 +0000]", + "request_time": "3", + "request_time_frac": "3.141593", + "request_time_micro": "3141593", + "response_size": 42, + "response_status": 200, + "request_header": {"H1": "a", "H2": "b"}, + } + + mock_logger.info.assert_called_with(expected, extra=extra) + + +def test_access_logger_dicts() -> None: + log_format = "%{User-Agent}i %{Content-Length}o %{None}i" + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, log_format) + request = mock.Mock( + headers={"User-Agent": "Mock/1.0"}, version=(1, 1), remote="127.0.0.2" + ) + response = mock.Mock(headers={"Content-Length": 123}) + access_logger.log(request, response, 0.0) + assert not mock_logger.error.called + expected = "Mock/1.0 123 -" + extra = { + "request_header": {"User-Agent": "Mock/1.0", "None": "-"}, + "response_header": {"Content-Length": 123}, + } + + mock_logger.info.assert_called_with(expected, extra=extra) + + +def test_access_logger_unix_socket() -> None: + log_format = "|%a|" + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, log_format) + request = mock.Mock(headers={"User-Agent": "Mock/1.0"}, version=(1, 1), remote="") + response = mock.Mock() + access_logger.log(request, response, 0.0) + assert not mock_logger.error.called + expected = "||" + mock_logger.info.assert_called_with(expected, extra={"remote_address": ""}) + + +def test_logger_no_message() -> None: + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, "%r %{content-type}i") + extra_dict = { + "first_request_line": "-", + "request_header": {"content-type": "(no headers)"}, + } + + access_logger.log(None, None, 0.0) + 
mock_logger.info.assert_called_with("- (no headers)", extra=extra_dict) + + +def test_logger_internal_error() -> None: + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, "%D") + access_logger.log(None, None, "invalid") + mock_logger.exception.assert_called_with("Error in logging") + + +def test_logger_no_transport() -> None: + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, "%a") + access_logger.log(None, None, 0) + mock_logger.info.assert_called_with("-", extra={"remote_address": "-"}) + + +def test_logger_abc() -> None: + class Logger(AbstractAccessLogger): + def log(self, request, response, time): + 1 / 0 + + mock_logger = mock.Mock() + access_logger = Logger(mock_logger, None) + + with pytest.raises(ZeroDivisionError): + access_logger.log(None, None, None) + + class Logger(AbstractAccessLogger): + def log(self, request, response, time): + self.logger.info( + self.log_format.format(request=request, response=response, time=time) + ) + + mock_logger = mock.Mock() + access_logger = Logger(mock_logger, "{request} {response} {time}") + access_logger.log("request", "response", 1) + mock_logger.info.assert_called_with("request response 1") + + +@pytest.mark.skipif(not PY_37, reason="contextvars support is required") +async def test_contextvars_logger(aiohttp_server, aiohttp_client): + VAR = ContextVar("VAR") + + async def handler(request): + return web.Response() + + @web.middleware + async def middleware(request, handler): + VAR.set("uuid") + return await handler(request) + + msg = None + + class Logger(AbstractAccessLogger): + def log(self, request, response, time): + nonlocal msg + msg = f"contextvars: {VAR.get()}" + + app = web.Application(middlewares=[middleware]) + app.router.add_get("/", handler) + server = await aiohttp_server(app, access_log_class=Logger) + client = await aiohttp_client(server) + resp = await client.get("/") + assert 200 == resp.status + assert msg == "contextvars: uuid" diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index 11879702c65..1a6ea61cdd5 100644 --- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -1,190 +1,566 @@ -import asyncio +import re +from typing import Any import pytest +from yarl import URL from aiohttp import web -@asyncio.coroutine -def test_middleware_modifies_response(loop, test_client): +async def test_middleware_modifies_response(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") - @asyncio.coroutine - def handler(request): - return web.Response(body=b'OK') + @web.middleware + async def middleware(request, handler): + resp = await handler(request) + assert 200 == resp.status + resp.set_status(201) + resp.text = resp.text + "[MIDDLEWARE]" + return resp - @asyncio.coroutine - def middleware_factory(app, handler): + app = web.Application() + app.middlewares.append(middleware) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[MIDDLEWARE]" == txt - @asyncio.coroutine - def middleware(request): - resp = yield from handler(request) - assert 200 == resp.status - resp.set_status(201) - resp.text = resp.text + '[MIDDLEWARE]' - return resp - return middleware + +async def test_middleware_handles_exception(loop, aiohttp_client) -> None: + async def handler(request): + raise RuntimeError("Error text") + + @web.middleware + async def middleware(request, handler): + with pytest.raises(RuntimeError) as 
ctx: + await handler(request) + return web.Response(status=501, text=str(ctx.value) + "[MIDDLEWARE]") app = web.Application() - app.middlewares.append(middleware_factory) - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.get('/') - assert 201 == resp.status - txt = yield from resp.text() - assert 'OK[MIDDLEWARE]' == txt + app.middlewares.append(middleware) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 501 == resp.status + txt = await resp.text() + assert "Error text[MIDDLEWARE]" == txt + +async def test_middleware_chain(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK") -@asyncio.coroutine -def test_middleware_handles_exception(loop, test_client): + handler.annotation = "annotation_value" - @asyncio.coroutine - def handler(request): - raise RuntimeError('Error text') + async def handler2(request): + return web.Response(text="OK") - @asyncio.coroutine - def middleware_factory(app, handler): + middleware_annotation_seen_values = [] - @asyncio.coroutine - def middleware(request): - with pytest.raises(RuntimeError) as ctx: - yield from handler(request) - return web.Response(status=501, - text=str(ctx.value) + '[MIDDLEWARE]') + def make_middleware(num): + @web.middleware + async def middleware(request, handler): + middleware_annotation_seen_values.append( + getattr(handler, "annotation", None) + ) + resp = await handler(request) + resp.text = resp.text + f"[{num}]" + return resp return middleware app = web.Application() - app.middlewares.append(middleware_factory) - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.get('/') - assert 501 == resp.status - txt = yield from resp.text() - assert 'Error text[MIDDLEWARE]' == txt + app.middlewares.append(make_middleware(1)) + app.middlewares.append(make_middleware(2)) + app.router.add_route("GET", "/", handler) + app.router.add_route("GET", "/r2", handler2) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 200 == resp.status + txt = await resp.text() + assert "OK[2][1]" == txt + assert middleware_annotation_seen_values == ["annotation_value", "annotation_value"] + # check that attributes from handler are not applied to handler2 + resp = await client.get("/r2") + assert 200 == resp.status + assert middleware_annotation_seen_values == [ + "annotation_value", + "annotation_value", + None, + None, + ] -@asyncio.coroutine -def test_middleware_chain(loop, test_client): - @asyncio.coroutine - def handler(request): - return web.Response(text='OK') +async def test_middleware_subapp(loop, aiohttp_client) -> None: + async def sub_handler(request): + return web.Response(text="OK") - def make_factory(num): + sub_handler.annotation = "annotation_value" - @asyncio.coroutine - def factory(app, handler): + async def handler(request): + return web.Response(text="OK") - def middleware(request): - resp = yield from handler(request) - resp.text = resp.text + '[{}]'.format(num) - return resp + middleware_annotation_seen_values = [] - return middleware - return factory + def make_middleware(num): + @web.middleware + async def middleware(request, handler): + annotation = getattr(handler, "annotation", None) + if annotation is not None: + middleware_annotation_seen_values.append(f"{annotation}/{num}") + return await handler(request) + + return middleware app = web.Application() - app.middlewares.append(make_factory(1)) 
- app.middlewares.append(make_factory(2)) - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - resp = yield from client.get('/') + app.middlewares.append(make_middleware(1)) + app.router.add_route("GET", "/r2", handler) + + subapp = web.Application() + subapp.middlewares.append(make_middleware(2)) + subapp.router.add_route("GET", "/", sub_handler) + app.add_subapp("/sub", subapp) + + client = await aiohttp_client(app) + resp = await client.get("/sub/") + assert 200 == resp.status + await resp.text() + assert middleware_annotation_seen_values == [ + "annotation_value/1", + "annotation_value/2", + ] + + # check that attributes from sub_handler are not applied to handler + del middleware_annotation_seen_values[:] + resp = await client.get("/r2") assert 200 == resp.status - txt = yield from resp.text() - assert 'OK[2][1]' == txt + assert middleware_annotation_seen_values == [] @pytest.fixture -def cli(loop, test_client): +def cli(loop, aiohttp_client): + async def handler(request): + return web.Response(text="OK") + def wrapper(extra_middlewares): app = web.Application() - app.router.add_route( - 'GET', '/resource1', lambda x: web.Response(text="OK")) - app.router.add_route( - 'GET', '/resource2/', lambda x: web.Response(text="OK")) - app.router.add_route( - 'GET', '/resource1/a/b', lambda x: web.Response(text="OK")) - app.router.add_route( - 'GET', '/resource2/a/b/', lambda x: web.Response(text="OK")) + app.router.add_route("GET", "/resource1", handler) + app.router.add_route("GET", "/resource2/", handler) + app.router.add_route("GET", "/resource1/a/b", handler) + app.router.add_route("GET", "/resource2/a/b/", handler) + app.router.add_route("GET", "/resource2/a/b%2Fc/", handler) app.middlewares.extend(extra_middlewares) - return test_client(app, server_kwargs={'skip_url_asserts': True}) + return aiohttp_client(app, server_kwargs={"skip_url_asserts": True}) + return wrapper class TestNormalizePathMiddleware: - - @asyncio.coroutine - @pytest.mark.parametrize("path, status", [ - ('/resource1', 200), - ('/resource1/', 404), - ('/resource2', 200), - ('/resource2/', 200) - ]) - def test_add_trailing_when_necessary( - self, path, status, cli): - extra_middlewares = [ - web.normalize_path_middleware(merge_slashes=False)] - client = yield from cli(extra_middlewares) - - resp = yield from client.get(path) + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1", 200), + ("/resource1/", 404), + ("/resource2", 200), + ("/resource2/", 200), + ("/resource1?p1=1&p2=2", 200), + ("/resource1/?p1=1&p2=2", 404), + ("/resource2?p1=1&p2=2", 200), + ("/resource2/?p1=1&p2=2", 200), + ("/resource2/a/b%2Fc", 200), + ("/resource2/a/b%2Fc/", 200), + ], + ) + async def test_add_trailing_when_necessary(self, path, status, cli): + extra_middlewares = [web.normalize_path_middleware(merge_slashes=False)] + client = await cli(extra_middlewares) + + resp = await client.get(path) assert resp.status == status - - @asyncio.coroutine - @pytest.mark.parametrize("path, status", [ - ('/resource1', 200), - ('/resource1/', 404), - ('/resource2', 404), - ('/resource2/', 200) - ]) - def test_no_trailing_slash_when_disabled( - self, path, status, cli): + assert resp.url.query == URL(path).query + + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1", 200), + ("/resource1/", 200), + ("/resource2", 404), + ("/resource2/", 200), + ("/resource1?p1=1&p2=2", 200), + ("/resource1/?p1=1&p2=2", 200), + ("/resource2?p1=1&p2=2", 404), + ("/resource2/?p1=1&p2=2", 200), + ("/resource2/a/b%2Fc", 
404), + ("/resource2/a/b%2Fc/", 200), + ], + ) + async def test_remove_trailing_when_necessary(self, path, status, cli) -> None: extra_middlewares = [ web.normalize_path_middleware( - append_slash=False, merge_slashes=False)] - client = yield from cli(extra_middlewares) + append_slash=False, remove_slash=True, merge_slashes=False + ) + ] + client = await cli(extra_middlewares) - resp = yield from client.get(path) + resp = await client.get(path) assert resp.status == status - - @asyncio.coroutine - @pytest.mark.parametrize("path, status", [ - ('/resource1/a/b', 200), - ('//resource1//a//b', 200), - ('//resource1//a//b/', 404), - ('///resource1//a//b', 200), - ('/////resource1/a///b', 200), - ('/////resource1/a//b/', 404) - ]) - def test_merge_slash(self, path, status, cli): + assert resp.url.query == URL(path).query + + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1", 200), + ("/resource1/", 404), + ("/resource2", 404), + ("/resource2/", 200), + ("/resource1?p1=1&p2=2", 200), + ("/resource1/?p1=1&p2=2", 404), + ("/resource2?p1=1&p2=2", 404), + ("/resource2/?p1=1&p2=2", 200), + ("/resource2/a/b%2Fc", 404), + ("/resource2/a/b%2Fc/", 200), + ], + ) + async def test_no_trailing_slash_when_disabled(self, path, status, cli): extra_middlewares = [ - web.normalize_path_middleware(append_slash=False)] - client = yield from cli(extra_middlewares) + web.normalize_path_middleware(append_slash=False, merge_slashes=False) + ] + client = await cli(extra_middlewares) - resp = yield from client.get(path) + resp = await client.get(path) assert resp.status == status - - @asyncio.coroutine - @pytest.mark.parametrize("path, status", [ - ('/resource1/a/b', 200), - ('/resource1/a/b/', 404), - ('//resource2//a//b', 200), - ('//resource2//a//b/', 200), - ('///resource1//a//b', 200), - ('///resource1//a//b/', 404), - ('/////resource1/a///b', 200), - ('/////resource1/a///b/', 404), - ('/resource2/a/b', 200), - ('//resource2//a//b', 200), - ('//resource2//a//b/', 200), - ('///resource2//a//b', 200), - ('///resource2//a//b/', 200), - ('/////resource2/a///b', 200), - ('/////resource2/a///b/', 200) - ]) - def test_append_and_merge_slash(self, path, status, cli): + assert resp.url.query == URL(path).query + + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1/a/b", 200), + ("//resource1//a//b", 200), + ("//resource1//a//b/", 404), + ("///resource1//a//b", 200), + ("/////resource1/a///b", 200), + ("/////resource1/a//b/", 404), + ("/resource1/a/b?p=1", 200), + ("//resource1//a//b?p=1", 200), + ("//resource1//a//b/?p=1", 404), + ("///resource1//a//b?p=1", 200), + ("/////resource1/a///b?p=1", 200), + ("/////resource1/a//b/?p=1", 404), + ], + ) + async def test_merge_slash(self, path, status, cli) -> None: + extra_middlewares = [web.normalize_path_middleware(append_slash=False)] + client = await cli(extra_middlewares) + + resp = await client.get(path) + assert resp.status == status + assert resp.url.query == URL(path).query + + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1/a/b", 200), + ("/resource1/a/b/", 404), + ("//resource2//a//b", 200), + ("//resource2//a//b/", 200), + ("///resource1//a//b", 200), + ("///resource1//a//b/", 404), + ("/////resource1/a///b", 200), + ("/////resource1/a///b/", 404), + ("/resource2/a/b", 200), + ("//resource2//a//b", 200), + ("//resource2//a//b/", 200), + ("///resource2//a//b", 200), + ("///resource2//a//b/", 200), + ("/////resource2/a///b", 200), + ("/////resource2/a///b/", 200), + ("/resource1/a/b?p=1", 200), + ("/resource1/a/b/?p=1", 404), + 
("//resource2//a//b?p=1", 200), + ("//resource2//a//b/?p=1", 200), + ("///resource1//a//b?p=1", 200), + ("///resource1//a//b/?p=1", 404), + ("/////resource1/a///b?p=1", 200), + ("/////resource1/a///b/?p=1", 404), + ("/resource2/a/b?p=1", 200), + ("//resource2//a//b?p=1", 200), + ("//resource2//a//b/?p=1", 200), + ("///resource2//a//b?p=1", 200), + ("///resource2//a//b/?p=1", 200), + ("/////resource2/a///b?p=1", 200), + ("/////resource2/a///b/?p=1", 200), + ], + ) + async def test_append_and_merge_slash(self, path, status, cli) -> None: + extra_middlewares = [web.normalize_path_middleware()] + + client = await cli(extra_middlewares) + resp = await client.get(path) + assert resp.status == status + assert resp.url.query == URL(path).query + + @pytest.mark.parametrize( + "path, status", + [ + ("/resource1/a/b", 200), + ("/resource1/a/b/", 200), + ("//resource2//a//b", 404), + ("//resource2//a//b/", 200), + ("///resource1//a//b", 200), + ("///resource1//a//b/", 200), + ("/////resource1/a///b", 200), + ("/////resource1/a///b/", 200), + ("/////resource1/a///b///", 200), + ("/resource2/a/b", 404), + ("//resource2//a//b", 404), + ("//resource2//a//b/", 200), + ("///resource2//a//b", 404), + ("///resource2//a//b/", 200), + ("/////resource2/a///b", 404), + ("/////resource2/a///b/", 200), + ("/resource1/a/b?p=1", 200), + ("/resource1/a/b/?p=1", 200), + ("//resource2//a//b?p=1", 404), + ("//resource2//a//b/?p=1", 200), + ("///resource1//a//b?p=1", 200), + ("///resource1//a//b/?p=1", 200), + ("/////resource1/a///b?p=1", 200), + ("/////resource1/a///b/?p=1", 200), + ("/resource2/a/b?p=1", 404), + ("//resource2//a//b?p=1", 404), + ("//resource2//a//b/?p=1", 200), + ("///resource2//a//b?p=1", 404), + ("///resource2//a//b/?p=1", 200), + ("/////resource2/a///b?p=1", 404), + ("/////resource2/a///b/?p=1", 200), + ], + ) + async def test_remove_and_merge_slash(self, path, status, cli) -> None: extra_middlewares = [ - web.normalize_path_middleware()] + web.normalize_path_middleware(append_slash=False, remove_slash=True) + ] - client = yield from cli(extra_middlewares) - resp = yield from client.get(path) + client = await cli(extra_middlewares) + resp = await client.get(path) assert resp.status == status + assert resp.url.query == URL(path).query + + async def test_cannot_remove_and_add_slash(self) -> None: + with pytest.raises(AssertionError): + web.normalize_path_middleware(append_slash=True, remove_slash=True) + + @pytest.mark.parametrize( + ["append_slash", "remove_slash"], + [ + (True, False), + (False, True), + (False, False), + ], + ) + async def test_open_redirects( + self, append_slash: bool, remove_slash: bool, aiohttp_client: Any + ) -> None: + async def handle(request: web.Request) -> web.StreamResponse: + pytest.fail( + msg="Security advisory 'GHSA-v6wp-4m6f-gcjg' test handler " + "matched unexpectedly", + pytrace=False, + ) + + app = web.Application( + middlewares=[ + web.normalize_path_middleware( + append_slash=append_slash, remove_slash=remove_slash + ) + ] + ) + app.add_routes([web.get("/", handle), web.get("/google.com", handle)]) + client = await aiohttp_client(app, server_kwargs={"skip_url_asserts": True}) + resp = await client.get("//google.com", allow_redirects=False) + assert resp.status == 308 + assert resp.headers["Location"] == "/google.com" + assert resp.url.query == URL("//google.com").query + + +async def test_old_style_middleware(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") + + async def middleware_factory(app, handler): + async def 
middleware(request): + resp = await handler(request) + assert 200 == resp.status + resp.set_status(201) + resp.text = resp.text + "[old style middleware]" + return resp + + return middleware + + with pytest.warns(DeprecationWarning) as warning_checker: + app = web.Application() + app.middlewares.append(middleware_factory) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[old style middleware]" == txt + + assert len(warning_checker) == 1 + msg = str(warning_checker.list[0].message) + assert re.match( + "^old-style middleware " + '"<function test_old_style_middleware.<locals>.' + 'middleware_factory at 0x[0-9a-fA-F]+>" ' + "deprecated, see #2252$", + msg, + ) + + +async def test_mixed_middleware(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") + + async def m_old1(app, handler): + async def middleware(request): + resp = await handler(request) + resp.text += "[old style 1]" + return resp + + return middleware + + @web.middleware + async def m_new1(request, handler): + resp = await handler(request) + resp.text += "[new style 1]" + return resp + + async def m_old2(app, handler): + async def middleware(request): + resp = await handler(request) + resp.text += "[old style 2]" + return resp + + return middleware + + @web.middleware + async def m_new2(request, handler): + resp = await handler(request) + resp.text += "[new style 2]" + return resp + + middlewares = m_old1, m_new1, m_old2, m_new2 + + with pytest.warns(DeprecationWarning) as w: + app = web.Application(middlewares=middlewares) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 200 == resp.status + txt = await resp.text() + assert "OK[new style 2][old style 2][new style 1][old style 1]" == txt + + assert len(w) == 2 + tmpl = ( + "^old-style middleware " + '"<function test_mixed_middleware.<locals>.' 
+ '{} at 0x[0-9a-fA-F]+>" ' + "deprecated, see #2252$" + ) + p1 = tmpl.format("m_old1") + p2 = tmpl.format("m_old2") + + assert re.match(p2, str(w.list[0].message)) + assert re.match(p1, str(w.list[1].message)) + + +async def test_old_style_middleware_class(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") + + class Middleware: + async def __call__(self, app, handler): + async def middleware(request): + resp = await handler(request) + assert 200 == resp.status + resp.set_status(201) + resp.text = resp.text + "[old style middleware]" + return resp + + return middleware + + with pytest.warns(DeprecationWarning) as warning_checker: + app = web.Application() + app.middlewares.append(Middleware()) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[old style middleware]" == txt + + assert len(warning_checker) == 1 + msg = str(warning_checker.list[0].message) + assert re.match( + "^old-style middleware " + '".Middleware object " + 'at 0x[0-9a-fA-F]+>" deprecated, see #2252$', + msg, + ) + + +async def test_new_style_middleware_class(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") + + @web.middleware + class Middleware: + async def __call__(self, request, handler): + resp = await handler(request) + assert 200 == resp.status + resp.set_status(201) + resp.text = resp.text + "[new style middleware]" + return resp + + with pytest.warns(None) as warning_checker: + app = web.Application() + app.middlewares.append(Middleware()) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[new style middleware]" == txt + + assert len(warning_checker) == 0 + + +async def test_new_style_middleware_method(loop, aiohttp_client) -> None: + async def handler(request): + return web.Response(body=b"OK") + + class Middleware: + @web.middleware + async def call(self, request, handler): + resp = await handler(request) + assert 200 == resp.status + resp.set_status(201) + resp.text = resp.text + "[new style middleware]" + return resp + + with pytest.warns(None) as warning_checker: + app = web.Application() + app.middlewares.append(Middleware().call) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[new style middleware]" == txt + + assert len(warning_checker) == 0 diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 1a93624321a..9795270cd59 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -1,23 +1,25 @@ -"""Tests for aiohttp/server.py""" +# Tests for aiohttp/server.py import asyncio +import platform import socket from functools import partial -from html import escape from unittest import mock import pytest from aiohttp import helpers, http, streams, web +IS_MACOS = platform.system() == "Darwin" -@pytest.yield_fixture + +@pytest.fixture def make_srv(loop, manager): srv = None def maker(*, cls=web.RequestHandler, **kwargs): nonlocal srv - m = kwargs.pop('manager', manager) + m = kwargs.pop("manager", manager) srv = cls(m, loop=loop, access_log=None, **kwargs) return srv @@ -30,7 +32,10 @@ def maker(*, cls=web.RequestHandler, **kwargs): @pytest.fixture def manager(request_handler, loop): - return 
web.Server(request_handler, loop=loop) + async def maker(): + return web.Server(request_handler) + + return loop.run_until_complete(maker()) @pytest.fixture @@ -38,7 +43,10 @@ def srv(make_srv, transport): srv = make_srv() srv.connection_made(transport) transport.close.side_effect = partial(srv.connection_lost, None) - return srv + with mock.patch.object( + web.RequestHandler, "_drain_helper", side_effect=helpers.noop + ): + yield srv @pytest.fixture @@ -48,9 +56,7 @@ def buf(): @pytest.fixture def request_handler(): - - @asyncio.coroutine - def handler(request): + async def handler(request): return web.Response() m = mock.Mock() @@ -61,23 +67,22 @@ def handler(request): @pytest.fixture def handle_with_error(): def wrapper(exc=ValueError): - - @asyncio.coroutine - def handle(request): + async def handle(request): raise exc h = mock.Mock() h.side_effect = handle return h + return wrapper -@pytest.yield_fixture +@pytest.fixture def writer(srv): - return http.PayloadWriter(srv.writer, srv._loop) + return http.StreamWriter(srv, srv.transport, srv._loop) -@pytest.yield_fixture +@pytest.fixture def transport(buf): transport = mock.Mock() @@ -85,135 +90,89 @@ def write(chunk): buf.extend(chunk) transport.write.side_effect = write - transport.drain.side_effect = helpers.noop + transport.is_closing.return_value = False return transport -@pytest.fixture -def ceil(mocker): - def ceil(val): - return val - - mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - -@asyncio.coroutine -def test_shutdown(srv, loop, transport): - assert transport is srv.transport - - srv._keepalive = True - srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - - request_handler = srv._request_handlers[-1] - - yield from asyncio.sleep(0.1, loop=loop) - assert len(srv._waiters) == 1 - assert len(srv._request_handlers) == 1 - - t0 = loop.time() - yield from srv.shutdown() - t1 = loop.time() - - assert t1 - t0 < 0.05, t1-t0 - - assert transport.close.called - assert srv.transport is None - - assert not srv._request_handlers - assert request_handler.done() - - -@asyncio.coroutine -def test_shutdown_multiple_handlers(srv, loop, transport): - srv.handle_request = mock.Mock() - srv.handle_request.side_effect = helpers.noop - +async def test_shutdown(srv, transport) -> None: + loop = asyncio.get_event_loop() assert transport is srv.transport srv._keepalive = True - srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n' - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') + task_handler = srv._task_handler - h1, h2 = srv._request_handlers - - yield from asyncio.sleep(0.1, loop=loop) - assert len(srv._waiters) == 2 - assert len(srv._request_handlers) == 2 + assert srv._waiter is not None + assert srv._task_handler is not None t0 = loop.time() - yield from srv.shutdown() + await srv.shutdown() t1 = loop.time() - assert t1 - t0 < 0.05, t1-t0 + assert t1 - t0 < 0.05, t1 - t0 assert transport.close.called assert srv.transport is None - assert not srv._request_handlers - assert h1.done() - assert h2.done() + assert not srv._task_handler + await asyncio.sleep(0.1) + assert task_handler.done() -@asyncio.coroutine -def test_double_shutdown(srv, transport): - yield from srv.shutdown() +async def test_double_shutdown(srv, transport) -> None: + await srv.shutdown() assert transport.close.called assert srv.transport is None transport.reset_mock() - yield from srv.shutdown() + await srv.shutdown() assert not 
transport.close.called assert srv.transport is None -@asyncio.coroutine -def test_close_after_response(srv, loop, transport): +async def test_shutdown_wait_error_handler(srv, transport) -> None: + loop = asyncio.get_event_loop() + + async def _error_handle(): + pass + + srv._error_handler = loop.create_task(_error_handle()) + await srv.shutdown() + assert srv._error_handler.done() + + +async def test_close_after_response(srv, transport) -> None: srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - h, = srv._request_handlers + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + h = srv._task_handler - yield from asyncio.sleep(0.1, loop=loop) - assert len(srv._waiters) == 0 - assert len(srv._request_handlers) == 0 + await asyncio.sleep(0.1) + assert srv._waiter is None + assert srv._task_handler is None assert transport.close.called assert srv.transport is None - assert not srv._request_handlers assert h.done() -def test_connection_made(make_srv): +def test_connection_made(make_srv) -> None: srv = make_srv() - assert not srv._request_handlers - srv.connection_made(mock.Mock()) - assert not srv._request_handlers assert not srv._force_close -def test_connection_made_with_keepaplive(make_srv, transport): +def test_connection_made_with_tcp_keepaplive(make_srv, transport) -> None: srv = make_srv() sock = mock.Mock() transport.get_extra_info.return_value = sock srv.connection_made(transport) - sock.setsockopt.assert_called_with(socket.SOL_SOCKET, - socket.SO_KEEPALIVE, 1) + sock.setsockopt.assert_called_with(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) -def test_connection_made_without_keepaplive(make_srv): +def test_connection_made_without_tcp_keepaplive(make_srv) -> None: srv = make_srv(tcp_keepalive=False) sock = mock.Mock() @@ -223,33 +182,31 @@ def test_connection_made_without_keepaplive(make_srv): assert not sock.setsockopt.called -def test_eof_received(make_srv): +def test_eof_received(make_srv) -> None: srv = make_srv() srv.connection_made(mock.Mock()) srv.eof_received() # assert srv.reader._eof -@asyncio.coroutine -def test_connection_lost(srv, loop): +async def test_connection_lost(srv) -> None: srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) srv._keepalive = True - handle = srv._request_handlers[0] - yield from asyncio.sleep(0, loop=loop) # wait for .start() starting + handle = srv._task_handler + await asyncio.sleep(0) # wait for .start() starting srv.connection_lost(None) assert srv._force_close - yield from handle + await handle - assert not srv._request_handlers + assert not srv._task_handler -def test_srv_keep_alive(srv): +def test_srv_keep_alive(srv) -> None: assert not srv._keepalive srv.keep_alive(True) @@ -259,90 +216,72 @@ def test_srv_keep_alive(srv): assert not srv._keepalive -def test_slow_request(make_srv): - with pytest.warns(DeprecationWarning): - make_srv(slow_request_timeout=0.01) +def test_srv_keep_alive_disable(srv) -> None: + handle = srv._keepalive_handle = mock.Mock() + srv.keep_alive(False) + assert not srv._keepalive + assert srv._keepalive_handle is None + handle.cancel.assert_called_with() -@asyncio.coroutine -def test_simple(srv, loop, buf): - srv.data_received( - b'GET / HTTP/1.1\r\n\r\n') - - yield from asyncio.sleep(0, loop=loop) - assert buf.startswith(b'HTTP/1.1 200 OK\r\n') +async def test_simple(srv, buf) -> None: + srv.data_received(b"GET / 
HTTP/1.1\r\n\r\n") -@asyncio.coroutine -def test_bad_method(srv, loop, buf): - srv.data_received( - b'!@#$ / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') + await asyncio.sleep(0.05) + assert buf.startswith(b"HTTP/1.1 200 OK\r\n") - yield from asyncio.sleep(0, loop=loop) - assert buf.startswith(b'HTTP/1.0 400 Bad Request\r\n') +async def test_bad_method(srv, buf) -> None: + srv.data_received(b":BAD; / HTTP/1.0\r\n" b"Host: example.com\r\n\r\n") -@asyncio.coroutine -def test_internal_error(srv, loop, buf): - srv._request_parser = mock.Mock() - srv._request_parser.feed_data.side_effect = TypeError + await asyncio.sleep(0) + assert buf.startswith(b"HTTP/1.0 400 Bad Request\r\n") - srv.data_received( - b'!@#$ / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) - assert buf.startswith(b'HTTP/1.0 500 Internal Server Error\r\n') +async def test_line_too_long(srv, buf) -> None: + srv.data_received(b"".join([b"a" for _ in range(10000)]) + b"\r\n\r\n") + await asyncio.sleep(0) + assert buf.startswith(b"HTTP/1.0 400 Bad Request\r\n") -@asyncio.coroutine -def test_line_too_long(srv, loop, buf): - srv.data_received(b''.join([b'a' for _ in range(10000)]) + b'\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) - assert buf.startswith(b'HTTP/1.0 400 Bad Request\r\n') - - -@asyncio.coroutine -def test_invalid_content_length(srv, loop, buf): +async def test_invalid_content_length(srv, buf) -> None: srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: sdgg\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: sdgg\r\n\r\n" + ) + await asyncio.sleep(0) - assert buf.startswith(b'HTTP/1.0 400 Bad Request\r\n') + assert buf.startswith(b"HTTP/1.0 400 Bad Request\r\n") -@asyncio.coroutine -def test_handle_error__utf(make_srv, buf, transport, loop, request_handler): - request_handler.side_effect = RuntimeError('что-то пошло не так') +async def test_unhandled_runtime_error(make_srv, transport, request_handler): + async def handle(request): + resp = web.Response() + resp.write_eof = mock.Mock() + resp.write_eof.side_effect = RuntimeError + return resp - srv = make_srv(debug=True) + srv = make_srv(lingering_time=0) + srv.debug = True srv.connection_made(transport) - srv.keep_alive(True) - srv.logger = mock.Mock() + srv.logger.exception = mock.Mock() + request_handler.side_effect = handle srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) - - assert b'HTTP/1.0 500 Internal Server Error' in buf - assert b'Content-Type: text/html; charset=utf-8' in buf - pattern = escape("RuntimeError: что-то пошло не так") - assert pattern.encode('utf-8') in buf - assert not srv._keepalive + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + await srv._task_handler + assert request_handler.called srv.logger.exception.assert_called_with( - "Error handling request", exc_info=mock.ANY) + "Unhandled runtime exception", exc_info=mock.ANY + ) -@asyncio.coroutine -def test_handle_uncompleted( - make_srv, loop, transport, handle_with_error, request_handler): +async def test_handle_uncompleted( + make_srv, transport, handle_with_error, request_handler +): closed = False def close(): @@ -357,20 +296,24 @@ def close(): request_handler.side_effect = handle_with_error() srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 50000\r\n\r\n') + b"GET / 
HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 50000\r\n\r\n" + ) - yield from srv._request_handlers[0] + await srv._task_handler assert request_handler.called assert closed - srv.logger.exception.assert_called_with( - "Error handling request", exc_info=mock.ANY) - - -@asyncio.coroutine -def test_handle_uncompleted_pipe( - make_srv, loop, transport, request_handler, handle_with_error): + srv.logger.exception.assert_called_with("Error handling request", exc_info=mock.ANY) + + +@pytest.mark.xfail( + IS_MACOS, + raises=TypeError, + reason="Intermittently fails on macOS", + strict=False, +) +async def test_handle_uncompleted_pipe( + make_srv, transport, request_handler, handle_with_error +): closed = False normal_completed = False @@ -384,246 +327,226 @@ def close(): srv.connection_made(transport) srv.logger.exception = mock.Mock() - @asyncio.coroutine - def handle(request): + async def handle(request): nonlocal normal_completed normal_completed = True - yield from asyncio.sleep(0.05, loop=loop) + await asyncio.sleep(0.05) return web.Response() # normal request_handler.side_effect = handle srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + await asyncio.sleep(0.01) # with exception request_handler.side_effect = handle_with_error() srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 50000\r\n\r\n') + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 50000\r\n\r\n" + ) - assert len(srv._request_handlers) == 2 + assert srv._task_handler - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0.01) - yield from srv._request_handlers[0] + await srv._task_handler assert normal_completed assert request_handler.called assert closed - srv.logger.exception.assert_called_with( - "Error handling request", exc_info=mock.ANY) + srv.logger.exception.assert_called_with("Error handling request", exc_info=mock.ANY) -@asyncio.coroutine -def test_lingering(srv, loop, transport): +async def test_lingering(srv, transport) -> None: assert not transport.close.called - @asyncio.coroutine - def handle(message, request, writer): + async def handle(message, request, writer): pass - srv.handle_request = handle - srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 3\r\n\r\n') - - yield from asyncio.sleep(0.05, loop=loop) - assert not transport.close.called + with mock.patch.object( + web.RequestHandler, "handle_request", create=True, new=handle + ): + srv.data_received( + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 3\r\n\r\n" + ) - srv.data_received(b'123') + await asyncio.sleep(0.05) + assert not transport.close.called - yield from asyncio.sleep(0, loop=loop) - transport.close.assert_called_with() + srv.data_received(b"123") + await asyncio.sleep(0) + transport.close.assert_called_with() -@asyncio.coroutine -def test_lingering_disabled(make_srv, loop, transport, request_handler): - @asyncio.coroutine - def handle_request(request): - yield from asyncio.sleep(0, loop=loop) +async def test_lingering_disabled(make_srv, transport, request_handler) -> None: + async def handle_request(request): + await asyncio.sleep(0) srv = make_srv(lingering_time=0) srv.connection_made(transport) request_handler.side_effect = handle_request - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0) assert not transport.close.called 
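# A note on the lingering tests above and below: once the handler returns
# while the client is still sending a request body, RequestHandler keeps
# reading and discarding the leftover bytes for up to `lingering_time`
# seconds before closing the transport, so a well-behaved client is not cut
# off mid-upload; `lingering_time=0` disables that grace period, which is
# why these tests expect an immediate close. A minimal sketch of setting the
# option directly (assuming, as the `manager`/`make_srv` fixtures above
# suggest, that web.Server forwards extra keyword arguments on to
# RequestHandler):
#
#     server = web.Server(handler, lingering_time=0)  # `handler` is hypothetical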
srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 50\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 50\r\n\r\n" + ) + await asyncio.sleep(0) assert not transport.close.called - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0.05) transport.close.assert_called_with() -@asyncio.coroutine -def test_lingering_timeout(make_srv, loop, transport, ceil, request_handler): - - @asyncio.coroutine - def handle_request(request): - yield from asyncio.sleep(0, loop=loop) +async def test_lingering_timeout(make_srv, transport, request_handler): + async def handle_request(request): + await asyncio.sleep(0) srv = make_srv(lingering_time=1e-30) srv.connection_made(transport) request_handler.side_effect = handle_request - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0.05) assert not transport.close.called srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n' - b'Content-Length: 50\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n" b"Content-Length: 50\r\n\r\n" + ) + await asyncio.sleep(0) assert not transport.close.called - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0.05) transport.close.assert_called_with() -def test_handle_cancel(make_srv, loop, transport): - log = mock.Mock() - - srv = make_srv(logger=log, debug=True) +async def test_handle_payload_access_error(make_srv, transport, request_handler): + srv = make_srv(lingering_time=0) srv.connection_made(transport) - - def handle_request(message, payload, writer): - yield from asyncio.sleep(10, loop=loop) - - srv.handle_request = handle_request - - @asyncio.coroutine - def cancel(): - srv._request_handlers[0].cancel() - srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Content-Length: 10\r\n' - b'Host: example.com\r\n\r\n') + b"POST /test HTTP/1.1\r\n" b"Content-Length: 9\r\n\r\n" b"some data" + ) + # start request_handler task + await asyncio.sleep(0.05) - loop.run_until_complete( - asyncio.gather(srv._request_handlers[0], cancel(), loop=loop)) - assert log.debug.called + with pytest.raises(web.PayloadAccessError): + await request_handler.call_args[0][0].content.read() -def test_handle_cancelled(make_srv, loop, transport): +async def test_handle_cancel(make_srv, transport) -> None: log = mock.Mock() srv = make_srv(logger=log, debug=True) srv.connection_made(transport) - srv.handle_request = mock.Mock() - # start request_handler task - loop.run_until_complete(asyncio.sleep(0, loop=loop)) + async def handle_request(message, payload, writer): + await asyncio.sleep(10) - srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') + async def cancel(): + srv._task_handler.cancel() - r_handler = srv._request_handlers[0] - assert loop.run_until_complete(r_handler) is None + with mock.patch.object( + web.RequestHandler, "handle_request", create=True, new=handle_request + ): + srv.data_received( + b"GET / HTTP/1.0\r\n" b"Content-Length: 10\r\n" b"Host: example.com\r\n\r\n" + ) + await asyncio.gather(srv._task_handler, cancel()) + assert log.debug.called -@asyncio.coroutine -def test_handle_400(srv, loop, buf, transport): - srv.data_received(b'GET / HT/asd\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) - assert b'400 Bad Request' in buf - - -def test_handle_500(srv, loop, buf, transport, request_handler): - request_handler.side_effect = ValueError - - srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - 
loop.run_until_complete(srv._request_handlers[0]) +async def test_handle_cancelled(make_srv, transport) -> None: + log = mock.Mock() - assert b'500 Internal Server Error' in buf + srv = make_srv(logger=log, debug=True) + srv.connection_made(transport) + # start request_handler task + await asyncio.sleep(0) -@asyncio.coroutine -def test_keep_alive(make_srv, loop, transport, ceil): - srv = make_srv(keepalive_timeout=0.05) - srv.connection_made(transport) + srv.data_received(b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n\r\n") - srv.keep_alive(True) - srv.handle_request = mock.Mock() - srv.handle_request.return_value = helpers.create_future(loop) - srv.handle_request.return_value.set_result(1) + r_handler = srv._task_handler + assert (await r_handler) is None - srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) - assert len(srv._waiters) == 1 - assert srv._keepalive_handle is not None - assert not transport.close.called +async def test_handle_400(srv, buf, transport) -> None: + srv.data_received(b"GET / HT/asd\r\n\r\n") - yield from asyncio.sleep(0.1, loop=loop) - assert transport.close.called - assert srv._waiters[0].cancelled + await asyncio.sleep(0) + assert b"400 Bad Request" in buf -def test_srv_process_request_without_timeout(make_srv, loop, transport): +async def test_keep_alive(make_srv, transport) -> None: + loop = asyncio.get_event_loop() + srv = make_srv(keepalive_timeout=0.05) + future = loop.create_future() + future.set_result(1) + + with mock.patch.object( + web.RequestHandler, "KEEPALIVE_RESCHEDULE_DELAY", new=0.1 + ), mock.patch.object( + web.RequestHandler, "handle_request", create=True, return_value=future + ): + srv.connection_made(transport) + srv.keep_alive(True) + srv.data_received( + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + + waiter = None + while waiter is None: + await asyncio.sleep(0) + waiter = srv._waiter + assert srv._keepalive_handle is not None + assert not transport.close.called + + await asyncio.sleep(0.2) + assert transport.close.called + assert waiter.cancelled + + +async def test_srv_process_request_without_timeout(make_srv, transport) -> None: srv = make_srv() srv.connection_made(transport) - srv.data_received( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') + srv.data_received(b"GET / HTTP/1.0\r\n" b"Host: example.com\r\n\r\n") - loop.run_until_complete(srv._request_handlers[0]) + await srv._task_handler assert transport.close.called -def test_keep_alive_timeout_default(srv): +def test_keep_alive_timeout_default(srv) -> None: assert 75 == srv.keepalive_timeout -def test_keep_alive_timeout_nondefault(make_srv): +def test_keep_alive_timeout_nondefault(make_srv) -> None: srv = make_srv(keepalive_timeout=10) assert 10 == srv.keepalive_timeout -@asyncio.coroutine -def test_supports_connect_method(srv, loop, transport, request_handler): +async def test_supports_connect_method(srv, transport, request_handler) -> None: srv.data_received( - b'CONNECT aiohttp.readthedocs.org:80 HTTP/1.0\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0.1, loop=loop) + b"CONNECT aiohttp.readthedocs.org:80 HTTP/1.0\r\n" b"Content-Length: 0\r\n\r\n" + ) + await asyncio.sleep(0.1) assert request_handler.called - assert isinstance( - request_handler.call_args[0][0].content, - streams.FlowControlStreamReader) + assert isinstance(request_handler.call_args[0][0].content, streams.StreamReader) -@asyncio.coroutine -def 
test_content_length_0(srv, loop, request_handler): +async def test_content_length_0(srv, request_handler) -> None: srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.org\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.1\r\n" b"Host: example.org\r\n" b"Content-Length: 0\r\n\r\n" + ) + await asyncio.sleep(0.01) assert request_handler.called assert request_handler.call_args[0][0].content == streams.EMPTY_PAYLOAD -def test_rudimentary_transport(srv, loop): +def test_rudimentary_transport(srv) -> None: transport = mock.Mock() srv.connection_made(transport) @@ -646,46 +569,12 @@ def test_rudimentary_transport(srv, loop): assert not srv._reading_paused -@asyncio.coroutine -def test_close(srv, loop, transport): +async def test_pipeline_multiple_messages(srv, transport, request_handler): transport.close.side_effect = partial(srv.connection_lost, None) - srv._max_concurrent_handlers = 2 - srv.connection_made(transport) - - srv.handle_request = mock.Mock() - srv.handle_request.side_effect = helpers.noop - - assert transport is srv.transport - - srv._keepalive = True - srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n' - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - - yield from asyncio.sleep(0, loop=loop) - assert len(srv._request_handlers) == 2 - assert len(srv._waiters) == 2 - - srv.close() - yield from asyncio.sleep(0, loop=loop) - assert len(srv._request_handlers) == 0 - assert srv.transport is None - assert transport.close.called - - -@asyncio.coroutine -def test_pipeline_multiple_messages(srv, loop, transport, request_handler): - transport.close.side_effect = partial(srv.connection_lost, None) - srv._max_concurrent_handlers = 1 processed = 0 - @asyncio.coroutine - def handle(request): + async def handle(request): nonlocal processed processed += 1 return web.Response() @@ -696,67 +585,174 @@ def handle(request): srv._keepalive = True srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n' - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - - assert len(srv._request_handlers) == 1 - assert len(srv._messages) == 1 - assert len(srv._waiters) == 0 - - yield from asyncio.sleep(0, loop=loop) - assert len(srv._request_handlers) == 1 - assert len(srv._waiters) == 1 + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + + assert srv._task_handler is not None + assert len(srv._messages) == 2 + assert srv._waiter is not None + + await asyncio.sleep(0.05) + assert srv._task_handler is not None + assert srv._waiter is not None assert processed == 2 -@asyncio.coroutine -def test_pipeline_response_order(srv, loop, buf, transport, request_handler): +async def test_pipeline_response_order(srv, buf, transport, request_handler): transport.close.side_effect = partial(srv.connection_lost, None) srv._keepalive = True processed = [] - @asyncio.coroutine - def handle1(request): + async def handle1(request): nonlocal processed - yield from asyncio.sleep(0.01, loop=loop) + await asyncio.sleep(0.01) resp = web.StreamResponse() - yield from resp.prepare(request) - yield from resp.write(b'test1') - yield from resp.write_eof() + await resp.prepare(request) + await resp.write(b"test1") + await resp.write_eof() processed.append(1) return resp request_handler.side_effect = handle1 
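# Pipelining background for this test: HTTP/1.1 requires responses to be
# written in the order the requests arrived on the connection, so `handle1`
# sleeps before streaming its body while `handle2` (below) replies
# immediately; the final `processed == [1, 2]` assertion shows the server
# still answered first-in, first-out. On the wire, two pipelined requests
# are just back-to-back request heads on one connection (illustrative bytes):
#
#     GET /a HTTP/1.1\r\nHost: h\r\n\r\nGET /b HTTP/1.1\r\nHost: h\r\n\r\n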
srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + await asyncio.sleep(0.01) # second - @asyncio.coroutine - def handle2(request): + + async def handle2(request): nonlocal processed resp = web.StreamResponse() - yield from resp.prepare(request) - resp.write(b'test2') - yield from resp.write_eof() + await resp.prepare(request) + await resp.write(b"test2") + await resp.write_eof() processed.append(2) return resp request_handler.side_effect = handle2 srv.data_received( - b'GET / HTTP/1.1\r\n' - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - yield from asyncio.sleep(0, loop=loop) + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + await asyncio.sleep(0.01) - assert len(srv._request_handlers) == 2 + assert srv._task_handler is not None - yield from asyncio.sleep(0.1, loop=loop) + await asyncio.sleep(0.1) assert processed == [1, 2] + + +def test_data_received_close(srv) -> None: + srv.close() + srv.data_received( + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + + assert not srv._messages + + +def test_data_received_force_close(srv) -> None: + srv.force_close() + srv.data_received( + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 0\r\n\r\n" + ) + + assert not srv._messages + + +async def test__process_keepalive(srv) -> None: + loop = asyncio.get_event_loop() + # wait till the waiter is waiting + await asyncio.sleep(0) + + assert srv._waiter is not None + + srv._keepalive_time = 1 + srv._keepalive = True + srv._keepalive_timeout = 1 + expired_time = srv._keepalive_time + srv._keepalive_timeout + 1 + with mock.patch.object(loop, "time", return_value=expired_time): + srv._process_keepalive() + assert srv._force_close + + +async def test__process_keepalive_schedule_next(srv) -> None: + loop = asyncio.get_event_loop() + # wait till the waiter is waiting + await asyncio.sleep(0) + + srv._keepalive = True + srv._keepalive_time = 1 + srv._keepalive_timeout = 1 + expire_time = srv._keepalive_time + srv._keepalive_timeout + with mock.patch.object(loop, "time", return_value=expire_time): + with mock.patch.object(loop, "call_later") as call_later_patched: + srv._process_keepalive() + call_later_patched.assert_called_with(1, srv._process_keepalive) + + +async def test__process_keepalive_force_close(srv) -> None: + loop = asyncio.get_event_loop() + srv._force_close = True + with mock.patch.object(loop, "call_at") as call_at_patched: + srv._process_keepalive() + assert not call_at_patched.called + + +async def test_two_data_received_without_waking_up_start_task(srv) -> None: + # make a chance to srv.start() method start waiting for srv._waiter + await asyncio.sleep(0.01) + assert srv._waiter is not None + + srv.data_received( + b"GET / HTTP/1.1\r\n" b"Host: ex.com\r\n" b"Content-Length: 1\r\n\r\n" b"a" + ) + srv.data_received( + b"GET / HTTP/1.1\r\n" b"Host: ex.com\r\n" b"Content-Length: 1\r\n\r\n" b"b" + ) + + assert len(srv._messages) == 2 + assert srv._waiter.done() + await asyncio.sleep(0.01) + + +async def test_client_disconnect(aiohttp_server) -> None: + async def handler(request): + buf = b"" + with pytest.raises(ConnectionError): + while len(buf) < 10: + buf += await request.content.read(10) + # return with closed transport means premature client disconnection + return web.Response() + + logger = mock.Mock() + app = web.Application() + 
app._debug = True + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app, logger=logger) + + _, writer = await asyncio.open_connection("127.0.0.1", server.port) + writer.write( + """POST / HTTP/1.1\r +Connection: keep-alive\r +Content-Length: 10\r +Host: localhost:{port}\r +\r +""".format( + port=server.port + ).encode( + "ascii" + ) + ) + await writer.drain() + await asyncio.sleep(0.1) + writer.write(b"x") + writer.close() + await asyncio.sleep(0.1) + logger.debug.assert_called_with("Ignored premature client disconnection") diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 34f522b59b4..f251e04f4b9 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -1,152 +1,254 @@ import asyncio -from collections import MutableMapping +import socket +from collections.abc import MutableMapping +from typing import Any from unittest import mock import pytest -from multidict import CIMultiDict, MultiDict +from multidict import CIMultiDict, CIMultiDictProxy, MultiDict from yarl import URL from aiohttp import HttpVersion +from aiohttp.helpers import DEBUG +from aiohttp.http_parser import RawRequestMessage from aiohttp.streams import StreamReader from aiohttp.test_utils import make_mocked_request -from aiohttp.web import HTTPRequestEntityTooLarge +from aiohttp.web import BaseRequest, HTTPRequestEntityTooLarge @pytest.fixture -def make_request(): - return make_mocked_request +def protocol(): + return mock.Mock(_reading_paused=False) + + +def test_base_ctor() -> None: + message = RawRequestMessage( + "GET", + "/path/to?a=1&b=2", + HttpVersion(1, 1), + CIMultiDictProxy(CIMultiDict()), + (), + False, + False, + False, + False, + URL("/path/to?a=1&b=2"), + ) + + req = BaseRequest( + message, mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock() + ) + + assert "GET" == req.method + assert HttpVersion(1, 1) == req.version + assert req.host == socket.getfqdn() + assert "/path/to?a=1&b=2" == req.path_qs + assert "/path/to" == req.path + assert "a=1&b=2" == req.query_string + assert CIMultiDict() == req.headers + assert () == req.raw_headers + + get = req.query + assert MultiDict([("a", "1"), ("b", "2")]) == get + # second call should return the same object + assert get is req.query + + assert req.keep_alive + + assert req -def test_ctor(make_request): - req = make_request('GET', '/path/to?a=1&b=2') +def test_ctor() -> None: + req = make_mocked_request("GET", "/path/to?a=1&b=2") - assert 'GET' == req.method + assert "GET" == req.method assert HttpVersion(1, 1) == req.version - assert req.host is None - assert '/path/to?a=1&b=2' == req.path_qs - assert '/path/to' == req.path - assert 'a=1&b=2' == req.query_string + # MacOS may return CamelCased host name, need .lower() + assert req.host.lower() == socket.getfqdn().lower() + assert "/path/to?a=1&b=2" == req.path_qs + assert "/path/to" == req.path + assert "a=1&b=2" == req.query_string assert CIMultiDict() == req.headers assert () == req.raw_headers - assert req.message == req._message get = req.query - assert MultiDict([('a', '1'), ('b', '2')]) == get + assert MultiDict([("a", "1"), ("b", "2")]) == get # second call should return the same object assert get is req.query assert req.keep_alive # just make sure that all lines of make_mocked_request covered - headers = CIMultiDict(FOO='bar') + headers = CIMultiDict(FOO="bar") payload = mock.Mock() protocol = mock.Mock() app = mock.Mock() - req = make_request('GET', '/path/to?a=1&b=2', headers=headers, - protocol=protocol, payload=payload, app=app) 
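# aiohttp.test_utils.make_mocked_request builds a web.Request around mock
# objects, with no server or socket involved, which is what lets this file
# drop the old `make_request` fixture in favour of plain calls. A minimal
# usage sketch (keyword arguments as exercised throughout these tests):
#
#     from aiohttp.test_utils import make_mocked_request
#     req = make_mocked_request("GET", "/path?a=1", headers={"Host": "example.com"})
#     assert req.method == "GET" and req.host == "example.com"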
+ req = make_mocked_request( + "GET", + "/path/to?a=1&b=2", + headers=headers, + protocol=protocol, + payload=payload, + app=app, + ) assert req.app is app assert req.content is payload assert req.protocol is protocol assert req.transport is protocol.transport assert req.headers == headers - assert req.raw_headers == ((b'Foo', b'bar'),) + assert req.raw_headers == ((b"FOO", b"bar"),) + assert req.task is req._task + + +def test_deprecated_message() -> None: + req = make_mocked_request("GET", "/path/to?a=1&b=2") + with pytest.warns(DeprecationWarning): + assert req.message == req._message -def test_doubleslashes(make_request): +def test_doubleslashes() -> None: # NB: //foo/bar is an absolute URL with foo netloc and /bar path - req = make_request('GET', '/bar//foo/') - assert '/bar//foo/' == req.path + req = make_mocked_request("GET", "/bar//foo/") + assert "/bar//foo/" == req.path -def test_content_type_not_specified(make_request): - req = make_request('Get', '/') - assert 'application/octet-stream' == req.content_type +def test_content_type_not_specified() -> None: + req = make_mocked_request("Get", "/") + assert "application/octet-stream" == req.content_type -def test_content_type_from_spec(make_request): - req = make_request('Get', '/', - CIMultiDict([('CONTENT-TYPE', 'application/json')])) - assert 'application/json' == req.content_type +def test_content_type_from_spec() -> None: + req = make_mocked_request( + "Get", "/", CIMultiDict([("CONTENT-TYPE", "application/json")]) + ) + assert "application/json" == req.content_type -def test_content_type_from_spec_with_charset(make_request): - req = make_request( - 'Get', '/', - CIMultiDict([('CONTENT-TYPE', 'text/html; charset=UTF-8')])) - assert 'text/html' == req.content_type - assert 'UTF-8' == req.charset +def test_content_type_from_spec_with_charset() -> None: + req = make_mocked_request( + "Get", "/", CIMultiDict([("CONTENT-TYPE", "text/html; charset=UTF-8")]) + ) + assert "text/html" == req.content_type + assert "UTF-8" == req.charset -def test_calc_content_type_on_getting_charset(make_request): - req = make_request( - 'Get', '/', - CIMultiDict([('CONTENT-TYPE', 'text/html; charset=UTF-8')])) - assert 'UTF-8' == req.charset - assert 'text/html' == req.content_type +def test_calc_content_type_on_getting_charset() -> None: + req = make_mocked_request( + "Get", "/", CIMultiDict([("CONTENT-TYPE", "text/html; charset=UTF-8")]) + ) + assert "UTF-8" == req.charset + assert "text/html" == req.content_type -def test_urlencoded_querystring(make_request): - req = make_request('GET', - '/yandsearch?text=%D1%82%D0%B5%D0%BA%D1%81%D1%82') - assert {'text': 'текст'} == req.query +def test_urlencoded_querystring() -> None: + req = make_mocked_request("GET", "/yandsearch?text=%D1%82%D0%B5%D0%BA%D1%81%D1%82") + assert {"text": "текст"} == req.query -def test_non_ascii_path(make_request): - req = make_request('GET', '/путь') - assert '/путь' == req.path +def test_non_ascii_path() -> None: + req = make_mocked_request("GET", "/путь") + assert "/путь" == req.path -def test_non_ascii_raw_path(make_request): - req = make_request('GET', '/путь') - assert '/путь' == req.raw_path +def test_non_ascii_raw_path() -> None: + req = make_mocked_request("GET", "/путь") + assert "/путь" == req.raw_path -def test_content_length(make_request): - req = make_request('Get', '/', - CIMultiDict([('CONTENT-LENGTH', '123')])) +def test_content_length() -> None: + req = make_mocked_request("Get", "/", CIMultiDict([("CONTENT-LENGTH", "123")])) assert 123 == req.content_length -def 
test_non_keepalive_on_http10(make_request): - req = make_request('GET', '/', version=HttpVersion(1, 0)) +def test_range_to_slice_head() -> None: + def bytes_gen(size): + for i in range(size): + yield i % 256 + + payload = bytearray(bytes_gen(10000)) + req = make_mocked_request( + "GET", "/", headers=CIMultiDict([("RANGE", "bytes=0-499")]), payload=payload + ) + assert isinstance(req.http_range, slice) + assert req.content[req.http_range] == payload[:500] + + +def test_range_to_slice_mid() -> None: + def bytes_gen(size): + for i in range(size): + yield i % 256 + + payload = bytearray(bytes_gen(10000)) + req = make_mocked_request( + "GET", "/", headers=CIMultiDict([("RANGE", "bytes=500-999")]), payload=payload + ) + assert isinstance(req.http_range, slice) + assert req.content[req.http_range] == payload[500:1000] + + +def test_range_to_slice_tail_start() -> None: + def bytes_gen(size): + for i in range(size): + yield i % 256 + + payload = bytearray(bytes_gen(10000)) + req = make_mocked_request( + "GET", "/", headers=CIMultiDict([("RANGE", "bytes=9500-")]), payload=payload + ) + assert isinstance(req.http_range, slice) + assert req.content[req.http_range] == payload[-500:] + + +def test_range_to_slice_tail_stop() -> None: + def bytes_gen(size): + for i in range(size): + yield i % 256 + + payload = bytearray(bytes_gen(10000)) + req = make_mocked_request( + "GET", "/", headers=CIMultiDict([("RANGE", "bytes=-500")]), payload=payload + ) + assert isinstance(req.http_range, slice) + assert req.content[req.http_range] == payload[-500:] + + +def test_non_keepalive_on_http10() -> None: + req = make_mocked_request("GET", "/", version=HttpVersion(1, 0)) assert not req.keep_alive -def test_non_keepalive_on_closing(make_request): - req = make_request('GET', '/', closing=True) +def test_non_keepalive_on_closing() -> None: + req = make_mocked_request("GET", "/", closing=True) assert not req.keep_alive -@asyncio.coroutine -def test_call_POST_on_GET_request(make_request): - req = make_request('GET', '/') +async def test_call_POST_on_GET_request() -> None: + req = make_mocked_request("GET", "/") - ret = yield from req.post() + ret = await req.post() assert CIMultiDict() == ret -@asyncio.coroutine -def test_call_POST_on_weird_content_type(make_request): - req = make_request( - 'POST', '/', - headers=CIMultiDict({'CONTENT-TYPE': 'something/weird'})) +async def test_call_POST_on_weird_content_type() -> None: + req = make_mocked_request( + "POST", "/", headers=CIMultiDict({"CONTENT-TYPE": "something/weird"}) + ) - ret = yield from req.post() + ret = await req.post() assert CIMultiDict() == ret -@asyncio.coroutine -def test_call_POST_twice(make_request): - req = make_request('GET', '/') +async def test_call_POST_twice() -> None: + req = make_mocked_request("GET", "/") - ret1 = yield from req.post() - ret2 = yield from req.post() + ret1 = await req.post() + ret2 = await req.post() assert ret1 is ret2 -def test_no_request_cookies(make_request): - req = make_request('GET', '/') +def test_no_request_cookies() -> None: + req = make_mocked_request("GET", "/") assert req.cookies == {} @@ -154,218 +256,489 @@ def test_no_request_cookies(make_request): assert cookies is req.cookies -def test_request_cookie(make_request): - headers = CIMultiDict(COOKIE='cookie1=value1; cookie2=value2') - req = make_request('GET', '/', headers=headers) +def test_request_cookie() -> None: + headers = CIMultiDict(COOKIE="cookie1=value1; cookie2=value2") + req = make_mocked_request("GET", "/", headers=headers) - assert req.cookies == 
{'cookie1': 'value1', - 'cookie2': 'value2'} + assert req.cookies == {"cookie1": "value1", "cookie2": "value2"} -def test_request_cookie__set_item(make_request): - headers = CIMultiDict(COOKIE='name=value') - req = make_request('GET', '/', headers=headers) +def test_request_cookie__set_item() -> None: + headers = CIMultiDict(COOKIE="name=value") + req = make_mocked_request("GET", "/", headers=headers) - assert req.cookies == {'name': 'value'} + assert req.cookies == {"name": "value"} with pytest.raises(TypeError): - req.cookies['my'] = 'value' + req.cookies["my"] = "value" -def test_match_info(make_request): - req = make_request('GET', '/') +def test_match_info() -> None: + req = make_mocked_request("GET", "/") assert req._match_info is req.match_info -def test_request_is_mutable_mapping(make_request): - req = make_request('GET', '/') +def test_request_is_mutable_mapping() -> None: + req = make_mocked_request("GET", "/") assert isinstance(req, MutableMapping) - req['key'] = 'value' - assert 'value' == req['key'] + req["key"] = "value" + assert "value" == req["key"] -def test_request_delitem(make_request): - req = make_request('GET', '/') - req['key'] = 'value' - assert 'value' == req['key'] - del req['key'] - assert 'key' not in req +def test_request_delitem() -> None: + req = make_mocked_request("GET", "/") + req["key"] = "value" + assert "value" == req["key"] + del req["key"] + assert "key" not in req -def test_request_len(make_request): - req = make_request('GET', '/') +def test_request_len() -> None: + req = make_mocked_request("GET", "/") assert len(req) == 0 - req['key'] = 'value' + req["key"] = "value" assert len(req) == 1 -def test_request_iter(make_request): - req = make_request('GET', '/') - req['key'] = 'value' - req['key2'] = 'value2' - assert set(req) == {'key', 'key2'} +def test_request_iter() -> None: + req = make_mocked_request("GET", "/") + req["key"] = "value" + req["key2"] = "value2" + assert set(req) == {"key", "key2"} -def test___repr__(make_request): - req = make_request('GET', '/path/to') +def test___repr__() -> None: + req = make_mocked_request("GET", "/path/to") assert "<Request GET /path/to >" == repr(req) -def test___repr___non_ascii_path(make_request): - req = make_request('GET', '/path/\U0001f415\U0001f308') +def test___repr___non_ascii_path() -> None: + req = make_mocked_request("GET", "/path/\U0001f415\U0001f308") assert "<Request GET /path/\\U0001f415\\U0001f308 >" == repr(req) -def test_http_scheme(make_request): - req = make_request('GET', '/') +def test_http_scheme() -> None: + req = make_mocked_request("GET", "/", headers={"Host": "example.com"}) assert "http" == req.scheme + assert req.secure is False -def test_https_scheme_by_ssl_transport(make_request): - req = make_request('GET', '/', sslcontext=True) +def test_https_scheme_by_ssl_transport() -> None: + req = make_mocked_request( + "GET", "/", headers={"Host": "example.com"}, sslcontext=True + ) assert "https" == req.scheme + assert req.secure is True + + +def test_single_forwarded_header() -> None: + header = "by=identifier;for=identifier;host=identifier;proto=identifier" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["by"] == "identifier" + assert req.forwarded[0]["for"] == "identifier" + assert req.forwarded[0]["host"] == "identifier" + assert req.forwarded[0]["proto"] == "identifier" + + +@pytest.mark.parametrize( + "forward_for_in, forward_for_out", + [ + ("1.2.3.4:1234", "1.2.3.4:1234"), + ("1.2.3.4", "1.2.3.4"), + ('"[2001:db8:cafe::17]:1234"', "[2001:db8:cafe::17]:1234"), + 
('"[2001:db8:cafe::17]"', "[2001:db8:cafe::17]"), + ], +) +def test_forwarded_node_identifier(forward_for_in, forward_for_out) -> None: + header = f"for={forward_for_in}" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded == ({"for": forward_for_out},) + + +def test_single_forwarded_header_camelcase() -> None: + header = "bY=identifier;fOr=identifier;HOst=identifier;pRoTO=identifier" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["by"] == "identifier" + assert req.forwarded[0]["for"] == "identifier" + assert req.forwarded[0]["host"] == "identifier" + assert req.forwarded[0]["proto"] == "identifier" + + +def test_single_forwarded_header_single_param() -> None: + header = "BY=identifier" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["by"] == "identifier" + + +def test_single_forwarded_header_multiple_param() -> None: + header = "By=identifier1,BY=identifier2, By=identifier3 , BY=identifier4" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert len(req.forwarded) == 4 + assert req.forwarded[0]["by"] == "identifier1" + assert req.forwarded[1]["by"] == "identifier2" + assert req.forwarded[2]["by"] == "identifier3" + assert req.forwarded[3]["by"] == "identifier4" + + +def test_single_forwarded_header_quoted_escaped() -> None: + header = r'BY=identifier;pROTO="\lala lan\d\~ 123\!&"' + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["by"] == "identifier" + assert req.forwarded[0]["proto"] == "lala land~ 123!&" + + +def test_single_forwarded_header_custom_param() -> None: + header = r'BY=identifier;PROTO=https;SOME="other, \"value\""' + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert len(req.forwarded) == 1 + assert req.forwarded[0]["by"] == "identifier" + assert req.forwarded[0]["proto"] == "https" + assert req.forwarded[0]["some"] == 'other, "value"' + + +def test_single_forwarded_header_empty_params() -> None: + # This is allowed by the grammar given in RFC 7239 + header = ";For=identifier;;PROTO=https;;;" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["for"] == "identifier" + assert req.forwarded[0]["proto"] == "https" + + +def test_single_forwarded_header_bad_separator() -> None: + header = "BY=identifier PROTO=https" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert "proto" not in req.forwarded[0] + + +def test_single_forwarded_header_injection1() -> None: + # We might receive a header like this if we're sitting behind a reverse + # proxy that blindly appends a forwarded-element without checking + # the syntax of existing field-values. We should be able to recover + # the appended element anyway. 
+ header = 'for=_injected;by=", for=_real' + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert len(req.forwarded) == 2 + assert "by" not in req.forwarded[0] + assert req.forwarded[1]["for"] == "_real" + + +def test_single_forwarded_header_injection2() -> None: + header = "very bad syntax, for=_real" + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert len(req.forwarded) == 2 + assert "for" not in req.forwarded[0] + assert req.forwarded[1]["for"] == "_real" + + +def test_single_forwarded_header_long_quoted_string() -> None: + header = 'for="' + "\\\\" * 5000 + '"' + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) + assert req.forwarded[0]["for"] == "\\" * 5000 + + +def test_multiple_forwarded_headers() -> None: + headers = CIMultiDict() + headers.add("Forwarded", "By=identifier1;for=identifier2, BY=identifier3") + headers.add("Forwarded", "By=identifier4;fOr=identifier5") + req = make_mocked_request("GET", "/", headers=headers) + assert len(req.forwarded) == 3 + assert req.forwarded[0]["by"] == "identifier1" + assert req.forwarded[0]["for"] == "identifier2" + assert req.forwarded[1]["by"] == "identifier3" + assert req.forwarded[2]["by"] == "identifier4" + assert req.forwarded[2]["for"] == "identifier5" + + +def test_multiple_forwarded_headers_bad_syntax() -> None: + headers = CIMultiDict() + headers.add("Forwarded", "for=_1;by=_2") + headers.add("Forwarded", "invalid value") + headers.add("Forwarded", "") + headers.add("Forwarded", "for=_3;by=_4") + req = make_mocked_request("GET", "/", headers=headers) + assert len(req.forwarded) == 4 + assert req.forwarded[0]["for"] == "_1" + assert "for" not in req.forwarded[1] + assert "for" not in req.forwarded[2] + assert req.forwarded[3]["by"] == "_4" -def test_https_scheme_by_secure_proxy_ssl_header(make_request): - req = make_request('GET', '/', - secure_proxy_ssl_header=('X-HEADER', '1'), - headers=CIMultiDict({'X-HEADER': '1'})) - assert "https" == req.scheme - - -def test_https_scheme_by_secure_proxy_ssl_header_false_test(make_request): - req = make_request('GET', '/', - secure_proxy_ssl_header=('X-HEADER', '1'), - headers=CIMultiDict({'X-HEADER': '0'})) - assert "http" == req.scheme +def test_multiple_forwarded_headers_injection() -> None: + headers = CIMultiDict() + # This could be sent by an attacker, hoping to "shadow" the second header. + headers.add("Forwarded", 'for=_injected;by="') + # This is added by our trusted reverse proxy. 
+ headers.add("Forwarded", "for=_real;by=_actual_proxy") + req = make_mocked_request("GET", "/", headers=headers) + assert len(req.forwarded) == 2 + assert "by" not in req.forwarded[0] + assert req.forwarded[1]["for"] == "_real" + assert req.forwarded[1]["by"] == "_actual_proxy" + +def test_host_by_host_header() -> None: + req = make_mocked_request("GET", "/", headers=CIMultiDict({"Host": "example.com"})) + assert req.host == "example.com" -def test_raw_headers(make_request): - req = make_request('GET', '/', - headers=CIMultiDict({'X-HEADER': 'aaa'})) - assert req.raw_headers == ((b'X-Header', b'aaa'),) + +def test_raw_headers() -> None: + req = make_mocked_request("GET", "/", headers=CIMultiDict({"X-HEADER": "aaa"})) + assert req.raw_headers == ((b"X-HEADER", b"aaa"),) -def test_rel_url(make_request): - req = make_request('GET', '/path') - assert URL('/path') == req.rel_url +def test_rel_url() -> None: + req = make_mocked_request("GET", "/path") + assert URL("/path") == req.rel_url + + +def test_url_url() -> None: + req = make_mocked_request("GET", "/path", headers={"HOST": "example.com"}) + assert URL("http://example.com/path") == req.url -def test_url_url(make_request): - req = make_request('GET', '/path', headers={'HOST': 'example.com'}) - assert URL('http://example.com/path') == req.url +def test_clone() -> None: + req = make_mocked_request("GET", "/path") + req2 = req.clone() + assert req2.method == "GET" + assert req2.rel_url == URL("/path") -def test_clone(): - req = make_mocked_request('GET', '/path') +def test_clone_client_max_size() -> None: + req = make_mocked_request("GET", "/path", client_max_size=1024) req2 = req.clone() - assert req2.method == 'GET' - assert req2.rel_url == URL('/path') + assert req._client_max_size == req2._client_max_size + assert req2._client_max_size == 1024 -def test_clone_method(): - req = make_mocked_request('GET', '/path') - req2 = req.clone(method='POST') - assert req2.method == 'POST' - assert req2.rel_url == URL('/path') +def test_clone_method() -> None: + req = make_mocked_request("GET", "/path") + req2 = req.clone(method="POST") + assert req2.method == "POST" + assert req2.rel_url == URL("/path") -def test_clone_rel_url(): - req = make_mocked_request('GET', '/path') - req2 = req.clone(rel_url=URL('/path2')) - assert req2.rel_url == URL('/path2') +def test_clone_rel_url() -> None: + req = make_mocked_request("GET", "/path") + req2 = req.clone(rel_url=URL("/path2")) + assert req2.rel_url == URL("/path2") -def test_clone_rel_url_str(): - req = make_mocked_request('GET', '/path') - req2 = req.clone(rel_url='/path2') - assert req2.rel_url == URL('/path2') +def test_clone_rel_url_str() -> None: + req = make_mocked_request("GET", "/path") + req2 = req.clone(rel_url="/path2") + assert req2.rel_url == URL("/path2") -def test_clone_headers(): - req = make_mocked_request('GET', '/path', headers={'A': 'B'}) - req2 = req.clone(headers=CIMultiDict({'B': 'C'})) - assert req2.headers == CIMultiDict({'B': 'C'}) - assert req2.raw_headers == ((b'B', b'C'),) +def test_clone_headers() -> None: + req = make_mocked_request("GET", "/path", headers={"A": "B"}) + req2 = req.clone(headers=CIMultiDict({"B": "C"})) + assert req2.headers == CIMultiDict({"B": "C"}) + assert req2.raw_headers == ((b"B", b"C"),) -def test_clone_headers_dict(): - req = make_mocked_request('GET', '/path', headers={'A': 'B'}) - req2 = req.clone(headers={'B': 'C'}) - assert req2.headers == CIMultiDict({'B': 'C'}) - assert req2.raw_headers == ((b'B', b'C'),) +def test_clone_headers_dict() -> None: + 
req = make_mocked_request("GET", "/path", headers={"A": "B"}) + req2 = req.clone(headers={"B": "C"}) + assert req2.headers == CIMultiDict({"B": "C"}) + assert req2.raw_headers == ((b"B", b"C"),) -@asyncio.coroutine -def test_cannot_clone_after_read(loop): - payload = StreamReader(loop=loop) - payload.feed_data(b'data') +async def test_cannot_clone_after_read(protocol) -> None: + payload = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + payload.feed_data(b"data") payload.feed_eof() - req = make_mocked_request('GET', '/path', payload=payload) - yield from req.read() + req = make_mocked_request("GET", "/path", payload=payload) + await req.read() with pytest.raises(RuntimeError): req.clone() -@asyncio.coroutine -def test_make_too_big_request(loop): - payload = StreamReader(loop=loop) - large_file = 1024 ** 2 * b'x' - too_large_file = large_file + b'x' +async def test_make_too_big_request(protocol) -> None: + payload = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + large_file = 1024 ** 2 * b"x" + too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() - req = make_mocked_request('POST', '/', payload=payload) + req = make_mocked_request("POST", "/", payload=payload) with pytest.raises(HTTPRequestEntityTooLarge) as err: - yield from req.read() + await req.read() assert err.value.status_code == 413 -@asyncio.coroutine -def test_make_too_big_request_adjust_limit(loop): - payload = StreamReader(loop=loop) - large_file = 1024 ** 2 * b'x' - too_large_file = large_file + b'x' +async def test_make_too_big_request_adjust_limit(protocol) -> None: + payload = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + large_file = 1024 ** 2 * b"x" + too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() - max_size = 1024**2 + 2 - req = make_mocked_request('POST', '/', payload=payload, - client_max_size=max_size) - txt = yield from req.read() - assert len(txt) == 1024**2 + 1 - - -@asyncio.coroutine -def test_multipart_formdata(loop): - payload = StreamReader(loop=loop) - payload.feed_data(b"""-----------------------------326931944431359\r -Content-Disposition: form-data; name="a"\r -\r -b\r ------------------------------326931944431359\r -Content-Disposition: form-data; name="c"\r -\r -d\r ------------------------------326931944431359--\r\n""") - content_type = "multipart/form-data; boundary="\ - "---------------------------326931944431359" + max_size = 1024 ** 2 + 2 + req = make_mocked_request("POST", "/", payload=payload, client_max_size=max_size) + txt = await req.read() + assert len(txt) == 1024 ** 2 + 1 + + +async def test_multipart_formdata(protocol) -> None: + payload = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + payload.feed_data( + b"-----------------------------326931944431359\r\n" + b'Content-Disposition: form-data; name="a"\r\n' + b"\r\n" + b"b\r\n" + b"-----------------------------326931944431359\r\n" + b'Content-Disposition: form-data; name="c"\r\n' + b"\r\n" + b"d\r\n" + b"-----------------------------326931944431359--\r\n" + ) + content_type = ( + "multipart/form-data; boundary=" "---------------------------326931944431359" + ) + payload.feed_eof() + req = make_mocked_request( + "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload + ) + result = await req.post() + assert dict(result) == {"a": "b", "c": "d"} + + +async def test_multipart_formdata_file(protocol) -> None: + # Make sure file uploads work, even without a content type + payload = 
StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + payload.feed_data( + b"-----------------------------326931944431359\r\n" + b'Content-Disposition: form-data; name="a_file"; filename="binary"\r\n' + b"\r\n" + b"\ff\r\n" + b"-----------------------------326931944431359--\r\n" + ) + content_type = ( + "multipart/form-data; boundary=" "---------------------------326931944431359" + ) payload.feed_eof() - req = make_mocked_request('POST', '/', - headers={'CONTENT-TYPE': content_type}, - payload=payload) - result = yield from req.post() - assert dict(result) == {'a': 'b', 'c': 'd'} - - -@asyncio.coroutine -def test_make_too_big_request_limit_None(loop): - payload = StreamReader(loop=loop) - large_file = 1024 ** 2 * b'x' - too_large_file = large_file + b'x' + req = make_mocked_request( + "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload + ) + result = await req.post() + assert hasattr(result["a_file"], "file") + content = result["a_file"].file.read() + assert content == b"\ff" + + +async def test_make_too_big_request_limit_None(protocol) -> None: + payload = StreamReader(protocol, 2 ** 16, loop=asyncio.get_event_loop()) + large_file = 1024 ** 2 * b"x" + too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() max_size = None - req = make_mocked_request('POST', '/', payload=payload, - client_max_size=max_size) - txt = yield from req.read() - assert len(txt) == 1024**2 + 1 + req = make_mocked_request("POST", "/", payload=payload, client_max_size=max_size) + txt = await req.read() + assert len(txt) == 1024 ** 2 + 1 + + +def test_remote_peername_tcp() -> None: + transp = mock.Mock() + transp.get_extra_info.return_value = ("10.10.10.10", 1234) + req = make_mocked_request("GET", "/", transport=transp) + assert req.remote == "10.10.10.10" + + +def test_remote_peername_unix() -> None: + transp = mock.Mock() + transp.get_extra_info.return_value = "/path/to/sock" + req = make_mocked_request("GET", "/", transport=transp) + assert req.remote == "/path/to/sock" + + +def test_save_state_on_clone() -> None: + req = make_mocked_request("GET", "/") + req["key"] = "val" + req2 = req.clone() + req2["key"] = "val2" + assert req["key"] == "val" + assert req2["key"] == "val2" + + +def test_clone_scheme() -> None: + req = make_mocked_request("GET", "/") + req2 = req.clone(scheme="https") + assert req2.scheme == "https" + + +def test_clone_host() -> None: + req = make_mocked_request("GET", "/") + req2 = req.clone(host="example.com") + assert req2.host == "example.com" + + +def test_clone_remote() -> None: + req = make_mocked_request("GET", "/") + req2 = req.clone(remote="11.11.11.11") + assert req2.remote == "11.11.11.11" + + +@pytest.mark.skipif(not DEBUG, reason="The check is applied in DEBUG mode only") +def test_request_custom_attr() -> None: + req = make_mocked_request("GET", "/") + with pytest.warns(DeprecationWarning): + req.custom = None + + +def test_remote_with_closed_transport() -> None: + transp = mock.Mock() + transp.get_extra_info.return_value = ("10.10.10.10", 1234) + req = make_mocked_request("GET", "/", transport=transp) + req._protocol = None + assert req.remote == "10.10.10.10" + + +def test_url_http_with_closed_transport() -> None: + req = make_mocked_request("GET", "/") + req._protocol = None + assert str(req.url).startswith("http://") + + +def test_url_https_with_closed_transport() -> None: + req = make_mocked_request("GET", "/", sslcontext=True) + req._protocol = None + assert str(req.url).startswith("https://") + + +async def 
test_get_extra_info() -> None:
+    valid_key = "test"
+    valid_value = "existent"
+    default_value = "default"
+
+    def get_extra_info(name: str, default: Any = None):
+        return {valid_key: valid_value}.get(name, default)
+
+    transp = mock.Mock()
+    transp.get_extra_info.side_effect = get_extra_info
+    req = make_mocked_request("GET", "/", transport=transp)
+
+    req_extra_info = req.get_extra_info(valid_key, default_value)
+    transp_extra_info = req._protocol.transport.get_extra_info(valid_key, default_value)
+    assert req_extra_info == transp_extra_info
+
+    req._protocol.transport = None
+    extra_info = req.get_extra_info(valid_key, default_value)
+    assert extra_info == default_value
+
+    req._protocol = None
+    extra_info = req.get_extra_info(valid_key, default_value)
+    assert extra_info == default_value
+
+
+def test_eq() -> None:
+    req1 = make_mocked_request("GET", "/path/to?a=1&b=2")
+    req2 = make_mocked_request("GET", "/path/to?a=1&b=2")
+    assert req1 != req2
+    assert req1 == req1
+
+
+async def test_loop_prop() -> None:
+    loop = asyncio.get_event_loop()
+    req = make_mocked_request("GET", "/path", loop=loop)
+    with pytest.warns(DeprecationWarning):
+        assert req.loop is loop
diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py
index b2865b2903d..a4c4ae0de4f 100644
--- a/tests/test_web_request_handler.py
+++ b/tests/test_web_request_handler.py
@@ -1,27 +1,25 @@
-import asyncio
 from unittest import mock
 from aiohttp import web
-from aiohttp.test_utils import make_mocked_coro, make_mocked_request
+from aiohttp.test_utils import make_mocked_coro
-def test_repr(loop):
-    app = web.Application()
-    manager = app.make_handler(loop=loop)
+async def serve(request):
+    return web.Response()
+
+
+async def test_repr() -> None:
+    manager = web.Server(serve)
     handler = manager()
-    assert '<RequestHandler none:none disconnected>' == repr(handler)
+    assert "<RequestHandler disconnected>" == repr(handler)
     handler.transport = object()
-    request = make_mocked_request('GET', '/index.html')
-    handler._request = request
-    # assert '<RequestHandler GET:/index.html connected>' == repr(handler)
-    assert '<RequestHandler GET:/index.html connected>' == repr(handler)
+    assert "<RequestHandler connected>" == repr(handler)
-def test_connections(loop):
-    app = web.Application()
-    manager = app.make_handler(loop=loop)
+async def test_connections() -> None:
+    manager = web.Server(serve)
     assert manager.connections == []
     handler = object()
@@ -33,34 +31,30 @@ def test_connections(loop):
     assert manager.connections == []
-@asyncio.coroutine
-def test_finish_connection_no_timeout(loop):
-    app = web.Application()
-    manager = app.make_handler(loop=loop)
+async def test_shutdown_no_timeout() -> None:
+    manager = web.Server(serve)
     handler = mock.Mock()
     handler.shutdown = make_mocked_coro(mock.Mock())
     transport = mock.Mock()
     manager.connection_made(handler, transport)
-    yield from manager.finish_connections()
+    await manager.shutdown()
     manager.connection_lost(handler, None)
     assert manager.connections == []
     handler.shutdown.assert_called_with(None)
-@asyncio.coroutine
-def test_finish_connection_timeout(loop):
-    app = web.Application()
-    manager = app.make_handler(loop=loop)
+async def test_shutdown_timeout() -> None:
+    manager = web.Server(serve)
     handler = mock.Mock()
     handler.shutdown = make_mocked_coro(mock.Mock())
     transport = mock.Mock()
     manager.connection_made(handler, transport)
-    yield from manager.finish_connections(timeout=0.1)
+    await manager.shutdown(timeout=0.1)
     manager.connection_lost(handler, None)
    assert manager.connections == []
diff --git a/tests/test_web_response.py b/tests/test_web_response.py
index 91e5140d78c..f8473431010 100644
--- a/tests/test_web_response.py
+++ 
b/tests/test_web_response.py @@ -1,34 +1,46 @@ -import asyncio +import collections.abc import datetime +import gzip import json -import re +from concurrent.futures import ThreadPoolExecutor from unittest import mock import pytest -from multidict import CIMultiDict +from multidict import CIMultiDict, CIMultiDictProxy +from re_assert import Matches from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, signals -from aiohttp.test_utils import make_mocked_request +from aiohttp.payload import BytesPayload +from aiohttp.test_utils import make_mocked_coro, make_mocked_request from aiohttp.web import ContentCoding, Response, StreamResponse, json_response -def make_request(method, path, headers=CIMultiDict(), - version=HttpVersion11, **kwargs): - app = kwargs.pop('app', None) or mock.Mock() +def make_request( + method, + path, + headers=CIMultiDict(), + version=HttpVersion11, + on_response_prepare=None, + **kwargs +): + app = kwargs.pop("app", None) or mock.Mock() app._debug = False - app.on_response_prepare = signals.Signal(app) - protocol = kwargs.pop('protocol', None) or mock.Mock() - return make_mocked_request(method, path, headers, - version=version, protocol=protocol, - app=app, **kwargs) + if on_response_prepare is None: + on_response_prepare = signals.Signal(app) + app.on_response_prepare = on_response_prepare + app.on_response_prepare.freeze() + protocol = kwargs.pop("protocol", None) or mock.Mock() + return make_mocked_request( + method, path, headers, version=version, protocol=protocol, app=app, **kwargs + ) -@pytest.yield_fixture +@pytest.fixture def buf(): return bytearray() -@pytest.yield_fixture +@pytest.fixture def writer(buf): writer = mock.Mock() @@ -41,14 +53,16 @@ def buffer_data(chunk): def write(chunk): buf.extend(chunk) - def write_headers(status_line, headers): - headers = status_line + ''.join( - [k + ': ' + v + '\r\n' for k, v in headers.items()]) - headers = headers.encode('utf-8') + b'\r\n' + async def write_headers(status_line, headers): + headers = ( + status_line + + "\r\n" + + "".join([k + ": " + v + "\r\n" for k, v in headers.items()]) + ) + headers = headers.encode("utf-8") + b"\r\n" buf.extend(headers) - @asyncio.coroutine - def write_eof(chunk=b''): + async def write_eof(chunk=b""): buf.extend(chunk) writer.acquire.side_effect = acquire @@ -62,101 +76,156 @@ def write_eof(chunk=b''): return writer -def test_stream_response_ctor(): +def test_stream_response_ctor() -> None: resp = StreamResponse() assert 200 == resp.status assert resp.keep_alive is None + assert resp.task is None + + req = mock.Mock() + resp._req = req + assert resp.task is req.task + + +def test_stream_response_hashable() -> None: + # should not raise exception + hash(StreamResponse()) + + +def test_stream_response_eq() -> None: + resp1 = StreamResponse() + resp2 = StreamResponse() + + assert resp1 == resp1 + assert not resp1 == resp2 -def test_content_length(): + +def test_stream_response_is_mutable_mapping() -> None: + resp = StreamResponse() + assert isinstance(resp, collections.abc.MutableMapping) + resp["key"] = "value" + assert "value" == resp["key"] + + +def test_stream_response_delitem() -> None: + resp = StreamResponse() + resp["key"] = "value" + del resp["key"] + assert "key" not in resp + + +def test_stream_response_len() -> None: + resp = StreamResponse() + assert len(resp) == 0 + resp["key"] = "value" + assert len(resp) == 1 + + +def test_request_iter() -> None: + resp = StreamResponse() + resp["key"] = "value" + resp["key2"] = "value2" + assert set(resp) == {"key", 
"key2"} + + +def test_content_length() -> None: resp = StreamResponse() assert resp.content_length is None -def test_content_length_setter(): +def test_content_length_setter() -> None: resp = StreamResponse() resp.content_length = 234 assert 234 == resp.content_length -def test_drop_content_length_header_on_setting_len_to_None(): +def test_content_length_setter_with_enable_chunked_encoding() -> None: + resp = StreamResponse() + + resp.enable_chunked_encoding() + with pytest.raises(RuntimeError): + resp.content_length = 234 + + +def test_drop_content_length_header_on_setting_len_to_None() -> None: resp = StreamResponse() resp.content_length = 1 - assert "1" == resp.headers['Content-Length'] + assert "1" == resp.headers["Content-Length"] resp.content_length = None - assert 'Content-Length' not in resp.headers + assert "Content-Length" not in resp.headers -def test_set_content_length_to_None_on_non_set(): +def test_set_content_length_to_None_on_non_set() -> None: resp = StreamResponse() resp.content_length = None - assert 'Content-Length' not in resp.headers + assert "Content-Length" not in resp.headers resp.content_length = None - assert 'Content-Length' not in resp.headers + assert "Content-Length" not in resp.headers -def test_setting_content_type(): +def test_setting_content_type() -> None: resp = StreamResponse() - resp.content_type = 'text/html' - assert 'text/html' == resp.headers['content-type'] + resp.content_type = "text/html" + assert "text/html" == resp.headers["content-type"] -def test_setting_charset(): +def test_setting_charset() -> None: resp = StreamResponse() - resp.content_type = 'text/html' - resp.charset = 'koi8-r' - assert 'text/html; charset=koi8-r' == resp.headers['content-type'] + resp.content_type = "text/html" + resp.charset = "koi8-r" + assert "text/html; charset=koi8-r" == resp.headers["content-type"] -def test_default_charset(): +def test_default_charset() -> None: resp = StreamResponse() assert resp.charset is None -def test_reset_charset(): +def test_reset_charset() -> None: resp = StreamResponse() - resp.content_type = 'text/html' + resp.content_type = "text/html" resp.charset = None assert resp.charset is None -def test_reset_charset_after_setting(): +def test_reset_charset_after_setting() -> None: resp = StreamResponse() - resp.content_type = 'text/html' - resp.charset = 'koi8-r' + resp.content_type = "text/html" + resp.charset = "koi8-r" resp.charset = None assert resp.charset is None -def test_charset_without_content_type(): +def test_charset_without_content_type() -> None: resp = StreamResponse() with pytest.raises(RuntimeError): - resp.charset = 'koi8-r' + resp.charset = "koi8-r" -def test_last_modified_initial(): +def test_last_modified_initial() -> None: resp = StreamResponse() assert resp.last_modified is None -def test_last_modified_string(): +def test_last_modified_string() -> None: resp = StreamResponse() dt = datetime.datetime(1990, 1, 2, 3, 4, 5, 0, datetime.timezone.utc) - resp.last_modified = 'Mon, 2 Jan 1990 03:04:05 GMT' + resp.last_modified = "Mon, 2 Jan 1990 03:04:05 GMT" assert resp.last_modified == dt -def test_last_modified_timestamp(): +def test_last_modified_timestamp() -> None: resp = StreamResponse() dt = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, datetime.timezone.utc) @@ -168,7 +237,7 @@ def test_last_modified_timestamp(): assert resp.last_modified == dt -def test_last_modified_datetime(): +def test_last_modified_datetime() -> None: resp = StreamResponse() dt = datetime.datetime(2001, 2, 3, 4, 5, 6, 0, datetime.timezone.utc) @@ 
-176,7 +245,7 @@ def test_last_modified_datetime(): assert resp.last_modified == dt -def test_last_modified_reset(): +def test_last_modified_reset() -> None: resp = StreamResponse() resp.last_modified = 0 @@ -184,69 +253,72 @@ def test_last_modified_reset(): assert resp.last_modified is None -@asyncio.coroutine -def test_start(): - req = make_request('GET', '/', payload_writer=mock.Mock()) +async def test_start() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert resp.keep_alive is None - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) assert msg.write_headers.called - msg2 = yield from resp.prepare(req) + msg2 = await resp.prepare(req) assert msg is msg2 assert resp.keep_alive - req2 = make_request('GET', '/') + req2 = make_request("GET", "/") # with pytest.raises(RuntimeError): - msg3 = yield from resp.prepare(req2) + msg3 = await resp.prepare(req2) assert msg is msg3 -@asyncio.coroutine -def test_chunked_encoding(): - req = make_request('GET', '/') +async def test_chunked_encoding() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert not resp.chunked resp.enable_chunked_encoding() assert resp.chunked - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) assert msg.chunked -@asyncio.coroutine -def test_chunk_size(): - req = make_request('GET', '/', payload_writer=mock.Mock()) +def test_enable_chunked_encoding_with_content_length() -> None: + resp = StreamResponse() + + resp.content_length = 234 + with pytest.raises(RuntimeError): + resp.enable_chunked_encoding() + + +async def test_chunk_size() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert not resp.chunked - resp.enable_chunked_encoding(chunk_size=8192) + with pytest.warns(DeprecationWarning): + resp.enable_chunked_encoding(chunk_size=8192) assert resp.chunked - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) assert msg.chunked assert msg.enable_chunking.called assert msg.filter is not None -@asyncio.coroutine -def test_chunked_encoding_forbidden_for_http_10(): - req = make_request('GET', '/', version=HttpVersion10) +async def test_chunked_encoding_forbidden_for_http_10() -> None: + req = make_request("GET", "/", version=HttpVersion10) resp = StreamResponse() resp.enable_chunked_encoding() with pytest.raises(RuntimeError) as ctx: - yield from resp.prepare(req) - assert re.match("Using chunked encoding is forbidden for HTTP/1.0", - str(ctx.value)) + await resp.prepare(req) + assert Matches("Using chunked encoding is forbidden for HTTP/1.0") == str(ctx.value) -@asyncio.coroutine -def test_compression_no_accept(): - req = make_request('GET', '/', payload_writer=mock.Mock()) +async def test_compression_no_accept() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert not resp.chunked @@ -254,43 +326,42 @@ def test_compression_no_accept(): resp.enable_compression() assert resp.compression - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) assert not msg.enable_compression.called -@asyncio.coroutine -def test_force_compression_no_accept_backwards_compat(): - req = make_request('GET', '/', payload_writer=mock.Mock()) +async def test_force_compression_no_accept_backwards_compat() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert not resp.chunked assert not resp.compression - resp.enable_compression(force=True) + with pytest.warns(DeprecationWarning): + resp.enable_compression(force=True) assert resp.compression - msg = yield from resp.prepare(req) + msg = await 
resp.prepare(req) assert msg.enable_compression.called assert msg.filter is not None -@asyncio.coroutine -def test_force_compression_false_backwards_compat(): - req = make_request('GET', '/', payload_writer=mock.Mock()) +async def test_force_compression_false_backwards_compat() -> None: + req = make_request("GET", "/") resp = StreamResponse() assert not resp.compression - resp.enable_compression(force=False) + with pytest.warns(DeprecationWarning): + resp.enable_compression(force=False) assert resp.compression - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) assert not msg.enable_compression.called -@asyncio.coroutine -def test_compression_default_coding(): +async def test_compression_default_coding() -> None: req = make_request( - 'GET', '/', - headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) + "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) + ) resp = StreamResponse() assert not resp.chunked @@ -298,173 +369,290 @@ def test_compression_default_coding(): resp.enable_compression() assert resp.compression - msg = yield from resp.prepare(req) + msg = await resp.prepare(req) - msg.enable_compression.assert_called_with('deflate') - assert 'deflate' == resp.headers.get(hdrs.CONTENT_ENCODING) + msg.enable_compression.assert_called_with("deflate") + assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) assert msg.filter is not None -@asyncio.coroutine -def test_force_compression_deflate(): +async def test_force_compression_deflate() -> None: req = make_request( - 'GET', '/', - headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) + "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) + ) resp = StreamResponse() resp.enable_compression(ContentCoding.deflate) assert resp.compression - msg = yield from resp.prepare(req) - msg.enable_compression.assert_called_with('deflate') - assert 'deflate' == resp.headers.get(hdrs.CONTENT_ENCODING) + msg = await resp.prepare(req) + msg.enable_compression.assert_called_with("deflate") + assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) -@asyncio.coroutine -def test_force_compression_no_accept_deflate(): - req = make_request('GET', '/') +async def test_force_compression_no_accept_deflate() -> None: + req = make_request("GET", "/") resp = StreamResponse() resp.enable_compression(ContentCoding.deflate) assert resp.compression - msg = yield from resp.prepare(req) - msg.enable_compression.assert_called_with('deflate') - assert 'deflate' == resp.headers.get(hdrs.CONTENT_ENCODING) + msg = await resp.prepare(req) + msg.enable_compression.assert_called_with("deflate") + assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) -@asyncio.coroutine -def test_force_compression_gzip(): +async def test_force_compression_gzip() -> None: req = make_request( - 'GET', '/', - headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) + "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) + ) resp = StreamResponse() resp.enable_compression(ContentCoding.gzip) assert resp.compression - msg = yield from resp.prepare(req) - msg.enable_compression.assert_called_with('gzip') - assert 'gzip' == resp.headers.get(hdrs.CONTENT_ENCODING) + msg = await resp.prepare(req) + msg.enable_compression.assert_called_with("gzip") + assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) -@asyncio.coroutine -def test_force_compression_no_accept_gzip(): - req = make_request('GET', '/') +async def test_force_compression_no_accept_gzip() -> None: + req = make_request("GET", "/") 
resp = StreamResponse() resp.enable_compression(ContentCoding.gzip) assert resp.compression - msg = yield from resp.prepare(req) - msg.enable_compression.assert_called_with('gzip') - assert 'gzip' == resp.headers.get(hdrs.CONTENT_ENCODING) + msg = await resp.prepare(req) + msg.enable_compression.assert_called_with("gzip") + assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) + + +async def test_change_content_threaded_compression_enabled() -> None: + req = make_request("GET", "/") + body_thread_size = 1024 + body = b"answer" * body_thread_size + resp = Response(body=body, zlib_executor_size=body_thread_size) + resp.enable_compression(ContentCoding.gzip) + + await resp.prepare(req) + assert gzip.decompress(resp._compressed_body) == body + + +async def test_change_content_threaded_compression_enabled_explicit() -> None: + req = make_request("GET", "/") + body_thread_size = 1024 + body = b"answer" * body_thread_size + with ThreadPoolExecutor(1) as executor: + resp = Response( + body=body, zlib_executor_size=body_thread_size, zlib_executor=executor + ) + resp.enable_compression(ContentCoding.gzip) + + await resp.prepare(req) + assert gzip.decompress(resp._compressed_body) == body + + +async def test_change_content_length_if_compression_enabled() -> None: + req = make_request("GET", "/") + resp = Response(body=b"answer") + resp.enable_compression(ContentCoding.gzip) + + await resp.prepare(req) + assert resp.content_length is not None and resp.content_length != len(b"answer") -@asyncio.coroutine -def test_delete_content_length_if_compression_enabled(): - req = make_request('GET', '/') - resp = Response(body=b'answer') +async def test_set_content_length_if_compression_enabled() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH in headers + assert headers[hdrs.CONTENT_LENGTH] == "26" + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + resp = Response(body=b"answer") resp.enable_compression(ContentCoding.gzip) - yield from resp.prepare(req) + await resp.prepare(req) + assert resp.content_length == 26 + del resp.headers[hdrs.CONTENT_LENGTH] + assert resp.content_length == 26 + + +async def test_remove_content_length_if_compression_enabled_http11() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert headers.get(hdrs.TRANSFER_ENCODING, "") == "chunked" + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + resp = StreamResponse() + resp.content_length = 123 + resp.enable_compression(ContentCoding.gzip) + await resp.prepare(req) assert resp.content_length is None -@asyncio.coroutine -def test_write_non_byteish(): +async def test_remove_content_length_if_compression_enabled_http10() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", version=HttpVersion10, writer=writer) resp = StreamResponse() - yield from resp.prepare(make_request('GET', '/')) + resp.content_length = 123 + resp.enable_compression(ContentCoding.gzip) + await resp.prepare(req) + assert resp.content_length is None + + +async def test_force_compression_identity() -> None: + writer = mock.Mock() + + async def write_headers(status_line, 
headers): + assert hdrs.CONTENT_LENGTH in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + resp = StreamResponse() + resp.content_length = 123 + resp.enable_compression(ContentCoding.identity) + await resp.prepare(req) + assert resp.content_length == 123 + + +async def test_force_compression_identity_response() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert headers[hdrs.CONTENT_LENGTH] == "6" + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + resp = Response(body=b"answer") + resp.enable_compression(ContentCoding.identity) + await resp.prepare(req) + assert resp.content_length == 6 + + +async def test_rm_content_length_if_compression_http11() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert headers.get(hdrs.TRANSFER_ENCODING, "") == "chunked" + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + payload = BytesPayload(b"answer", headers={"X-Test-Header": "test"}) + resp = Response(body=payload) + assert resp.content_length == 6 + resp.body = payload + resp.enable_compression(ContentCoding.gzip) + await resp.prepare(req) + assert resp.content_length is None + + +async def test_rm_content_length_if_compression_http10() -> None: + writer = mock.Mock() + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", version=HttpVersion10, writer=writer) + resp = Response(body=BytesPayload(b"answer")) + resp.enable_compression(ContentCoding.gzip) + await resp.prepare(req) + assert resp.content_length is None + + +async def test_content_length_on_chunked() -> None: + req = make_request("GET", "/") + resp = Response(body=b"answer") + assert resp.content_length == 6 + resp.enable_chunked_encoding() + assert resp.content_length is None + await resp.prepare(req) + + +async def test_write_non_byteish() -> None: + resp = StreamResponse() + await resp.prepare(make_request("GET", "/")) with pytest.raises(AssertionError): - resp.write(123) + await resp.write(123) -def test_write_before_start(): +async def test_write_before_start() -> None: resp = StreamResponse() with pytest.raises(RuntimeError): - resp.write(b'data') + await resp.write(b"data") -@asyncio.coroutine -def test_cannot_write_after_eof(): +async def test_cannot_write_after_eof() -> None: resp = StreamResponse() - writer = mock.Mock() - resp_impl = yield from resp.prepare( - make_request('GET', '/', writer=writer)) - resp_impl.write_eof = mock.Mock() - resp_impl.write_eof.return_value = () + req = make_request("GET", "/") + await resp.prepare(req) - resp.write(b'data') - yield from resp.write_eof() - writer.write.reset_mock() + await resp.write(b"data") + await resp.write_eof() + req.writer.write.reset_mock() with pytest.raises(RuntimeError): - resp.write(b'next data') - assert not writer.write.called + await resp.write(b"next data") + assert not req.writer.write.called -@asyncio.coroutine -def test___repr___after_eof(): +async def test___repr___after_eof() -> None: resp = StreamResponse() - yield from resp.prepare(make_request('GET', '/')) + await resp.prepare(make_request("GET", "/")) assert 
resp.prepared
-    resp.write(b'data')
-    yield from resp.write_eof()
+    await resp.write(b"data")
+    await resp.write_eof()
     assert not resp.prepared
     resp_repr = repr(resp)
-    assert resp_repr == '<StreamResponse OK eof>'
+    assert resp_repr == "<StreamResponse OK eof>"
-@asyncio.coroutine
-def test_cannot_write_eof_before_headers():
+async def test_cannot_write_eof_before_headers() -> None:
     resp = StreamResponse()
     with pytest.raises(AssertionError):
-        yield from resp.write_eof()
+        await resp.write_eof()
-@asyncio.coroutine
-def test_cannot_write_eof_twice():
+async def test_cannot_write_eof_twice() -> None:
     resp = StreamResponse()
     writer = mock.Mock()
-    resp_impl = yield from resp.prepare(make_request('GET', '/'))
-    resp_impl.write = mock.Mock()
-    resp_impl.write_eof = mock.Mock()
-    resp_impl.write_eof.return_value = ()
+    resp_impl = await resp.prepare(make_request("GET", "/"))
+    resp_impl.write = make_mocked_coro(None)
+    resp_impl.write_eof = make_mocked_coro(None)
-    resp.write(b'data')
+    await resp.write(b"data")
     assert resp_impl.write.called
-    yield from resp.write_eof()
+    await resp.write_eof()
     resp_impl.write.reset_mock()
-    yield from resp.write_eof()
+    await resp.write_eof()
     assert not writer.write.called
-@asyncio.coroutine
-def _test_write_returns_drain():
-    resp = StreamResponse()
-    yield from resp.prepare(make_request('GET', '/'))
-
-    with mock.patch('aiohttp.http_writer.noop') as noop:
-        assert noop == resp.write(b'data')
-
-
-@asyncio.coroutine
-def _test_write_returns_empty_tuple_on_empty_data():
-    resp = StreamResponse()
-    yield from resp.prepare(make_request('GET', '/'))
-
-    with mock.patch('aiohttp.http_writer.noop') as noop:
-        assert noop.return_value == resp.write(b'')
-
-
-def test_force_close():
+def test_force_close() -> None:
     resp = StreamResponse()
     assert resp.keep_alive is None
@@ -472,86 +660,101 @@ def test_force_close():
     assert resp.keep_alive is False
-@asyncio.coroutine
-def test_response_output_length():
+async def test_response_output_length() -> None:
     resp = StreamResponse()
-    yield from resp.prepare(make_request('GET', '/'))
+    await resp.prepare(make_request("GET", "/"))
     with pytest.warns(DeprecationWarning):
         assert resp.output_length
-def test_response_cookies():
+def test_response_cookies() -> None:
     resp = StreamResponse()
     assert resp.cookies == {}
-    assert str(resp.cookies) == ''
-
-    resp.set_cookie('name', 'value')
-    assert str(resp.cookies) == 'Set-Cookie: name=value; Path=/'
-    resp.set_cookie('name', 'other_value')
-    assert str(resp.cookies) == 'Set-Cookie: name=other_value; Path=/'
-
-    resp.cookies['name'] = 'another_other_value'
-    resp.cookies['name']['max-age'] = 10
-    assert (str(resp.cookies) ==
-            'Set-Cookie: name=another_other_value; Max-Age=10; Path=/')
-
-    resp.del_cookie('name')
-    expected = ('Set-Cookie: name=("")?; '
-                'expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; Path=/')
-    assert re.match(expected, str(resp.cookies))
-
-    resp.set_cookie('name', 'value', domain='local.host')
-    expected = 'Set-Cookie: name=value; Domain=local.host; Path=/'
+    assert str(resp.cookies) == ""
+
+    resp.set_cookie("name", "value")
+    assert str(resp.cookies) == "Set-Cookie: name=value; Path=/"
+    resp.set_cookie("name", "other_value")
+    assert str(resp.cookies) == "Set-Cookie: name=other_value; Path=/"
+
+    resp.cookies["name"] = "another_other_value"
+    resp.cookies["name"]["max-age"] = 10
+    assert (
+        str(resp.cookies) == "Set-Cookie: name=another_other_value; Max-Age=10; Path=/"
+    )
+
+    resp.del_cookie("name")
+    expected = (
+        'Set-Cookie: name=("")?; '
+        "expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; 
Path=/" + ) + assert Matches(expected) == str(resp.cookies) + + resp.set_cookie("name", "value", domain="local.host") + expected = "Set-Cookie: name=value; Domain=local.host; Path=/" assert str(resp.cookies) == expected -def test_response_cookie_path(): +def test_response_cookie_path() -> None: resp = StreamResponse() assert resp.cookies == {} - resp.set_cookie('name', 'value', path='/some/path') - assert str(resp.cookies) == 'Set-Cookie: name=value; Path=/some/path' - resp.set_cookie('name', 'value', expires='123') - assert (str(resp.cookies) == - 'Set-Cookie: name=value; expires=123; Path=/') - resp.set_cookie('name', 'value', domain='example.com', - path='/home', expires='123', max_age='10', - secure=True, httponly=True, version='2.0') - assert (str(resp.cookies).lower() == 'set-cookie: name=value; ' - 'domain=example.com; ' - 'expires=123; ' - 'httponly; ' - 'max-age=10; ' - 'path=/home; ' - 'secure; ' - 'version=2.0') - - -def test_response_cookie__issue_del_cookie(): + resp.set_cookie("name", "value", path="/some/path") + assert str(resp.cookies) == "Set-Cookie: name=value; Path=/some/path" + resp.set_cookie("name", "value", expires="123") + assert str(resp.cookies) == "Set-Cookie: name=value; expires=123; Path=/" + resp.set_cookie( + "name", + "value", + domain="example.com", + path="/home", + expires="123", + max_age="10", + secure=True, + httponly=True, + version="2.0", + samesite="lax", + ) + assert ( + str(resp.cookies).lower() == "set-cookie: name=value; " + "domain=example.com; " + "expires=123; " + "httponly; " + "max-age=10; " + "path=/home; " + "samesite=lax; " + "secure; " + "version=2.0" + ) + + +def test_response_cookie__issue_del_cookie() -> None: resp = StreamResponse() assert resp.cookies == {} - assert str(resp.cookies) == '' + assert str(resp.cookies) == "" - resp.del_cookie('name') - expected = ('Set-Cookie: name=("")?; ' - 'expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; Path=/') - assert re.match(expected, str(resp.cookies)) + resp.del_cookie("name") + expected = ( + 'Set-Cookie: name=("")?; ' + "expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; Path=/" + ) + assert Matches(expected) == str(resp.cookies) -def test_cookie_set_after_del(): +def test_cookie_set_after_del() -> None: resp = StreamResponse() - resp.del_cookie('name') - resp.set_cookie('name', 'val') + resp.del_cookie("name") + resp.set_cookie("name", "val") # check for Max-Age dropped - expected = 'Set-Cookie: name=val; Path=/' + expected = "Set-Cookie: name=val; Path=/" assert str(resp.cookies) == expected -def test_set_status_with_reason(): +def test_set_status_with_reason() -> None: resp = StreamResponse() resp.set_status(200, "Everithing is fine!") @@ -559,325 +762,295 @@ def test_set_status_with_reason(): assert "Everithing is fine!" 
== resp.reason
-@asyncio.coroutine
-def test_start_force_close():
-    req = make_request('GET', '/')
+async def test_start_force_close() -> None:
+    req = make_request("GET", "/")
     resp = StreamResponse()
     resp.force_close()
     assert not resp.keep_alive
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     assert not resp.keep_alive
-@asyncio.coroutine
-def test___repr__():
-    req = make_request('GET', '/path/to')
+async def test___repr__() -> None:
+    req = make_request("GET", "/path/to")
     resp = StreamResponse(reason=301)
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     assert "<StreamResponse 301 GET /path/to >" == repr(resp)
-def test___repr___not_prepared():
+def test___repr___not_prepared() -> None:
     resp = StreamResponse(reason=301)
     assert "<StreamResponse 301 not prepared>" == repr(resp)
-@asyncio.coroutine
-def test_keep_alive_http10_default():
-    req = make_request('GET', '/', version=HttpVersion10)
+async def test_keep_alive_http10_default() -> None:
+    req = make_request("GET", "/", version=HttpVersion10)
     resp = StreamResponse()
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     assert not resp.keep_alive
-@asyncio.coroutine
-def test_keep_alive_http10_switched_on():
-    headers = CIMultiDict(Connection='keep-alive')
-    req = make_request('GET', '/', version=HttpVersion10, headers=headers)
+async def test_keep_alive_http10_switched_on() -> None:
+    headers = CIMultiDict(Connection="keep-alive")
+    req = make_request("GET", "/", version=HttpVersion10, headers=headers)
     req._message = req._message._replace(should_close=False)
     resp = StreamResponse()
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     assert resp.keep_alive
-@asyncio.coroutine
-def test_keep_alive_http09():
-    headers = CIMultiDict(Connection='keep-alive')
-    req = make_request('GET', '/', version=HttpVersion(0, 9), headers=headers)
+async def test_keep_alive_http09() -> None:
+    headers = CIMultiDict(Connection="keep-alive")
+    req = make_request("GET", "/", version=HttpVersion(0, 9), headers=headers)
     resp = StreamResponse()
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     assert not resp.keep_alive
-def test_prepare_twice():
-    req = make_request('GET', '/')
+async def test_prepare_twice() -> None:
+    req = make_request("GET", "/")
     resp = StreamResponse()
-    impl1 = yield from resp.prepare(req)
-    impl2 = yield from resp.prepare(req)
+    impl1 = await resp.prepare(req)
+    impl2 = await resp.prepare(req)
     assert impl1 is impl2
-@asyncio.coroutine
-def test_prepare_calls_signal():
+async def test_prepare_calls_signal() -> None:
     app = mock.Mock()
-    req = make_request('GET', '/', app=app)
+    sig = make_mocked_coro()
+    on_response_prepare = signals.Signal(app)
+    on_response_prepare.append(sig)
+    req = make_request("GET", "/", app=app, on_response_prepare=on_response_prepare)
     resp = StreamResponse()
-    sig = mock.Mock()
-    app.on_response_prepare.append(sig)
-    yield from resp.prepare(req)
+    await resp.prepare(req)
     sig.assert_called_with(req, resp)
-def test_get_nodelay_unprepared():
-    resp = StreamResponse()
-    with pytest.raises(AssertionError):
-        resp.tcp_nodelay
-
-
-def test_set_nodelay_unprepared():
-    resp = StreamResponse()
-    with pytest.raises(AssertionError):
-        resp.set_tcp_nodelay(True)
-
-
-@asyncio.coroutine
-def test_get_nodelay_prepared():
-    resp = StreamResponse()
-    writer = mock.Mock()
-    writer.tcp_nodelay = False
-    req = make_request('GET', '/', payload_writer=writer)
-
-    yield from resp.prepare(req)
-    assert not resp.tcp_nodelay
-
-
-def test_set_nodelay_prepared():
-    resp = StreamResponse()
-    writer = mock.Mock()
-    req = make_request('GET', '/', payload_writer=writer)
- 
- yield from resp.prepare(req) - resp.set_tcp_nodelay(True) - writer.set_tcp_nodelay.assert_called_with(True) - - -def test_get_cork_unprepared(): - resp = StreamResponse() - with pytest.raises(AssertionError): - resp.tcp_cork - - -def test_set_cork_unprepared(): - resp = StreamResponse() - with pytest.raises(AssertionError): - resp.set_tcp_cork(True) - - -@asyncio.coroutine -def test_get_cork_prepared(): - resp = StreamResponse() - writer = mock.Mock() - writer.tcp_cork = False - req = make_request('GET', '/', payload_writer=writer) - - yield from resp.prepare(req) - assert not resp.tcp_cork - - -def test_set_cork_prepared(): - resp = StreamResponse() - writer = mock.Mock() - req = make_request('GET', '/', payload_writer=writer) - - yield from resp.prepare(req) - resp.set_tcp_cork(True) - writer.set_tcp_cork.assert_called_with(True) - - # Response class -def test_response_ctor(): +def test_response_ctor() -> None: resp = Response() assert 200 == resp.status - assert 'OK' == resp.reason + assert "OK" == resp.reason assert resp.body is None assert resp.content_length == 0 - assert 'CONTENT-LENGTH' not in resp.headers + assert "CONTENT-LENGTH" not in resp.headers -def test_ctor_with_headers_and_status(): - resp = Response(body=b'body', status=201, - headers={'Age': '12', 'DATE': 'date'}) +async def test_ctor_with_headers_and_status() -> None: + resp = Response(body=b"body", status=201, headers={"Age": "12", "DATE": "date"}) assert 201 == resp.status - assert b'body' == resp.body - assert resp.headers['AGE'] == '12' + assert b"body" == resp.body + assert resp.headers["AGE"] == "12" - resp._start(mock.Mock(version=HttpVersion11)) + req = make_mocked_request("GET", "/") + await resp._start(req) assert 4 == resp.content_length - assert resp.headers['CONTENT-LENGTH'] == '4' + assert resp.headers["CONTENT-LENGTH"] == "4" -def test_ctor_content_type(): - resp = Response(content_type='application/json') +def test_ctor_content_type() -> None: + resp = Response(content_type="application/json") assert 200 == resp.status - assert 'OK' == resp.reason + assert "OK" == resp.reason assert 0 == resp.content_length - assert (CIMultiDict([('CONTENT-TYPE', 'application/json')]) == - resp.headers) + assert CIMultiDict([("CONTENT-TYPE", "application/json")]) == resp.headers -def test_ctor_text_body_combined(): +def test_ctor_text_body_combined() -> None: with pytest.raises(ValueError): - Response(body=b'123', text='test text') + Response(body=b"123", text="test text") -def test_ctor_text(): - resp = Response(text='test text') +async def test_ctor_text() -> None: + resp = Response(text="test text") assert 200 == resp.status - assert 'OK' == resp.reason + assert "OK" == resp.reason assert 9 == resp.content_length - assert (CIMultiDict( - [('CONTENT-TYPE', 'text/plain; charset=utf-8')]) == resp.headers) + assert CIMultiDict([("CONTENT-TYPE", "text/plain; charset=utf-8")]) == resp.headers - assert resp.body == b'test text' - assert resp.text == 'test text' + assert resp.body == b"test text" + assert resp.text == "test text" - resp.headers['DATE'] = 'date' - resp._start(mock.Mock(version=HttpVersion11)) - assert resp.headers['CONTENT-LENGTH'] == '9' + resp.headers["DATE"] = "date" + req = make_mocked_request("GET", "/", version=HttpVersion11) + await resp._start(req) + assert resp.headers["CONTENT-LENGTH"] == "9" -def test_ctor_charset(): - resp = Response(text='текст', charset='koi8-r') +def test_ctor_charset() -> None: + resp = Response(text="текст", charset="koi8-r") - assert 'текст'.encode('koi8-r') == resp.body - 
assert 'koi8-r' == resp.charset + assert "текст".encode("koi8-r") == resp.body + assert "koi8-r" == resp.charset -def test_ctor_charset_default_utf8(): - resp = Response(text='test test', charset=None) +def test_ctor_charset_default_utf8() -> None: + resp = Response(text="test test", charset=None) - assert 'utf-8' == resp.charset + assert "utf-8" == resp.charset -def test_ctor_charset_in_content_type(): +def test_ctor_charset_in_content_type() -> None: with pytest.raises(ValueError): - Response(text='test test', content_type='text/plain; charset=utf-8') + Response(text="test test", content_type="text/plain; charset=utf-8") + + +def test_ctor_charset_without_text() -> None: + resp = Response(content_type="text/plain", charset="koi8-r") + + assert "koi8-r" == resp.charset -def test_ctor_charset_without_text(): - resp = Response(content_type='text/plain', charset='koi8-r') +def test_ctor_content_type_with_extra() -> None: + resp = Response(text="test test", content_type="text/plain; version=0.0.4") - assert 'koi8-r' == resp.charset + assert resp.content_type == "text/plain" + assert resp.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" -def test_ctor_both_content_type_param_and_header_with_text(): +def test_ctor_both_content_type_param_and_header_with_text() -> None: with pytest.raises(ValueError): - Response(headers={'Content-Type': 'application/json'}, - content_type='text/html', text='text') + Response( + headers={"Content-Type": "application/json"}, + content_type="text/html", + text="text", + ) -def test_ctor_both_charset_param_and_header_with_text(): +def test_ctor_both_charset_param_and_header_with_text() -> None: with pytest.raises(ValueError): - Response(headers={'Content-Type': 'application/json'}, - charset='koi8-r', text='text') + Response( + headers={"Content-Type": "application/json"}, charset="koi8-r", text="text" + ) -def test_ctor_both_content_type_param_and_header(): +def test_ctor_both_content_type_param_and_header() -> None: with pytest.raises(ValueError): - Response(headers={'Content-Type': 'application/json'}, - content_type='text/html') + Response(headers={"Content-Type": "application/json"}, content_type="text/html") -def test_ctor_both_charset_param_and_header(): +def test_ctor_both_charset_param_and_header() -> None: with pytest.raises(ValueError): - Response(headers={'Content-Type': 'application/json'}, - charset='koi8-r') + Response(headers={"Content-Type": "application/json"}, charset="koi8-r") -def test_assign_nonbyteish_body(): - resp = Response(body=b'data') +async def test_assign_nonbyteish_body() -> None: + resp = Response(body=b"data") with pytest.raises(ValueError): resp.body = 123 - assert b'data' == resp.body + assert b"data" == resp.body assert 4 == resp.content_length - resp.headers['DATE'] = 'date' - resp._start(mock.Mock(version=HttpVersion11)) - assert resp.headers['CONTENT-LENGTH'] == '4' + resp.headers["DATE"] = "date" + req = make_mocked_request("GET", "/", version=HttpVersion11) + await resp._start(req) + assert resp.headers["CONTENT-LENGTH"] == "4" assert 4 == resp.content_length -def test_assign_nonstr_text(): - resp = Response(text='test') +def test_assign_nonstr_text() -> None: + resp = Response(text="test") with pytest.raises(AssertionError): - resp.text = b'123' - assert b'test' == resp.body + resp.text = b"123" + assert b"test" == resp.body assert 4 == resp.content_length -@asyncio.coroutine -def test_send_headers_for_empty_body(buf, writer): - req = make_request('GET', '/', payload_writer=writer) +def 
test_response_set_content_length() -> None: resp = Response() + with pytest.raises(RuntimeError): + resp.content_length = 1 - yield from resp.prepare(req) - yield from resp.write_eof() - txt = buf.decode('utf8') - assert re.match('HTTP/1.1 200 OK\r\n' - 'Content-Length: 0\r\n' - 'Content-Type: application/octet-stream\r\n' - 'Date: .+\r\n' - 'Server: .+\r\n\r\n', txt) - - -@asyncio.coroutine -def test_render_with_body(buf, writer): - req = make_request('GET', '/', payload_writer=writer) - resp = Response(body=b'data') - yield from resp.prepare(req) - yield from resp.write_eof() +async def test_send_headers_for_empty_body(buf, writer) -> None: + req = make_request("GET", "/", writer=writer) + resp = Response() - txt = buf.decode('utf8') - assert re.match('HTTP/1.1 200 OK\r\n' - 'Content-Length: 4\r\n' - 'Content-Type: application/octet-stream\r\n' - 'Date: .+\r\n' - 'Server: .+\r\n\r\n' - 'data', txt) + await resp.prepare(req) + await resp.write_eof() + txt = buf.decode("utf8") + assert ( + Matches( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Content-Type: application/octet-stream\r\n" + "Date: .+\r\n" + "Server: .+\r\n\r\n" + ) + == txt + ) + + +async def test_render_with_body(buf, writer) -> None: + req = make_request("GET", "/", writer=writer) + resp = Response(body=b"data") + + await resp.prepare(req) + await resp.write_eof() + + txt = buf.decode("utf8") + assert ( + Matches( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 4\r\n" + "Content-Type: application/octet-stream\r\n" + "Date: .+\r\n" + "Server: .+\r\n\r\n" + "data" + ) + == txt + ) -@asyncio.coroutine -def test_send_set_cookie_header(buf, writer): +async def test_send_set_cookie_header(buf, writer) -> None: resp = Response() - resp.cookies['name'] = 'value' - req = make_request('GET', '/', payload_writer=writer) + resp.cookies["name"] = "value" + req = make_request("GET", "/", writer=writer) + + await resp.prepare(req) + await resp.write_eof() + + txt = buf.decode("utf8") + assert ( + Matches( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Set-Cookie: name=value\r\n" + "Content-Type: application/octet-stream\r\n" + "Date: .+\r\n" + "Server: .+\r\n\r\n" + ) + == txt + ) + - yield from resp.prepare(req) - yield from resp.write_eof() +async def test_consecutive_write_eof() -> None: + writer = mock.Mock() + writer.write_eof = make_mocked_coro() + writer.write_headers = make_mocked_coro() + req = make_request("GET", "/", writer=writer) + data = b"data" + resp = Response(body=data) - txt = buf.decode('utf8') - assert re.match('HTTP/1.1 200 OK\r\n' - 'Content-Length: 0\r\n' - 'Set-Cookie: name=value\r\n' - 'Content-Type: application/octet-stream\r\n' - 'Date: .+\r\n' - 'Server: .+\r\n\r\n', txt) + await resp.prepare(req) + await resp.write_eof() + await resp.write_eof() + writer.write_eof.assert_called_once_with(data) -def test_set_text_with_content_type(): +def test_set_text_with_content_type() -> None: resp = Response() resp.content_type = "text/html" resp.text = "text" @@ -887,143 +1060,165 @@ def test_set_text_with_content_type(): assert "text/html" == resp.content_type -def test_set_text_with_charset(): +def test_set_text_with_charset() -> None: resp = Response() - resp.content_type = 'text/plain' + resp.content_type = "text/plain" resp.charset = "KOI8-R" resp.text = "текст" assert "текст" == resp.text - assert "текст".encode('koi8-r') == resp.body + assert "текст".encode("koi8-r") == resp.body assert "koi8-r" == resp.charset -def test_default_content_type_in_stream_response(): +def 
test_default_content_type_in_stream_response() -> None: resp = StreamResponse() - assert resp.content_type == 'application/octet-stream' + assert resp.content_type == "application/octet-stream" -def test_default_content_type_in_response(): +def test_default_content_type_in_response() -> None: resp = Response() - assert resp.content_type == 'application/octet-stream' + assert resp.content_type == "application/octet-stream" -def test_content_type_with_set_text(): - resp = Response(text='text') - assert resp.content_type == 'text/plain' +def test_content_type_with_set_text() -> None: + resp = Response(text="text") + assert resp.content_type == "text/plain" -def test_content_type_with_set_body(): - resp = Response(body=b'body') - assert resp.content_type == 'application/octet-stream' +def test_content_type_with_set_body() -> None: + resp = Response(body=b"body") + assert resp.content_type == "application/octet-stream" -def test_started_when_not_started(): +def test_started_when_not_started() -> None: resp = StreamResponse() assert not resp.prepared -@asyncio.coroutine -def test_started_when_started(): +async def test_started_when_started() -> None: resp = StreamResponse() - yield from resp.prepare(make_request('GET', '/')) + await resp.prepare(make_request("GET", "/")) assert resp.prepared -@asyncio.coroutine -def test_drain_before_start(): +async def test_drain_before_start() -> None: resp = StreamResponse() with pytest.raises(AssertionError): - yield from resp.drain() + await resp.drain() -@asyncio.coroutine -def test_changing_status_after_prepare_raises(): +async def test_changing_status_after_prepare_raises() -> None: resp = StreamResponse() - yield from resp.prepare(make_request('GET', '/')) + await resp.prepare(make_request("GET", "/")) with pytest.raises(AssertionError): resp.set_status(400) -def test_nonstr_text_in_ctor(): +def test_nonstr_text_in_ctor() -> None: with pytest.raises(TypeError): - Response(text=b'data') + Response(text=b"data") -def test_text_in_ctor_with_content_type(): - resp = Response(text='data', content_type='text/html') - assert 'data' == resp.text - assert 'text/html' == resp.content_type +def test_text_in_ctor_with_content_type() -> None: + resp = Response(text="data", content_type="text/html") + assert "data" == resp.text + assert "text/html" == resp.content_type -def test_text_in_ctor_with_content_type_header(): - resp = Response(text='текст', - headers={'Content-Type': 'text/html; charset=koi8-r'}) - assert 'текст'.encode('koi8-r') == resp.body - assert 'text/html' == resp.content_type - assert 'koi8-r' == resp.charset +def test_text_in_ctor_with_content_type_header() -> None: + resp = Response(text="текст", headers={"Content-Type": "text/html; charset=koi8-r"}) + assert "текст".encode("koi8-r") == resp.body + assert "text/html" == resp.content_type + assert "koi8-r" == resp.charset -def test_text_in_ctor_with_content_type_header_multidict(): - headers = CIMultiDict({'Content-Type': 'text/html; charset=koi8-r'}) - resp = Response(text='текст', - headers=headers) - assert 'текст'.encode('koi8-r') == resp.body - assert 'text/html' == resp.content_type - assert 'koi8-r' == resp.charset +def test_text_in_ctor_with_content_type_header_multidict() -> None: + headers = CIMultiDict({"Content-Type": "text/html; charset=koi8-r"}) + resp = Response(text="текст", headers=headers) + assert "текст".encode("koi8-r") == resp.body + assert "text/html" == resp.content_type + assert "koi8-r" == resp.charset -def test_body_in_ctor_with_content_type_header_multidict(): - headers 
= CIMultiDict({'Content-Type': 'text/html; charset=koi8-r'}) - resp = Response(body='текст'.encode('koi8-r'), - headers=headers) - assert 'текст'.encode('koi8-r') == resp.body - assert 'text/html' == resp.content_type - assert 'koi8-r' == resp.charset +def test_body_in_ctor_with_content_type_header_multidict() -> None: + headers = CIMultiDict({"Content-Type": "text/html; charset=koi8-r"}) + resp = Response(body="текст".encode("koi8-r"), headers=headers) + assert "текст".encode("koi8-r") == resp.body + assert "text/html" == resp.content_type + assert "koi8-r" == resp.charset -def test_text_with_empty_payload(): +def test_text_with_empty_payload() -> None: resp = Response(status=200) assert resp.body is None assert resp.text is None -def test_response_with_content_length_header_without_body(): - resp = Response(headers={'Content-Length': 123}) +def test_response_with_content_length_header_without_body() -> None: + resp = Response(headers={"Content-Length": 123}) assert resp.content_length == 123 -class TestJSONResponse: +def test_response_with_immutable_headers() -> None: + resp = Response( + text="text", headers=CIMultiDictProxy(CIMultiDict({"Header": "Value"})) + ) + assert resp.headers == { + "Header": "Value", + "Content-Type": "text/plain; charset=utf-8", + } + + +async def test_response_prepared_after_header_preparation() -> None: + req = make_request("GET", "/") + resp = StreamResponse() + await resp.prepare(req) - def test_content_type_is_application_json_by_default(self): - resp = json_response('') - assert 'application/json' == resp.content_type + assert type(resp.headers["Server"]) is str - def test_passing_text_only(self): - resp = json_response(text=json.dumps('jaysawn')) - assert resp.text == json.dumps('jaysawn') + async def _strip_server(req, res): + assert "Server" in res.headers - def test_data_and_text_raises_value_error(self): + if "Server" in res.headers: + del res.headers["Server"] + + app = mock.Mock() + sig = signals.Signal(app) + sig.append(_strip_server) + + req = make_request("GET", "/", on_response_prepare=sig, app=app) + resp = StreamResponse() + await resp.prepare(req) + + assert "Server" not in resp.headers + + +class TestJSONResponse: + def test_content_type_is_application_json_by_default(self) -> None: + resp = json_response("") + assert "application/json" == resp.content_type + + def test_passing_text_only(self) -> None: + resp = json_response(text=json.dumps("jaysawn")) + assert resp.text == json.dumps("jaysawn") + + def test_data_and_text_raises_value_error(self) -> None: with pytest.raises(ValueError) as excinfo: - json_response(data='foo', text='bar') - expected_message = ( - 'only one of data, text, or body should be specified' - ) + json_response(data="foo", text="bar") + expected_message = "only one of data, text, or body should be specified" assert expected_message == excinfo.value.args[0] - def test_data_and_body_raises_value_error(self): + def test_data_and_body_raises_value_error(self) -> None: with pytest.raises(ValueError) as excinfo: - json_response(data='foo', body=b'bar') - expected_message = ( - 'only one of data, text, or body should be specified' - ) + json_response(data="foo", body=b"bar") + expected_message = "only one of data, text, or body should be specified" assert expected_message == excinfo.value.args[0] - def test_text_is_json_encoded(self): - resp = json_response({'foo': 42}) - assert json.dumps({'foo': 42}) == resp.text + def test_text_is_json_encoded(self) -> None: + resp = json_response({"foo": 42}) + assert json.dumps({"foo": 
42}) == resp.text - def test_content_type_is_overrideable(self): - resp = json_response({'foo': 42}, - content_type='application/vnd.json+api') - assert 'application/vnd.json+api' == resp.content_type + def test_content_type_is_overrideable(self) -> None: + resp = json_response({"foo": 42}, content_type="application/vnd.json+api") + assert "application/vnd.json+api" == resp.content_type diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py new file mode 100644 index 00000000000..af6df1aa8e0 --- /dev/null +++ b/tests/test_web_runner.py @@ -0,0 +1,164 @@ +import asyncio +import platform +import signal +from unittest.mock import patch + +import pytest + +from aiohttp import web +from aiohttp.test_utils import get_unused_port_socket + + +@pytest.fixture +def app(): + return web.Application() + + +@pytest.fixture +def make_runner(loop, app): + asyncio.set_event_loop(loop) + runners = [] + + def go(**kwargs): + runner = web.AppRunner(app, **kwargs) + runners.append(runner) + return runner + + yield go + for runner in runners: + loop.run_until_complete(runner.cleanup()) + + +async def test_site_for_nonfrozen_app(make_runner) -> None: + runner = make_runner() + with pytest.raises(RuntimeError): + web.TCPSite(runner) + assert len(runner.sites) == 0 + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="the test is not valid for Windows" +) +async def test_runner_setup_handle_signals(make_runner) -> None: + runner = make_runner(handle_signals=True) + await runner.setup() + assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL + await runner.cleanup() + assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="the test is not valid for Windows" +) +async def test_runner_setup_without_signal_handling(make_runner) -> None: + runner = make_runner(handle_signals=False) + await runner.setup() + assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL + await runner.cleanup() + assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL + + +async def test_site_double_added(make_runner) -> None: + _sock = get_unused_port_socket("127.0.0.1") + runner = make_runner() + await runner.setup() + site = web.SockSite(runner, _sock) + await site.start() + with pytest.raises(RuntimeError): + await site.start() + + assert len(runner.sites) == 1 + + +async def test_site_stop_not_started(make_runner) -> None: + runner = make_runner() + await runner.setup() + site = web.TCPSite(runner) + with pytest.raises(RuntimeError): + await site.stop() + + assert len(runner.sites) == 0 + + +async def test_custom_log_format(make_runner) -> None: + runner = make_runner(access_log_format="abc") + await runner.setup() + assert runner.server._kwargs["access_log_format"] == "abc" + + +async def test_unreg_site(make_runner) -> None: + runner = make_runner() + await runner.setup() + site = web.TCPSite(runner) + with pytest.raises(RuntimeError): + runner._unreg_site(site) + + +async def test_app_property(make_runner, app) -> None: + runner = make_runner() + assert runner.app is app + + +def test_non_app() -> None: + with pytest.raises(TypeError): + web.AppRunner(object()) + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="Unix socket support is required" +) +async def test_addresses(make_runner, shorttmpdir) -> None: + _sock = get_unused_port_socket("127.0.0.1") + runner = make_runner() + await runner.setup() + tcp = web.SockSite(runner, _sock) + await tcp.start() + path = str(shorttmpdir / "tmp.sock") + unix = 
web.UnixSite(runner, path) + await unix.start() + actual_addrs = runner.addresses + expected_host, expected_post = _sock.getsockname()[:2] + assert actual_addrs == [(expected_host, expected_post), path] + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_runner_wrong_loop(app, selector_loop, pipe_name) -> None: + runner = web.AppRunner(app) + await runner.setup() + with pytest.raises(RuntimeError): + web.NamedPipeSite(runner, pipe_name) + + +@pytest.mark.skipif( + platform.system() != "Windows", reason="Proactor Event loop present only in Windows" +) +async def test_named_pipe_runner_proactor_loop(proactor_loop, app, pipe_name) -> None: + runner = web.AppRunner(app) + await runner.setup() + pipe = web.NamedPipeSite(runner, pipe_name) + await pipe.start() + await runner.cleanup() + + +async def test_tcpsite_default_host(make_runner): + runner = make_runner() + await runner.setup() + site = web.TCPSite(runner) + assert site.name == "http://0.0.0.0:8080" + + calls = [] + + async def mock_create_server(*args, **kwargs): + calls.append((args, kwargs)) + + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_get_loop.return_value.create_server = mock_create_server + await site.start() + + assert len(calls) == 1 + server, host, port = calls[0][0] + assert server is runner.server + assert host is None + assert port == 8080 diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index b3e43456f1c..48353547abe 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -1,81 +1,13 @@ from unittest import mock -from aiohttp import hdrs, helpers +from aiohttp import hdrs from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web_fileresponse import FileResponse, SendfilePayloadWriter - - -def test_static_handle_eof(loop): - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = helpers.create_future(loop) - m_os.sendfile.return_value = 0 - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert fut.done() - assert fut.result() is None - assert not fake_loop.add_writer.called - assert not fake_loop.remove_writer.called - - -def test_static_handle_again(loop): - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = helpers.create_future(loop) - m_os.sendfile.side_effect = BlockingIOError() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert not fut.done() - fake_loop.add_writer.assert_called_with(out_fd, - writer._sendfile_cb, - fut, out_fd, in_fd, 0, 100, - fake_loop, True) - assert not fake_loop.remove_writer.called - - -def test_static_handle_exception(loop): - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = helpers.create_future(loop) - exc = OSError() - m_os.sendfile.side_effect = exc - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert fut.done() - assert exc is fut.exception() - assert not fake_loop.add_writer.called - assert not 
fake_loop.remove_writer.called - - -def test__sendfile_cb_return_on_cancelling(loop): - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = helpers.create_future(loop) - fut.cancel() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - assert fut.done() - assert not fake_loop.add_writer.called - assert not fake_loop.remove_writer.called - assert not m_os.sendfile.called - - -def test_using_gzip_if_header_present_and_file_available(loop): +from aiohttp.web_fileresponse import FileResponse + + +def test_using_gzip_if_header_present_and_file_available(loop) -> None: request = make_mocked_request( - 'GET', 'http://python.org/logo.png', headers={ - hdrs.ACCEPT_ENCODING: 'gzip' - } + "GET", "http://python.org/logo.png", headers={hdrs.ACCEPT_ENCODING: "gzip"} ) gz_filepath = mock.Mock() @@ -85,7 +17,7 @@ def test_using_gzip_if_header_present_and_file_available(loop): gz_filepath.stat.st_size = 1024 filepath = mock.Mock() - filepath.name = 'logo.png' + filepath.name = "logo.png" filepath.open = mock.mock_open() filepath.with_name.return_value = gz_filepath @@ -98,18 +30,15 @@ def test_using_gzip_if_header_present_and_file_available(loop): assert gz_filepath.open.called -def test_gzip_if_header_not_present_and_file_available(loop): - request = make_mocked_request( - 'GET', 'http://python.org/logo.png', headers={ - } - ) +def test_gzip_if_header_not_present_and_file_available(loop) -> None: + request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) gz_filepath = mock.Mock() gz_filepath.open = mock.mock_open() gz_filepath.is_file.return_value = True filepath = mock.Mock() - filepath.name = 'logo.png' + filepath.name = "logo.png" filepath.open = mock.mock_open() filepath.with_name.return_value = gz_filepath filepath.stat.return_value = mock.MagicMock() @@ -124,18 +53,15 @@ def test_gzip_if_header_not_present_and_file_available(loop): assert not gz_filepath.open.called -def test_gzip_if_header_not_present_and_file_not_available(loop): - request = make_mocked_request( - 'GET', 'http://python.org/logo.png', headers={ - } - ) +def test_gzip_if_header_not_present_and_file_not_available(loop) -> None: + request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) gz_filepath = mock.Mock() gz_filepath.open = mock.mock_open() gz_filepath.is_file.return_value = False filepath = mock.Mock() - filepath.name = 'logo.png' + filepath.name = "logo.png" filepath.open = mock.mock_open() filepath.with_name.return_value = gz_filepath filepath.stat.return_value = mock.MagicMock() @@ -150,11 +76,9 @@ def test_gzip_if_header_not_present_and_file_not_available(loop): assert not gz_filepath.open.called -def test_gzip_if_header_present_and_file_not_available(loop): +def test_gzip_if_header_present_and_file_not_available(loop) -> None: request = make_mocked_request( - 'GET', 'http://python.org/logo.png', headers={ - hdrs.ACCEPT_ENCODING: 'gzip' - } + "GET", "http://python.org/logo.png", headers={hdrs.ACCEPT_ENCODING: "gzip"} ) gz_filepath = mock.Mock() @@ -162,7 +86,7 @@ def test_gzip_if_header_present_and_file_not_available(loop): gz_filepath.is_file.return_value = False filepath = mock.Mock() - filepath.name = 'logo.png' + filepath.name = "logo.png" filepath.open = mock.mock_open() filepath.with_name.return_value = gz_filepath filepath.stat.return_value = mock.MagicMock() @@ -175,3 +99,20 @@ def 
test_gzip_if_header_present_and_file_not_available(loop): assert filepath.open.called assert not gz_filepath.open.called + + +def test_status_controlled_by_user(loop) -> None: + request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) + + filepath = mock.Mock() + filepath.name = "logo.png" + filepath.open = mock.mock_open() + filepath.stat.return_value = mock.MagicMock() + filepath.stat.st_size = 1024 + + file_sender = FileResponse(filepath, status=203) + file_sender._sendfile = make_mocked_coro(None) + + loop.run_until_complete(file_sender.prepare(request)) + + assert file_sender._status == 203 diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index fdea28bde64..60a542b83cb 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,6 +1,8 @@ import asyncio import os import pathlib +import socket +import zlib import pytest @@ -9,274 +11,331 @@ try: import ssl -except: - ssl = False +except ImportError: + ssl = None # type: ignore -@pytest.fixture(params=['sendfile', 'fallback'], ids=['sendfile', 'fallback']) -def sender(request): +@pytest.fixture +def loop_without_sendfile(loop): + def sendfile(*args, **kwargs): + raise NotImplementedError + + loop.sendfile = sendfile + return loop + + +@pytest.fixture(params=["sendfile", "no_sendfile"], ids=["sendfile", "no_sendfile"]) +def sender(request, loop_without_sendfile): def maker(*args, **kwargs): ret = web.FileResponse(*args, **kwargs) - if request.param == 'fallback': - ret._sendfile = ret._sendfile_fallback + if request.param == "no_sendfile": + asyncio.set_event_loop(loop_without_sendfile) return ret + return maker -@asyncio.coroutine -def test_static_file_ok(loop, test_client, sender): - filepath = pathlib.Path(__file__).parent / 'data.unknown_mime_type' +async def test_static_file_ok(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert resp.status == 200 + txt = await resp.text() + assert "file content" == txt.rstrip() + assert "application/octet-stream" == resp.headers["Content-Type"] + assert resp.headers.get("Content-Encoding") is None + await resp.release() + + +async def test_zero_bytes_file_ok(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "data.zero_bytes" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 200 - txt = yield from resp.text() - assert 'file content' == txt.rstrip() - assert 'application/octet-stream' == resp.headers['Content-Type'] - assert resp.headers.get('Content-Encoding') is None - yield from resp.release() + txt = await resp.text() + assert "" == txt.rstrip() + assert "application/octet-stream" == resp.headers["Content-Type"] + assert resp.headers.get("Content-Encoding") is None + await resp.release() -@asyncio.coroutine -def test_static_file_ok_string_path(loop, test_client, sender): - filepath = pathlib.Path(__file__).parent / 'data.unknown_mime_type' +async def test_static_file_ok_string_path(aiohttp_client, sender) 
-> None: + filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(str(filepath)) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 200 - txt = yield from resp.text() - assert 'file content' == txt.rstrip() - assert 'application/octet-stream' == resp.headers['Content-Type'] - assert resp.headers.get('Content-Encoding') is None - yield from resp.release() + txt = await resp.text() + assert "file content" == txt.rstrip() + assert "application/octet-stream" == resp.headers["Content-Type"] + assert resp.headers.get("Content-Encoding") is None + await resp.release() -@asyncio.coroutine -def test_static_file_not_exists(loop, test_client): +async def test_static_file_not_exists(aiohttp_client) -> None: app = web.Application() - client = yield from test_client(lambda loop: app) + client = await aiohttp_client(app) - resp = yield from client.get('/fake') + resp = await client.get("/fake") assert resp.status == 404 - yield from resp.release() + await resp.release() -@asyncio.coroutine -def test_static_file_name_too_long(loop, test_client): +async def test_static_file_name_too_long(aiohttp_client) -> None: app = web.Application() - client = yield from test_client(lambda loop: app) + client = await aiohttp_client(app) - resp = yield from client.get('/x*500') + resp = await client.get("/x*500") assert resp.status == 404 - yield from resp.release() + await resp.release() -@asyncio.coroutine -def test_static_file_upper_directory(loop, test_client): +async def test_static_file_upper_directory(aiohttp_client) -> None: app = web.Application() - client = yield from test_client(lambda loop: app) + client = await aiohttp_client(app) - resp = yield from client.get('/../../') + resp = await client.get("/../../") assert resp.status == 404 - yield from resp.release() + await resp.release() -@asyncio.coroutine -def test_static_file_with_content_type(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent / 'aiohttp.jpg') +async def test_static_file_with_content_type(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "aiohttp.jpg" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert resp.status == 200 + body = await resp.read() + with filepath.open("rb") as f: + content = f.read() + assert content == body + assert resp.headers["Content-Type"] == "image/jpeg" + assert resp.headers.get("Content-Encoding") is None + resp.close() + + +async def test_static_file_custom_content_type(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "hello.txt.gz" + + async def handler(request): + resp = sender(filepath, chunk_size=16) + resp.content_type = "application/pdf" + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 200 - body = yield from resp.read() - with filepath.open('rb') as f: + body = await resp.read() + 
with filepath.open("rb") as f: content = f.read() assert content == body - assert resp.headers['Content-Type'] == 'image/jpeg' - assert resp.headers.get('Content-Encoding') is None + assert resp.headers["Content-Type"] == "application/pdf" + assert resp.headers.get("Content-Encoding") is None resp.close() -@asyncio.coroutine -def test_static_file_with_content_encoding(loop, test_client, sender): - filepath = pathlib.Path(__file__).parent / 'hello.txt.gz' +async def test_static_file_custom_content_type_compress(aiohttp_client, sender): + filepath = pathlib.Path(__file__).parent / "hello.txt" - @asyncio.coroutine - def handler(request): + async def handler(request): + resp = sender(filepath, chunk_size=16) + resp.content_type = "application/pdf" + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert resp.status == 200 + body = await resp.read() + assert b"hello aiohttp\n" == body + assert resp.headers["Content-Type"] == "application/pdf" + assert resp.headers.get("Content-Encoding") == "gzip" + resp.close() + + +async def test_static_file_with_content_encoding(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "hello.txt.gz" + + async def handler(request): return sender(filepath) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - body = yield from resp.read() - assert b'hello aiohttp\n' == body - ct = resp.headers['CONTENT-TYPE'] - assert 'text/plain' == ct - encoding = resp.headers['CONTENT-ENCODING'] - assert 'gzip' == encoding + body = await resp.read() + assert b"hello aiohttp\n" == body + ct = resp.headers["CONTENT-TYPE"] + assert "text/plain" == ct + encoding = resp.headers["CONTENT-ENCODING"] + assert "gzip" == encoding resp.close() -@asyncio.coroutine -def test_static_file_if_modified_since(loop, test_client, sender): - filename = 'data.unknown_mime_type' +async def test_static_file_if_modified_since(aiohttp_client, sender) -> None: + filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert 200 == resp.status - lastmod = resp.headers.get('Last-Modified') + lastmod = resp.headers.get("Last-Modified") assert lastmod is not None resp.close() - resp = yield from client.get('/', headers={'If-Modified-Since': lastmod}) + resp = await client.get("/", headers={"If-Modified-Since": lastmod}) + body = await resp.read() assert 304 == resp.status + assert resp.headers.get("Content-Length") is None + assert b"" == body resp.close() -@asyncio.coroutine -def test_static_file_if_modified_since_past_date(loop, test_client, sender): - filename = 'data.unknown_mime_type' +async def test_static_file_if_modified_since_past_date(aiohttp_client, sender) -> None: + filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath) app = web.Application() - 
app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - lastmod = 'Mon, 1 Jan 1990 01:01:01 GMT' + lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" - resp = yield from client.get('/', headers={'If-Modified-Since': lastmod}) + resp = await client.get("/", headers={"If-Modified-Since": lastmod}) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_static_file_if_modified_since_invalid_date(loop, test_client, sender): - filename = 'data.unknown_mime_type' +async def test_static_file_if_modified_since_invalid_date(aiohttp_client, sender): + filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - lastmod = 'not a valid HTTP-date' + lastmod = "not a valid HTTP-date" - resp = yield from client.get('/', headers={'If-Modified-Since': lastmod}) + resp = await client.get("/", headers={"If-Modified-Since": lastmod}) assert 200 == resp.status resp.close() -@asyncio.coroutine -def test_static_file_if_modified_since_future_date(loop, test_client, sender): - filename = 'data.unknown_mime_type' +async def test_static_file_if_modified_since_future_date(aiohttp_client, sender): + filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - lastmod = 'Fri, 31 Dec 9999 23:59:59 GMT' + lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" - resp = yield from client.get('/', headers={'If-Modified-Since': lastmod}) + resp = await client.get("/", headers={"If-Modified-Since": lastmod}) + body = await resp.read() assert 304 == resp.status + assert resp.headers.get("Content-Length") is None + assert b"" == body resp.close() @pytest.mark.skipif(not ssl, reason="ssl not supported") -@asyncio.coroutine -def test_static_file_ssl(loop, test_server, test_client): +async def test_static_file_ssl( + aiohttp_server, + ssl_ctx, + aiohttp_client, + client_ssl_ctx, +) -> None: dirname = os.path.dirname(__file__) - filename = 'data.unknown_mime_type' - ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_ctx.load_cert_chain( - os.path.join(dirname, 'sample.crt'), - os.path.join(dirname, 'sample.key') - ) + filename = "data.unknown_mime_type" app = web.Application() - app.router.add_static('/static', dirname) - server = yield from test_server(app, ssl=ssl_ctx) - conn = aiohttp.TCPConnector(verify_ssl=False, loop=loop) - client = yield from test_client(server, connector=conn) + app.router.add_static("/static", dirname) + server = await aiohttp_server(app, ssl=ssl_ctx) + conn = aiohttp.TCPConnector(ssl=client_ssl_ctx) + client = await aiohttp_client(server, connector=conn) - resp = yield from client.get('/static/'+filename) + resp = await client.get("/static/" + filename) assert 200 == resp.status - txt = yield from resp.text() - assert 'file content' == txt.rstrip() - ct = resp.headers['CONTENT-TYPE'] - assert 'application/octet-stream' == ct - assert resp.headers.get('CONTENT-ENCODING') is None + txt = await 
resp.text() + assert "file content" == txt.rstrip() + ct = resp.headers["CONTENT-TYPE"] + assert "application/octet-stream" == ct + assert resp.headers.get("CONTENT-ENCODING") is None -@asyncio.coroutine -def test_static_file_directory_traversal_attack(loop, test_client): +async def test_static_file_directory_traversal_attack(aiohttp_client) -> None: dirname = os.path.dirname(__file__) - relpath = '../README.rst' + relpath = "../README.rst" assert os.path.isfile(os.path.join(dirname, relpath)) app = web.Application() - app.router.add_static('/static', dirname) - client = yield from test_client(app) + app.router.add_static("/static", dirname) + client = await aiohttp_client(app) - resp = yield from client.get('/static/'+relpath) + resp = await client.get("/static/" + relpath) assert 404 == resp.status - url_relpath2 = '/static/dir/../' + relpath - resp = yield from client.get(url_relpath2) + url_relpath2 = "/static/dir/../" + relpath + resp = await client.get(url_relpath2) assert 404 == resp.status - url_abspath = \ - '/static/' + os.path.abspath(os.path.join(dirname, relpath)) - resp = yield from client.get(url_abspath) - assert 404 == resp.status + url_abspath = "/static/" + os.path.abspath(os.path.join(dirname, relpath)) + resp = await client.get(url_abspath) + assert 403 == resp.status -def test_static_route_path_existence_check(): +def test_static_route_path_existence_check() -> None: directory = os.path.dirname(__file__) web.StaticResource("/", directory) @@ -285,33 +344,32 @@ def test_static_route_path_existence_check(): web.StaticResource("/", nodirectory) -@asyncio.coroutine -def test_static_file_huge(loop, test_client, tmpdir): - filename = 'huge_data.unknown_mime_type' +async def test_static_file_huge(aiohttp_client, tmpdir) -> None: + filename = "huge_data.unknown_mime_type" - # fill 100MB file - with tmpdir.join(filename).open('w') as f: - for i in range(1024*20): - f.write(chr(i % 64 + 0x20) * 1024) + # fill 20MB file + with tmpdir.join(filename).open("wb") as f: + for i in range(1024 * 20): + f.write((chr(i % 64 + 0x20) * 1024).encode()) file_st = os.stat(str(tmpdir.join(filename))) app = web.Application() - app.router.add_static('/static', str(tmpdir)) - client = yield from test_client(app) + app.router.add_static("/static", str(tmpdir)) + client = await aiohttp_client(app) - resp = yield from client.get('/static/'+filename) + resp = await client.get("/static/" + filename) assert 200 == resp.status - ct = resp.headers['CONTENT-TYPE'] - assert 'application/octet-stream' == ct - assert resp.headers.get('CONTENT-ENCODING') is None - assert int(resp.headers.get('CONTENT-LENGTH')) == file_st.st_size + ct = resp.headers["CONTENT-TYPE"] + assert "application/octet-stream" == ct + assert resp.headers.get("CONTENT-ENCODING") is None + assert int(resp.headers.get("CONTENT-LENGTH")) == file_st.st_size - f = tmpdir.join(filename).open('rb') + f = tmpdir.join(filename).open("rb") off = 0 cnt = 0 while off < file_st.st_size: - chunk = yield from resp.content.readany() + chunk = await resp.content.readany() expected = f.read(len(chunk)) assert chunk == expected off += len(chunk) @@ -319,45 +377,53 @@ def test_static_file_huge(loop, test_client, tmpdir): f.close() -@asyncio.coroutine -def test_static_file_range(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent.parent / 'LICENSE.txt') +async def test_static_file_range(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent.parent / "LICENSE.txt" - @asyncio.coroutine - def handler(request): + filesize 
= filepath.stat().st_size + + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - with filepath.open('rb') as f: + with filepath.open("rb") as f: content = f.read() # Ensure the whole file requested in parts is correct - responses = yield from asyncio.gather( - client.get('/', headers={'Range': 'bytes=0-999'}), - client.get('/', headers={'Range': 'bytes=1000-1999'}), - client.get('/', headers={'Range': 'bytes=2000-'}), - loop=loop + responses = await asyncio.gather( + client.get("/", headers={"Range": "bytes=0-999"}), + client.get("/", headers={"Range": "bytes=1000-1999"}), + client.get("/", headers={"Range": "bytes=2000-"}), ) assert len(responses) == 3 - assert responses[0].status == 206, \ - "failed 'bytes=0-999': %s" % responses[0].reason - assert responses[1].status == 206, \ + assert responses[0].status == 206, "failed 'bytes=0-999': %s" % responses[0].reason + assert responses[0].headers["Content-Range"] == "bytes 0-999/{}".format( + filesize + ), "failed: Content-Range Error" + assert responses[1].status == 206, ( "failed 'bytes=1000-1999': %s" % responses[1].reason - assert responses[2].status == 206, \ - "failed 'bytes=2000-': %s" % responses[2].reason - - body = yield from asyncio.gather( + ) + assert responses[1].headers["Content-Range"] == "bytes 1000-1999/{}".format( + filesize + ), "failed: Content-Range Error" + assert responses[2].status == 206, "failed 'bytes=2000-': %s" % responses[2].reason + assert responses[2].headers["Content-Range"] == "bytes 2000-{}/{}".format( + filesize - 1, filesize + ), "failed: Content-Range Error" + + body = await asyncio.gather( *(resp.read() for resp in responses), - loop=loop ) - assert len(body[0]) == 1000, \ - "failed 'bytes=0-999', received %d bytes" % len(body[0]) - assert len(body[1]) == 1000, \ - "failed 'bytes=1000-1999', received %d bytes" % len(body[1]) + assert len(body[0]) == 1000, "failed 'bytes=0-999', received %d bytes" % len( + body[0] + ) + assert len(body[1]) == 1000, "failed 'bytes=1000-1999', received %d bytes" % len( + body[1] + ) responses[0].close() responses[1].close() responses[2].close() @@ -365,117 +431,407 @@ def handler(request): assert content == b"".join(body) -@asyncio.coroutine -def test_static_file_range_end_bigger_than_size(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent / 'aiohttp.png') +async def test_static_file_range_end_bigger_than_size(aiohttp_client, sender): + filepath = pathlib.Path(__file__).parent / "aiohttp.png" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - with filepath.open('rb') as f: + with filepath.open("rb") as f: content = f.read() # Ensure the whole file requested in parts is correct - response = yield from client.get( - '/', headers={'Range': 'bytes=61000-62000'}) + response = await client.get("/", headers={"Range": "bytes=54000-55000"}) - assert response.status == 206, \ - "failed 'bytes=61000-62000': %s" % response.reason + assert response.status == 206, ( + "failed 'bytes=54000-55000': %s" % response.reason + ) + assert ( + response.headers["Content-Range"] == "bytes 54000-54996/54997" + ), "failed: Content-Range 
Error" - body = yield from response.read() - assert len(body) == 108, \ - "failed 'bytes=0-999', received %d bytes" % len(body[0]) + body = await response.read() + assert len(body) == 997, "failed 'bytes=54000-55000', received %d bytes" % len( + body + ) - assert content[61000:] == body + assert content[54000:] == body -@asyncio.coroutine -def test_static_file_range_beyond_eof(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent / 'aiohttp.png') +async def test_static_file_range_beyond_eof(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "aiohttp.png" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) # Ensure the whole file requested in parts is correct - response = yield from client.get( - '/', headers={'Range': 'bytes=1000000-1200000'}) + response = await client.get("/", headers={"Range": "bytes=1000000-1200000"}) - assert response.status == 206, \ + assert response.status == 416, ( "failed 'bytes=1000000-1200000': %s" % response.reason - assert response.headers['content-length'] == '0' + ) -@asyncio.coroutine -def test_static_file_range_tail(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent / 'aiohttp.png') +async def test_static_file_range_tail(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "aiohttp.png" - @asyncio.coroutine - def handler(request): + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - with filepath.open('rb') as f: + with filepath.open("rb") as f: content = f.read() # Ensure the tail of the file is correct - resp = yield from client.get('/', headers={'Range': 'bytes=-500'}) + resp = await client.get("/", headers={"Range": "bytes=-500"}) assert resp.status == 206, resp.reason - body4 = yield from resp.read() + assert ( + resp.headers["Content-Range"] == "bytes 54497-54996/54997" + ), "failed: Content-Range Error" + body4 = await resp.read() resp.close() assert content[-500:] == body4 + # Ensure out-of-range tails could be handled + resp2 = await client.get("/", headers={"Range": "bytes=-99999999999999"}) + assert resp2.status == 206, resp.reason + assert ( + resp2.headers["Content-Range"] == "bytes 0-54996/54997" + ), "failed: Content-Range Error" -@asyncio.coroutine -def test_static_file_invalid_range(loop, test_client, sender): - filepath = (pathlib.Path(__file__).parent / 'aiohttp.png') - @asyncio.coroutine - def handler(request): +async def test_static_file_invalid_range(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "aiohttp.png" + + async def handler(request): return sender(filepath, chunk_size=16) app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(lambda loop: app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) # range must be in bytes - resp = yield from client.get('/', headers={'Range': 'blocks=0-10'}) - assert resp.status == 416, 'Range must be in bytes' + resp = await client.get("/", headers={"Range": "blocks=0-10"}) + assert resp.status == 416, "Range must be in bytes" resp.close() # start > end - resp = yield from 
client.get('/', headers={'Range': 'bytes=100-0'}) + resp = await client.get("/", headers={"Range": "bytes=100-0"}) assert resp.status == 416, "Range start can't be greater than end" resp.close() # start > end - resp = yield from client.get('/', headers={'Range': 'bytes=10-9'}) + resp = await client.get("/", headers={"Range": "bytes=10-9"}) assert resp.status == 416, "Range start can't be greater than end" resp.close() # non-number range - resp = yield from client.get('/', headers={'Range': 'bytes=a-f'}) - assert resp.status == 416, 'Range must be integers' + resp = await client.get("/", headers={"Range": "bytes=a-f"}) + assert resp.status == 416, "Range must be integers" resp.close() # double dash range - resp = yield from client.get('/', headers={'Range': 'bytes=0--10'}) - assert resp.status == 416, 'double dash in range' + resp = await client.get("/", headers={"Range": "bytes=0--10"}) + assert resp.status == 416, "double dash in range" resp.close() # no range - resp = yield from client.get('/', headers={'Range': 'bytes=-'}) - assert resp.status == 416, 'no range given' + resp = await client.get("/", headers={"Range": "bytes=-"}) + assert resp.status == 416, "no range given" + resp.close() + + +async def test_static_file_if_unmodified_since_past_with_range(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" + + resp = await client.get( + "/", headers={"If-Unmodified-Since": lastmod, "Range": "bytes=2-"} + ) + assert 412 == resp.status + resp.close() + + +async def test_static_file_if_unmodified_since_future_with_range( + aiohttp_client, sender +): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" + + resp = await client.get( + "/", headers={"If-Unmodified-Since": lastmod, "Range": "bytes=2-"} + ) + assert 206 == resp.status + assert resp.headers["Content-Range"] == "bytes 2-12/13" + assert resp.headers["Content-Length"] == "11" + resp.close() + + +async def test_static_file_if_range_past_with_range(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" + + resp = await client.get("/", headers={"If-Range": lastmod, "Range": "bytes=2-"}) + assert 200 == resp.status + assert resp.headers["Content-Length"] == "13" + resp.close() + + +async def test_static_file_if_range_future_with_range(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" + + resp = await client.get("/", headers={"If-Range": lastmod, "Range": "bytes=2-"}) + assert 206 == resp.status + assert resp.headers["Content-Range"] == "bytes 2-12/13" + assert resp.headers["Content-Length"] == "11" + 
resp.close() + + +async def test_static_file_if_unmodified_since_past_without_range( + aiohttp_client, sender +): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" + + resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) + assert 412 == resp.status + resp.close() + + +async def test_static_file_if_unmodified_since_future_without_range( + aiohttp_client, sender +): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" + + resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) + assert 200 == resp.status + assert resp.headers["Content-Length"] == "13" + resp.close() + + +async def test_static_file_if_range_past_without_range(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" + + resp = await client.get("/", headers={"If-Range": lastmod}) + assert 200 == resp.status + assert resp.headers["Content-Length"] == "13" + resp.close() + + +async def test_static_file_if_range_future_without_range(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" + + resp = await client.get("/", headers={"If-Range": lastmod}) + assert 200 == resp.status + assert resp.headers["Content-Length"] == "13" + resp.close() + + +async def test_static_file_if_unmodified_since_invalid_date(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "not a valid HTTP-date" + + resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) + assert 200 == resp.status + resp.close() + + +async def test_static_file_if_range_invalid_date(aiohttp_client, sender): + filename = "data.unknown_mime_type" + filepath = pathlib.Path(__file__).parent / filename + + async def handler(request): + return sender(filepath) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + lastmod = "not a valid HTTP-date" + + resp = await client.get("/", headers={"If-Range": lastmod}) + assert 200 == resp.status + resp.close() + + +async def test_static_file_compression(aiohttp_client, sender) -> None: + filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" + + async def handler(request): + ret = sender(filepath) + ret.enable_compression() + return ret + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app, auto_decompress=False) + + resp = await client.get("/") + assert 
resp.status == 200 + zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS) + expected_body = zcomp.compress(b"file content\n") + zcomp.flush() + assert expected_body == await resp.read() + assert "application/octet-stream" == resp.headers["Content-Type"] + assert resp.headers.get("Content-Encoding") == "deflate" + await resp.release() + + +async def test_static_file_huge_cancel(aiohttp_client, tmpdir) -> None: + filename = "huge_data.unknown_mime_type" + + # fill 100MB file + with tmpdir.join(filename).open("wb") as f: + for i in range(1024 * 20): + f.write((chr(i % 64 + 0x20) * 1024).encode()) + + task = None + + async def handler(request): + nonlocal task + task = request.task + # reduce send buffer size + tr = request.transport + sock = tr.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename)))) + return ret + + app = web.Application() + + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert resp.status == 200 + task.cancel() + await asyncio.sleep(0) + data = b"" + while True: + try: + data += await resp.content.read(1024) + except aiohttp.ClientPayloadError: + break + assert len(data) < 1024 * 1024 * 20 + + +async def test_static_file_huge_error(aiohttp_client, tmpdir) -> None: + filename = "huge_data.unknown_mime_type" + + # fill 20MB file + with tmpdir.join(filename).open("wb") as f: + f.seek(20 * 1024 * 1024) + f.write(b"1") + + async def handler(request): + # reduce send buffer size + tr = request.transport + sock = tr.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename)))) + return ret + + app = web.Application() + + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + assert resp.status == 200 + # raise an exception on server side resp.close() diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 710d2b05162..b02787b88c2 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -6,110 +6,147 @@ from aiohttp import client, web -@asyncio.coroutine -def test_simple_server(raw_test_server, test_client): - @asyncio.coroutine - def handler(request): +async def test_simple_server(aiohttp_raw_server, aiohttp_client) -> None: + async def handler(request): return web.Response(text=str(request.rel_url)) - server = yield from raw_test_server(handler) - cli = yield from test_client(server) - resp = yield from cli.get('/path/to') + server = await aiohttp_raw_server(handler) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to") assert resp.status == 200 - txt = yield from resp.text() - assert txt == '/path/to' + txt = await resp.text() + assert txt == "/path/to" -@asyncio.coroutine -def test_raw_server_not_http_exception(raw_test_server, test_client): +async def test_raw_server_not_http_exception(aiohttp_raw_server, aiohttp_client): exc = RuntimeError("custom runtime error") - @asyncio.coroutine - def handler(request): + async def handler(request): raise exc logger = mock.Mock() - server = yield from raw_test_server(handler, logger=logger) - cli = yield from test_client(server) - resp = yield from cli.get('/path/to') + server = await aiohttp_raw_server(handler, logger=logger, debug=False) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to") assert resp.status == 500 + assert resp.headers["Content-Type"].startswith("text/plain") - txt = 
yield from resp.text() - assert "

    500 Internal Server Error

    " in txt + txt = await resp.text() + assert txt.startswith("500 Internal Server Error") + assert "Traceback" not in txt - logger.exception.assert_called_with( - "Error handling request", - exc_info=exc) + logger.exception.assert_called_with("Error handling request", exc_info=exc) -@asyncio.coroutine -def test_raw_server_handler_timeout(raw_test_server, test_client): +async def test_raw_server_handler_timeout(aiohttp_raw_server, aiohttp_client) -> None: exc = asyncio.TimeoutError("error") - @asyncio.coroutine - def handler(request): + async def handler(request): raise exc logger = mock.Mock() - server = yield from raw_test_server(handler, logger=logger) - cli = yield from test_client(server) - resp = yield from cli.get('/path/to') + server = await aiohttp_raw_server(handler, logger=logger) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to") assert resp.status == 504 - yield from resp.text() - logger.debug.assert_called_with("Request handler timed out.") + await resp.text() + logger.debug.assert_called_with("Request handler timed out.", exc_info=exc) -@asyncio.coroutine -def test_raw_server_do_not_swallow_exceptions(raw_test_server, test_client): - exc = None +async def test_raw_server_do_not_swallow_exceptions(aiohttp_raw_server, aiohttp_client): + async def handler(request): + raise asyncio.CancelledError() - @asyncio.coroutine - def handler(request): + logger = mock.Mock() + server = await aiohttp_raw_server(handler, logger=logger) + cli = await aiohttp_client(server) + + with pytest.raises(client.ServerDisconnectedError): + await cli.get("/path/to") + + logger.debug.assert_called_with("Ignored premature client disconnection") + + +async def test_raw_server_cancelled_in_write_eof(aiohttp_raw_server, aiohttp_client): + async def handler(request): + resp = web.Response(text=str(request.rel_url)) + resp.write_eof = mock.Mock(side_effect=asyncio.CancelledError("error")) + return resp + + logger = mock.Mock() + server = await aiohttp_raw_server(handler, logger=logger) + cli = await aiohttp_client(server) + + resp = await cli.get("/path/to") + with pytest.raises(client.ClientPayloadError): + await resp.read() + + logger.debug.assert_called_with("Ignored premature client disconnection") + + +async def test_raw_server_not_http_exception_debug(aiohttp_raw_server, aiohttp_client): + exc = RuntimeError("custom runtime error") + + async def handler(request): raise exc logger = mock.Mock() - server = yield from raw_test_server(handler, logger=logger) - cli = yield from test_client(server) + server = await aiohttp_raw_server(handler, logger=logger, debug=True) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to") + assert resp.status == 500 + assert resp.headers["Content-Type"].startswith("text/plain") - for _exc, msg in ( - (asyncio.CancelledError("error"), - 'Ignored premature client disconnection'),): - exc = _exc - with pytest.raises(client.ServerDisconnectedError): - yield from cli.get('/path/to') + txt = await resp.text() + assert "Traceback (most recent call last):\n" in txt - logger.debug.assert_called_with(msg) + logger.exception.assert_called_with("Error handling request", exc_info=exc) -@asyncio.coroutine -def test_raw_server_not_http_exception_debug(raw_test_server, test_client): +async def test_raw_server_html_exception(aiohttp_raw_server, aiohttp_client): exc = RuntimeError("custom runtime error") - @asyncio.coroutine - def handler(request): + async def handler(request): raise exc logger = mock.Mock() - server = yield from raw_test_server(handler, 
logger=logger, debug=True) - cli = yield from test_client(server) - resp = yield from cli.get('/path/to') + server = await aiohttp_raw_server(handler, logger=logger, debug=False) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to", headers={"Accept": "text/html"}) assert resp.status == 500 + assert resp.headers["Content-Type"].startswith("text/html") + + txt = await resp.text() + assert txt == ( + "<html><head><title>500 Internal Server Error</title></head><body>\n" + "<h1>500 Internal Server Error</h1>\n" + "Server got itself in trouble\n" + "</body></html>\n" + ) - txt = yield from resp.text() - assert "<h2>Traceback:</h2>
    " in txt + logger.exception.assert_called_with("Error handling request", exc_info=exc) - logger.exception.assert_called_with( - "Error handling request", - exc_info=exc) +async def test_raw_server_html_exception_debug(aiohttp_raw_server, aiohttp_client): + exc = RuntimeError("custom runtime error") + + async def handler(request): + raise exc -def test_create_web_server_with_implicit_loop(loop): - asyncio.set_event_loop(loop) + logger = mock.Mock() + server = await aiohttp_raw_server(handler, logger=logger, debug=True) + cli = await aiohttp_client(server) + resp = await cli.get("/path/to", headers={"Accept": "text/html"}) + assert resp.status == 500 + assert resp.headers["Content-Type"].startswith("text/html") - @asyncio.coroutine - def handler(request): - return web.Response() # pragma: no cover + txt = await resp.text() + assert txt.startswith( + "500 Internal Server Error \n" + "

    500 Internal Server Error

    \n" + "

    Traceback:

    \n" + "
    Traceback (most recent call last):\n"
    +    )
     
    -    srv = web.Server(handler)
    -    assert srv._loop is loop
    +    logger.exception.assert_called_with("Error handling request", exc_info=exc)
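Aside: the test_web_server.py changes above pin down the low-level server's error-rendering contract — a 500 is rendered as text/plain by default, as HTML only when the client sends "Accept: text/html", and with a traceback only when the server runs in debug mode. A minimal sketch of driving that low-level API directly, outside the aiohttp_raw_server/aiohttp_client fixtures (the handler body and the address are illustrative assumptions, not part of this diff):

    import asyncio

    from aiohttp import web


    async def handler(request):
        # An exception escaping the handler is rendered by the server itself:
        # text/plain by default, text/html for clients that accept it, and a
        # traceback is appended only in debug mode.
        if request.rel_url.path == "/boom":
            raise RuntimeError("custom runtime error")
        return web.Response(text=str(request.rel_url))


    async def main():
        # debug=True mirrors the fixture argument used in the tests above.
        server = web.Server(handler, debug=True)
        runner = web.ServerRunner(server)
        await runner.setup()
        site = web.TCPSite(runner, "127.0.0.1", 8080)
        await site.start()
        await asyncio.sleep(3600)  # serve for a while, then tear down
        await runner.cleanup()


    if __name__ == "__main__":
        asyncio.run(main())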
    diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py
    index f41a7d74575..0ba2e7c2034 100644
    --- a/tests/test_web_urldispatcher.py
    +++ b/tests/test_web_urldispatcher.py
    @@ -1,7 +1,9 @@
     import asyncio
     import functools
     import os
    +import pathlib
     import shutil
    +import sys
     import tempfile
     from unittest import mock
     from unittest.mock import MagicMock
    @@ -12,7 +14,7 @@
     from aiohttp.web_urldispatcher import SystemRoute
     
     
    -@pytest.fixture(scope='function')
    +@pytest.fixture(scope="function")
     def tmp_dir_path(request):
         """
         Give a path for a temporary directory
    @@ -29,89 +31,106 @@ def teardown():
         return tmp_dir
     
     
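For context on the parametrized index tests below: add_static(..., show_index=True) makes a GET on a directory URL render the "Index of ..." HTML listing asserted there, while the default show_index=False yields 403 Forbidden (the first parametrized case). A minimal sketch of the feature under test (the prefix and served directory are illustrative assumptions, not taken from this diff):

    import pathlib

    from aiohttp import web

    app = web.Application()
    # show_index=True: directory URLs return an HTML listing.
    # show_index=False (the default): the same request returns 403.
    app.router.add_static("/static", pathlib.Path(__file__).parent, show_index=True)

    if __name__ == "__main__":
        web.run_app(app)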
    -@pytest.mark.parametrize("show_index,status,data",
    -                         [(False, 403, None),
    -                          (True, 200,
-                           b'<html>\n<head>\n<title>Index of /</title>\n'
-                           b'</head>\n<body>\n<h1>Index of /</h1>\n<ul>\n'
-                           b'<li><a href="/my_dir">my_dir/</a></li>\n'
-                           b'<li><a href="/my_file">my_file</a></li>\n'
-                           b'</ul>\n</body>\n</html>')]) -@asyncio.coroutine -def test_access_root_of_static_handler(tmp_dir_path, loop, test_client, - show_index, status, data): - """ - Tests the operation of static file server. - Try to access the root of static file server, and make - sure that correct HTTP statuses are returned depending if we directory - index should be shown or not. - """
+@pytest.mark.parametrize( + "show_index,status,prefix,data", + [ + pytest.param(False, 403, "/", None, id="index_forbidden"), + pytest.param( + True, + 200, + "/", + b"<html>\n<head>\n<title>Index of /.</title>\n" + b"</head>\n<body>\n<h1>Index of /.</h1>\n<ul>\n" + b'<li><a href="/my_dir">my_dir/</a></li>\n' + b'<li><a href="/my_file">my_file</a></li>\n' + b"</ul>\n</body>\n</html>", + id="index_root", + ), + pytest.param( + True, + 200, + "/static", + b"<html>\n<head>\n<title>Index of /.</title>\n" + b"</head>\n<body>\n<h1>Index of /.</h1>\n<ul>\n" + b'<li><a href="/static/my_dir">my_dir/</a></li>\n' + b'<li><a href="/static/my_file">my_file</a></li>\n' + b"</ul>\n</body>\n</html>", + id="index_static", + ), + ], +) +async def test_access_root_of_static_handler( + tmp_dir_path, aiohttp_client, show_index, status, prefix, data +) -> None: + # Tests the operation of static file server. + # Try to access the root of static file server, and make + # sure that correct HTTP statuses are returned depending if we directory + # index should be shown or not. # Put a file inside tmp_dir_path: - my_file_path = os.path.join(tmp_dir_path, 'my_file') - with open(my_file_path, 'w') as fw: - fw.write('hello') + my_file_path = os.path.join(tmp_dir_path, "my_file") + with open(my_file_path, "w") as fw: + fw.write("hello") - my_dir_path = os.path.join(tmp_dir_path, 'my_dir') + my_dir_path = os.path.join(tmp_dir_path, "my_dir") os.mkdir(my_dir_path) - my_file_path = os.path.join(my_dir_path, 'my_file_in_dir') - with open(my_file_path, 'w') as fw: - fw.write('world') + my_file_path = os.path.join(my_dir_path, "my_file_in_dir") + with open(my_file_path, "w") as fw: + fw.write("world") app = web.Application() # Register global static route: - app.router.add_static('/', tmp_dir_path, show_index=show_index) - client = yield from test_client(app) + app.router.add_static(prefix, tmp_dir_path, show_index=show_index) + client = await aiohttp_client(app) # Request the root of the static directory. - r = yield from client.get('/') + r = await client.get(prefix) assert r.status == status if data: - assert r.headers['Content-Type'] == "text/html; charset=utf-8" - read_ = (yield from r.read()) + assert r.headers["Content-Type"] == "text/html; charset=utf-8" + read_ = await r.read() assert read_ == data
-@pytest.mark.parametrize('data', ['hello world']) -@asyncio.coroutine -def test_follow_symlink(tmp_dir_path, loop, test_client, data): - """ - Tests the access to a symlink, in static folder - """ - my_dir_path = os.path.join(tmp_dir_path, 'my_dir') +async def test_follow_symlink(tmp_dir_path, aiohttp_client) -> None: + # Tests the access to a symlink, in static folder + data = "hello world" + + my_dir_path = os.path.join(tmp_dir_path, "my_dir") os.mkdir(my_dir_path) - my_file_path = os.path.join(my_dir_path, 'my_file_in_dir') - with open(my_file_path, 'w') as fw: + my_file_path = os.path.join(my_dir_path, "my_file_in_dir") + with open(my_file_path, "w") as fw: fw.write(data) - my_symlink_path = os.path.join(tmp_dir_path, 'my_symlink') + my_symlink_path = os.path.join(tmp_dir_path, "my_symlink") os.symlink(my_dir_path, my_symlink_path) app = web.Application() # Register global static route: - app.router.add_static('/', tmp_dir_path, follow_symlinks=True) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path, follow_symlinks=True) + client = await aiohttp_client(app) # Request the root of the static directory. 
- r = yield from client.get('/my_symlink/my_file_in_dir') + r = await client.get("/my_symlink/my_file_in_dir") assert r.status == 200 - assert (yield from r.text()) == data + assert (await r.text()) == data -@pytest.mark.parametrize('dir_name,filename,data', [ - ('', 'test file.txt', 'test text'), - ('test dir name', 'test dir file .txt', 'test text file folder') -]) -@asyncio.coroutine -def test_access_to_the_file_with_spaces(tmp_dir_path, loop, test_client, - dir_name, filename, data): - """ - Checks operation of static files with spaces - """ +@pytest.mark.parametrize( + "dir_name,filename,data", + [ + ("", "test file.txt", "test text"), + ("test dir name", "test dir file .txt", "test text file folder"), + ], +) +async def test_access_to_the_file_with_spaces( + tmp_dir_path, aiohttp_client, dir_name, filename, data +): + # Checks operation of static files with spaces my_dir_path = os.path.join(tmp_dir_path, dir_name) @@ -120,73 +139,92 @@ def test_access_to_the_file_with_spaces(tmp_dir_path, loop, test_client, my_file_path = os.path.join(my_dir_path, filename) - with open(my_file_path, 'w') as fw: + with open(my_file_path, "w") as fw: fw.write(data) app = web.Application() - url = os.path.join('/', dir_name, filename) + url = os.path.join("/", dir_name, filename) - app.router.add_static('/', tmp_dir_path) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path) + client = await aiohttp_client(app) - r = yield from client.get(url) + r = await client.get(url) assert r.status == 200 - assert (yield from r.text()) == data + assert (await r.text()) == data -@asyncio.coroutine -def test_access_non_existing_resource(tmp_dir_path, loop, test_client): - """ - Tests accessing non-existing resource - Try to access a non-exiting resource and make sure that 404 HTTP status - returned. - """ +async def test_access_non_existing_resource(tmp_dir_path, aiohttp_client) -> None: + # Tests accessing non-existing resource + # Try to access a non-exiting resource and make sure that 404 HTTP status + # returned. app = web.Application() # Register global static route: - app.router.add_static('/', tmp_dir_path, show_index=True) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path, show_index=True) + client = await aiohttp_client(app) # Request the root of the static directory. - r = yield from client.get('/non_existing_resource') + r = await client.get("/non_existing_resource") assert r.status == 404 -@pytest.mark.parametrize('registered_path,request_url', [ - ('/a:b', '/a:b'), - ('/a@b', '/a@b'), - ('/a:b', '/a%3Ab'), -]) -@asyncio.coroutine -def test_url_escaping(loop, test_client, registered_path, request_url): - """ - Tests accessing a resource with - """ +@pytest.mark.parametrize( + "registered_path,request_url", + [ + ("/a:b", "/a:b"), + ("/a@b", "/a@b"), + ("/a:b", "/a%3Ab"), + ], +) +async def test_url_escaping(aiohttp_client, registered_path, request_url) -> None: + # Tests accessing a resource with app = web.Application() - def handler(_): + async def handler(request): return web.Response() + app.router.add_get(registered_path, handler) - client = yield from test_client(app) + client = await aiohttp_client(app) - r = yield from client.get(request_url) + r = await client.get(request_url) assert r.status == 200 -@asyncio.coroutine -def test_unauthorized_folder_access(tmp_dir_path, loop, test_client): +async def test_handler_metadata_persistence() -> None: """ - Tests the unauthorized access to a folder of static file server. 
- Try to list a folder content of static file server when server does not - have permissions to do so for the folder. + Tests accessing metadata of a handler after registering it on the app + router. """ - my_dir_path = os.path.join(tmp_dir_path, 'my_dir') + app = web.Application() + + async def async_handler(request): + """Doc""" + return web.Response() + + def sync_handler(request): + """Doc""" + return web.Response() + + app.router.add_get("/async", async_handler) + with pytest.warns(DeprecationWarning): + app.router.add_get("/sync", sync_handler) + + for resource in app.router.resources(): + for route in resource: + assert route.handler.__doc__ == "Doc" + + +async def test_unauthorized_folder_access(tmp_dir_path, aiohttp_client) -> None: + # Tests the unauthorized access to a folder of static file server. + # Try to list a folder content of static file server when server does not + # have permissions to do so for the folder. + my_dir_path = os.path.join(tmp_dir_path, "my_dir") os.mkdir(my_dir_path) app = web.Application() - with mock.patch('pathlib.Path.__new__') as path_constructor: + with mock.patch("pathlib.Path.__new__") as path_constructor: path = MagicMock() path.joinpath.return_value = path path.resolve.return_value = path @@ -194,136 +232,286 @@ def test_unauthorized_folder_access(tmp_dir_path, loop, test_client): path_constructor.return_value = path # Register global static route: - app.router.add_static('/', tmp_dir_path, show_index=True) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path, show_index=True) + client = await aiohttp_client(app) # Request the root of the static directory. - r = yield from client.get('/my_dir') + r = await client.get("/my_dir") assert r.status == 403 -@asyncio.coroutine -def test_access_symlink_loop(tmp_dir_path, loop, test_client): - """ - Tests the access to a looped symlink, which could not be resolved. - """ - my_dir_path = os.path.join(tmp_dir_path, 'my_symlink') +async def test_access_symlink_loop(tmp_dir_path, aiohttp_client) -> None: + # Tests the access to a looped symlink, which could not be resolved. + my_dir_path = os.path.join(tmp_dir_path, "my_symlink") os.symlink(my_dir_path, my_dir_path) app = web.Application() # Register global static route: - app.router.add_static('/', tmp_dir_path, show_index=True) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path, show_index=True) + client = await aiohttp_client(app) # Request the root of the static directory. - r = yield from client.get('/my_symlink') + r = await client.get("/my_symlink") assert r.status == 404 -@asyncio.coroutine -def test_access_special_resource(tmp_dir_path, loop, test_client): - """ - Tests the access to a resource that is neither a file nor a directory. - Checks that if a special resource is accessed (f.e. named pipe or UNIX - domain socket) then 404 HTTP status returned. - """ +async def test_access_special_resource(tmp_dir_path, aiohttp_client) -> None: + # Tests the access to a resource that is neither a file nor a directory. + # Checks that if a special resource is accessed (f.e. named pipe or UNIX + # domain socket) then 404 HTTP status returned. 
app = web.Application() - with mock.patch('pathlib.Path.__new__') as path_constructor: + with mock.patch("pathlib.Path.__new__") as path_constructor: special = MagicMock() special.is_dir.return_value = False special.is_file.return_value = False path = MagicMock() - path.joinpath.side_effect = lambda p: (special if p == 'special' - else path) + path.joinpath.side_effect = lambda p: (special if p == "special" else path) path.resolve.return_value = path special.resolve.return_value = special path_constructor.return_value = path # Register global static route: - app.router.add_static('/', tmp_dir_path, show_index=True) - client = yield from test_client(app) + app.router.add_static("/", tmp_dir_path, show_index=True) + client = await aiohttp_client(app) # Request the root of the static directory. - r = yield from client.get('/special') - assert r.status == 404 + r = await client.get("/special") + assert r.status == 403 -@asyncio.coroutine -def test_partialy_applied_handler(loop, test_client): +async def test_partially_applied_handler(aiohttp_client) -> None: app = web.Application() - @asyncio.coroutine - def handler(data, request): + async def handler(data, request): return web.Response(body=data) - app.router.add_route('GET', '/', functools.partial(handler, b'hello')) - client = yield from test_client(app) + if sys.version_info >= (3, 8): + app.router.add_route("GET", "/", functools.partial(handler, b"hello")) + else: + with pytest.warns(DeprecationWarning): + app.router.add_route("GET", "/", functools.partial(handler, b"hello")) - r = yield from client.get('/') - data = (yield from r.read()) - assert data == b'hello' + client = await aiohttp_client(app) + r = await client.get("/") + data = await r.read() + assert data == b"hello" -def test_system_route(): - route = SystemRoute(web.HTTPCreated(reason='test')) - with pytest.raises(RuntimeError): - route.url() + +async def test_static_head(tmp_path, aiohttp_client) -> None: + # Test HEAD on static route + my_file_path = tmp_path / "test.txt" + with my_file_path.open("wb") as fw: + fw.write(b"should_not_see_this\n") + + app = web.Application() + app.router.add_static("/", str(tmp_path)) + client = await aiohttp_client(app) + + r = await client.head("/test.txt") + assert r.status == 200 + + # Check that there is no content sent (see #4809). This can't easily be + # done with aiohttp_client because the buffering can consume the content. 
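For contrast, a client-level check is inconclusive here: a HEAD response read through the aiohttp client yields an empty payload whether or not the server actually wrote a body on the wire, since the client parser never reads a body for HEAD. A short sketch of that weaker check, reusing this test's client fixture:

    r = await client.head("/test.txt")
    assert r.status == 200
    assert await r.read() == b""  # passes either way, hence the raw socket below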
+    reader, writer = await asyncio.open_connection(client.host, client.port)
+    writer.write(b"HEAD /test.txt HTTP/1.1\r\n")
+    writer.write(b"Host: localhost\r\n")
+    writer.write(b"Connection: close\r\n")
+    writer.write(b"\r\n")
+    while await reader.readline() != b"\r\n":
+        pass
+    content = await reader.read()
+    writer.close()
+    assert content == b""
+
+
+def test_system_route() -> None:
+    route = SystemRoute(web.HTTPCreated(reason="test"))
     with pytest.raises(RuntimeError):
         route.url_for()
     assert route.name is None
     assert route.resource is None
     assert "" == repr(route)
     assert 201 == route.status
-    assert 'test' == route.reason
+    assert "test" == route.reason


-@asyncio.coroutine
-def test_412_is_returned(loop, test_client):
-
+async def test_412_is_returned(aiohttp_client) -> None:
     class MyRouter(abc.AbstractRouter):
-
-        @asyncio.coroutine
-        def resolve(self, request):
+        async def resolve(self, request):
             raise web.HTTPPreconditionFailed()

-    app = web.Application(router=MyRouter(), loop=loop)
+    with pytest.warns(DeprecationWarning):
+        app = web.Application(router=MyRouter())

-    client = yield from test_client(app)
+    client = await aiohttp_client(app)

-    resp = yield from client.get('/')
+    resp = await client.get("/")
     assert resp.status == 412


-@asyncio.coroutine
-def test_allow_head(loop, test_client):
-    """
-    Test allow_head on routes.
-    """
+async def test_allow_head(aiohttp_client) -> None:
+    # Test allow_head on routes.
     app = web.Application()

-    def handler(_):
+    async def handler(_):
         return web.Response()

-    app.router.add_get('/a', handler, name='a')
-    app.router.add_get('/b', handler, allow_head=False, name='b')
-    client = yield from test_client(app)
-    r = yield from client.get('/a')
+    app.router.add_get("/a", handler, name="a")
+    app.router.add_get("/b", handler, allow_head=False, name="b")
+    client = await aiohttp_client(app)
+
+    r = await client.get("/a")
     assert r.status == 200
-    yield from r.release()
+    await r.release()

-    r = yield from client.head('/a')
+    r = await client.head("/a")
     assert r.status == 200
-    yield from r.release()
+    await r.release()

-    r = yield from client.get('/b')
+    r = await client.get("/b")
     assert r.status == 200
-    yield from r.release()
+    await r.release()

-    r = yield from client.head('/b')
+    r = await client.head("/b")
     assert r.status == 405
-    yield from r.release()
+    await r.release()
+
+
+@pytest.mark.parametrize(
+    "path",
+    [
+        "/a",
+        "/{a}",
+    ],
+)
+def test_reuse_last_added_resource(path) -> None:
+    # Test that adding a route with the same name and path as the last added
+    # resource doesn't create a new resource.
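The reuse behavior verified by this test is what lets one named resource accumulate routes for several methods and still serve URL building. A small sketch, with a hypothetical route name "user" and the module's usual web import:

app = web.Application()

async def user_handler(request):
    return web.Response()

# Both registrations attach to the resource created by the first call; the
# shared name "user" is not reported as a duplicate.
app.router.add_get("/users/{id}", user_handler, name="user")
app.router.add_post("/users/{id}", user_handler, name="user")

resource = app.router["user"]
assert str(resource.url_for(id="42")) == "/users/42"
assert len(app.router.resources()) == 1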
+ app = web.Application() + + async def handler(request): + return web.Response() + + app.router.add_get(path, handler, name="a") + app.router.add_post(path, handler, name="a") + + assert len(app.router.resources()) == 1 + + +def test_resource_raw_match() -> None: + app = web.Application() + + async def handler(request): + return web.Response() + + route = app.router.add_get("/a", handler, name="a") + assert route.resource.raw_match("/a") + + route = app.router.add_get("/{b}", handler, name="b") + assert route.resource.raw_match("/{b}") + + resource = app.router.add_static("/static", ".") + assert not resource.raw_match("/static") + + +async def test_add_view(aiohttp_client) -> None: + app = web.Application() + + class MyView(web.View): + async def get(self): + return web.Response() + + async def post(self): + return web.Response() + + app.router.add_view("/a", MyView) + + client = await aiohttp_client(app) + + r = await client.get("/a") + assert r.status == 200 + await r.release() + + r = await client.post("/a") + assert r.status == 200 + await r.release() + + r = await client.put("/a") + assert r.status == 405 + await r.release() + + +async def test_decorate_view(aiohttp_client) -> None: + routes = web.RouteTableDef() + + @routes.view("/a") + class MyView(web.View): + async def get(self): + return web.Response() + + async def post(self): + return web.Response() + + app = web.Application() + app.router.add_routes(routes) + + client = await aiohttp_client(app) + + r = await client.get("/a") + assert r.status == 200 + await r.release() + + r = await client.post("/a") + assert r.status == 200 + await r.release() + + r = await client.put("/a") + assert r.status == 405 + await r.release() + + +async def test_web_view(aiohttp_client) -> None: + app = web.Application() + + class MyView(web.View): + async def get(self): + return web.Response() + + async def post(self): + return web.Response() + + app.router.add_routes([web.view("/a", MyView)]) + + client = await aiohttp_client(app) + + r = await client.get("/a") + assert r.status == 200 + await r.release() + + r = await client.post("/a") + assert r.status == 200 + await r.release() + + r = await client.put("/a") + assert r.status == 405 + await r.release() + + +async def test_static_absolute_url(aiohttp_client, tmpdir) -> None: + # requested url is an absolute name like + # /static/\\machine_name\c$ or /static/D:\path + # where the static dir is totally different + app = web.Application() + fname = tmpdir / "file.txt" + fname.write_text("sample text", "ascii") + here = pathlib.Path(__file__).parent + app.router.add_static("/static", here) + client = await aiohttp_client(app) + resp = await client.get("/static/" + str(fname)) + assert resp.status == 403 diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 0f290588463..0a79113537e 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -4,10 +4,10 @@ import pytest from multidict import CIMultiDict -from aiohttp import WSMessage, WSMsgType, helpers, signals -from aiohttp.log import ws_logger +from aiohttp import WSMessage, WSMsgType, signals +from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web import HTTPBadRequest, HTTPMethodNotAllowed, WebSocketResponse +from aiohttp.web import HTTPBadRequest, WebSocketResponse from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady @@ -17,17 +17,10 @@ def app(loop): ret.loop = loop ret._debug = False ret.on_response_prepare = 
signals.Signal(ret) + ret.on_response_prepare.freeze() return ret -@pytest.fixture -def writer(): - writer = mock.Mock() - writer.drain.return_value = () - writer.write_eof.return_value = () - return writer - - @pytest.fixture def protocol(): ret = mock.Mock() @@ -36,474 +29,412 @@ def protocol(): @pytest.fixture -def make_request(app, protocol, writer): +def make_request(app, protocol): def maker(method, path, headers=None, protocols=False): if headers is None: headers = CIMultiDict( - {'HOST': 'server.example.com', - 'UPGRADE': 'websocket', - 'CONNECTION': 'Upgrade', - 'SEC-WEBSOCKET-KEY': 'dGhlIHNhbXBsZSBub25jZQ==', - 'ORIGIN': 'http://example.com', - 'SEC-WEBSOCKET-VERSION': '13'}) + { + "HOST": "server.example.com", + "UPGRADE": "websocket", + "CONNECTION": "Upgrade", + "SEC-WEBSOCKET-KEY": "dGhlIHNhbXBsZSBub25jZQ==", + "ORIGIN": "http://example.com", + "SEC-WEBSOCKET-VERSION": "13", + } + ) if protocols: - headers['SEC-WEBSOCKET-PROTOCOL'] = 'chat, superchat' + headers["SEC-WEBSOCKET-PROTOCOL"] = "chat, superchat" return make_mocked_request( - method, path, headers, - app=app, protocol=protocol, payload_writer=writer) + method, path, headers, app=app, protocol=protocol, loop=app.loop + ) return maker -def test_nonstarted_ping(): +async def test_nonstarted_ping() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.ping() + await ws.ping() -def test_nonstarted_pong(): +async def test_nonstarted_pong() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.pong() + await ws.pong() -def test_nonstarted_send_str(): +async def test_nonstarted_send_str() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.send_str('string') + await ws.send_str("string") -def test_nonstarted_send_bytes(): +async def test_nonstarted_send_bytes() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.send_bytes(b'bytes') + await ws.send_bytes(b"bytes") -def test_nonstarted_send_json(): +async def test_nonstarted_send_json() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.send_json({'type': 'json'}) + await ws.send_json({"type": "json"}) -@asyncio.coroutine -def test_nonstarted_close(): +async def test_nonstarted_close() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.close() + await ws.close() -@asyncio.coroutine -def test_nonstarted_receive_str(): - +async def test_nonstarted_receive_str() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.receive_str() - + await ws.receive_str() -@asyncio.coroutine -def test_nonstarted_receive_bytes(): +async def test_nonstarted_receive_bytes() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.receive_bytes() + await ws.receive_bytes() -@asyncio.coroutine -def test_nonstarted_receive_json(): +async def test_nonstarted_receive_json() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.receive_json() - + await ws.receive_json() -@asyncio.coroutine -def test_receive_str_nonstring(make_request): - req = make_request('GET', '/') +async def test_receive_str_nonstring(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) - @asyncio.coroutine - def receive(): - return WSMessage(WSMsgType.BINARY, b'data', b'') + async def receive(): + return WSMessage(WSMsgType.BINARY, b"data", b"") ws.receive = receive with pytest.raises(TypeError): - yield from ws.receive_str() + await 
ws.receive_str() -@asyncio.coroutine -def test_receive_bytes_nonsbytes(make_request): - req = make_request('GET', '/') +async def test_receive_bytes_nonsbytes(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) - @asyncio.coroutine - def receive(): - return WSMessage(WSMsgType.TEXT, 'data', b'') + async def receive(): + return WSMessage(WSMsgType.TEXT, "data", b"") ws.receive = receive with pytest.raises(TypeError): - yield from ws.receive_bytes() + await ws.receive_bytes() -@asyncio.coroutine -def test_send_str_nonstring(make_request): - req = make_request('GET', '/') +async def test_send_str_nonstring(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) with pytest.raises(TypeError): - ws.send_str(b'bytes') + await ws.send_str(b"bytes") -@asyncio.coroutine -def test_send_bytes_nonbytes(make_request): - req = make_request('GET', '/') +async def test_send_bytes_nonbytes(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) with pytest.raises(TypeError): - ws.send_bytes('string') + await ws.send_bytes("string") -@asyncio.coroutine -def test_send_json_nonjson(make_request): - req = make_request('GET', '/') +async def test_send_json_nonjson(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) with pytest.raises(TypeError): - ws.send_json(set()) + await ws.send_json(set()) -def test_write_non_prepared(): +async def test_write_non_prepared() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - ws.write(b'data') + await ws.write(b"data") -def test_websocket_ready(): - websocket_ready = WebSocketReady(True, 'chat') +def test_websocket_ready() -> None: + websocket_ready = WebSocketReady(True, "chat") assert websocket_ready.ok is True - assert websocket_ready.protocol == 'chat' + assert websocket_ready.protocol == "chat" -def test_websocket_not_ready(): +def test_websocket_not_ready() -> None: websocket_ready = WebSocketReady(False, None) assert websocket_ready.ok is False assert websocket_ready.protocol is None -def test_websocket_ready_unknown_protocol(): +def test_websocket_ready_unknown_protocol() -> None: websocket_ready = WebSocketReady(True, None) assert websocket_ready.ok is True assert websocket_ready.protocol is None -def test_bool_websocket_ready(): +def test_bool_websocket_ready() -> None: websocket_ready = WebSocketReady(True, None) assert bool(websocket_ready) is True -def test_bool_websocket_not_ready(): +def test_bool_websocket_not_ready() -> None: websocket_ready = WebSocketReady(False, None) assert bool(websocket_ready) is False -def test_can_prepare_ok(make_request): - req = make_request('GET', '/', protocols=True) - ws = WebSocketResponse(protocols=('chat',)) - assert(True, 'chat') == ws.can_prepare(req) +def test_can_prepare_ok(make_request) -> None: + req = make_request("GET", "/", protocols=True) + ws = WebSocketResponse(protocols=("chat",)) + assert WebSocketReady(True, "chat") == ws.can_prepare(req) -def test_can_prepare_unknown_protocol(make_request): - req = make_request('GET', '/') +def test_can_prepare_unknown_protocol(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - assert (True, None) == ws.can_prepare(req) + assert WebSocketReady(True, None) == ws.can_prepare(req) -def 
test_can_prepare_invalid_method(make_request): - req = make_request('POST', '/') +def test_can_prepare_without_upgrade(make_request) -> None: + req = make_request("GET", "/", headers=CIMultiDict({})) ws = WebSocketResponse() - assert (False, None) == ws.can_prepare(req) + assert WebSocketReady(False, None) == ws.can_prepare(req) -def test_can_prepare_without_upgrade(make_request): - req = make_request('GET', '/', - headers=CIMultiDict({})) +async def test_can_prepare_started(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - assert (False, None) == ws.can_prepare(req) - - -@asyncio.coroutine -def test_can_prepare_started(make_request): - req = make_request('GET', '/') - ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) with pytest.raises(RuntimeError) as ctx: ws.can_prepare(req) - assert 'Already started' in str(ctx.value) + assert "Already started" in str(ctx.value) -def test_closed_after_ctor(): +def test_closed_after_ctor() -> None: ws = WebSocketResponse() assert not ws.closed assert ws.close_code is None -@asyncio.coroutine -def test_send_str_closed(make_request, mocker): - req = make_request('GET', '/') +async def test_send_str_closed(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - mocker.spy(ws_logger, 'warning') - ws.send_str('string') - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_str("string") -@asyncio.coroutine -def test_send_bytes_closed(make_request, mocker): - req = make_request('GET', '/') +async def test_send_bytes_closed(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - mocker.spy(ws_logger, 'warning') - ws.send_bytes(b'bytes') - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_bytes(b"bytes") -@asyncio.coroutine -def test_send_json_closed(make_request, mocker): - req = make_request('GET', '/') +async def test_send_json_closed(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - mocker.spy(ws_logger, 'warning') - ws.send_json({'type': 'json'}) - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_json({"type": "json"}) -@asyncio.coroutine -def test_ping_closed(make_request, mocker): - req = make_request('GET', '/') +async def test_ping_closed(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - mocker.spy(ws_logger, 'warning') - ws.ping() - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.ping() -@asyncio.coroutine -def test_pong_closed(make_request, mocker): - req = make_request('GET', '/') +async def test_pong_closed(make_request, mocker) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - mocker.spy(ws_logger, 'warning') - ws.pong() - assert 
ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.pong() -@asyncio.coroutine -def test_close_idempotent(make_request, writer): - req = make_request('GET', '/') +async def test_close_idempotent(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - assert (yield from ws.close(code=1, message='message1')) + assert await ws.close(code=1, message="message1") assert ws.closed - assert not (yield from ws.close(code=2, message='message2')) + assert not (await ws.close(code=2, message="message2")) -@asyncio.coroutine -def test_prepare_invalid_method(make_request): - req = make_request('POST', '/') +async def test_prepare_post_method_ok(make_request) -> None: + req = make_request("POST", "/") ws = WebSocketResponse() - with pytest.raises(HTTPMethodNotAllowed): - yield from ws.prepare(req) + await ws.prepare(req) + assert ws.prepared -@asyncio.coroutine -def test_prepare_without_upgrade(make_request): - req = make_request('GET', '/', - headers=CIMultiDict({})) +async def test_prepare_without_upgrade(make_request) -> None: + req = make_request("GET", "/", headers=CIMultiDict({})) ws = WebSocketResponse() with pytest.raises(HTTPBadRequest): - yield from ws.prepare(req) + await ws.prepare(req) -@asyncio.coroutine -def test_wait_closed_before_start(): +async def test_wait_closed_before_start() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.close() + await ws.close() -@asyncio.coroutine -def test_write_eof_not_started(): +async def test_write_eof_not_started() -> None: ws = WebSocketResponse() with pytest.raises(RuntimeError): - yield from ws.write_eof() + await ws.write_eof() -@asyncio.coroutine -def test_write_eof_idempotent(make_request): - req = make_request('GET', '/') +async def test_write_eof_idempotent(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - yield from ws.write_eof() - yield from ws.write_eof() - yield from ws.write_eof() + await ws.write_eof() + await ws.write_eof() + await ws.write_eof() -@asyncio.coroutine -def test_receive_exc_in_reader(make_request, loop): - req = make_request('GET', '/') +async def test_receive_eofstream_in_reader(make_request, loop) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader = mock.Mock() - exc = ValueError() - res = helpers.create_future(loop) + exc = EofStream() + res = loop.create_future() res.set_exception(exc) ws._reader.read = make_mocked_coro(res) ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = helpers.create_future(loop) + ws._payload_writer.drain.return_value = loop.create_future() ws._payload_writer.drain.return_value.set_result(True) - msg = yield from ws.receive() - assert msg.type == WSMsgType.ERROR - assert msg.type is msg.tp - assert msg.data is exc - assert ws.exception() is exc - - -@asyncio.coroutine -def test_receive_cancelled(make_request, loop): - req = make_request('GET', '/') - ws = WebSocketResponse() - yield from ws.prepare(req) - - ws._reader = mock.Mock() - res = helpers.create_future(loop) - res.set_exception(asyncio.CancelledError()) - ws._reader.read = make_mocked_coro(res) - - with pytest.raises(asyncio.CancelledError): - yield from ws.receive() + msg = await 
ws.receive() + assert msg.type == WSMsgType.CLOSED + assert ws.closed -@asyncio.coroutine -def test_receive_timeouterror(make_request, loop): - req = make_request('GET', '/') +async def test_receive_timeouterror(make_request, loop) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader = mock.Mock() - res = helpers.create_future(loop) + res = loop.create_future() res.set_exception(asyncio.TimeoutError()) ws._reader.read = make_mocked_coro(res) with pytest.raises(asyncio.TimeoutError): - yield from ws.receive() + await ws.receive() -@asyncio.coroutine -def test_multiple_receive_on_close_connection(make_request): - req = make_request('GET', '/') +async def test_multiple_receive_on_close_connection(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) - yield from ws.close() + await ws.close() - yield from ws.receive() - yield from ws.receive() - yield from ws.receive() - yield from ws.receive() + await ws.receive() + await ws.receive() + await ws.receive() + await ws.receive() with pytest.raises(RuntimeError): - yield from ws.receive() + await ws.receive() -@asyncio.coroutine -def test_concurrent_receive(make_request): - req = make_request('GET', '/') +async def test_concurrent_receive(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) ws._waiting = True with pytest.raises(RuntimeError): - yield from ws.receive() + await ws.receive() -@asyncio.coroutine -def test_close_exc(make_request, loop, mocker): - req = make_request('GET', '/') +async def test_close_exc(make_request) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) - ws._reader = mock.Mock() exc = ValueError() - ws._reader.read.return_value = helpers.create_future(loop) - ws._reader.read.return_value.set_exception(exc) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = helpers.create_future(loop) - ws._payload_writer.drain.return_value.set_result(True) - - yield from ws.close() + ws._writer = mock.Mock() + ws._writer.close.side_effect = exc + await ws.close() assert ws.closed assert ws.exception() is exc ws._closed = False - ws._reader.read.return_value = helpers.create_future(loop) - ws._reader.read.return_value.set_exception(asyncio.CancelledError()) + ws._writer.close.side_effect = asyncio.CancelledError() with pytest.raises(asyncio.CancelledError): - yield from ws.close() - assert ws.close_code == 1006 + await ws.close() + +async def test_prepare_twice_idempotent(make_request) -> None: + req = make_request("GET", "/") + ws = WebSocketResponse() + + impl1 = await ws.prepare(req) + impl2 = await ws.prepare(req) + assert impl1 is impl2 -@asyncio.coroutine -def test_close_exc2(make_request): - req = make_request('GET', '/') +async def test_send_with_per_message_deflate(make_request, mocker) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() - yield from ws.prepare(req) + await ws.prepare(req) + writer_send = ws._writer.send = make_mocked_coro() - exc = ValueError() - ws._writer = mock.Mock() - ws._writer.close.side_effect = exc - yield from ws.close() - assert ws.closed - assert ws.exception() is exc + await ws.send_str("string", compress=15) + writer_send.assert_called_with("string", binary=False, compress=15) - ws._closed = False - 
ws._writer.close.side_effect = asyncio.CancelledError() - with pytest.raises(asyncio.CancelledError): - yield from ws.close() + await ws.send_bytes(b"bytes", compress=0) + writer_send.assert_called_with(b"bytes", binary=True, compress=0) + await ws.send_json("[{}]", compress=9) + writer_send.assert_called_with('"[{}]"', binary=False, compress=9) -@asyncio.coroutine -def test_prepare_twice_idempotent(make_request): - req = make_request('GET', '/') + +async def test_no_transfer_encoding_header(make_request, mocker) -> None: + req = make_request("GET", "/") ws = WebSocketResponse() + await ws._start(req) - impl1 = yield from ws.prepare(req) - impl2 = yield from ws.prepare(req) - assert impl1 is impl2 + assert "Transfer-Encoding" not in ws.headers diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index da281ce79e5..e5ea2a5539d 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1,762 +1,770 @@ -"""HTTP websocket server functional tests""" +# HTTP websocket server functional tests import asyncio import pytest import aiohttp -from aiohttp import helpers, web +from aiohttp import web from aiohttp.http import WSMsgType -@pytest.fixture -def ceil(mocker): - def ceil(val): - return val - - mocker.patch('aiohttp.helpers.ceil').side_effect = ceil - - -@asyncio.coroutine -def test_websocket_can_prepare(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_websocket_can_prepare(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + raise web.HTTPUpgradeRequired() - return web.HTTPOk() + return web.Response() app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - resp = yield from client.get('/') + resp = await client.get("/") assert resp.status == 426 -@asyncio.coroutine -def test_websocket_json(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_websocket_json(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() if not ws.can_prepare(request): return web.HTTPUpgradeRequired() - yield from ws.prepare(request) - msg = yield from ws.receive() + await ws.prepare(request) + msg = await ws.receive() msg_json = msg.json() - answer = msg_json['test'] - ws.send_str(answer) + answer = msg_json["test"] + await ws.send_str(answer) - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - expected_value = 'value' + ws = await client.ws_connect("/") + expected_value = "value" payload = '{"test": "%s"}' % expected_value - ws.send_str(payload) + await ws.send_str(payload) - resp = yield from ws.receive() + resp = await ws.receive() assert resp.data == expected_value -@asyncio.coroutine -def test_websocket_json_invalid_message(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_websocket_json_invalid_message(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) try: - yield from ws.receive_json() + await ws.receive_json() except ValueError: - 
ws.send_str('ValueError was raised') + await ws.send_str("ValueError was raised") else: - raise Exception('No Exception') + raise Exception("No Exception") finally: - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - payload = 'NOT A VALID JSON STRING' - ws.send_str(payload) + ws = await client.ws_connect("/") + payload = "NOT A VALID JSON STRING" + await ws.send_str(payload) - data = yield from ws.receive_str() - assert 'ValueError was raised' in data + data = await ws.receive_str() + assert "ValueError was raised" in data -@asyncio.coroutine -def test_websocket_send_json(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_websocket_send_json(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - - data = yield from ws.receive_json() - ws.send_json(data) - - yield from ws.close() - return ws - - app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) - - ws = yield from client.ws_connect('/') - expected_value = 'value' - ws.send_json({'test': expected_value}) - - data = yield from ws.receive_json() - assert data['test'] == expected_value - - -@asyncio.coroutine -def test_websocket_send_drain(loop, test_client): - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) - - ws._writer._limit = 1 + await ws.prepare(request) - data = yield from ws.receive_json() - drain = ws.send_json(data) - assert drain + data = await ws.receive_json() + await ws.send_json(data) - yield from drain - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - expected_value = 'value' - ws.send_json({'test': expected_value}) + ws = await client.ws_connect("/") + expected_value = "value" + await ws.send_json({"test": expected_value}) - data = yield from ws.receive_json() - assert data['test'] == expected_value + data = await ws.receive_json() + assert data["test"] == expected_value -@asyncio.coroutine -def test_websocket_receive_json(loop, test_client): - @asyncio.coroutine - def handler(request): +async def test_websocket_receive_json(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - data = yield from ws.receive_json() - answer = data['test'] - ws.send_str(answer) + data = await ws.receive_json() + answer = data["test"] + await ws.send_str(answer) - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_route('GET', '/', handler) - client = yield from test_client(app) + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - expected_value = 'value' + ws = await client.ws_connect("/") + expected_value = "value" payload = '{"test": "%s"}' % expected_value - ws.send_str(payload) + await ws.send_str(payload) - resp = yield from ws.receive() + resp = await ws.receive() assert resp.data == expected_value -@asyncio.coroutine -def test_send_recv_text(loop, test_client): 
+async def test_send_recv_text(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse()
-        yield from ws.prepare(request)
-        msg = yield from ws.receive_str()
-        ws.send_str(msg+'/answer')
-        yield from ws.close()
+        await ws.prepare(request)
+        msg = await ws.receive_str()
+        await ws.send_str(msg + "/answer")
+        await ws.close()
         closed.set_result(1)
         return ws

     app = web.Application()
-    app.router.add_route('GET', '/', handler)
-    client = yield from test_client(app)
+    app.router.add_route("GET", "/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/')
-    ws.send_str('ask')
-    msg = yield from ws.receive()
+    ws = await client.ws_connect("/")
+    await ws.send_str("ask")
+    msg = await ws.receive()
     assert msg.type == aiohttp.WSMsgType.TEXT
-    assert 'ask/answer' == msg.data
+    assert "ask/answer" == msg.data

-    msg = yield from ws.receive()
+    msg = await ws.receive()
     assert msg.type == aiohttp.WSMsgType.CLOSE
     assert msg.data == 1000
-    assert msg.extra == ''
+    assert msg.extra == ""

     assert ws.closed
     assert ws.close_code == 1000

-    yield from closed
+    await closed


-@asyncio.coroutine
-def test_send_recv_bytes(loop, test_client):
+async def test_send_recv_bytes(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse()
-        yield from ws.prepare(request)
+        await ws.prepare(request)

-        msg = yield from ws.receive_bytes()
-        ws.send_bytes(msg+b'/answer')
-        yield from ws.close()
+        msg = await ws.receive_bytes()
+        await ws.send_bytes(msg + b"/answer")
+        await ws.close()
         closed.set_result(1)
         return ws

     app = web.Application()
-    app.router.add_route('GET', '/', handler)
-    client = yield from test_client(app)
+    app.router.add_route("GET", "/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/')
-    ws.send_bytes(b'ask')
-    msg = yield from ws.receive()
+    ws = await client.ws_connect("/")
+    await ws.send_bytes(b"ask")
+    msg = await ws.receive()
     assert msg.type == aiohttp.WSMsgType.BINARY
-    assert b'ask/answer' == msg.data
+    assert b"ask/answer" == msg.data

-    msg = yield from ws.receive()
+    msg = await ws.receive()
     assert msg.type == aiohttp.WSMsgType.CLOSE
     assert msg.data == 1000
-    assert msg.extra == ''
+    assert msg.extra == ""

     assert ws.closed
     assert ws.close_code == 1000

-    yield from closed
+    await closed


-@asyncio.coroutine
-def test_send_recv_json(loop, test_client):
-    closed = helpers.create_future(loop)
+async def test_send_recv_json(loop, aiohttp_client) -> None:
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse()
-        yield from ws.prepare(request)
-        data = yield from ws.receive_json()
-        ws.send_json({'response': data['request']})
-        yield from ws.close()
+        await ws.prepare(request)
+        data = await ws.receive_json()
+        await ws.send_json({"response": data["request"]})
+        await ws.close()
         closed.set_result(1)
         return ws

     app = web.Application()
-    app.router.add_route('GET', '/', handler)
-    client = yield from test_client(app)
+    app.router.add_route("GET", "/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/')
+    ws = await client.ws_connect("/")

-    ws.send_str('{"request": "test"}')
-    msg = yield from ws.receive()
+    await ws.send_str('{"request": "test"}')
+    msg = await ws.receive()
     data = msg.json()
     assert msg.type == aiohttp.WSMsgType.TEXT
-    assert data['response'] == 'test'
+    assert data["response"] == "test"

-    msg = yield from ws.receive()
+    msg = await ws.receive()
     assert msg.type == aiohttp.WSMsgType.CLOSE
     assert msg.data == 1000
-    assert msg.extra == ''
+    assert msg.extra == ""

-    yield from ws.close()
+    await ws.close()

-    yield from closed
+    await closed


-@asyncio.coroutine
-def test_close_timeout(loop, test_client):
-    aborted = helpers.create_future(loop)
+async def test_close_timeout(loop, aiohttp_client) -> None:
+    aborted = loop.create_future()
+    elapsed = 1e10  # something big

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
+        nonlocal elapsed
         ws = web.WebSocketResponse(timeout=0.1)
-        yield from ws.prepare(request)
-        assert 'request' == (yield from ws.receive_str())
-        ws.send_str('reply')
+        await ws.prepare(request)
+        assert "request" == (await ws.receive_str())
+        await ws.send_str("reply")
         begin = ws._loop.time()
-        assert (yield from ws.close())
+        assert await ws.close()
         elapsed = ws._loop.time() - begin
-        assert elapsed < 0.201, \
-            'close() should have returned before ' \
-            'at most 2x timeout.'
         assert ws.close_code == 1006
         assert isinstance(ws.exception(), asyncio.TimeoutError)
         aborted.set_result(1)
         return ws

     app = web.Application()
-    app.router.add_route('GET', '/', handler)
-    client = yield from test_client(app)
+    app.router.add_route("GET", "/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/')
-    ws.send_str('request')
-    assert 'reply' == (yield from ws.receive_str())
+    ws = await client.ws_connect("/")
+    await ws.send_str("request")
+    assert "reply" == (await ws.receive_str())

     # The server closes here. Then the client sends bogus messages with an
     # interval shorter than the server-side close timeout, to make the server
     # hang indefinitely.
-    yield from asyncio.sleep(0.08, loop=loop)
-    msg = yield from ws._reader.read()
+    await asyncio.sleep(0.08)
+    msg = await ws._reader.read()
     assert msg.type == WSMsgType.CLOSE

-    ws.send_str('hang')
-    # i am not sure what do we test here
-    # under uvloop this code raises RuntimeError
-    try:
-        yield from asyncio.sleep(0.08, loop=loop)
-        ws.send_str('hang')
-        yield from asyncio.sleep(0.08, loop=loop)
-        ws.send_str('hang')
-        yield from asyncio.sleep(0.08, loop=loop)
-        ws.send_str('hang')
-    except RuntimeError:
-        pass
+    await asyncio.sleep(0.08)
+    assert await aborted

-    yield from asyncio.sleep(0.08, loop=loop)
-    assert (yield from aborted)
+    assert elapsed < 0.25, "close() should have returned within at most 2x timeout."
- yield from ws.close() + await ws.close() -@asyncio.coroutine -def test_concurrent_close(loop, test_client): +async def test_concurrent_close(loop, aiohttp_client) -> None: srv_ws = None - @asyncio.coroutine - def handler(request): + async def handler(request): nonlocal srv_ws - ws = srv_ws = web.WebSocketResponse( - autoclose=False, protocols=('foo', 'bar')) - yield from ws.prepare(request) + ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) + await ws.prepare(request) - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSING - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSING - yield from asyncio.sleep(0, loop=loop) + await asyncio.sleep(0) - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSED return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/', autoclose=False, - protocols=('eggs', 'bar')) + ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) - yield from srv_ws.close(code=1007) + await srv_ws.close(code=1007) - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSE - yield from asyncio.sleep(0, loop=loop) - msg = yield from ws.receive() + await asyncio.sleep(0) + msg = await ws.receive() assert msg.type == WSMsgType.CLOSED -@asyncio.coroutine -def test_auto_pong_with_closing_by_peer(loop, test_client): +async def test_auto_pong_with_closing_by_peer(loop, aiohttp_client) -> None: - closed = helpers.create_future(loop) + closed = loop.create_future() - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive() + await ws.prepare(request) + await ws.receive() - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSE assert msg.data == 1000 - assert msg.extra == 'exit message' + assert msg.extra == "exit message" closed.set_result(None) return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/', autoclose=False, autoping=False) - ws.ping() - ws.send_str('ask') + ws = await client.ws_connect("/", autoclose=False, autoping=False) + await ws.ping() + await ws.send_str("ask") - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.PONG - yield from ws.close(code=1000, message='exit message') - yield from closed + await ws.close(code=1000, message="exit message") + await closed -@asyncio.coroutine -def test_ping(loop, test_client): +async def test_ping(loop, aiohttp_client) -> None: - closed = helpers.create_future(loop) + closed = loop.create_future() - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse() - yield from ws.prepare(request) + await ws.prepare(request) - ws.ping('data') - yield from ws.receive() + await ws.ping("data") + await ws.receive() closed.set_result(None) return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/', autoping=False) + ws = await client.ws_connect("/", 
autoping=False)

-    msg = yield from ws.receive()
+    msg = await ws.receive()
     assert msg.type == WSMsgType.PING
-    assert msg.data == b'data'
-    ws.pong()
-    yield from ws.close()
-    yield from closed
+    assert msg.data == b"data"
+    await ws.pong()
+    await ws.close()
+    await closed


-@asyncio.coroutine
-def test_client_ping(loop, test_client):
+async def test_client_ping(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse()
-        yield from ws.prepare(request)
+        await ws.prepare(request)

-        yield from ws.receive()
+        await ws.receive()
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', autoping=False)
+    ws = await client.ws_connect("/", autoping=False)

-    ws.ping('data')
-    msg = yield from ws.receive()
+    await ws.ping("data")
+    msg = await ws.receive()
     assert msg.type == WSMsgType.PONG
-    assert msg.data == b'data'
-    ws.pong()
-    yield from ws.close()
+    assert msg.data == b"data"
+    await ws.pong()
+    await ws.close()


-@asyncio.coroutine
-def test_pong(loop, test_client):
+async def test_pong(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse(autoping=False)
-        yield from ws.prepare(request)
+        await ws.prepare(request)

-        msg = yield from ws.receive()
+        msg = await ws.receive()
         assert msg.type == WSMsgType.PING
-        ws.pong('data')
+        await ws.pong("data")

-        msg = yield from ws.receive()
+        msg = await ws.receive()
         assert msg.type == WSMsgType.CLOSE
         assert msg.data == 1000
-        assert msg.extra == 'exit message'
+        assert msg.extra == "exit message"
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', autoping=False)
+    ws = await client.ws_connect("/", autoping=False)

-    ws.ping('data')
-    msg = yield from ws.receive()
+    await ws.ping("data")
+    msg = await ws.receive()
     assert msg.type == WSMsgType.PONG
-    assert msg.data == b'data'
+    assert msg.data == b"data"

-    yield from ws.close(code=1000, message='exit message')
+    await ws.close(code=1000, message="exit message")

-    yield from closed
+    await closed


-@asyncio.coroutine
-def test_change_status(loop, test_client):
+async def test_change_status(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
+    async def handler(request):
         ws = web.WebSocketResponse()
         ws.set_status(200)
         assert 200 == ws.status
-        yield from ws.prepare(request)
+        await ws.prepare(request)
         assert 101 == ws.status
-        yield from ws.close()
+        await ws.close()
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', autoping=False)
+    ws = await client.ws_connect("/", autoping=False)

-    yield from ws.close()
-    yield from closed
-    yield from ws.close()
+    await ws.close()
+    await closed
+    await ws.close()


-@asyncio.coroutine
-def test_handle_protocol(loop,
test_client):
+async def test_handle_protocol(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
-        ws = web.WebSocketResponse(protocols=('foo', 'bar'))
-        yield from ws.prepare(request)
-        yield from ws.close()
-        assert 'bar' == ws.ws_protocol
+    async def handler(request):
+        ws = web.WebSocketResponse(protocols=("foo", "bar"))
+        await ws.prepare(request)
+        await ws.close()
+        assert "bar" == ws.ws_protocol
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', protocols=('eggs', 'bar'))
+    ws = await client.ws_connect("/", protocols=("eggs", "bar"))

-    yield from ws.close()
-    yield from closed
+    await ws.close()
+    await closed


-@asyncio.coroutine
-def test_server_close_handshake(loop, test_client):
+async def test_server_close_handshake(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
-        ws = web.WebSocketResponse(protocols=('foo', 'bar'))
-        yield from ws.prepare(request)
-        yield from ws.close()
+    async def handler(request):
+        ws = web.WebSocketResponse(protocols=("foo", "bar"))
+        await ws.prepare(request)
+        await ws.close()
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', autoclose=False,
-                                      protocols=('eggs', 'bar'))
+    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

-    msg = yield from ws.receive()
+    msg = await ws.receive()
     assert msg.type == WSMsgType.CLOSE

-    yield from ws.close()
-    yield from closed
+    await ws.close()
+    await closed


-@asyncio.coroutine
-def test_client_close_handshake(loop, test_client, ceil):
+async def test_client_close_handshake(loop, aiohttp_client) -> None:

-    closed = helpers.create_future(loop)
+    closed = loop.create_future()

-    @asyncio.coroutine
-    def handler(request):
-        ws = web.WebSocketResponse(
-            autoclose=False, protocols=('foo', 'bar'))
-        yield from ws.prepare(request)
+    async def handler(request):
+        ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
+        await ws.prepare(request)

-        msg = yield from ws.receive()
+        msg = await ws.receive()
         assert msg.type == WSMsgType.CLOSE
         assert not ws.closed
-        yield from ws.close()
+        await ws.close()
         assert ws.closed
         assert ws.close_code == 1007

-        msg = yield from ws.receive()
+        msg = await ws.receive()
         assert msg.type == WSMsgType.CLOSED
         closed.set_result(None)
         return ws

     app = web.Application()
-    app.router.add_get('/', handler)
-    client = yield from test_client(app)
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)

-    ws = yield from client.ws_connect('/', autoclose=False,
-                                      protocols=('eggs', 'bar'))
+    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

-    yield from ws.close(code=1007)
-    msg = yield from ws.receive()
+    await ws.close(code=1007)
+    msg = await ws.receive()
     assert msg.type == WSMsgType.CLOSED

-    yield from closed
+    await closed


-@asyncio.coroutine
-def test_server_close_handshake_server_eats_client_messages(loop, test_client):
+async def test_server_close_handshake_server_eats_client_messages(loop, aiohttp_client):
+    closed = loop.create_future()

-    closed =
helpers.create_future(loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse(protocols=('foo', 'bar')) - yield from ws.prepare(request) - yield from ws.close() + async def handler(request): + ws = web.WebSocketResponse(protocols=("foo", "bar")) + await ws.prepare(request) + await ws.close() closed.set_result(None) return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/', autoclose=False, autoping=False, - protocols=('eggs', 'bar')) + ws = await client.ws_connect( + "/", autoclose=False, autoping=False, protocols=("eggs", "bar") + ) - msg = yield from ws.receive() + msg = await ws.receive() assert msg.type == WSMsgType.CLOSE - ws.send_str('text') - ws.send_bytes(b'bytes') - ws.ping() + await ws.send_str("text") + await ws.send_bytes(b"bytes") + await ws.ping() - yield from ws.close() - yield from closed + await ws.close() + await closed -@asyncio.coroutine -def test_receive_timeout(loop, test_client): +async def test_receive_timeout(loop, aiohttp_client) -> None: raised = False - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse(receive_timeout=0.1) - yield from ws.prepare(request) + await ws.prepare(request) try: - yield from ws.receive() + await ws.receive() except asyncio.TimeoutError: nonlocal raised raised = True - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - yield from ws.receive() - yield from ws.close() + ws = await client.ws_connect("/") + await ws.receive() + await ws.close() assert raised -@asyncio.coroutine -def test_custom_receive_timeout(loop, test_client): +async def test_custom_receive_timeout(loop, aiohttp_client) -> None: raised = False - @asyncio.coroutine - def handler(request): + async def handler(request): ws = web.WebSocketResponse(receive_timeout=None) - yield from ws.prepare(request) + await ws.prepare(request) try: - yield from ws.receive(0.1) + await ws.receive(0.1) except asyncio.TimeoutError: nonlocal raised raised = True - yield from ws.close() + await ws.close() return ws app = web.Application() - app.router.add_get('/', handler) - client = yield from test_client(app) + app.router.add_get("/", handler) + client = await aiohttp_client(app) - ws = yield from client.ws_connect('/') - yield from ws.receive() - yield from ws.close() + ws = await client.ws_connect("/") + await ws.receive() + await ws.close() assert raised -@asyncio.coroutine -def test_heartbeat(loop, test_client, ceil): - @asyncio.coroutine - def handler(request): +async def test_heartbeat(loop, aiohttp_client) -> None: + async def handler(request): ws = web.WebSocketResponse(heartbeat=0.05) - yield from ws.prepare(request) - yield from ws.receive() - yield from ws.close() + await ws.prepare(request) + await ws.receive() + await ws.close() return ws app = web.Application() - app.router.add_get('/', handler) + app.router.add_get("/", handler) - client = yield from test_client(app) - ws = yield from client.ws_connect('/', autoping=False) - msg = yield from ws.receive() + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.ping - yield from ws.close() + 
await ws.close() -@asyncio.coroutine -def test_heartbeat_no_pong(loop, test_client, ceil): - cancelled = False +async def test_heartbeat_no_pong(loop, aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) - @asyncio.coroutine - def handler(request): - nonlocal cancelled - request._time_service._interval = 0.1 - request._time_service._on_cb() + await ws.receive() + return ws - ws = web.WebSocketResponse(heartbeat=0.05) - yield from ws.prepare(request) + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + msg = await ws.receive() + assert msg.type == aiohttp.WSMsgType.ping + await ws.close() - try: - yield from ws.receive() - except asyncio.CancelledError: - cancelled = True +async def test_server_ws_async_for(loop, aiohttp_server) -> None: + closed = loop.create_future() + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + async for msg in ws: + assert msg.type == aiohttp.WSMsgType.TEXT + s = msg.data + await ws.send_str(s + "/answer") + await ws.close() + closed.set_result(1) return ws app = web.Application() - app.router.add_get('/', handler) + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app) - client = yield from test_client(app) - ws = yield from client.ws_connect('/', autoping=False) - msg = yield from ws.receive() - assert msg.type == aiohttp.WSMsgType.ping - yield from ws.receive() + async with aiohttp.ClientSession() as sm: + async with sm.ws_connect(server.make_url("/")) as resp: + + items = ["q1", "q2", "q3"] + for item in items: + await resp.send_str(item) + msg = await resp.receive() + assert msg.type == aiohttp.WSMsgType.TEXT + assert item + "/answer" == msg.data + + await resp.close() + await closed + + +async def test_closed_async_for(loop, aiohttp_client) -> None: + + closed = loop.create_future() + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + messages = [] + async for msg in ws: + messages.append(msg) + if "stop" == msg.data: + await ws.send_str("stopping") + await ws.close() + + assert 1 == len(messages) + assert messages[0].type == WSMsgType.TEXT + assert messages[0].data == "stop" + + closed.set_result(None) + return ws + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + ws = await client.ws_connect("/") + await ws.send_str("stop") + msg = await ws.receive() + assert msg.type == WSMsgType.TEXT + assert msg.data == "stopping" + + await ws.close() + await closed + + +async def test_websocket_disable_keepalive(loop, aiohttp_client) -> None: + async def handler(request): + ws = web.WebSocketResponse() + if not ws.can_prepare(request): + return web.Response(text="OK") + assert request.protocol._keepalive + await ws.prepare(request) + assert not request.protocol._keepalive + assert not request.protocol._keepalive_handle + + await ws.send_str("OK") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + resp = await client.get("/") + txt = await resp.text() + assert txt == "OK" + + ws = await client.ws_connect("/") + data = await ws.receive_str() + assert data == "OK" + + +async def test_bug3380(loop, aiohttp_client) -> None: + async def handle_null(request): + return aiohttp.web.json_response({"err": None}) + + async def ws_handler(request): + return 
web.Response(status=401) + + app = web.Application() + app.router.add_route("GET", "/ws", ws_handler) + app.router.add_route("GET", "/api/null", handle_null) + + client = await aiohttp_client(app) + + resp = await client.get("/api/null") + assert (await resp.json()) == {"err": None} + resp.close() + + with pytest.raises(aiohttp.WSServerHandshakeError): + await client.ws_connect("/ws") - assert cancelled + resp = await client.get("/api/null", timeout=1) + assert (await resp.json()) == {"err": None} + resp.close() diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index 15ec918cfe6..bbfa1d9260d 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -1,151 +1,296 @@ -"""Tests for http/websocket.py""" +# Tests for http/websocket.py import base64 -import hashlib import os -from unittest import mock -import multidict import pytest -from yarl import URL -from aiohttp import http, http_exceptions -from aiohttp.http import WS_KEY, do_handshake +from aiohttp import web +from aiohttp.test_utils import make_mocked_request -@pytest.fixture() -def transport(): - return mock.Mock() +def gen_ws_headers( + protocols="", + compress=0, + extension_text="", + server_notakeover=False, + client_notakeover=False, +): + key = base64.b64encode(os.urandom(16)).decode() + hdrs = [ + ("Upgrade", "websocket"), + ("Connection", "upgrade"), + ("Sec-Websocket-Version", "13"), + ("Sec-Websocket-Key", key), + ] + if protocols: + hdrs += [("Sec-Websocket-Protocol", protocols)] + if compress: + params = "permessage-deflate" + if compress < 15: + params += "; server_max_window_bits=" + str(compress) + if server_notakeover: + params += "; server_no_context_takeover" + if client_notakeover: + params += "; client_no_context_takeover" + if extension_text: + params += "; " + extension_text + hdrs += [("Sec-Websocket-Extensions", params)] + return hdrs, key -@pytest.fixture() -def message(): - headers = multidict.MultiDict() - return http.RawRequestMessage( - 'GET', '/path', (1, 0), headers, [], - True, None, True, False, URL('/path')) +async def test_no_upgrade() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request("GET", "/") + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_no_connection() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "GET", "/", headers={"Upgrade": "websocket", "Connection": "keep-alive"} + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_protocol_version_unset() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "GET", "/", headers={"Upgrade": "websocket", "Connection": "upgrade"} + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_protocol_version_not_supported() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "GET", + "/", + headers={ + "Upgrade": "websocket", + "Connection": "upgrade", + "Sec-Websocket-Version": "1", + }, + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_protocol_key_not_present() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "GET", + "/", + headers={ + "Upgrade": "websocket", + "Connection": "upgrade", + "Sec-Websocket-Version": "13", + }, + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_protocol_key_invalid() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "GET", + "/", + headers={ + "Upgrade": 
"websocket", + "Connection": "upgrade", + "Sec-Websocket-Version": "13", + "Sec-Websocket-Key": "123", + }, + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_protocol_key_bad_size() -> None: + ws = web.WebSocketResponse() + sec_key = base64.b64encode(os.urandom(2)) + val = sec_key.decode() + req = make_mocked_request( + "GET", + "/", + headers={ + "Upgrade": "websocket", + "Connection": "upgrade", + "Sec-Websocket-Version": "13", + "Sec-Websocket-Key": val, + }, + ) + with pytest.raises(web.HTTPBadRequest): + await ws.prepare(req) + + +async def test_handshake_ok() -> None: + hdrs, sec_key = gen_ws_headers() + ws = web.WebSocketResponse() + req = make_mocked_request("GET", "/", headers=hdrs) + await ws.prepare(req) -def gen_ws_headers(protocols=''): - key = base64.b64encode(os.urandom(16)).decode() - hdrs = [('Upgrade', 'websocket'), - ('Connection', 'upgrade'), - ('Sec-Websocket-Version', '13'), - ('Sec-Websocket-Key', key)] - if protocols: - hdrs += [('Sec-Websocket-Protocol', protocols)] - return hdrs, key + assert ws.ws_protocol is None -def test_not_get(message, transport): - with pytest.raises(http_exceptions.HttpProcessingError): - do_handshake('POST', message.headers, transport) +async def test_handshake_protocol() -> None: + # Tests if one protocol is returned by handshake + proto = "chat" + ws = web.WebSocketResponse(protocols={"chat"}) + req = make_mocked_request("GET", "/", headers=gen_ws_headers(proto)[0]) -def test_no_upgrade(message, transport): - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) + await ws.prepare(req) + assert ws.ws_protocol == proto -def test_no_connection(message, transport): - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'keep-alive')]) - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) +async def test_handshake_protocol_agreement() -> None: + # Tests if the right protocol is selected given multiple + best_proto = "worse_proto" + wanted_protos = ["best", "chat", "worse_proto"] + server_protos = "worse_proto,chat" -def test_protocol_version(message, transport): - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'upgrade')]) - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) + ws = web.WebSocketResponse(protocols=wanted_protos) + req = make_mocked_request("GET", "/", headers=gen_ws_headers(server_protos)[0]) - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'upgrade'), - ('Sec-Websocket-Version', '1')]) + await ws.prepare(req) - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) + assert ws.ws_protocol == best_proto -def test_protocol_key(message, transport): - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'upgrade'), - ('Sec-Websocket-Version', '13')]) - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) +async def test_handshake_protocol_unsupported(caplog) -> None: + # Tests if a protocol mismatch handshake warns and returns None + proto = "chat" + req = make_mocked_request("GET", "/", headers=gen_ws_headers("test")[0]) - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'upgrade'), - ('Sec-Websocket-Version', '13'), - ('Sec-Websocket-Key', '123')]) - with pytest.raises(http_exceptions.HttpBadRequest): - 
do_handshake(message.method, message.headers, transport) + ws = web.WebSocketResponse(protocols=[proto]) + await ws.prepare(req) - sec_key = base64.b64encode(os.urandom(2)) - message.headers.extend([('Upgrade', 'websocket'), - ('Connection', 'upgrade'), - ('Sec-Websocket-Version', '13'), - ('Sec-Websocket-Key', sec_key.decode())]) - with pytest.raises(http_exceptions.HttpBadRequest): - do_handshake(message.method, message.headers, transport) + assert ( + caplog.records[-1].msg + == "Client protocols %r don’t overlap server-known ones %r" + ) + assert ws.ws_protocol is None -def test_handshake(message, transport): - hdrs, sec_key = gen_ws_headers() +async def test_handshake_compress() -> None: + hdrs, sec_key = gen_ws_headers(compress=15) + + req = make_mocked_request("GET", "/", headers=hdrs) + + ws = web.WebSocketResponse() + await ws.prepare(req) + + assert ws.compress == 15 + + +def test_handshake_compress_server_notakeover() -> None: + hdrs, sec_key = gen_ws_headers(compress=15, server_notakeover=True) + + req = make_mocked_request("GET", "/", headers=hdrs) + + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + + assert compress == 15 + assert notakeover is True + assert "Sec-Websocket-Extensions" in headers + assert headers["Sec-Websocket-Extensions"] == ( + "permessage-deflate; server_no_context_takeover" + ) + + +def test_handshake_compress_client_notakeover() -> None: + hdrs, sec_key = gen_ws_headers(compress=15, client_notakeover=True) - message.headers.extend(hdrs) - status, headers, parser, writer, protocol = do_handshake( - message.method, message.headers, transport) - assert status == 101 - assert protocol is None + req = make_mocked_request("GET", "/", headers=hdrs) - key = base64.b64encode( - hashlib.sha1(sec_key.encode() + WS_KEY).digest()) - headers = dict(headers) - assert headers['Sec-Websocket-Accept'] == key.decode() + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + assert "Sec-Websocket-Extensions" in headers + assert headers["Sec-Websocket-Extensions"] == ("permessage-deflate"), hdrs -def test_handshake_protocol(message, transport): - '''Tests if one protocol is returned by do_handshake''' - proto = 'chat' + assert compress == 15 - message.headers.extend(gen_ws_headers(proto)[0]) - _, resp_headers, _, _, protocol = do_handshake( - message.method, message.headers, transport, - protocols=[proto]) - assert protocol == proto +def test_handshake_compress_wbits() -> None: + hdrs, sec_key = gen_ws_headers(compress=9) - # also test if we reply with the protocol - resp_headers = dict(resp_headers) - assert resp_headers['Sec-Websocket-Protocol'] == proto + req = make_mocked_request("GET", "/", headers=hdrs) + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) -def test_handshake_protocol_agreement(message, transport): - '''Tests if the right protocol is selected given multiple''' - best_proto = 'worse_proto' - wanted_protos = ['best', 'chat', 'worse_proto'] - server_protos = 'worse_proto,chat' + assert "Sec-Websocket-Extensions" in headers + assert headers["Sec-Websocket-Extensions"] == ( + "permessage-deflate; server_max_window_bits=9" + ) + assert compress == 9 - message.headers.extend(gen_ws_headers(server_protos)[0]) - _, resp_headers, _, _, protocol = do_handshake( - message.method, message.headers, transport, - protocols=wanted_protos) - assert protocol == best_proto +def test_handshake_compress_wbits_error() -> None: + hdrs, sec_key = gen_ws_headers(compress=6) + 
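# test_handshake_compress_wbits_error, continuing just below, offers
# server_max_window_bits=6 and expects the server to drop the extension
# entirely (compress == 0, no response header). Together with the compress=9
# case above, the rule the tests pin down is that only window sizes 9..15 are
# negotiable, since zlib cannot emit 8-bit windows. A simplified decision
# function, assuming the offer is already parsed into a dict (negotiate_deflate
# is an illustrative name, not the aiohttp internal):
def negotiate_deflate(offer: dict) -> int:
    """Return the accepted window size, or 0 to disable permessage-deflate."""
    bits = int(offer.get("server_max_window_bits", 15))
    if 9 <= bits <= 15:
        return bits
    return 0  # out-of-range or malformed offer: fall back to no compression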
req = make_mocked_request("GET", "/", headers=hdrs) -def test_handshake_protocol_unsupported(log, message, transport): - '''Tests if a protocol mismatch handshake warns and returns None''' - proto = 'chat' - message.headers.extend(gen_ws_headers('test')[0]) + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + + assert "Sec-Websocket-Extensions" not in headers + assert compress == 0 + + +def test_handshake_compress_bad_ext() -> None: + hdrs, sec_key = gen_ws_headers(compress=15, extension_text="bad") + + req = make_mocked_request("GET", "/", headers=hdrs) + + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + + assert "Sec-Websocket-Extensions" not in headers + assert compress == 0 + + +def test_handshake_compress_multi_ext_bad() -> None: + hdrs, sec_key = gen_ws_headers( + compress=15, extension_text="bad, permessage-deflate" + ) + + req = make_mocked_request("GET", "/", headers=hdrs) + + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + + assert "Sec-Websocket-Extensions" in headers + assert headers["Sec-Websocket-Extensions"] == "permessage-deflate" + + +def test_handshake_compress_multi_ext_wbits() -> None: + hdrs, sec_key = gen_ws_headers(compress=6, extension_text=", permessage-deflate") + + req = make_mocked_request("GET", "/", headers=hdrs) + + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) + + assert "Sec-Websocket-Extensions" in headers + assert headers["Sec-Websocket-Extensions"] == "permessage-deflate" + assert compress == 15 + + +def test_handshake_no_transfer_encoding() -> None: + hdrs, sec_key = gen_ws_headers() + req = make_mocked_request("GET", "/", headers=hdrs) - with log('aiohttp.websocket') as ctx: - _, _, _, _, protocol = do_handshake( - message.method, message.headers, transport, - protocols=[proto]) + ws = web.WebSocketResponse() + headers, _, compress, notakeover = ws._handshake(req) - assert protocol is None - assert (ctx.records[-1].msg == - 'Client protocols %r don’t overlap server-known ones %r') + assert "Transfer-Encoding" not in headers diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index f9d5fe3420f..3bdd8108e35 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -1,5 +1,7 @@ +import pickle import random import struct +import zlib from unittest import mock import pytest @@ -7,33 +9,53 @@ import aiohttp from aiohttp import http_websocket from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType -from aiohttp.http_websocket import (PACK_CLOSE_CODE, PACK_LEN1, PACK_LEN2, - PACK_LEN3, WebSocketReader, - _websocket_mask) - - -def build_frame(message, opcode, use_mask=False, noheader=False): - """Send a frame over the websocket with message as its payload.""" +from aiohttp.http_websocket import ( + _WS_DEFLATE_TRAILING, + PACK_CLOSE_CODE, + PACK_LEN1, + PACK_LEN2, + PACK_LEN3, + WebSocketReader, + _websocket_mask, +) + + +def build_frame( + message, opcode, use_mask=False, noheader=False, is_fin=True, compress=False +): + # Send a frame over the websocket with message as its payload. 
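# What the compress branch just below does: permessage-deflate frames carry raw
# DEFLATE data (negative wbits, so no zlib header), flushed with Z_SYNC_FLUSH
# and with the resulting 0x00 0x00 0xff 0xff trailer (_WS_DEFLATE_TRAILING)
# cut off; the receiving side re-appends it before inflating. A self-contained
# round trip under those assumptions (function names are illustrative):
import zlib

TRAILER = b"\x00\x00\xff\xff"

def deflate_ws(payload: bytes) -> bytes:
    comp = zlib.compressobj(wbits=-9)
    data = comp.compress(payload) + comp.flush(zlib.Z_SYNC_FLUSH)
    return data[:-4] if data.endswith(TRAILER) else data

def inflate_ws(data: bytes) -> bytes:
    return zlib.decompressobj(wbits=-9).decompress(data + TRAILER)

assert inflate_ws(deflate_ws(b"hello websocket")) == b"hello websocket"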
+ if compress: + compressobj = zlib.compressobj(wbits=-9) + message = compressobj.compress(message) + message = message + compressobj.flush(zlib.Z_SYNC_FLUSH) + if message.endswith(_WS_DEFLATE_TRAILING): + message = message[:-4] msg_length = len(message) if use_mask: # pragma: no cover mask_bit = 0x80 else: mask_bit = 0 + if is_fin: + header_first_byte = 0x80 | opcode + else: + header_first_byte = opcode + + if compress: + header_first_byte |= 0x40 + if msg_length < 126: - header = PACK_LEN1( - 0x80 | opcode, msg_length | mask_bit) + header = PACK_LEN1(header_first_byte, msg_length | mask_bit) elif msg_length < (1 << 16): # pragma: no cover - header = PACK_LEN2( - 0x80 | opcode, 126 | mask_bit, msg_length) + header = PACK_LEN2(header_first_byte, 126 | mask_bit, msg_length) else: - header = PACK_LEN3( - 0x80 | opcode, 127 | mask_bit, msg_length) + header = PACK_LEN3(header_first_byte, 127 | mask_bit, msg_length) if use_mask: # pragma: no cover - mask = random.randrange(0, 0xffffffff) - mask = mask.to_bytes(4, 'big') - message = _websocket_mask(mask, bytearray(message)) + mask = random.randrange(0, 0xFFFFFFFF) + mask = mask.to_bytes(4, "big") + message = bytearray(message) + _websocket_mask(mask, message) if noheader: return message else: @@ -45,142 +67,136 @@ def build_frame(message, opcode, use_mask=False, noheader=False): return header + message -def build_close_frame(code=1000, message=b'', noheader=False): - """Close the websocket, sending the specified code and message.""" +def build_close_frame(code=1000, message=b"", noheader=False): + # Close the websocket, sending the specified code and message. if isinstance(message, str): # pragma: no cover - message = message.encode('utf-8') + message = message.encode("utf-8") return build_frame( - PACK_CLOSE_CODE(code) + message, - opcode=WSMsgType.CLOSE, noheader=noheader) + PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE, noheader=noheader + ) @pytest.fixture() def out(loop): - return aiohttp.DataQueue(loop=loop) + return aiohttp.DataQueue(loop) @pytest.fixture() def parser(out): - return WebSocketReader(out) + return WebSocketReader(out, 4 * 1024 * 1024) -def test_parse_frame(parser): - parser.parse_frame(struct.pack('!BB', 0b00000001, 0b00000001)) - res = parser.parse_frame(b'1') - fin, opcode, payload = res[0] +def test_parse_frame(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b00000001, 0b00000001)) + res = parser.parse_frame(b"1") + fin, opcode, payload, compress = res[0] - assert (0, 1, b'1') == (fin, opcode, payload) + assert (0, 1, b"1", False) == (fin, opcode, payload, not not compress) -def test_parse_frame_length0(parser): - fin, opcode, payload = parser.parse_frame( - struct.pack('!BB', 0b00000001, 0b00000000))[0] +def test_parse_frame_length0(parser) -> None: + fin, opcode, payload, compress = parser.parse_frame( + struct.pack("!BB", 0b00000001, 0b00000000) + )[0] - assert (0, 1, b'') == (fin, opcode, payload) + assert (0, 1, b"", False) == (fin, opcode, payload, not not compress) -def test_parse_frame_length2(parser): - parser.parse_frame(struct.pack('!BB', 0b00000001, 126)) - parser.parse_frame(struct.pack('!H', 4)) - res = parser.parse_frame(b'1234') - fin, opcode, payload = res[0] +def test_parse_frame_length2(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) + parser.parse_frame(struct.pack("!H", 4)) + res = parser.parse_frame(b"1234") + fin, opcode, payload, compress = res[0] - assert (0, 1, b'1234') == (fin, opcode, payload) + assert (0, 1, b"1234", False) == (fin, opcode, 
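# The length2/length4 tests here cover RFC 6455's three payload-length shapes:
# a 7-bit field for sizes below 126, the sentinel 126 followed by a 16-bit (!H)
# extended length, and 127 followed by a 64-bit (!Q) length: the same split as
# PACK_LEN1/PACK_LEN2/PACK_LEN3 in build_frame above. Building the headers by
# hand (length_header is an illustrative helper, not aiohttp API):
import struct

def length_header(first_byte: int, n: int) -> bytes:
    if n < 126:
        return struct.pack("!BB", first_byte, n)
    if n < (1 << 16):
        return struct.pack("!BBH", first_byte, 126, n)
    return struct.pack("!BBQ", first_byte, 127, n)

assert length_header(0x81, 4) == b"\x81\x04"            # FIN + TEXT, short form
assert length_header(0x81, 300) == b"\x81\x7e\x01\x2c"  # sentinel 126 plus !H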
payload, not not compress) -def test_parse_frame_length4(parser): - parser.parse_frame(struct.pack('!BB', 0b00000001, 127)) - parser.parse_frame(struct.pack('!Q', 4)) - fin, opcode, payload = parser.parse_frame(b'1234')[0] +def test_parse_frame_length4(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b00000001, 127)) + parser.parse_frame(struct.pack("!Q", 4)) + fin, opcode, payload, compress = parser.parse_frame(b"1234")[0] - assert (0, 1, b'1234') == (fin, opcode, payload) + assert (0, 1, b"1234", False) == (fin, opcode, payload, not not compress) -def test_parse_frame_mask(parser): - parser.parse_frame(struct.pack('!BB', 0b00000001, 0b10000001)) - parser.parse_frame(b'0001') - fin, opcode, payload = parser.parse_frame(b'1')[0] +def test_parse_frame_mask(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b00000001, 0b10000001)) + parser.parse_frame(b"0001") + fin, opcode, payload, compress = parser.parse_frame(b"1")[0] - assert (0, 1, b'\x01') == (fin, opcode, payload) + assert (0, 1, b"\x01", False) == (fin, opcode, payload, not not compress) -def test_parse_frame_header_reversed_bits(out, parser): +def test_parse_frame_header_reversed_bits(out, parser) -> None: with pytest.raises(WebSocketError): - parser.parse_frame(struct.pack('!BB', 0b01100000, 0b00000000)) + parser.parse_frame(struct.pack("!BB", 0b01100000, 0b00000000)) raise out.exception() -def test_parse_frame_header_control_frame(out, parser): +def test_parse_frame_header_control_frame(out, parser) -> None: with pytest.raises(WebSocketError): - parser.parse_frame(struct.pack('!BB', 0b00001000, 0b00000000)) - raise out.exception() - - -def test_parse_frame_header_continuation(out, parser): - with pytest.raises(WebSocketError): - parser._frame_fin = True - parser.parse_frame(struct.pack('!BB', 0b00000000, 0b00000000)) + parser.parse_frame(struct.pack("!BB", 0b00001000, 0b00000000)) raise out.exception() def _test_parse_frame_header_new_data_err(out, parser): with pytest.raises(WebSocketError): - parser.parse_frame(struct.pack('!BB', 0b000000000, 0b00000000)) + parser.parse_frame(struct.pack("!BB", 0b000000000, 0b00000000)) raise out.exception() -def test_parse_frame_header_payload_size(out, parser): +def test_parse_frame_header_payload_size(out, parser) -> None: with pytest.raises(WebSocketError): - parser.parse_frame(struct.pack('!BB', 0b10001000, 0b01111110)) + parser.parse_frame(struct.pack("!BB", 0b10001000, 0b01111110)) raise out.exception() -def test_ping_frame(out, parser): +def test_ping_frame(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.PING, b'data')] + parser.parse_frame.return_value = [(1, WSMsgType.PING, b"data", False)] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res == ((WSMsgType.PING, b'data', ''), 4) + assert res == ((WSMsgType.PING, b"data", ""), 4) -def test_pong_frame(out, parser): +def test_pong_frame(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.PONG, b'data')] + parser.parse_frame.return_value = [(1, WSMsgType.PONG, b"data", False)] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res == ((WSMsgType.PONG, b'data', ''), 4) + assert res == ((WSMsgType.PONG, b"data", ""), 4) -def test_close_frame(out, parser): +def test_close_frame(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'')] + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b"", 
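# test_close_frame_info below relies on the CLOSE payload layout: the first two
# bytes are a big-endian close code, the remainder a UTF-8 reason. That is why
# payload b"0112345" decodes to code 12337 (0x3031, i.e. "01" packed as !H)
# with message "12345". The same decoding, shown standalone:
import struct

payload = b"0112345"
code, = struct.unpack("!H", payload[:2])
reason = payload[2:].decode("utf-8")
assert (code, reason) == (12337, "12345")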
False)] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res == ((WSMsgType.CLOSE, 0, ''), 0) + assert res == ((WSMsgType.CLOSE, 0, ""), 0) -def test_close_frame_info(out, parser): +def test_close_frame_info(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'0112345')] + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b"0112345", False)] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res == (WSMessage(WSMsgType.CLOSE, 12337, '12345'), 0) + assert res == (WSMessage(WSMsgType.CLOSE, 12337, "12345"), 0) -def test_close_frame_invalid(out, parser): +def test_close_frame_invalid(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'1')] - parser.feed_data(b'') + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b"1", False)] + parser.feed_data(b"") assert isinstance(out.exception(), WebSocketError) assert out.exception().code == WSCloseCode.PROTOCOL_ERROR -def test_close_frame_invalid_2(out, parser): +def test_close_frame_invalid_2(out, parser) -> None: data = build_close_frame(code=1) with pytest.raises(WebSocketError) as ctx: @@ -189,9 +205,8 @@ def test_close_frame_invalid_2(out, parser): assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR -def test_close_frame_unicode_err(parser): - data = build_close_frame( - code=1000, message=b'\xf4\x90\x80\x80') +def test_close_frame_unicode_err(parser) -> None: + data = build_close_frame(code=1000, message=b"\xf4\x90\x80\x80") with pytest.raises(WebSocketError) as ctx: parser._feed_data(data) @@ -199,24 +214,24 @@ def test_close_frame_unicode_err(parser): assert ctx.value.code == WSCloseCode.INVALID_TEXT -def test_unknown_frame(out, parser): +def test_unknown_frame(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CONTINUATION, b'')] + parser.parse_frame.return_value = [(1, WSMsgType.CONTINUATION, b"", False)] with pytest.raises(WebSocketError): - parser.feed_data(b'') + parser.feed_data(b"") raise out.exception() -def test_simple_text(out, parser): - data = build_frame(b'text', WSMsgType.TEXT) +def test_simple_text(out, parser) -> None: + data = build_frame(b"text", WSMsgType.TEXT) parser._feed_data(data) res = out._buffer[0] - assert res == ((WSMsgType.TEXT, 'text', ''), 4) + assert res == ((WSMsgType.TEXT, "text", ""), 4) -def test_simple_text_unicode_err(parser): - data = build_frame(b'\xf4\x90\x80\x80', WSMsgType.TEXT) +def test_simple_text_unicode_err(parser) -> None: + data = build_frame(b"\xf4\x90\x80\x80", WSMsgType.TEXT) with pytest.raises(WebSocketError) as ctx: parser._feed_data(data) @@ -224,160 +239,184 @@ def test_simple_text_unicode_err(parser): assert ctx.value.code == WSCloseCode.INVALID_TEXT -def test_simple_binary(out, parser): +def test_simple_binary(out, parser) -> None: parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.BINARY, b'binary')] + parser.parse_frame.return_value = [(1, WSMsgType.BINARY, b"binary", False)] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res == ((WSMsgType.BINARY, b'binary', ''), 6) + assert res == ((WSMsgType.BINARY, b"binary", ""), 6) -def test_continuation(out, parser): - parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (1, WSMsgType.CONTINUATION, b'line2')] +def test_fragmentation_header(out, parser) -> None: + data = 
build_frame(b"a", WSMsgType.TEXT) + parser._feed_data(data[:1]) + parser._feed_data(data[1:]) + + res = out._buffer[0] + assert res == (WSMessage(WSMsgType.TEXT, "a", ""), 1) - parser._feed_data(b'') + +def test_continuation(out, parser) -> None: + data1 = build_frame(b"line1", WSMsgType.TEXT, is_fin=False) + parser._feed_data(data1) + + data2 = build_frame(b"line2", WSMsgType.CONTINUATION) + parser._feed_data(data2) res = out._buffer[0] - assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) + assert res == (WSMessage(WSMsgType.TEXT, "line1line2", ""), 10) -def test_continuation_with_ping(out, parser): +def test_continuation_with_ping(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.PING, b''), - (1, WSMsgType.CONTINUATION, b'line2'), + (0, WSMsgType.TEXT, b"line1", False), + (0, WSMsgType.PING, b"", False), + (1, WSMsgType.CONTINUATION, b"line2", False), ] - parser.feed_data(b'') + data1 = build_frame(b"line1", WSMsgType.TEXT, is_fin=False) + parser._feed_data(data1) + + data2 = build_frame(b"", WSMsgType.PING) + parser._feed_data(data2) + + data3 = build_frame(b"line2", WSMsgType.CONTINUATION) + parser._feed_data(data3) + res = out._buffer[0] - assert res == (WSMessage(WSMsgType.PING, b'', ''), 0) + assert res == (WSMessage(WSMsgType.PING, b"", ""), 0) res = out._buffer[1] - assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) + assert res == (WSMessage(WSMsgType.TEXT, "line1line2", ""), 10) -def test_continuation_err(out, parser): +def test_continuation_err(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (1, WSMsgType.TEXT, b'line2')] + (0, WSMsgType.TEXT, b"line1", False), + (1, WSMsgType.TEXT, b"line2", False), + ] with pytest.raises(WebSocketError): - parser._feed_data(b'') + parser._feed_data(b"") -def test_continuation_with_close(out, parser): +def test_continuation_with_close(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, - build_close_frame(1002, b'test', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2'), + (0, WSMsgType.TEXT, b"line1", False), + (0, WSMsgType.CLOSE, build_close_frame(1002, b"test", noheader=True), False), + (1, WSMsgType.CONTINUATION, b"line2", False), ] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res, (WSMessage(WSMsgType.CLOSE, 1002, 'test'), 0) + assert res, (WSMessage(WSMsgType.CLOSE, 1002, "test"), 0) res = out._buffer[1] - assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) + assert res == (WSMessage(WSMsgType.TEXT, "line1line2", ""), 10) -def test_continuation_with_close_unicode_err(out, parser): +def test_continuation_with_close_unicode_err(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, - build_close_frame(1000, b'\xf4\x90\x80\x80', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2')] + (0, WSMsgType.TEXT, b"line1", False), + ( + 0, + WSMsgType.CLOSE, + build_close_frame(1000, b"\xf4\x90\x80\x80", noheader=True), + False, + ), + (1, WSMsgType.CONTINUATION, b"line2", False), + ] with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b'') + parser._feed_data(b"") assert ctx.value.code == WSCloseCode.INVALID_TEXT -def test_continuation_with_close_bad_code(out, parser): +def 
test_continuation_with_close_bad_code(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, - build_close_frame(1, b'test', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2')] + (0, WSMsgType.TEXT, b"line1", False), + (0, WSMsgType.CLOSE, build_close_frame(1, b"test", noheader=True), False), + (1, WSMsgType.CONTINUATION, b"line2", False), + ] with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b'') + parser._feed_data(b"") assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR -def test_continuation_with_close_bad_payload(out, parser): +def test_continuation_with_close_bad_payload(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, b'1'), - (1, WSMsgType.CONTINUATION, b'line2')] + (0, WSMsgType.TEXT, b"line1", False), + (0, WSMsgType.CLOSE, b"1", False), + (1, WSMsgType.CONTINUATION, b"line2", False), + ] with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b'') + parser._feed_data(b"") assert ctx.value.code, WSCloseCode.PROTOCOL_ERROR -def test_continuation_with_close_empty(out, parser): +def test_continuation_with_close_empty(out, parser) -> None: parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, b''), - (1, WSMsgType.CONTINUATION, b'line2'), + (0, WSMsgType.TEXT, b"line1", False), + (0, WSMsgType.CLOSE, b"", False), + (1, WSMsgType.CONTINUATION, b"line2", False), ] - parser.feed_data(b'') + parser.feed_data(b"") res = out._buffer[0] - assert res, (WSMessage(WSMsgType.CLOSE, 0, ''), 0) + assert res, (WSMessage(WSMsgType.CLOSE, 0, ""), 0) res = out._buffer[1] - assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) + assert res == (WSMessage(WSMsgType.TEXT, "line1line2", ""), 10) -websocket_mask_data = bytearray( - b'some very long data for masking by websocket') -websocket_mask_mask = b'1234' -websocket_mask_masked = (b'B]^Q\x11DVFH\x12_[_U\x13PPFR\x14W]A\x14\\S@_X' - b'\\T\x14SK\x13CTP@[RYV@') +websocket_mask_data = b"some very long data for masking by websocket" +websocket_mask_mask = b"1234" +websocket_mask_masked = ( + b"B]^Q\x11DVFH\x12_[_U\x13PPFR\x14W]A\x14\\S@_X" b"\\T\x14SK\x13CTP@[RYV@" +) -def test_websocket_mask_python(): - ret = http_websocket._websocket_mask_python( - websocket_mask_mask, websocket_mask_data) - assert ret == websocket_mask_masked +def test_websocket_mask_python() -> None: + message = bytearray(websocket_mask_data) + http_websocket._websocket_mask_python(websocket_mask_mask, message) + assert message == websocket_mask_masked -@pytest.mark.skipif(not hasattr(http_websocket, '_websocket_mask_cython'), - reason='Requires Cython') -def test_websocket_mask_cython(): - ret = http_websocket._websocket_mask_cython( - websocket_mask_mask, websocket_mask_data) - assert ret == websocket_mask_masked +@pytest.mark.skipif( + not hasattr(http_websocket, "_websocket_mask_cython"), reason="Requires Cython" +) +def test_websocket_mask_cython() -> None: + message = bytearray(websocket_mask_data) + http_websocket._websocket_mask_cython(websocket_mask_mask, message) + assert message == websocket_mask_masked -def test_websocket_mask_python_empty(): - ret = http_websocket._websocket_mask_python( - websocket_mask_mask, bytearray()) - assert ret == bytearray() +def test_websocket_mask_python_empty() -> None: + message = bytearray() + http_websocket._websocket_mask_python(websocket_mask_mask, 
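# The masking tests in this block compare the pure-Python and Cython
# implementations, both of which now mutate the buffer in place instead of
# returning a copy. The algorithm is a cyclic XOR with the 4-byte mask, so
# applying it twice restores the data. A reference version under that
# assumption (websocket_mask is an illustrative name for the in-place routine):
def websocket_mask(mask: bytes, data: bytearray) -> None:
    for i in range(len(data)):
        data[i] ^= mask[i % 4]

buf = bytearray(b"some very long data for masking by websocket")
websocket_mask(b"1234", buf)  # mask
websocket_mask(b"1234", buf)  # unmask: XOR is its own inverse
assert buf == b"some very long data for masking by websocket"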
message) + assert message == bytearray() -@pytest.mark.skipif(not hasattr(http_websocket, '_websocket_mask_cython'), - reason='Requires Cython') -def test_websocket_mask_cython_empty(): - ret = http_websocket._websocket_mask_cython( - websocket_mask_mask, bytearray()) - assert ret == bytearray() +@pytest.mark.skipif( + not hasattr(http_websocket, "_websocket_mask_cython"), reason="Requires Cython" +) +def test_websocket_mask_cython_empty() -> None: + message = bytearray() + http_websocket._websocket_mask_cython(websocket_mask_mask, message) + assert message == bytearray() -def test_msgtype_aliases(): +def test_msgtype_aliases() -> None: assert aiohttp.WSMsgType.TEXT == aiohttp.WSMsgType.text assert aiohttp.WSMsgType.BINARY == aiohttp.WSMsgType.binary assert aiohttp.WSMsgType.PING == aiohttp.WSMsgType.ping @@ -385,3 +424,92 @@ def test_msgtype_aliases(): assert aiohttp.WSMsgType.CLOSE == aiohttp.WSMsgType.close assert aiohttp.WSMsgType.CLOSED == aiohttp.WSMsgType.closed assert aiohttp.WSMsgType.ERROR == aiohttp.WSMsgType.error + + +def test_parse_compress_frame_single(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) + res = parser.parse_frame(b"1") + fin, opcode, payload, compress = res[0] + + assert (1, 1, b"1", True) == (fin, opcode, payload, not not compress) + + +def test_parse_compress_frame_multi(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b01000001, 126)) + parser.parse_frame(struct.pack("!H", 4)) + res = parser.parse_frame(b"1234") + fin, opcode, payload, compress = res[0] + assert (0, 1, b"1234", True) == (fin, opcode, payload, not not compress) + + parser.parse_frame(struct.pack("!BB", 0b10000001, 126)) + parser.parse_frame(struct.pack("!H", 4)) + res = parser.parse_frame(b"1234") + fin, opcode, payload, compress = res[0] + assert (1, 1, b"1234", True) == (fin, opcode, payload, not not compress) + + parser.parse_frame(struct.pack("!BB", 0b10000001, 126)) + parser.parse_frame(struct.pack("!H", 4)) + res = parser.parse_frame(b"1234") + fin, opcode, payload, compress = res[0] + assert (1, 1, b"1234", False) == (fin, opcode, payload, not not compress) + + +def test_parse_compress_error_frame(parser) -> None: + parser.parse_frame(struct.pack("!BB", 0b01000001, 0b00000001)) + parser.parse_frame(b"1") + + with pytest.raises(WebSocketError) as ctx: + parser.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) + parser.parse_frame(b"1") + + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + + +def test_parse_no_compress_frame_single() -> None: + parser_no_compress = WebSocketReader(out, 0, compress=False) + with pytest.raises(WebSocketError) as ctx: + parser_no_compress.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) + parser_no_compress.parse_frame(b"1") + + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + + +def test_msg_too_large(out) -> None: + parser = WebSocketReader(out, 256, compress=False) + data = build_frame(b"text" * 256, WSMsgType.TEXT) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG + + +def test_msg_too_large_not_fin(out) -> None: + parser = WebSocketReader(out, 256, compress=False) + data = build_frame(b"text" * 256, WSMsgType.TEXT, is_fin=False) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG + + +def test_compressed_msg_too_large(out) -> None: + parser = WebSocketReader(out, 256, compress=True) + data = build_frame(b"aaa" * 256, WSMsgType.TEXT, 
compress=True) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG + + +class TestWebSocketError: + def test_ctor(self) -> None: + err = WebSocketError(WSCloseCode.PROTOCOL_ERROR, "Something invalid") + assert err.code == WSCloseCode.PROTOCOL_ERROR + assert str(err) == "Something invalid" + + def test_pickle(self) -> None: + err = WebSocketError(WSCloseCode.PROTOCOL_ERROR, "Something invalid") + err.foo = "bar" + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(err, proto) + err2 = pickle.loads(pickled) + assert err2.code == WSCloseCode.PROTOCOL_ERROR + assert str(err2) == "Something invalid" + assert err2.foo == "bar" diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 59bc9734fbf..fce3c330d27 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -4,65 +4,103 @@ import pytest from aiohttp.http import WebSocketWriter +from aiohttp.test_utils import make_mocked_coro @pytest.fixture -def stream(): - return mock.Mock() +def protocol(): + ret = mock.Mock() + ret._drain_helper = make_mocked_coro() + return ret @pytest.fixture -def writer(stream): - return WebSocketWriter(stream, use_mask=False) +def transport(): + ret = mock.Mock() + ret.is_closing.return_value = False + return ret -def test_pong(stream, writer): - writer.pong() - stream.transport.write.assert_called_with(b'\x8a\x00') +@pytest.fixture +def writer(protocol, transport): + return WebSocketWriter(protocol, transport, use_mask=False) -def test_ping(stream, writer): - writer.ping() - stream.transport.write.assert_called_with(b'\x89\x00') +async def test_pong(writer) -> None: + await writer.pong() + writer.transport.write.assert_called_with(b"\x8a\x00") -def test_send_text(stream, writer): - writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x04text') +async def test_ping(writer) -> None: + await writer.ping() + writer.transport.write.assert_called_with(b"\x89\x00") -def test_send_binary(stream, writer): - writer.send('binary', True) - stream.transport.write.assert_called_with(b'\x82\x06binary') +async def test_send_text(writer) -> None: + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\x81\x04text") -def test_send_binary_long(stream, writer): - writer.send(b'b' * 127, True) - assert stream.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') +async def test_send_binary(writer) -> None: + await writer.send("binary", True) + writer.transport.write.assert_called_with(b"\x82\x06binary") -def test_send_binary_very_long(stream, writer): - writer.send(b'b' * 65537, True) - assert (stream.transport.write.call_args_list[0][0][0] == - b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01') - assert stream.transport.write.call_args_list[1][0][0] == b'b' * 65537 +async def test_send_binary_long(writer) -> None: + await writer.send(b"b" * 127, True) + assert writer.transport.write.call_args[0][0].startswith(b"\x82~\x00\x7fb") -def test_close(stream, writer): - writer.close(1001, 'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') +async def test_send_binary_very_long(writer) -> None: + await writer.send(b"b" * 65537, True) + assert ( + writer.transport.write.call_args_list[0][0][0] + == b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01" + ) + assert writer.transport.write.call_args_list[1][0][0] == b"b" * 65537 - writer.close(1001, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') - # Test that Service 
Restart close code is also supported - writer.close(1012, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') +async def test_close(writer) -> None: + await writer.close(1001, "msg") + writer.transport.write.assert_called_with(b"\x88\x05\x03\xe9msg") + await writer.close(1001, b"msg") + writer.transport.write.assert_called_with(b"\x88\x05\x03\xe9msg") -def test_send_text_masked(stream, writer): - writer = WebSocketWriter(stream, - use_mask=True, - random=random.Random(123)) - writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') + # Test that Service Restart close code is also supported + await writer.close(1012, b"msg") + writer.transport.write.assert_called_with(b"\x88\x05\x03\xf4msg") + + +async def test_send_text_masked(protocol, transport) -> None: + writer = WebSocketWriter( + protocol, transport, use_mask=True, random=random.Random(123) + ) + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\x81\x84\rg\xb3fy\x02\xcb\x12") + + +async def test_send_compress_text(protocol, transport) -> None: + writer = WebSocketWriter(protocol, transport, compress=15) + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\xc1\x05*\x01b\x00\x00") + + +async def test_send_compress_text_notakeover(protocol, transport) -> None: + writer = WebSocketWriter(protocol, transport, compress=15, notakeover=True) + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") + + +async def test_send_compress_text_per_message(protocol, transport) -> None: + writer = WebSocketWriter(protocol, transport) + await writer.send(b"text", compress=15) + writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") + await writer.send(b"text") + writer.transport.write.assert_called_with(b"\x81\x04text") + await writer.send(b"text", compress=15) + writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") diff --git a/tests/test_worker.py b/tests/test_worker.py index 3b89a933385..64cff82e643 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,16 +1,15 @@ -"""Tests for aiohttp/worker.py""" +# Tests for aiohttp/worker.py import asyncio -import pathlib +import os import socket import ssl from unittest import mock import pytest -from aiohttp import helpers -from aiohttp.test_utils import make_mocked_coro +from aiohttp import web -base_worker = pytest.importorskip('aiohttp.worker') +base_worker = pytest.importorskip("aiohttp.worker") try: @@ -23,41 +22,48 @@ ACCEPTABLE_LOG_FORMAT = '%a "%{Referrer}i" %s' -class BaseTestWorker: +# tokio event loop does not allow to override attributes +def skip_if_no_dict(loop): + if not hasattr(loop, "__dict__"): + pytest.skip("can not override loop attributes") + +class BaseTestWorker: def __init__(self): self.servers = {} self.exit_code = 0 + self._notify_waiter = None self.cfg = mock.Mock() self.cfg.graceful_timeout = 100 - - try: - self.pid = 'pid' - except: - pass + self.pid = "pid" + self.wsgi = web.Application() -class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): +class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): # type: ignore pass PARAMS = [AsyncioWorker] if uvloop is not None: - class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): + + class 
UvloopWorker( + BaseTestWorker, base_worker.GunicornUVLoopWebWorker # type: ignore + ): pass PARAMS.append(UvloopWorker) @pytest.fixture(params=PARAMS) -def worker(request): +def worker(request, loop): + asyncio.set_event_loop(loop) ret = request.param() ret.notify = mock.Mock() return ret -def test_init_process(worker): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: +def test_init_process(worker) -> None: + with mock.patch("aiohttp.worker.asyncio") as m_asyncio: try: worker.init_process() except TypeError: @@ -68,61 +74,81 @@ def test_init_process(worker): assert m_asyncio.set_event_loop.called -def test_run(worker, loop): - worker.wsgi = mock.Mock() +def test_run(worker, loop) -> None: + worker.log = mock.Mock() + worker.cfg = mock.Mock() + worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + worker.cfg.is_ssl = False + worker.sockets = [] worker.loop = loop - worker._run = mock.Mock( - wraps=asyncio.coroutine(lambda: None)) - worker.wsgi.startup = make_mocked_coro(None) with pytest.raises(SystemExit): worker.run() - assert worker._run.called - worker.wsgi.startup.assert_called_once_with() + worker.log.exception.assert_not_called() assert loop.is_closed() -def test_run_wsgi(worker, loop): - worker.wsgi = lambda env, start_resp: start_resp() +def test_run_async_factory(worker, loop) -> None: + worker.log = mock.Mock() + worker.cfg = mock.Mock() + worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + worker.cfg.is_ssl = False + worker.sockets = [] + app = worker.wsgi + + async def make_app(): + return app + + worker.wsgi = make_app worker.loop = loop - worker._run = mock.Mock( - wraps=asyncio.coroutine(lambda: None)) + worker.alive = False with pytest.raises(SystemExit): worker.run() - assert worker._run.called + worker.log.exception.assert_not_called() assert loop.is_closed() -def test_handle_quit(worker): - with mock.patch('aiohttp.worker.ensure_future') as m_ensure_future: - worker.loop = mock.Mock() - worker.handle_quit(object(), object()) - assert not worker.alive - assert worker.exit_code == 0 - assert m_ensure_future.called - worker.loop.call_later.asset_called_with( - 0.1, worker._notify_waiter_done) +def test_run_not_app(worker, loop) -> None: + worker.log = mock.Mock() + worker.cfg = mock.Mock() + worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + + worker.loop = loop + worker.wsgi = "not-app" + worker.alive = False + with pytest.raises(SystemExit): + worker.run() + worker.log.exception.assert_called_with("Exception in gunicorn worker") + assert loop.is_closed() + + +def test_handle_quit(worker, loop) -> None: + worker.loop = mock.Mock() + worker.handle_quit(object(), object()) + assert not worker.alive + assert worker.exit_code == 0 + worker.loop.call_later.asset_called_with(0.1, worker._notify_waiter_done) -def test_handle_abort(worker): - with mock.patch('aiohttp.worker.sys') as m_sys: +def test_handle_abort(worker) -> None: + with mock.patch("aiohttp.worker.sys") as m_sys: worker.handle_abort(object(), object()) assert not worker.alive assert worker.exit_code == 1 m_sys.exit.assert_called_with(1) -def test__wait_next_notify(worker): +def test__wait_next_notify(worker) -> None: worker.loop = mock.Mock() worker._notify_waiter_done = mock.Mock() fut = worker._wait_next_notify() assert worker._notify_waiter == fut - worker.loop.call_later.assert_called_with(1.0, worker._notify_waiter_done) + worker.loop.call_later.assert_called_with(1.0, worker._notify_waiter_done, fut) -def test__notify_waiter_done(worker): +def test__notify_waiter_done(worker) -> None: 
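# test__wait_next_notify and test__notify_waiter_done (continued below) cover
# the worker's liveness plumbing: a future is parked as _notify_waiter, a 1s
# timer resolves it, and the run loop awaits it between notify() heartbeats to
# the gunicorn arbiter. A stripped-down sketch of that pattern, not the exact
# worker implementation (NotifySketch is an illustrative name):
import asyncio

class NotifySketch:
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self.loop = loop
        self._notify_waiter = None

    def _wait_next_notify(self) -> "asyncio.Future[bool]":
        fut = self.loop.create_future()
        self._notify_waiter = fut
        # Wake the run loop after 1s; passing fut explicitly mirrors the
        # call_later(1.0, self._notify_waiter_done, fut) assertion above.
        self.loop.call_later(1.0, self._notify_waiter_done, fut)
        return fut

    def _notify_waiter_done(self, waiter=None) -> None:
        if waiter is None:
            waiter = self._notify_waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(True)
        if waiter is self._notify_waiter:
            self._notify_waiter = None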
worker._notify_waiter = None worker._notify_waiter_done() assert worker._notify_waiter is None @@ -135,283 +161,131 @@ def test__notify_waiter_done(worker): waiter.set_result.assert_called_with(True) -def test_init_signals(worker): - worker.loop = mock.Mock() - worker.init_signals() - assert worker.loop.add_signal_handler.called - +def test__notify_waiter_done_explicit_waiter(worker) -> None: + worker._notify_waiter = None + assert worker._notify_waiter is None -def test_make_handler(worker, mocker): - worker.wsgi = mock.Mock() - worker.loop = mock.Mock() - worker.log = mock.Mock() - worker.cfg = mock.Mock() - worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT - mocker.spy(worker, '_get_valid_log_format') + waiter = worker._notify_waiter = mock.Mock() + waiter.done.return_value = False + waiter2 = worker._notify_waiter = mock.Mock() + worker._notify_waiter_done(waiter) - f = worker.make_handler(worker.wsgi) - assert f is worker.wsgi.make_handler.return_value - assert worker._get_valid_log_format.called + assert worker._notify_waiter is waiter2 + waiter.set_result.assert_called_with(True) + assert not waiter2.set_result.called -def test_make_handler_wsgi(worker, mocker): - worker.wsgi = lambda env, start_resp: start_resp() +def test_init_signals(worker) -> None: worker.loop = mock.Mock() - worker.loop.time.return_value = 1477797232 - worker.log = mock.Mock() - worker.cfg = mock.Mock() - worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT - mocker.spy(worker, '_get_valid_log_format') - - with pytest.raises(RuntimeError): - worker.make_handler(worker.wsgi) + worker.init_signals() + assert worker.loop.add_signal_handler.called -@pytest.mark.parametrize('source,result', [ - (ACCEPTABLE_LOG_FORMAT, ACCEPTABLE_LOG_FORMAT), - (AsyncioWorker.DEFAULT_GUNICORN_LOG_FORMAT, - AsyncioWorker.DEFAULT_AIOHTTP_LOG_FORMAT), -]) -def test__get_valid_log_format_ok(worker, source, result): +@pytest.mark.parametrize( + "source,result", + [ + (ACCEPTABLE_LOG_FORMAT, ACCEPTABLE_LOG_FORMAT), + ( + AsyncioWorker.DEFAULT_GUNICORN_LOG_FORMAT, + AsyncioWorker.DEFAULT_AIOHTTP_LOG_FORMAT, + ), + ], +) +def test__get_valid_log_format_ok(worker, source, result) -> None: assert result == worker._get_valid_log_format(source) -def test__get_valid_log_format_exc(worker): +def test__get_valid_log_format_exc(worker) -> None: with pytest.raises(ValueError) as exc: worker._get_valid_log_format(WRONG_LOG_FORMAT) - assert '%(name)s' in str(exc) - - -@asyncio.coroutine -def test__run_ok(worker, loop): - worker.ppid = 1 - worker.alive = True - worker.servers = {} - sock = mock.Mock() - sock.cfg_addr = ('localhost', 8080) - worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.close = make_mocked_coro(None) - worker.log = mock.Mock() - worker.loop = loop - loop.create_server = make_mocked_coro(sock) - worker.wsgi.make_handler.return_value.requests_count = 1 - worker.cfg.max_requests = 100 - worker.cfg.is_ssl = True - worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT - - ssl_context = mock.Mock() - with mock.patch('ssl.SSLContext', return_value=ssl_context): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - m_asyncio.sleep = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None)) - yield from worker._run() - - worker.notify.assert_called_with() - worker.log.info.assert_called_with("Parent changed, shutting down: %s", - worker) + assert "%(name)s" in str(exc.value) - args, kwargs = loop.create_server.call_args - assert 'ssl' in kwargs - ctx = kwargs['ssl'] - assert ctx is ssl_context +async def 
test__run_ok_parent_changed(worker, loop, aiohttp_unused_port) -> None: + skip_if_no_dict(loop) -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="UNIX sockets are not supported") -@asyncio.coroutine -def test__run_ok_unix_socket(worker, loop): - worker.ppid = 1 + worker.ppid = 0 worker.alive = True - worker.servers = {} - sock = mock.Mock() - sock.cfg_addr = ('/path/to') - sock.family = socket.AF_UNIX + sock = socket.socket() + addr = ("localhost", aiohttp_unused_port()) + sock.bind(addr) worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.close = make_mocked_coro(None) worker.log = mock.Mock() worker.loop = loop - loop.create_unix_server = make_mocked_coro(sock) - worker.wsgi.make_handler.return_value.requests_count = 1 - worker.cfg.max_requests = 100 - worker.cfg.is_ssl = True worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + worker.cfg.max_requests = 0 + worker.cfg.is_ssl = False - ssl_context = mock.Mock() - with mock.patch('ssl.SSLContext', return_value=ssl_context): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - m_asyncio.sleep = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None)) - yield from worker._run() + await worker._run() worker.notify.assert_called_with() - worker.log.info.assert_called_with("Parent changed, shutting down: %s", - worker) - - args, kwargs = loop.create_unix_server.call_args - assert 'ssl' in kwargs - ctx = kwargs['ssl'] - assert ctx is ssl_context - - -@asyncio.coroutine -def test__run_exc(worker, loop): - with mock.patch('aiohttp.worker.os') as m_os: - m_os.getpid.return_value = 1 - m_os.getppid.return_value = 1 - - handler = mock.Mock() - handler.requests_count = 0 - worker.servers = {mock.Mock(): handler} - worker._wait_next_notify = mock.Mock() - worker.ppid = 1 - worker.alive = True - worker.sockets = [] - worker.log = mock.Mock() - worker.loop = loop - worker.cfg.is_ssl = False - worker.cfg.max_redirects = 0 - worker.cfg.max_requests = 100 - - with mock.patch('aiohttp.worker.asyncio.sleep') as m_sleep: - slp = helpers.create_future(loop) - slp.set_exception(KeyboardInterrupt) - m_sleep.return_value = slp - - worker.close = make_mocked_coro(None) - - yield from worker._run() - - assert worker._wait_next_notify.called - worker.close.assert_called_with() - - -@asyncio.coroutine -def test_close(worker, loop): - srv = mock.Mock() - srv.wait_closed = make_mocked_coro(None) - handler = mock.Mock() - worker.servers = {srv: handler} - worker.log = mock.Mock() - worker.loop = loop - app = worker.wsgi = mock.Mock() - app.cleanup = make_mocked_coro(None) - handler.connections = [object()] - handler.shutdown.return_value = helpers.create_future(loop) - handler.shutdown.return_value.set_result(1) - - app.shutdown.return_value = helpers.create_future(loop) - app.shutdown.return_value.set_result(None) - - yield from worker.close() - app.shutdown.assert_called_with() - app.cleanup.assert_called_with() - handler.shutdown.assert_called_with(timeout=95.0) - srv.close.assert_called_with() - assert worker.servers is None - - yield from worker.close() - - -@asyncio.coroutine -def test_close_wsgi(worker, loop): - srv = mock.Mock() - srv.wait_closed = make_mocked_coro(None) - handler = mock.Mock() - worker.servers = {srv: handler} - worker.log = mock.Mock() - worker.loop = loop - worker.wsgi = lambda env, start_resp: start_resp() - handler.connections = [object()] - handler.shutdown.return_value = helpers.create_future(loop) - handler.shutdown.return_value.set_result(1) - - yield from worker.close() - 
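# The rewritten test__run_* cases in this area all exercise one control flow:
# while alive, heartbeat to the arbiter, then bail out when the parent pid
# changes (the worker was re-parented) or when the notify waiter raises.
# Roughly, as a hedged sketch rather than the worker's exact code (run_sketch
# is an illustrative name):
import os

async def run_sketch(worker) -> None:
    while worker.alive:
        worker.notify()                   # tell the gunicorn arbiter we're alive
        if worker.ppid != os.getppid():   # parent changed: orphaned worker
            worker.log.info("Parent changed, shutting down: %s", worker)
            break
        await worker._wait_next_notify()  # sleep ~1s until the next heartbeat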
handler.shutdown.assert_called_with(timeout=95.0) - srv.close.assert_called_with() - assert worker.servers is None + worker.log.info.assert_called_with("Parent changed, shutting down: %s", worker) - yield from worker.close() +async def test__run_exc(worker, loop, aiohttp_unused_port) -> None: + skip_if_no_dict(loop) -@asyncio.coroutine -def test__run_ok_no_max_requests(worker, loop): - worker.ppid = 1 + worker.ppid = os.getppid() worker.alive = True - worker.servers = {} - sock = mock.Mock() - sock.cfg_addr = ('localhost', 8080) + sock = socket.socket() + addr = ("localhost", aiohttp_unused_port()) + sock.bind(addr) worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.close = make_mocked_coro(None) worker.log = mock.Mock() worker.loop = loop - loop.create_server = make_mocked_coro(sock) - worker.wsgi.make_handler.return_value.requests_count = 1 worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.max_requests = 0 - worker.cfg.is_ssl = True + worker.cfg.is_ssl = False - ssl_context = mock.Mock() - with mock.patch('ssl.SSLContext', return_value=ssl_context): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - m_asyncio.sleep = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None)) - yield from worker._run() + def raiser(): + waiter = worker._notify_waiter + worker.alive = False + waiter.set_exception(RuntimeError()) - worker.notify.assert_called_with() - worker.log.info.assert_called_with("Parent changed, shutting down: %s", - worker) - - args, kwargs = loop.create_server.call_args - assert 'ssl' in kwargs - ctx = kwargs['ssl'] - assert ctx is ssl_context + loop.call_later(0.1, raiser) + await worker._run() + worker.notify.assert_called_with() -@asyncio.coroutine -def test__run_ok_max_requests_exceeded(worker, loop): - worker.ppid = 1 - worker.alive = True - worker.servers = {} - sock = mock.Mock() - sock.cfg_addr = ('localhost', 8080) - worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.close = make_mocked_coro(None) - worker.log = mock.Mock() - worker.loop = loop - loop.create_server = make_mocked_coro(sock) - worker.wsgi.make_handler.return_value.requests_count = 15 - worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT - worker.cfg.max_requests = 10 - worker.cfg.is_ssl = True - ssl_context = mock.Mock() - with mock.patch('ssl.SSLContext', return_value=ssl_context): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - m_asyncio.sleep = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None)) - yield from worker._run() +def test__create_ssl_context_without_certs_and_ciphers( + worker, + tls_certificate_pem_path, +) -> None: + worker.cfg.ssl_version = ssl.PROTOCOL_SSLv23 + worker.cfg.cert_reqs = ssl.CERT_OPTIONAL + worker.cfg.certfile = tls_certificate_pem_path + worker.cfg.keyfile = tls_certificate_pem_path + worker.cfg.ca_certs = None + worker.cfg.ciphers = None + ctx = worker._create_ssl_context(worker.cfg) + assert isinstance(ctx, ssl.SSLContext) - worker.notify.assert_called_with() - worker.log.info.assert_called_with("Max requests, shutting down: %s", - worker) - args, kwargs = loop.create_server.call_args - assert 'ssl' in kwargs - ctx = kwargs['ssl'] - assert ctx is ssl_context +def test__create_ssl_context_with_ciphers( + worker, + tls_certificate_pem_path, +) -> None: + worker.cfg.ssl_version = ssl.PROTOCOL_SSLv23 + worker.cfg.cert_reqs = ssl.CERT_OPTIONAL + worker.cfg.certfile = tls_certificate_pem_path + worker.cfg.keyfile = tls_certificate_pem_path + worker.cfg.ca_certs = None + worker.cfg.ciphers = "3DES PSK" + ctx = 
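# The _create_ssl_context tests around this point hand gunicorn-style cfg
# attributes (ssl_version, cert_reqs, certfile/keyfile, ca_certs, ciphers) to
# the worker and only assert that an ssl.SSLContext comes back. Roughly what
# such a factory does, sketched under those assumptions (create_ssl_context is
# an illustrative name):
import ssl

def create_ssl_context(cfg) -> ssl.SSLContext:
    ctx = ssl.SSLContext(cfg.ssl_version)
    ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
    ctx.verify_mode = cfg.cert_reqs
    if cfg.ca_certs:
        ctx.load_verify_locations(cfg.ca_certs)
    if cfg.ciphers:
        ctx.set_ciphers(cfg.ciphers)
    return ctx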
worker._create_ssl_context(worker.cfg) + assert isinstance(ctx, ssl.SSLContext) -def test__create_ssl_context_without_certs_and_ciphers(worker): - here = pathlib.Path(__file__).parent +def test__create_ssl_context_with_ca_certs( + worker, + tls_ca_certificate_pem_path, + tls_certificate_pem_path, +) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_SSLv23 worker.cfg.cert_reqs = ssl.CERT_OPTIONAL - worker.cfg.certfile = str(here / 'sample.crt') - worker.cfg.keyfile = str(here / 'sample.key') - worker.cfg.ca_certs = None + worker.cfg.certfile = tls_certificate_pem_path + worker.cfg.keyfile = tls_certificate_pem_path + worker.cfg.ca_certs = tls_ca_certificate_pem_path worker.cfg.ciphers = None - crt = worker._create_ssl_context(worker.cfg) - assert isinstance(crt, ssl.SSLContext) + ctx = worker._create_ssl_context(worker.cfg) + assert isinstance(ctx, ssl.SSLContext) diff --git a/tools/check_changes.py b/tools/check_changes.py new file mode 100755 index 00000000000..4ee3fc1b2de --- /dev/null +++ b/tools/check_changes.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +ALLOWED_SUFFIXES = [".feature", ".bugfix", ".doc", ".removal", ".misc"] + + +def get_root(script_path): + folder = script_path.absolute().parent + while not (folder / ".git").exists(): + folder = folder.parent + if folder == folder.anchor: + raise RuntimeError("git repo not found") + return folder + + +def main(argv): + print('Check "CHANGES" folder... ', end="", flush=True) + here = Path(argv[0]) + root = get_root(here) + changes = root / "CHANGES" + failed = False + for fname in changes.iterdir(): + if fname.name in (".gitignore", ".TEMPLATE.rst"): + continue + if fname.suffix not in ALLOWED_SUFFIXES: + if not failed: + print("") + print(fname, "has illegal suffix", file=sys.stderr) + failed = True + + if failed: + print("", file=sys.stderr) + print("Allowed suffixes are:", ALLOWED_SUFFIXES, file=sys.stderr) + print("", file=sys.stderr) + else: + print("OK") + + return int(failed) + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/tools/check_sum.py b/tools/check_sum.py new file mode 100755 index 00000000000..50dec4d2be5 --- /dev/null +++ b/tools/check_sum.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +import argparse +import hashlib +import pathlib +import sys + +PARSER = argparse.ArgumentParser( + description="Helper for check file hashes in Makefile instead of bare timestamps" +) +PARSER.add_argument("dst", metavar="DST", type=pathlib.Path) +PARSER.add_argument("-d", "--debug", action="store_true", default=False) + + +def main(argv): + args = PARSER.parse_args(argv) + dst = args.dst + assert dst.suffix == ".hash" + dirname = dst.parent + if dirname.name != ".hash": + if args.debug: + print(f"Invalid name {dst} -> dirname {dirname}", file=sys.stderr) + return 0 + dirname.mkdir(exist_ok=True) + src_dir = dirname.parent + src_name = dst.stem # drop .hash + full_src = src_dir / src_name + hasher = hashlib.sha256() + try: + hasher.update(full_src.read_bytes()) + except OSError: + if args.debug: + print(f"Cannot open {full_src}", file=sys.stderr) + return 0 + src_hash = hasher.hexdigest() + if dst.exists(): + dst_hash = dst.read_text() + else: + dst_hash = "" + if src_hash != dst_hash: + dst.write_text(src_hash) + print(f"re-hash {src_hash}") + else: + if args.debug: + print(f"Skip {src_hash} checksum, up-to-date") + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/tools/drop_merged_branches.sh b/tools/drop_merged_branches.sh new file mode 
100755 index 00000000000..d4f315a8987 --- /dev/null +++ b/tools/drop_merged_branches.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +git remote prune origin
diff --git a/tools/gen.py b/tools/gen.py new file mode 100755 index 00000000000..ab2b39a2df0 --- /dev/null +++ b/tools/gen.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +import io +import pathlib +from collections import defaultdict + +import multidict + +ROOT = pathlib.Path.cwd() +while ROOT.parent != ROOT and not (ROOT / ".git").exists(): + ROOT = ROOT.parent + + +def calc_headers(root): + hdrs_file = root / "aiohttp/hdrs.py" + code = compile(hdrs_file.read_text(), str(hdrs_file), "exec") + globs = {} + exec(code, globs) + headers = [val for val in globs.values() if isinstance(val, multidict.istr)] + return sorted(headers) + + +headers = calc_headers(ROOT) + + +def factory(): + return defaultdict(factory) + + +TERMINAL = object() + + +def build(headers): + dct = defaultdict(factory) + for hdr in headers: + d = dct + for ch in hdr: + d = d[ch] + d[TERMINAL] = hdr + return dct + + +dct = build(headers) + + +HEADER = """\ +/* The file is autogenerated from aiohttp/hdrs.py +Run ./tools/gen.py to update it after the origin changes. */ + +#include "_find_header.h" + +#define NEXT_CHAR() \\ +{ \\ + count++; \\ + if (count == size) { \\ + /* end of search */ \\ + return -1; \\ + } \\ + pchar++; \\ + ch = *pchar; \\ + last = (count == size -1); \\ +} while(0); + +int +find_header(const char *str, int size) +{ + char *pchar = str; + int last; + char ch; + int count = -1; + pchar--; +""" + +BLOCK = """ +{label} + NEXT_CHAR(); + switch (ch) {{ +{cases} + default: + return -1; + }} +""" + +CASE = """\ + case '{char}': + if (last) {{ + return {index}; + }} + goto {next};""" + +FOOTER = """ +{missing} +missing: + /* nothing found */ + return -1; +}} +""" + + +def gen_prefix(prefix, k): + if k == "-": + return prefix + "_" + else: + return prefix + k.upper() + + +def gen_block(dct, prefix, used_blocks, missing, out): + cases = {} + for k, v in dct.items(): + if k is TERMINAL: + continue + next_prefix = gen_prefix(prefix, k) + term = v.get(TERMINAL) + if term is not None: + index = headers.index(term) + else: + index = -1 + hi = k.upper() + lo = k.lower() + case = CASE.format(char=hi, index=index, next=next_prefix) + cases[hi] = case + if lo != hi: + case = CASE.format(char=lo, index=index, next=next_prefix) + cases[lo] = case + label = prefix + ":" if prefix else "" + if cases: + block = BLOCK.format(label=label, cases="\n".join(cases.values())) + out.write(block) + else: + missing.add(label) + for k, v in dct.items(): + if not isinstance(v, defaultdict): + continue + block_name = gen_prefix(prefix, k) + if block_name in used_blocks: + continue + used_blocks.add(block_name) + gen_block(v, block_name, used_blocks, missing, out) + + +def gen(dct): + out = io.StringIO() + out.write(HEADER) + missing = set() + gen_block(dct, "", set(), missing, out) + missing_labels = "\n".join(m for m in sorted(missing)) + out.write(FOOTER.format(missing=missing_labels)) + return out + + +def gen_headers(headers): + out = io.StringIO() + out.write("# The file is autogenerated from aiohttp/hdrs.py\n") + out.write("# Run ./tools/gen.py to update it after the origin changes.") + out.write("\n\n") + out.write("from . import hdrs\n") + out.write("cdef tuple headers = (\n") + for hdr in headers: + out.write(" hdrs.{},\n".format(hdr.upper().replace("-", "_"))) + out.write(")\n") + return out + + +# print(gen(dct).getvalue()) +# print(gen_headers(headers).getvalue()) + +folder = ROOT / "aiohttp" + +with (folder / "_find_header.c").open("w") as f: + f.write(gen(dct).getvalue()) + +with (folder / "_headers.pxi").open("w") as f: + f.write(gen_headers(headers).getvalue())
diff --git a/tox.ini b/tox.ini index fd0a550e333..5a8e90cf374 100644 --- a/tox.ini +++ b/tox.ini @@ -1,22 +1,25 @@ [tox] -envlist = check, {py34,py35}-{debug,release}-{cchardet,cython,pure}, report +envlist = check, {py35,py36}-{debug,release}-{cython,pure}, report [testenv] deps = pytest + pytest-mock + pytest-xdist # pytest-cov coverage gunicorn - cchardet: cython - cchardet: cchardet + async-generator + brotlipy cython: cython + -e . commands = # --cov={envsitepackagesdir}/tests # py.test --cov={envsitepackagesdir}/aiohttp tests {posargs} - coverage run -m pytest tests {posargs} + coverage run -m pytest {posargs:tests} mv .coverage .coverage.{envname} setenv = @@ -24,8 +27,8 @@ setenv = pure: AIOHTTP_NO_EXTENSIONS = 1 basepython: - py34: python3.4 py35: python3.5 + py36: python3.6 whitelist_externals = coverage @@ -39,14 +42,18 @@ deps = flake8 pyflakes>=1.0.0 coverage + docutils + pygments + isort commands = flake8 aiohttp examples tests - python setup.py check -rm + python setup.py check -rms + # isort --check -rc aiohttp tests examples coverage erase basepython: - python3.5 + python3.6 [testenv:report] @@ -54,7 +61,7 @@ commands = coverage combine coverage report coverage html - echo "open file://{toxinidir}/coverage/index.html" + echo "open file://{toxinidir}/htmlcov/index.html" basepython: - python3.5 + python3.6
diff --git a/vendor/http-parser b/vendor/http-parser new file mode 160000 index 00000000000..2343fd6b521 --- /dev/null +++ b/vendor/http-parser @@ -0,0 +1 @@ +Subproject commit 2343fd6b5214b2ded2cdcf76de2bf60903bb90cd
diff --git a/vendor/http-parser/.gitignore b/vendor/http-parser/.gitignore deleted file mode 100644 index c122e76fb91..00000000000 --- a/vendor/http-parser/.gitignore +++ /dev/null @@ -1,30 +0,0 @@ -/out/ -core -tags -*.o -test -test_g -test_fast -bench -url_parser -parsertrace -parsertrace_g -*.mk -*.Makefile -*.so.* -*.exe.* -*.exe -*.a - - -# Visual Studio uglies -*.suo -*.sln -*.vcxproj -*.vcxproj.filters -*.vcxproj.user -*.opensdf -*.ncrunchsolution* -*.sdf -*.vsp -*.psess
diff --git a/vendor/http-parser/.mailmap b/vendor/http-parser/.mailmap deleted file mode 100644 index 278d1412637..00000000000 --- a/vendor/http-parser/.mailmap +++ /dev/null @@ -1,8 +0,0 @@ -# update AUTHORS with: -# git log --all --reverse --format='%aN <%aE>' | perl -ne 'BEGIN{print "# Authors ordered by first contribution.\n"} print unless $h{$_}; $h{$_} = 1' > AUTHORS -Ryan Dahl -Salman Haq -Simon Zimmermann -Thomas LE ROUX LE ROUX Thomas -Thomas LE ROUX Thomas LE ROUX -Fedor Indutny
diff --git a/vendor/http-parser/.travis.yml b/vendor/http-parser/.travis.yml deleted file mode 100644 index 4b038e6e62d..00000000000 --- a/vendor/http-parser/.travis.yml +++ /dev/null @@ -1,13 +0,0 @@ -language: c - -compiler: - - clang - - gcc - -script: - - "make" - -notifications: - email: false - irc: - - "irc.freenode.net#node-ci"
diff --git a/vendor/http-parser/AUTHORS b/vendor/http-parser/AUTHORS deleted file mode 100644 index 5323b685cae..00000000000 --- a/vendor/http-parser/AUTHORS +++ /dev/null @@ -1,68 +0,0 @@ -# Authors ordered by first
contribution. -Ryan Dahl -Jeremy Hinegardner -Sergey Shepelev -Joe Damato -tomika -Phoenix Sol -Cliff Frey -Ewen Cheslack-Postava -Santiago Gala -Tim Becker -Jeff Terrace -Ben Noordhuis -Nathan Rajlich -Mark Nottingham -Aman Gupta -Tim Becker -Sean Cunningham -Peter Griess -Salman Haq -Cliff Frey -Jon Kolb -Fouad Mardini -Paul Querna -Felix Geisendörfer -koichik -Andre Caron -Ivo Raisr -James McLaughlin -David Gwynne -Thomas LE ROUX -Randy Rizun -Andre Louis Caron -Simon Zimmermann -Erik Dubbelboer -Martell Malone -Bertrand Paquet -BogDan Vatra -Peter Faiman -Corey Richardson -Tóth Tamás -Cam Swords -Chris Dickinson -Uli Köhler -Charlie Somerville -Patrik Stutz -Fedor Indutny -runner -Alexis Campailla -David Wragg -Vinnie Falco -Alex Butum -Rex Feng -Alex Kocharin -Mark Koopman -Helge Heß -Alexis La Goutte -George Miroshnykov -Maciej Małecki -Marc O'Morain -Jeff Pinner -Timothy J Fontaine -Akagi201 -Romain Giraud -Jay Satiro -Arne Steen -Kjell Schubert -Olivier Mengué diff --git a/vendor/http-parser/LICENSE-MIT b/vendor/http-parser/LICENSE-MIT deleted file mode 100644 index 58010b38894..00000000000 --- a/vendor/http-parser/LICENSE-MIT +++ /dev/null @@ -1,23 +0,0 @@ -http_parser.c is based on src/http/ngx_http_parse.c from NGINX copyright -Igor Sysoev. - -Additional changes are licensed under the same terms as NGINX and -copyright Joyent, Inc. and other Node contributors. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to -deal in the Software without restriction, including without limitation the -rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -sell copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. diff --git a/vendor/http-parser/Makefile b/vendor/http-parser/Makefile deleted file mode 100644 index b2476dbd4a8..00000000000 --- a/vendor/http-parser/Makefile +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright Joyent, Inc. and other Node contributors. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. - -PLATFORM ?= $(shell sh -c 'uname -s | tr "[A-Z]" "[a-z]"') -HELPER ?= -BINEXT ?= -ifeq (darwin,$(PLATFORM)) -SONAME ?= libhttp_parser.2.7.1.dylib -SOEXT ?= dylib -else ifeq (wine,$(PLATFORM)) -CC = winegcc -BINEXT = .exe.so -HELPER = wine -else -SONAME ?= libhttp_parser.so.2.7.1 -SOEXT ?= so -endif - -CC?=gcc -AR?=ar - -CPPFLAGS ?= -LDFLAGS ?= - -CPPFLAGS += -I. -CPPFLAGS_DEBUG = $(CPPFLAGS) -DHTTP_PARSER_STRICT=1 -CPPFLAGS_DEBUG += $(CPPFLAGS_DEBUG_EXTRA) -CPPFLAGS_FAST = $(CPPFLAGS) -DHTTP_PARSER_STRICT=0 -CPPFLAGS_FAST += $(CPPFLAGS_FAST_EXTRA) -CPPFLAGS_BENCH = $(CPPFLAGS_FAST) - -CFLAGS += -Wall -Wextra -Werror -CFLAGS_DEBUG = $(CFLAGS) -O0 -g $(CFLAGS_DEBUG_EXTRA) -CFLAGS_FAST = $(CFLAGS) -O3 $(CFLAGS_FAST_EXTRA) -CFLAGS_BENCH = $(CFLAGS_FAST) -Wno-unused-parameter -CFLAGS_LIB = $(CFLAGS_FAST) -fPIC - -LDFLAGS_LIB = $(LDFLAGS) -shared - -INSTALL ?= install -PREFIX ?= $(DESTDIR)/usr/local -LIBDIR = $(PREFIX)/lib -INCLUDEDIR = $(PREFIX)/include - -ifneq (darwin,$(PLATFORM)) -# TODO(bnoordhuis) The native SunOS linker expects -h rather than -soname... -LDFLAGS_LIB += -Wl,-soname=$(SONAME) -endif - -test: test_g test_fast - $(HELPER) ./test_g$(BINEXT) - $(HELPER) ./test_fast$(BINEXT) - -test_g: http_parser_g.o test_g.o - $(CC) $(CFLAGS_DEBUG) $(LDFLAGS) http_parser_g.o test_g.o -o $@ - -test_g.o: test.c http_parser.h Makefile - $(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) -c test.c -o $@ - -http_parser_g.o: http_parser.c http_parser.h Makefile - $(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) -c http_parser.c -o $@ - -test_fast: http_parser.o test.o http_parser.h - $(CC) $(CFLAGS_FAST) $(LDFLAGS) http_parser.o test.o -o $@ - -test.o: test.c http_parser.h Makefile - $(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) -c test.c -o $@ - -bench: http_parser.o bench.o - $(CC) $(CFLAGS_BENCH) $(LDFLAGS) http_parser.o bench.o -o $@ - -bench.o: bench.c http_parser.h Makefile - $(CC) $(CPPFLAGS_BENCH) $(CFLAGS_BENCH) -c bench.c -o $@ - -http_parser.o: http_parser.c http_parser.h Makefile - $(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) -c http_parser.c - -test-run-timed: test_fast - while(true) do time $(HELPER) ./test_fast$(BINEXT) > /dev/null; done - -test-valgrind: test_g - valgrind ./test_g - -libhttp_parser.o: http_parser.c http_parser.h Makefile - $(CC) $(CPPFLAGS_FAST) $(CFLAGS_LIB) -c http_parser.c -o libhttp_parser.o - -library: libhttp_parser.o - $(CC) $(LDFLAGS_LIB) -o $(SONAME) $< - -package: http_parser.o - $(AR) rcs libhttp_parser.a http_parser.o - -url_parser: http_parser.o contrib/url_parser.c - $(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) $^ -o $@ - -url_parser_g: http_parser_g.o contrib/url_parser.c - $(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) $^ -o $@ - -parsertrace: http_parser.o contrib/parsertrace.c - $(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) $^ -o parsertrace$(BINEXT) - -parsertrace_g: http_parser_g.o contrib/parsertrace.c - $(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) $^ -o parsertrace_g$(BINEXT) - -tags: http_parser.c http_parser.h test.c - ctags $^ - -install: library - $(INSTALL) -D http_parser.h $(INCLUDEDIR)/http_parser.h - $(INSTALL) -D $(SONAME) $(LIBDIR)/$(SONAME) - ln -s $(LIBDIR)/$(SONAME) $(LIBDIR)/libhttp_parser.$(SOEXT) - -install-strip: library - $(INSTALL) -D http_parser.h $(INCLUDEDIR)/http_parser.h - $(INSTALL) -D -s $(SONAME) 
$(LIBDIR)/$(SONAME) - ln -s $(LIBDIR)/$(SONAME) $(LIBDIR)/libhttp_parser.$(SOEXT) - -uninstall: - rm $(INCLUDEDIR)/http_parser.h - rm $(LIBDIR)/$(SONAME) - rm $(LIBDIR)/libhttp_parser.so - -clean: - rm -f *.o *.a tags test test_fast test_g \ - http_parser.tar libhttp_parser.so.* \ - url_parser url_parser_g parsertrace parsertrace_g \ - *.exe *.exe.so - -contrib/url_parser.c: http_parser.h -contrib/parsertrace.c: http_parser.h - -.PHONY: clean package test-run test-run-timed test-valgrind install install-strip uninstall diff --git a/vendor/http-parser/README.md b/vendor/http-parser/README.md deleted file mode 100644 index 439b30998d4..00000000000 --- a/vendor/http-parser/README.md +++ /dev/null @@ -1,246 +0,0 @@ -HTTP Parser -=========== - -[![Build Status](https://api.travis-ci.org/nodejs/http-parser.svg?branch=master)](https://travis-ci.org/nodejs/http-parser) - -This is a parser for HTTP messages written in C. It parses both requests and -responses. The parser is designed to be used in performance HTTP -applications. It does not make any syscalls nor allocations, it does not -buffer data, it can be interrupted at anytime. Depending on your -architecture, it only requires about 40 bytes of data per message -stream (in a web server that is per connection). - -Features: - - * No dependencies - * Handles persistent streams (keep-alive). - * Decodes chunked encoding. - * Upgrade support - * Defends against buffer overflow attacks. - -The parser extracts the following information from HTTP messages: - - * Header fields and values - * Content-Length - * Request method - * Response status code - * Transfer-Encoding - * HTTP version - * Request URL - * Message body - - -Usage ------ - -One `http_parser` object is used per TCP connection. Initialize the struct -using `http_parser_init()` and set the callbacks. That might look something -like this for a request parser: -```c -http_parser_settings settings; -settings.on_url = my_url_callback; -settings.on_header_field = my_header_field_callback; -/* ... */ - -http_parser *parser = malloc(sizeof(http_parser)); -http_parser_init(parser, HTTP_REQUEST); -parser->data = my_socket; -``` - -When data is received on the socket execute the parser and check for errors. - -```c -size_t len = 80*1024, nparsed; -char buf[len]; -ssize_t recved; - -recved = recv(fd, buf, len, 0); - -if (recved < 0) { - /* Handle error. */ -} - -/* Start up / continue the parser. - * Note we pass recved==0 to signal that EOF has been received. - */ -nparsed = http_parser_execute(parser, &settings, buf, recved); - -if (parser->upgrade) { - /* handle new protocol */ -} else if (nparsed != recved) { - /* Handle error. Usually just close the connection. */ -} -``` - -HTTP needs to know where the end of the stream is. For example, sometimes -servers send responses without Content-Length and expect the client to -consume input (for the body) until EOF. To tell http_parser about EOF, give -`0` as the fourth parameter to `http_parser_execute()`. Callbacks and errors -can still be encountered during an EOF, so one must still be prepared -to receive them. - -Scalar valued message information such as `status_code`, `method`, and the -HTTP version are stored in the parser structure. This data is only -temporally stored in `http_parser` and gets reset on each new message. If -this information is needed later, copy it out of the structure during the -`headers_complete` callback. - -The parser decodes the transfer-encoding for both requests and responses -transparently. 
That is, a chunked encoding is decoded before being sent to -the on_body callback. - - -The Special Problem of Upgrade ------------------------------- - -HTTP supports upgrading the connection to a different protocol. An -increasingly common example of this is the WebSocket protocol which sends -a request like - - GET /demo HTTP/1.1 - Upgrade: WebSocket - Connection: Upgrade - Host: example.com - Origin: http://example.com - WebSocket-Protocol: sample - -followed by non-HTTP data. - -(See [RFC6455](https://tools.ietf.org/html/rfc6455) for more information the -WebSocket protocol.) - -To support this, the parser will treat this as a normal HTTP message without a -body, issuing both on_headers_complete and on_message_complete callbacks. However -http_parser_execute() will stop parsing at the end of the headers and return. - -The user is expected to check if `parser->upgrade` has been set to 1 after -`http_parser_execute()` returns. Non-HTTP data begins at the buffer supplied -offset by the return value of `http_parser_execute()`. - - -Callbacks ---------- - -During the `http_parser_execute()` call, the callbacks set in -`http_parser_settings` will be executed. The parser maintains state and -never looks behind, so buffering the data is not necessary. If you need to -save certain data for later usage, you can do that from the callbacks. - -There are two types of callbacks: - -* notification `typedef int (*http_cb) (http_parser*);` - Callbacks: on_message_begin, on_headers_complete, on_message_complete. -* data `typedef int (*http_data_cb) (http_parser*, const char *at, size_t length);` - Callbacks: (requests only) on_url, - (common) on_header_field, on_header_value, on_body; - -Callbacks must return 0 on success. Returning a non-zero value indicates -error to the parser, making it exit immediately. - -For cases where it is necessary to pass local information to/from a callback, -the `http_parser` object's `data` field can be used. -An example of such a case is when using threads to handle a socket connection, -parse a request, and then give a response over that socket. By instantiation -of a thread-local struct containing relevant data (e.g. accepted socket, -allocated memory for callbacks to write into, etc), a parser's callbacks are -able to communicate data between the scope of the thread and the scope of the -callback in a threadsafe manner. This allows http-parser to be used in -multi-threaded contexts. - -Example: -```c - typedef struct { - socket_t sock; - void* buffer; - int buf_len; - } custom_data_t; - - -int my_url_callback(http_parser* parser, const char *at, size_t length) { - /* access to thread local custom_data_t struct. - Use this access save parsed data for later use into thread local - buffer, or communicate over socket - */ - parser->data; - ... - return 0; -} - -... - -void http_parser_thread(socket_t sock) { - int nparsed = 0; - /* allocate memory for user data */ - custom_data_t *my_data = malloc(sizeof(custom_data_t)); - - /* some information for use by callbacks. 
- * achieves thread -> callback information flow */ - my_data->sock = sock; - - /* instantiate a thread-local parser */ - http_parser *parser = malloc(sizeof(http_parser)); - http_parser_init(parser, HTTP_REQUEST); /* initialise parser */ - /* this custom data reference is accessible through the reference to the - parser supplied to callback functions */ - parser->data = my_data; - - http_parser_settings settings; /* set up callbacks */ - settings.on_url = my_url_callback; - - /* execute parser */ - nparsed = http_parser_execute(parser, &settings, buf, recved); - - ... - /* parsed information copied from callback. - can now perform action on data copied into thread-local memory from callbacks. - achieves callback -> thread information flow */ - my_data->buffer; - ... -} - -``` - -In case you parse HTTP message in chunks (i.e. `read()` request line -from socket, parse, read half headers, parse, etc) your data callbacks -may be called more than once. Http-parser guarantees that data pointer is only -valid for the lifetime of callback. You can also `read()` into a heap allocated -buffer to avoid copying memory around if this fits your application. - -Reading headers may be a tricky task if you read/parse headers partially. -Basically, you need to remember whether last header callback was field or value -and apply the following logic: - - (on_header_field and on_header_value shortened to on_h_*) - ------------------------ ------------ -------------------------------------------- - | State (prev. callback) | Callback | Description/action | - ------------------------ ------------ -------------------------------------------- - | nothing (first call) | on_h_field | Allocate new buffer and copy callback data | - | | | into it | - ------------------------ ------------ -------------------------------------------- - | value | on_h_field | New header started. | - | | | Copy current name,value buffers to headers | - | | | list and allocate new buffer for new name | - ------------------------ ------------ -------------------------------------------- - | field | on_h_field | Previous name continues. Reallocate name | - | | | buffer and append callback data to it | - ------------------------ ------------ -------------------------------------------- - | field | on_h_value | Value for current header started. Allocate | - | | | new buffer and copy callback data to it | - ------------------------ ------------ -------------------------------------------- - | value | on_h_value | Value continues. Reallocate value buffer | - | | | and append callback data to it | - ------------------------ ------------ -------------------------------------------- - - -Parsing URLs ------------- - -A simplistic zero-copy URL parser is provided as `http_parser_parse_url()`. -Users of this library may wish to use it to parse URLs constructed from -consecutive `on_url` callbacks. - -See examples of reading in headers: - -* [partial example](http://gist.github.com/155877) in C -* [from http-parser tests](http://github.com/joyent/http-parser/blob/37a0ff8/test.c#L403) in C -* [from Node library](http://github.com/joyent/node/blob/842eaf4/src/http.js#L284) in Javascript diff --git a/vendor/http-parser/bench.c b/vendor/http-parser/bench.c deleted file mode 100644 index 5b452fa1cdb..00000000000 --- a/vendor/http-parser/bench.c +++ /dev/null @@ -1,111 +0,0 @@ -/* Copyright Fedor Indutny. All rights reserved. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ -#include "http_parser.h" -#include -#include -#include -#include - -static const char data[] = - "POST /joyent/http-parser HTTP/1.1\r\n" - "Host: github.com\r\n" - "DNT: 1\r\n" - "Accept-Encoding: gzip, deflate, sdch\r\n" - "Accept-Language: ru-RU,ru;q=0.8,en-US;q=0.6,en;q=0.4\r\n" - "User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/39.0.2171.65 Safari/537.36\r\n" - "Accept: text/html,application/xhtml+xml,application/xml;q=0.9," - "image/webp,*/*;q=0.8\r\n" - "Referer: https://github.com/joyent/http-parser\r\n" - "Connection: keep-alive\r\n" - "Transfer-Encoding: chunked\r\n" - "Cache-Control: max-age=0\r\n\r\nb\r\nhello world\r\n0\r\n\r\n"; -static const size_t data_len = sizeof(data) - 1; - -static int on_info(http_parser* p) { - return 0; -} - - -static int on_data(http_parser* p, const char *at, size_t length) { - return 0; -} - -static http_parser_settings settings = { - .on_message_begin = on_info, - .on_headers_complete = on_info, - .on_message_complete = on_info, - .on_header_field = on_data, - .on_header_value = on_data, - .on_url = on_data, - .on_status = on_data, - .on_body = on_data -}; - -int bench(int iter_count, int silent) { - struct http_parser parser; - int i; - int err; - struct timeval start; - struct timeval end; - float rps; - - if (!silent) { - err = gettimeofday(&start, NULL); - assert(err == 0); - } - - for (i = 0; i < iter_count; i++) { - size_t parsed; - http_parser_init(&parser, HTTP_REQUEST); - - parsed = http_parser_execute(&parser, &settings, data, data_len); - assert(parsed == data_len); - } - - if (!silent) { - err = gettimeofday(&end, NULL); - assert(err == 0); - - fprintf(stdout, "Benchmark result:\n"); - - rps = (float) (end.tv_sec - start.tv_sec) + - (end.tv_usec - start.tv_usec) * 1e-6f; - fprintf(stdout, "Took %f seconds to run\n", rps); - - rps = (float) iter_count / rps; - fprintf(stdout, "%f req/sec\n", rps); - fflush(stdout); - } - - return 0; -} - -int main(int argc, char** argv) { - if (argc == 2 && strcmp(argv[1], "infinite") == 0) { - for (;;) - bench(5000000, 1); - return 0; - } else { - return bench(5000000, 0); - } -} diff --git a/vendor/http-parser/contrib/parsertrace.c b/vendor/http-parser/contrib/parsertrace.c deleted file mode 100644 index e7153680f46..00000000000 --- a/vendor/http-parser/contrib/parsertrace.c +++ /dev/null @@ -1,160 +0,0 @@ -/* Based on src/http/ngx_http_parse.c from NGINX 
copyright Igor Sysoev - * - * Additional changes are licensed under the same terms as NGINX and - * copyright Joyent, Inc. and other Node contributors. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ - -/* Dump what the parser finds to stdout as it happen */ - -#include "http_parser.h" -#include -#include -#include - -int on_message_begin(http_parser* _) { - (void)_; - printf("\n***MESSAGE BEGIN***\n\n"); - return 0; -} - -int on_headers_complete(http_parser* _) { - (void)_; - printf("\n***HEADERS COMPLETE***\n\n"); - return 0; -} - -int on_message_complete(http_parser* _) { - (void)_; - printf("\n***MESSAGE COMPLETE***\n\n"); - return 0; -} - -int on_url(http_parser* _, const char* at, size_t length) { - (void)_; - printf("Url: %.*s\n", (int)length, at); - return 0; -} - -int on_header_field(http_parser* _, const char* at, size_t length) { - (void)_; - printf("Header field: %.*s\n", (int)length, at); - return 0; -} - -int on_header_value(http_parser* _, const char* at, size_t length) { - (void)_; - printf("Header value: %.*s\n", (int)length, at); - return 0; -} - -int on_body(http_parser* _, const char* at, size_t length) { - (void)_; - printf("Body: %.*s\n", (int)length, at); - return 0; -} - -void usage(const char* name) { - fprintf(stderr, - "Usage: %s $type $filename\n" - " type: -x, where x is one of {r,b,q}\n" - " parses file as a Response, reQuest, or Both\n", - name); - exit(EXIT_FAILURE); -} - -int main(int argc, char* argv[]) { - enum http_parser_type file_type; - - if (argc != 3) { - usage(argv[0]); - } - - char* type = argv[1]; - if (type[0] != '-') { - usage(argv[0]); - } - - switch (type[1]) { - /* in the case of "-", type[1] will be NUL */ - case 'r': - file_type = HTTP_RESPONSE; - break; - case 'q': - file_type = HTTP_REQUEST; - break; - case 'b': - file_type = HTTP_BOTH; - break; - default: - usage(argv[0]); - } - - char* filename = argv[2]; - FILE* file = fopen(filename, "r"); - if (file == NULL) { - perror("fopen"); - goto fail; - } - - fseek(file, 0, SEEK_END); - long file_length = ftell(file); - if (file_length == -1) { - perror("ftell"); - goto fail; - } - fseek(file, 0, SEEK_SET); - - char* data = malloc(file_length); - if (fread(data, 1, file_length, file) != (size_t)file_length) { - fprintf(stderr, "couldn't read entire file\n"); - free(data); - goto fail; - } - - http_parser_settings settings; - memset(&settings, 0, sizeof(settings)); - settings.on_message_begin = on_message_begin; - settings.on_url = on_url; - 
settings.on_header_field = on_header_field; - settings.on_header_value = on_header_value; - settings.on_headers_complete = on_headers_complete; - settings.on_body = on_body; - settings.on_message_complete = on_message_complete; - - http_parser parser; - http_parser_init(&parser, file_type); - size_t nparsed = http_parser_execute(&parser, &settings, data, file_length); - free(data); - - if (nparsed != (size_t)file_length) { - fprintf(stderr, - "Error: %s (%s)\n", - http_errno_description(HTTP_PARSER_ERRNO(&parser)), - http_errno_name(HTTP_PARSER_ERRNO(&parser))); - goto fail; - } - - return EXIT_SUCCESS; - -fail: - fclose(file); - return EXIT_FAILURE; -} diff --git a/vendor/http-parser/contrib/url_parser.c b/vendor/http-parser/contrib/url_parser.c deleted file mode 100644 index f235bed9e48..00000000000 --- a/vendor/http-parser/contrib/url_parser.c +++ /dev/null @@ -1,47 +0,0 @@ -#include "http_parser.h" -#include -#include - -void -dump_url (const char *url, const struct http_parser_url *u) -{ - unsigned int i; - - printf("\tfield_set: 0x%x, port: %u\n", u->field_set, u->port); - for (i = 0; i < UF_MAX; i++) { - if ((u->field_set & (1 << i)) == 0) { - printf("\tfield_data[%u]: unset\n", i); - continue; - } - - printf("\tfield_data[%u]: off: %u, len: %u, part: %.*s\n", - i, - u->field_data[i].off, - u->field_data[i].len, - u->field_data[i].len, - url + u->field_data[i].off); - } -} - -int main(int argc, char ** argv) { - struct http_parser_url u; - int len, connect, result; - - if (argc != 3) { - printf("Syntax : %s connect|get url\n", argv[0]); - return 1; - } - len = strlen(argv[2]); - connect = strcmp("connect", argv[1]) == 0 ? 1 : 0; - printf("Parsing %s, connect %d\n", argv[2], connect); - - http_parser_url_init(&u); - result = http_parser_parse_url(argv[2], len, connect, &u); - if (result != 0) { - printf("Parse error : %d\n", result); - return result; - } - printf("Parse ok, result : \n"); - dump_url(argv[2], &u); - return 0; -} diff --git a/vendor/http-parser/http_parser.c b/vendor/http-parser/http_parser.c deleted file mode 100644 index 895bf0c7376..00000000000 --- a/vendor/http-parser/http_parser.c +++ /dev/null @@ -1,2470 +0,0 @@ -/* Based on src/http/ngx_http_parse.c from NGINX copyright Igor Sysoev - * - * Additional changes are licensed under the same terms as NGINX and - * copyright Joyent, Inc. and other Node contributors. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. 
- */ -#include "http_parser.h" -#include -#include -#include -#include -#include -#include - -#ifndef ULLONG_MAX -# define ULLONG_MAX ((uint64_t) -1) /* 2^64-1 */ -#endif - -#ifndef MIN -# define MIN(a,b) ((a) < (b) ? (a) : (b)) -#endif - -#ifndef ARRAY_SIZE -# define ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0])) -#endif - -#ifndef BIT_AT -# define BIT_AT(a, i) \ - (!!((unsigned int) (a)[(unsigned int) (i) >> 3] & \ - (1 << ((unsigned int) (i) & 7)))) -#endif - -#ifndef ELEM_AT -# define ELEM_AT(a, i, v) ((unsigned int) (i) < ARRAY_SIZE(a) ? (a)[(i)] : (v)) -#endif - -#define SET_ERRNO(e) \ -do { \ - parser->http_errno = (e); \ -} while(0) - -#define CURRENT_STATE() p_state -#define UPDATE_STATE(V) p_state = (enum state) (V); -#define RETURN(V) \ -do { \ - parser->state = CURRENT_STATE(); \ - return (V); \ -} while (0); -#define REEXECUTE() \ - goto reexecute; \ - - -#ifdef __GNUC__ -# define LIKELY(X) __builtin_expect(!!(X), 1) -# define UNLIKELY(X) __builtin_expect(!!(X), 0) -#else -# define LIKELY(X) (X) -# define UNLIKELY(X) (X) -#endif - - -/* Run the notify callback FOR, returning ER if it fails */ -#define CALLBACK_NOTIFY_(FOR, ER) \ -do { \ - assert(HTTP_PARSER_ERRNO(parser) == HPE_OK); \ - \ - if (LIKELY(settings->on_##FOR)) { \ - parser->state = CURRENT_STATE(); \ - if (UNLIKELY(0 != settings->on_##FOR(parser))) { \ - SET_ERRNO(HPE_CB_##FOR); \ - } \ - UPDATE_STATE(parser->state); \ - \ - /* We either errored above or got paused; get out */ \ - if (UNLIKELY(HTTP_PARSER_ERRNO(parser) != HPE_OK)) { \ - return (ER); \ - } \ - } \ -} while (0) - -/* Run the notify callback FOR and consume the current byte */ -#define CALLBACK_NOTIFY(FOR) CALLBACK_NOTIFY_(FOR, p - data + 1) - -/* Run the notify callback FOR and don't consume the current byte */ -#define CALLBACK_NOTIFY_NOADVANCE(FOR) CALLBACK_NOTIFY_(FOR, p - data) - -/* Run data callback FOR with LEN bytes, returning ER if it fails */ -#define CALLBACK_DATA_(FOR, LEN, ER) \ -do { \ - assert(HTTP_PARSER_ERRNO(parser) == HPE_OK); \ - \ - if (FOR##_mark) { \ - if (LIKELY(settings->on_##FOR)) { \ - parser->state = CURRENT_STATE(); \ - if (UNLIKELY(0 != \ - settings->on_##FOR(parser, FOR##_mark, (LEN)))) { \ - SET_ERRNO(HPE_CB_##FOR); \ - } \ - UPDATE_STATE(parser->state); \ - \ - /* We either errored above or got paused; get out */ \ - if (UNLIKELY(HTTP_PARSER_ERRNO(parser) != HPE_OK)) { \ - return (ER); \ - } \ - } \ - FOR##_mark = NULL; \ - } \ -} while (0) - -/* Run the data callback FOR and consume the current byte */ -#define CALLBACK_DATA(FOR) \ - CALLBACK_DATA_(FOR, p - FOR##_mark, p - data + 1) - -/* Run the data callback FOR and don't consume the current byte */ -#define CALLBACK_DATA_NOADVANCE(FOR) \ - CALLBACK_DATA_(FOR, p - FOR##_mark, p - data) - -/* Set the mark FOR; non-destructive if mark is already set */ -#define MARK(FOR) \ -do { \ - if (!FOR##_mark) { \ - FOR##_mark = p; \ - } \ -} while (0) - -/* Don't allow the total size of the HTTP headers (including the status - * line) to exceed HTTP_MAX_HEADER_SIZE. This check is here to protect - * embedders against denial-of-service attacks where the attacker feeds - * us a never-ending header that the embedder keeps buffering. - * - * This check is arguably the responsibility of embedders but we're doing - * it on the embedder's behalf because most won't bother and this way we - * make the web a little safer. HTTP_MAX_HEADER_SIZE is still far bigger - * than any reasonable request or response so this should never affect - * day-to-day operation. 
- */ -#define COUNT_HEADER_SIZE(V) \ -do { \ - parser->nread += (V); \ - if (UNLIKELY(parser->nread > (HTTP_MAX_HEADER_SIZE))) { \ - SET_ERRNO(HPE_HEADER_OVERFLOW); \ - goto error; \ - } \ -} while (0) - - -#define PROXY_CONNECTION "proxy-connection" -#define CONNECTION "connection" -#define CONTENT_LENGTH "content-length" -#define TRANSFER_ENCODING "transfer-encoding" -#define UPGRADE "upgrade" -#define CHUNKED "chunked" -#define KEEP_ALIVE "keep-alive" -#define CLOSE "close" - - -static const char *method_strings[] = - { -#define XX(num, name, string) #string, - HTTP_METHOD_MAP(XX) -#undef XX - }; - - -/* Tokens as defined by rfc 2616. Also lowercases them. - * token = 1* - * separators = "(" | ")" | "<" | ">" | "@" - * | "," | ";" | ":" | "\" | <"> - * | "/" | "[" | "]" | "?" | "=" - * | "{" | "}" | SP | HT - */ -static const char tokens[256] = { -/* 0 nul 1 soh 2 stx 3 etx 4 eot 5 enq 6 ack 7 bel */ - 0, 0, 0, 0, 0, 0, 0, 0, -/* 8 bs 9 ht 10 nl 11 vt 12 np 13 cr 14 so 15 si */ - 0, 0, 0, 0, 0, 0, 0, 0, -/* 16 dle 17 dc1 18 dc2 19 dc3 20 dc4 21 nak 22 syn 23 etb */ - 0, 0, 0, 0, 0, 0, 0, 0, -/* 24 can 25 em 26 sub 27 esc 28 fs 29 gs 30 rs 31 us */ - 0, 0, 0, 0, 0, 0, 0, 0, -/* 32 sp 33 ! 34 " 35 # 36 $ 37 % 38 & 39 ' */ - 0, '!', 0, '#', '$', '%', '&', '\'', -/* 40 ( 41 ) 42 * 43 + 44 , 45 - 46 . 47 / */ - 0, 0, '*', '+', 0, '-', '.', 0, -/* 48 0 49 1 50 2 51 3 52 4 53 5 54 6 55 7 */ - '0', '1', '2', '3', '4', '5', '6', '7', -/* 56 8 57 9 58 : 59 ; 60 < 61 = 62 > 63 ? */ - '8', '9', 0, 0, 0, 0, 0, 0, -/* 64 @ 65 A 66 B 67 C 68 D 69 E 70 F 71 G */ - 0, 'a', 'b', 'c', 'd', 'e', 'f', 'g', -/* 72 H 73 I 74 J 75 K 76 L 77 M 78 N 79 O */ - 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', -/* 80 P 81 Q 82 R 83 S 84 T 85 U 86 V 87 W */ - 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', -/* 88 X 89 Y 90 Z 91 [ 92 \ 93 ] 94 ^ 95 _ */ - 'x', 'y', 'z', 0, 0, 0, '^', '_', -/* 96 ` 97 a 98 b 99 c 100 d 101 e 102 f 103 g */ - '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', -/* 104 h 105 i 106 j 107 k 108 l 109 m 110 n 111 o */ - 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', -/* 112 p 113 q 114 r 115 s 116 t 117 u 118 v 119 w */ - 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', -/* 120 x 121 y 122 z 123 { 124 | 125 } 126 ~ 127 del */ - 'x', 'y', 'z', 0, '|', 0, '~', 0 }; - - -static const int8_t unhex[256] = - {-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 - ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 - ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 - , 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,-1,-1,-1,-1,-1,-1 - ,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1 - ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 - ,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1 - ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 - }; - - -#if HTTP_PARSER_STRICT -# define T(v) 0 -#else -# define T(v) v -#endif - - -static const uint8_t normal_url_char[32] = { -/* 0 nul 1 soh 2 stx 3 etx 4 eot 5 enq 6 ack 7 bel */ - 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, -/* 8 bs 9 ht 10 nl 11 vt 12 np 13 cr 14 so 15 si */ - 0 | T(2) | 0 | 0 | T(16) | 0 | 0 | 0, -/* 16 dle 17 dc1 18 dc2 19 dc3 20 dc4 21 nak 22 syn 23 etb */ - 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, -/* 24 can 25 em 26 sub 27 esc 28 fs 29 gs 30 rs 31 us */ - 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, -/* 32 sp 33 ! 34 " 35 # 36 $ 37 % 38 & 39 ' */ - 0 | 2 | 4 | 0 | 16 | 32 | 64 | 128, -/* 40 ( 41 ) 42 * 43 + 44 , 45 - 46 . 47 / */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 48 0 49 1 50 2 51 3 52 4 53 5 54 6 55 7 */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 56 8 57 9 58 : 59 ; 60 < 61 = 62 > 63 ? 
*/ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 0, -/* 64 @ 65 A 66 B 67 C 68 D 69 E 70 F 71 G */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 72 H 73 I 74 J 75 K 76 L 77 M 78 N 79 O */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 80 P 81 Q 82 R 83 S 84 T 85 U 86 V 87 W */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 88 X 89 Y 90 Z 91 [ 92 \ 93 ] 94 ^ 95 _ */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 96 ` 97 a 98 b 99 c 100 d 101 e 102 f 103 g */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 104 h 105 i 106 j 107 k 108 l 109 m 110 n 111 o */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 112 p 113 q 114 r 115 s 116 t 117 u 118 v 119 w */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, -/* 120 x 121 y 122 z 123 { 124 | 125 } 126 ~ 127 del */ - 1 | 2 | 4 | 8 | 16 | 32 | 64 | 0, }; - -#undef T - -enum state - { s_dead = 1 /* important that this is > 0 */ - - , s_start_req_or_res - , s_res_or_resp_H - , s_start_res - , s_res_H - , s_res_HT - , s_res_HTT - , s_res_HTTP - , s_res_first_http_major - , s_res_http_major - , s_res_first_http_minor - , s_res_http_minor - , s_res_first_status_code - , s_res_status_code - , s_res_status_start - , s_res_status - , s_res_line_almost_done - - , s_start_req - - , s_req_method - , s_req_spaces_before_url - , s_req_schema - , s_req_schema_slash - , s_req_schema_slash_slash - , s_req_server_start - , s_req_server - , s_req_server_with_at - , s_req_path - , s_req_query_string_start - , s_req_query_string - , s_req_fragment_start - , s_req_fragment - , s_req_http_start - , s_req_http_H - , s_req_http_HT - , s_req_http_HTT - , s_req_http_HTTP - , s_req_first_http_major - , s_req_http_major - , s_req_first_http_minor - , s_req_http_minor - , s_req_line_almost_done - - , s_header_field_start - , s_header_field - , s_header_value_discard_ws - , s_header_value_discard_ws_almost_done - , s_header_value_discard_lws - , s_header_value_start - , s_header_value - , s_header_value_lws - - , s_header_almost_done - - , s_chunk_size_start - , s_chunk_size - , s_chunk_parameters - , s_chunk_size_almost_done - - , s_headers_almost_done - , s_headers_done - - /* Important: 's_headers_done' must be the last 'header' state. All - * states beyond this must be 'body' states. It is used for overflow - * checking. See the PARSING_HEADER() macro. 
- */ - - , s_chunk_data - , s_chunk_data_almost_done - , s_chunk_data_done - - , s_body_identity - , s_body_identity_eof - - , s_message_done - }; - - -#define PARSING_HEADER(state) (state <= s_headers_done) - - -enum header_states - { h_general = 0 - , h_C - , h_CO - , h_CON - - , h_matching_connection - , h_matching_proxy_connection - , h_matching_content_length - , h_matching_transfer_encoding - , h_matching_upgrade - - , h_connection - , h_content_length - , h_transfer_encoding - , h_upgrade - - , h_matching_transfer_encoding_chunked - , h_matching_connection_token_start - , h_matching_connection_keep_alive - , h_matching_connection_close - , h_matching_connection_upgrade - , h_matching_connection_token - - , h_transfer_encoding_chunked - , h_connection_keep_alive - , h_connection_close - , h_connection_upgrade - }; - -enum http_host_state - { - s_http_host_dead = 1 - , s_http_userinfo_start - , s_http_userinfo - , s_http_host_start - , s_http_host_v6_start - , s_http_host - , s_http_host_v6 - , s_http_host_v6_end - , s_http_host_v6_zone_start - , s_http_host_v6_zone - , s_http_host_port_start - , s_http_host_port -}; - -/* Macros for character classes; depends on strict-mode */ -#define CR '\r' -#define LF '\n' -#define LOWER(c) (unsigned char)(c | 0x20) -#define IS_ALPHA(c) (LOWER(c) >= 'a' && LOWER(c) <= 'z') -#define IS_NUM(c) ((c) >= '0' && (c) <= '9') -#define IS_ALPHANUM(c) (IS_ALPHA(c) || IS_NUM(c)) -#define IS_HEX(c) (IS_NUM(c) || (LOWER(c) >= 'a' && LOWER(c) <= 'f')) -#define IS_MARK(c) ((c) == '-' || (c) == '_' || (c) == '.' || \ - (c) == '!' || (c) == '~' || (c) == '*' || (c) == '\'' || (c) == '(' || \ - (c) == ')') -#define IS_USERINFO_CHAR(c) (IS_ALPHANUM(c) || IS_MARK(c) || (c) == '%' || \ - (c) == ';' || (c) == ':' || (c) == '&' || (c) == '=' || (c) == '+' || \ - (c) == '$' || (c) == ',') - -#define STRICT_TOKEN(c) (tokens[(unsigned char)c]) - -#if HTTP_PARSER_STRICT -#define TOKEN(c) (tokens[(unsigned char)c]) -#define IS_URL_CHAR(c) (BIT_AT(normal_url_char, (unsigned char)c)) -#define IS_HOST_CHAR(c) (IS_ALPHANUM(c) || (c) == '.' || (c) == '-') -#else -#define TOKEN(c) ((c == ' ') ? ' ' : tokens[(unsigned char)c]) -#define IS_URL_CHAR(c) \ - (BIT_AT(normal_url_char, (unsigned char)c) || ((c) & 0x80)) -#define IS_HOST_CHAR(c) \ - (IS_ALPHANUM(c) || (c) == '.' || (c) == '-' || (c) == '_') -#endif - -/** - * Verify that a char is a valid visible (printable) US-ASCII - * character or %x80-FF - **/ -#define IS_HEADER_CHAR(ch) \ - (ch == CR || ch == LF || ch == 9 || ((unsigned char)ch > 31 && ch != 127)) - -#define start_state (parser->type == HTTP_REQUEST ? s_start_req : s_start_res) - - -#if HTTP_PARSER_STRICT -# define STRICT_CHECK(cond) \ -do { \ - if (cond) { \ - SET_ERRNO(HPE_STRICT); \ - goto error; \ - } \ -} while (0) -# define NEW_MESSAGE() (http_should_keep_alive(parser) ? start_state : s_dead) -#else -# define STRICT_CHECK(cond) -# define NEW_MESSAGE() start_state -#endif - - -/* Map errno values to strings for human-readable output */ -#define HTTP_STRERROR_GEN(n, s) { "HPE_" #n, s }, -static struct { - const char *name; - const char *description; -} http_strerror_tab[] = { - HTTP_ERRNO_MAP(HTTP_STRERROR_GEN) -}; -#undef HTTP_STRERROR_GEN - -int http_message_needs_eof(const http_parser *parser); - -/* Our URL parser. - * - * This is designed to be shared by http_parser_execute() for URL validation, - * hence it has a state transition + byte-for-byte interface. 
In addition, it - * is meant to be embedded in http_parser_parse_url(), which does the dirty - * work of turning state transitions URL components for its API. - * - * This function should only be invoked with non-space characters. It is - * assumed that the caller cares about (and can detect) the transition between - * URL and non-URL states by looking for these. - */ -static enum state -parse_url_char(enum state s, const char ch) -{ - if (ch == ' ' || ch == '\r' || ch == '\n') { - return s_dead; - } - -#if HTTP_PARSER_STRICT - if (ch == '\t' || ch == '\f') { - return s_dead; - } -#endif - - switch (s) { - case s_req_spaces_before_url: - /* Proxied requests are followed by scheme of an absolute URI (alpha). - * All methods except CONNECT are followed by '/' or '*'. - */ - - if (ch == '/' || ch == '*') { - return s_req_path; - } - - if (IS_ALPHA(ch)) { - return s_req_schema; - } - - break; - - case s_req_schema: - if (IS_ALPHA(ch)) { - return s; - } - - if (ch == ':') { - return s_req_schema_slash; - } - - break; - - case s_req_schema_slash: - if (ch == '/') { - return s_req_schema_slash_slash; - } - - break; - - case s_req_schema_slash_slash: - if (ch == '/') { - return s_req_server_start; - } - - break; - - case s_req_server_with_at: - if (ch == '@') { - return s_dead; - } - - /* FALLTHROUGH */ - case s_req_server_start: - case s_req_server: - if (ch == '/') { - return s_req_path; - } - - if (ch == '?') { - return s_req_query_string_start; - } - - if (ch == '@') { - return s_req_server_with_at; - } - - if (IS_USERINFO_CHAR(ch) || ch == '[' || ch == ']') { - return s_req_server; - } - - break; - - case s_req_path: - if (IS_URL_CHAR(ch)) { - return s; - } - - switch (ch) { - case '?': - return s_req_query_string_start; - - case '#': - return s_req_fragment_start; - } - - break; - - case s_req_query_string_start: - case s_req_query_string: - if (IS_URL_CHAR(ch)) { - return s_req_query_string; - } - - switch (ch) { - case '?': - /* allow extra '?' in query string */ - return s_req_query_string; - - case '#': - return s_req_fragment_start; - } - - break; - - case s_req_fragment_start: - if (IS_URL_CHAR(ch)) { - return s_req_fragment; - } - - switch (ch) { - case '?': - return s_req_fragment; - - case '#': - return s; - } - - break; - - case s_req_fragment: - if (IS_URL_CHAR(ch)) { - return s; - } - - switch (ch) { - case '?': - case '#': - return s; - } - - break; - - default: - break; - } - - /* We should never fall out of the switch above unless there's an error */ - return s_dead; -} - -size_t http_parser_execute (http_parser *parser, - const http_parser_settings *settings, - const char *data, - size_t len) -{ - char c, ch; - int8_t unhex_val; - const char *p = data; - const char *header_field_mark = 0; - const char *header_value_mark = 0; - const char *url_mark = 0; - const char *body_mark = 0; - const char *status_mark = 0; - enum state p_state = (enum state) parser->state; - const unsigned int lenient = parser->lenient_http_headers; - - /* We're in an error state. Don't bother doing anything. */ - if (HTTP_PARSER_ERRNO(parser) != HPE_OK) { - return 0; - } - - if (len == 0) { - switch (CURRENT_STATE()) { - case s_body_identity_eof: - /* Use of CALLBACK_NOTIFY() here would erroneously return 1 byte read if - * we got paused. 
- */ - CALLBACK_NOTIFY_NOADVANCE(message_complete); - return 0; - - case s_dead: - case s_start_req_or_res: - case s_start_res: - case s_start_req: - return 0; - - default: - SET_ERRNO(HPE_INVALID_EOF_STATE); - return 1; - } - } - - - if (CURRENT_STATE() == s_header_field) - header_field_mark = data; - if (CURRENT_STATE() == s_header_value) - header_value_mark = data; - switch (CURRENT_STATE()) { - case s_req_path: - case s_req_schema: - case s_req_schema_slash: - case s_req_schema_slash_slash: - case s_req_server_start: - case s_req_server: - case s_req_server_with_at: - case s_req_query_string_start: - case s_req_query_string: - case s_req_fragment_start: - case s_req_fragment: - url_mark = data; - break; - case s_res_status: - status_mark = data; - break; - default: - break; - } - - for (p=data; p != data + len; p++) { - ch = *p; - - if (PARSING_HEADER(CURRENT_STATE())) - COUNT_HEADER_SIZE(1); - -reexecute: - switch (CURRENT_STATE()) { - - case s_dead: - /* this state is used after a 'Connection: close' message - * the parser will error out if it reads another message - */ - if (LIKELY(ch == CR || ch == LF)) - break; - - SET_ERRNO(HPE_CLOSED_CONNECTION); - goto error; - - case s_start_req_or_res: - { - if (ch == CR || ch == LF) - break; - parser->flags = 0; - parser->content_length = ULLONG_MAX; - - if (ch == 'H') { - UPDATE_STATE(s_res_or_resp_H); - - CALLBACK_NOTIFY(message_begin); - } else { - parser->type = HTTP_REQUEST; - UPDATE_STATE(s_start_req); - REEXECUTE(); - } - - break; - } - - case s_res_or_resp_H: - if (ch == 'T') { - parser->type = HTTP_RESPONSE; - UPDATE_STATE(s_res_HT); - } else { - if (UNLIKELY(ch != 'E')) { - SET_ERRNO(HPE_INVALID_CONSTANT); - goto error; - } - - parser->type = HTTP_REQUEST; - parser->method = HTTP_HEAD; - parser->index = 2; - UPDATE_STATE(s_req_method); - } - break; - - case s_start_res: - { - parser->flags = 0; - parser->content_length = ULLONG_MAX; - - switch (ch) { - case 'H': - UPDATE_STATE(s_res_H); - break; - - case CR: - case LF: - break; - - default: - SET_ERRNO(HPE_INVALID_CONSTANT); - goto error; - } - - CALLBACK_NOTIFY(message_begin); - break; - } - - case s_res_H: - STRICT_CHECK(ch != 'T'); - UPDATE_STATE(s_res_HT); - break; - - case s_res_HT: - STRICT_CHECK(ch != 'T'); - UPDATE_STATE(s_res_HTT); - break; - - case s_res_HTT: - STRICT_CHECK(ch != 'P'); - UPDATE_STATE(s_res_HTTP); - break; - - case s_res_HTTP: - STRICT_CHECK(ch != '/'); - UPDATE_STATE(s_res_first_http_major); - break; - - case s_res_first_http_major: - if (UNLIKELY(ch < '0' || ch > '9')) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_major = ch - '0'; - UPDATE_STATE(s_res_http_major); - break; - - /* major HTTP version or dot */ - case s_res_http_major: - { - if (ch == '.') { - UPDATE_STATE(s_res_first_http_minor); - break; - } - - if (!IS_NUM(ch)) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_major *= 10; - parser->http_major += ch - '0'; - - if (UNLIKELY(parser->http_major > 999)) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - break; - } - - /* first digit of minor HTTP version */ - case s_res_first_http_minor: - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_minor = ch - '0'; - UPDATE_STATE(s_res_http_minor); - break; - - /* minor HTTP version or end of request line */ - case s_res_http_minor: - { - if (ch == ' ') { - UPDATE_STATE(s_res_first_status_code); - break; - } - - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - 
parser->http_minor *= 10; - parser->http_minor += ch - '0'; - - if (UNLIKELY(parser->http_minor > 999)) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - break; - } - - case s_res_first_status_code: - { - if (!IS_NUM(ch)) { - if (ch == ' ') { - break; - } - - SET_ERRNO(HPE_INVALID_STATUS); - goto error; - } - parser->status_code = ch - '0'; - UPDATE_STATE(s_res_status_code); - break; - } - - case s_res_status_code: - { - if (!IS_NUM(ch)) { - switch (ch) { - case ' ': - UPDATE_STATE(s_res_status_start); - break; - case CR: - UPDATE_STATE(s_res_line_almost_done); - break; - case LF: - UPDATE_STATE(s_header_field_start); - break; - default: - SET_ERRNO(HPE_INVALID_STATUS); - goto error; - } - break; - } - - parser->status_code *= 10; - parser->status_code += ch - '0'; - - if (UNLIKELY(parser->status_code > 999)) { - SET_ERRNO(HPE_INVALID_STATUS); - goto error; - } - - break; - } - - case s_res_status_start: - { - if (ch == CR) { - UPDATE_STATE(s_res_line_almost_done); - break; - } - - if (ch == LF) { - UPDATE_STATE(s_header_field_start); - break; - } - - MARK(status); - UPDATE_STATE(s_res_status); - parser->index = 0; - break; - } - - case s_res_status: - if (ch == CR) { - UPDATE_STATE(s_res_line_almost_done); - CALLBACK_DATA(status); - break; - } - - if (ch == LF) { - UPDATE_STATE(s_header_field_start); - CALLBACK_DATA(status); - break; - } - - break; - - case s_res_line_almost_done: - STRICT_CHECK(ch != LF); - UPDATE_STATE(s_header_field_start); - break; - - case s_start_req: - { - if (ch == CR || ch == LF) - break; - parser->flags = 0; - parser->content_length = ULLONG_MAX; - - if (UNLIKELY(!IS_ALPHA(ch))) { - SET_ERRNO(HPE_INVALID_METHOD); - goto error; - } - - parser->method = (enum http_method) 0; - parser->index = 1; - switch (ch) { - case 'A': parser->method = HTTP_ACL; break; - case 'B': parser->method = HTTP_BIND; break; - case 'C': parser->method = HTTP_CONNECT; /* or COPY, CHECKOUT */ break; - case 'D': parser->method = HTTP_DELETE; break; - case 'G': parser->method = HTTP_GET; break; - case 'H': parser->method = HTTP_HEAD; break; - case 'L': parser->method = HTTP_LOCK; /* or LINK */ break; - case 'M': parser->method = HTTP_MKCOL; /* or MOVE, MKACTIVITY, MERGE, M-SEARCH, MKCALENDAR */ break; - case 'N': parser->method = HTTP_NOTIFY; break; - case 'O': parser->method = HTTP_OPTIONS; break; - case 'P': parser->method = HTTP_POST; - /* or PROPFIND|PROPPATCH|PUT|PATCH|PURGE */ - break; - case 'R': parser->method = HTTP_REPORT; /* or REBIND */ break; - case 'S': parser->method = HTTP_SUBSCRIBE; /* or SEARCH */ break; - case 'T': parser->method = HTTP_TRACE; break; - case 'U': parser->method = HTTP_UNLOCK; /* or UNSUBSCRIBE, UNBIND, UNLINK */ break; - default: - SET_ERRNO(HPE_INVALID_METHOD); - goto error; - } - UPDATE_STATE(s_req_method); - - CALLBACK_NOTIFY(message_begin); - - break; - } - - case s_req_method: - { - const char *matcher; - if (UNLIKELY(ch == '\0')) { - SET_ERRNO(HPE_INVALID_METHOD); - goto error; - } - - matcher = method_strings[parser->method]; - if (ch == ' ' && matcher[parser->index] == '\0') { - UPDATE_STATE(s_req_spaces_before_url); - } else if (ch == matcher[parser->index]) { - ; /* nada */ - } else if (IS_ALPHA(ch)) { - - switch (parser->method << 16 | parser->index << 8 | ch) { -#define XX(meth, pos, ch, new_meth) \ - case (HTTP_##meth << 16 | pos << 8 | ch): \ - parser->method = HTTP_##new_meth; break; - - XX(POST, 1, 'U', PUT) - XX(POST, 1, 'A', PATCH) - XX(CONNECT, 1, 'H', CHECKOUT) - XX(CONNECT, 2, 'P', COPY) - XX(MKCOL, 1, 'O', MOVE) - XX(MKCOL, 1, 
'E', MERGE) - XX(MKCOL, 2, 'A', MKACTIVITY) - XX(MKCOL, 3, 'A', MKCALENDAR) - XX(SUBSCRIBE, 1, 'E', SEARCH) - XX(REPORT, 2, 'B', REBIND) - XX(POST, 1, 'R', PROPFIND) - XX(PROPFIND, 4, 'P', PROPPATCH) - XX(PUT, 2, 'R', PURGE) - XX(LOCK, 1, 'I', LINK) - XX(UNLOCK, 2, 'S', UNSUBSCRIBE) - XX(UNLOCK, 2, 'B', UNBIND) - XX(UNLOCK, 3, 'I', UNLINK) -#undef XX - - default: - SET_ERRNO(HPE_INVALID_METHOD); - goto error; - } - } else if (ch == '-' && - parser->index == 1 && - parser->method == HTTP_MKCOL) { - parser->method = HTTP_MSEARCH; - } else { - SET_ERRNO(HPE_INVALID_METHOD); - goto error; - } - - ++parser->index; - break; - } - - case s_req_spaces_before_url: - { - if (ch == ' ') break; - - MARK(url); - if (parser->method == HTTP_CONNECT) { - UPDATE_STATE(s_req_server_start); - } - - UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); - if (UNLIKELY(CURRENT_STATE() == s_dead)) { - SET_ERRNO(HPE_INVALID_URL); - goto error; - } - - break; - } - - case s_req_schema: - case s_req_schema_slash: - case s_req_schema_slash_slash: - case s_req_server_start: - { - switch (ch) { - /* No whitespace allowed here */ - case ' ': - case CR: - case LF: - SET_ERRNO(HPE_INVALID_URL); - goto error; - default: - UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); - if (UNLIKELY(CURRENT_STATE() == s_dead)) { - SET_ERRNO(HPE_INVALID_URL); - goto error; - } - } - - break; - } - - case s_req_server: - case s_req_server_with_at: - case s_req_path: - case s_req_query_string_start: - case s_req_query_string: - case s_req_fragment_start: - case s_req_fragment: - { - switch (ch) { - case ' ': - UPDATE_STATE(s_req_http_start); - CALLBACK_DATA(url); - break; - case CR: - case LF: - parser->http_major = 0; - parser->http_minor = 9; - UPDATE_STATE((ch == CR) ? - s_req_line_almost_done : - s_header_field_start); - CALLBACK_DATA(url); - break; - default: - UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); - if (UNLIKELY(CURRENT_STATE() == s_dead)) { - SET_ERRNO(HPE_INVALID_URL); - goto error; - } - } - break; - } - - case s_req_http_start: - switch (ch) { - case 'H': - UPDATE_STATE(s_req_http_H); - break; - case ' ': - break; - default: - SET_ERRNO(HPE_INVALID_CONSTANT); - goto error; - } - break; - - case s_req_http_H: - STRICT_CHECK(ch != 'T'); - UPDATE_STATE(s_req_http_HT); - break; - - case s_req_http_HT: - STRICT_CHECK(ch != 'T'); - UPDATE_STATE(s_req_http_HTT); - break; - - case s_req_http_HTT: - STRICT_CHECK(ch != 'P'); - UPDATE_STATE(s_req_http_HTTP); - break; - - case s_req_http_HTTP: - STRICT_CHECK(ch != '/'); - UPDATE_STATE(s_req_first_http_major); - break; - - /* first digit of major HTTP version */ - case s_req_first_http_major: - if (UNLIKELY(ch < '1' || ch > '9')) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_major = ch - '0'; - UPDATE_STATE(s_req_http_major); - break; - - /* major HTTP version or dot */ - case s_req_http_major: - { - if (ch == '.') { - UPDATE_STATE(s_req_first_http_minor); - break; - } - - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_major *= 10; - parser->http_major += ch - '0'; - - if (UNLIKELY(parser->http_major > 999)) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - break; - } - - /* first digit of minor HTTP version */ - case s_req_first_http_minor: - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_minor = ch - '0'; - UPDATE_STATE(s_req_http_minor); - break; - - /* minor HTTP version or end of request line */ - case s_req_http_minor: - { - if (ch == CR) { - 
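/* For illustration, a worked example of the version states above, using a
 * hypothetical request line: for "GET / HTTP/1.1\r\n",
 * s_req_first_http_major consumes the first '1' (http_major = 1),
 * s_req_http_major sees '.', s_req_first_http_minor consumes the second
 * '1' (http_minor = 1), and the CR handled here terminates the request
 * line. Longer versions accumulate digit by digit via
 * http_minor = http_minor * 10 + (ch - '0') and are rejected above 999. */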
UPDATE_STATE(s_req_line_almost_done); - break; - } - - if (ch == LF) { - UPDATE_STATE(s_header_field_start); - break; - } - - /* XXX allow spaces after digit? */ - - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - parser->http_minor *= 10; - parser->http_minor += ch - '0'; - - if (UNLIKELY(parser->http_minor > 999)) { - SET_ERRNO(HPE_INVALID_VERSION); - goto error; - } - - break; - } - - /* end of request line */ - case s_req_line_almost_done: - { - if (UNLIKELY(ch != LF)) { - SET_ERRNO(HPE_LF_EXPECTED); - goto error; - } - - UPDATE_STATE(s_header_field_start); - break; - } - - case s_header_field_start: - { - if (ch == CR) { - UPDATE_STATE(s_headers_almost_done); - break; - } - - if (ch == LF) { - /* they might be just sending \n instead of \r\n so this would be - * the second \n to denote the end of headers*/ - UPDATE_STATE(s_headers_almost_done); - REEXECUTE(); - } - - c = TOKEN(ch); - - if (UNLIKELY(!c)) { - SET_ERRNO(HPE_INVALID_HEADER_TOKEN); - goto error; - } - - MARK(header_field); - - parser->index = 0; - UPDATE_STATE(s_header_field); - - switch (c) { - case 'c': - parser->header_state = h_C; - break; - - case 'p': - parser->header_state = h_matching_proxy_connection; - break; - - case 't': - parser->header_state = h_matching_transfer_encoding; - break; - - case 'u': - parser->header_state = h_matching_upgrade; - break; - - default: - parser->header_state = h_general; - break; - } - break; - } - - case s_header_field: - { - const char* start = p; - for (; p != data + len; p++) { - ch = *p; - c = TOKEN(ch); - - if (!c) - break; - - switch (parser->header_state) { - case h_general: - break; - - case h_C: - parser->index++; - parser->header_state = (c == 'o' ? h_CO : h_general); - break; - - case h_CO: - parser->index++; - parser->header_state = (c == 'n' ? 
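/* The field matcher advances one state per character: 'c', 'o', 'n' walk
 * h_C -> h_CO -> h_CON, and the next character then picks between
 * "Connection" ('n') and "Content-Length" ('t'). For a hypothetical
 * "Content-Type" header, matching follows h_matching_content_length
 * through "content-" and drops to h_general when the 't' of "Type"
 * mismatches CONTENT_LENGTH. */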
h_CON : h_general); - break; - - case h_CON: - parser->index++; - switch (c) { - case 'n': - parser->header_state = h_matching_connection; - break; - case 't': - parser->header_state = h_matching_content_length; - break; - default: - parser->header_state = h_general; - break; - } - break; - - /* connection */ - - case h_matching_connection: - parser->index++; - if (parser->index > sizeof(CONNECTION)-1 - || c != CONNECTION[parser->index]) { - parser->header_state = h_general; - } else if (parser->index == sizeof(CONNECTION)-2) { - parser->header_state = h_connection; - } - break; - - /* proxy-connection */ - - case h_matching_proxy_connection: - parser->index++; - if (parser->index > sizeof(PROXY_CONNECTION)-1 - || c != PROXY_CONNECTION[parser->index]) { - parser->header_state = h_general; - } else if (parser->index == sizeof(PROXY_CONNECTION)-2) { - parser->header_state = h_connection; - } - break; - - /* content-length */ - - case h_matching_content_length: - parser->index++; - if (parser->index > sizeof(CONTENT_LENGTH)-1 - || c != CONTENT_LENGTH[parser->index]) { - parser->header_state = h_general; - } else if (parser->index == sizeof(CONTENT_LENGTH)-2) { - parser->header_state = h_content_length; - } - break; - - /* transfer-encoding */ - - case h_matching_transfer_encoding: - parser->index++; - if (parser->index > sizeof(TRANSFER_ENCODING)-1 - || c != TRANSFER_ENCODING[parser->index]) { - parser->header_state = h_general; - } else if (parser->index == sizeof(TRANSFER_ENCODING)-2) { - parser->header_state = h_transfer_encoding; - } - break; - - /* upgrade */ - - case h_matching_upgrade: - parser->index++; - if (parser->index > sizeof(UPGRADE)-1 - || c != UPGRADE[parser->index]) { - parser->header_state = h_general; - } else if (parser->index == sizeof(UPGRADE)-2) { - parser->header_state = h_upgrade; - } - break; - - case h_connection: - case h_content_length: - case h_transfer_encoding: - case h_upgrade: - if (ch != ' ') parser->header_state = h_general; - break; - - default: - assert(0 && "Unknown header_state"); - break; - } - } - - COUNT_HEADER_SIZE(p - start); - - if (p == data + len) { - --p; - break; - } - - if (ch == ':') { - UPDATE_STATE(s_header_value_discard_ws); - CALLBACK_DATA(header_field); - break; - } - - SET_ERRNO(HPE_INVALID_HEADER_TOKEN); - goto error; - } - - case s_header_value_discard_ws: - if (ch == ' ' || ch == '\t') break; - - if (ch == CR) { - UPDATE_STATE(s_header_value_discard_ws_almost_done); - break; - } - - if (ch == LF) { - UPDATE_STATE(s_header_value_discard_lws); - break; - } - - /* FALLTHROUGH */ - - case s_header_value_start: - { - MARK(header_value); - - UPDATE_STATE(s_header_value); - parser->index = 0; - - c = LOWER(ch); - - switch (parser->header_state) { - case h_upgrade: - parser->flags |= F_UPGRADE; - parser->header_state = h_general; - break; - - case h_transfer_encoding: - /* looking for 'Transfer-Encoding: chunked' */ - if ('c' == c) { - parser->header_state = h_matching_transfer_encoding_chunked; - } else { - parser->header_state = h_general; - } - break; - - case h_content_length: - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); - goto error; - } - - if (parser->flags & F_CONTENTLENGTH) { - SET_ERRNO(HPE_UNEXPECTED_CONTENT_LENGTH); - goto error; - } - - parser->flags |= F_CONTENTLENGTH; - parser->content_length = ch - '0'; - break; - - case h_connection: - /* looking for 'Connection: keep-alive' */ - if (c == 'k') { - parser->header_state = h_matching_connection_keep_alive; - /* looking for 'Connection: close' */ - } 
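/* Value matching below mirrors the field matching above. For a
 * hypothetical "Connection: keep-alive, Upgrade" header, the first token
 * reaches h_connection_keep_alive, the comma resets parser->index for the
 * next token, and "Upgrade" reaches h_connection_upgrade, so both
 * F_CONNECTION_KEEP_ALIVE and F_CONNECTION_UPGRADE end up set. */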
else if (c == 'c') { - parser->header_state = h_matching_connection_close; - } else if (c == 'u') { - parser->header_state = h_matching_connection_upgrade; - } else { - parser->header_state = h_matching_connection_token; - } - break; - - /* Multi-value `Connection` header */ - case h_matching_connection_token_start: - break; - - default: - parser->header_state = h_general; - break; - } - break; - } - - case s_header_value: - { - const char* start = p; - enum header_states h_state = (enum header_states) parser->header_state; - for (; p != data + len; p++) { - ch = *p; - if (ch == CR) { - UPDATE_STATE(s_header_almost_done); - parser->header_state = h_state; - CALLBACK_DATA(header_value); - break; - } - - if (ch == LF) { - UPDATE_STATE(s_header_almost_done); - COUNT_HEADER_SIZE(p - start); - parser->header_state = h_state; - CALLBACK_DATA_NOADVANCE(header_value); - REEXECUTE(); - } - - if (!lenient && !IS_HEADER_CHAR(ch)) { - SET_ERRNO(HPE_INVALID_HEADER_TOKEN); - goto error; - } - - c = LOWER(ch); - - switch (h_state) { - case h_general: - { - const char* p_cr; - const char* p_lf; - size_t limit = data + len - p; - - limit = MIN(limit, HTTP_MAX_HEADER_SIZE); - - p_cr = (const char*) memchr(p, CR, limit); - p_lf = (const char*) memchr(p, LF, limit); - if (p_cr != NULL) { - if (p_lf != NULL && p_cr >= p_lf) - p = p_lf; - else - p = p_cr; - } else if (UNLIKELY(p_lf != NULL)) { - p = p_lf; - } else { - p = data + len; - } - --p; - - break; - } - - case h_connection: - case h_transfer_encoding: - assert(0 && "Shouldn't get here."); - break; - - case h_content_length: - { - uint64_t t; - - if (ch == ' ') break; - - if (UNLIKELY(!IS_NUM(ch))) { - SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); - parser->header_state = h_state; - goto error; - } - - t = parser->content_length; - t *= 10; - t += ch - '0'; - - /* Overflow? Test against a conservative limit for simplicity. 
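 * The exact per-digit test would be
 * parser->content_length > (ULLONG_MAX - (ch - '0')) / 10; comparing
 * against (ULLONG_MAX - 10) / 10 instead is marginally stricter but needs
 * no per-digit arithmetic. Concretely, with ULLONG_MAX = 2^64-1 =
 * 18446744073709551615, any accumulated value above 1844674407370955160
 * is rejected before the multiplication can wrap.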
*/ - if (UNLIKELY((ULLONG_MAX - 10) / 10 < parser->content_length)) { - SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); - parser->header_state = h_state; - goto error; - } - - parser->content_length = t; - break; - } - - /* Transfer-Encoding: chunked */ - case h_matching_transfer_encoding_chunked: - parser->index++; - if (parser->index > sizeof(CHUNKED)-1 - || c != CHUNKED[parser->index]) { - h_state = h_general; - } else if (parser->index == sizeof(CHUNKED)-2) { - h_state = h_transfer_encoding_chunked; - } - break; - - case h_matching_connection_token_start: - /* looking for 'Connection: keep-alive' */ - if (c == 'k') { - h_state = h_matching_connection_keep_alive; - /* looking for 'Connection: close' */ - } else if (c == 'c') { - h_state = h_matching_connection_close; - } else if (c == 'u') { - h_state = h_matching_connection_upgrade; - } else if (STRICT_TOKEN(c)) { - h_state = h_matching_connection_token; - } else if (c == ' ' || c == '\t') { - /* Skip lws */ - } else { - h_state = h_general; - } - break; - - /* looking for 'Connection: keep-alive' */ - case h_matching_connection_keep_alive: - parser->index++; - if (parser->index > sizeof(KEEP_ALIVE)-1 - || c != KEEP_ALIVE[parser->index]) { - h_state = h_matching_connection_token; - } else if (parser->index == sizeof(KEEP_ALIVE)-2) { - h_state = h_connection_keep_alive; - } - break; - - /* looking for 'Connection: close' */ - case h_matching_connection_close: - parser->index++; - if (parser->index > sizeof(CLOSE)-1 || c != CLOSE[parser->index]) { - h_state = h_matching_connection_token; - } else if (parser->index == sizeof(CLOSE)-2) { - h_state = h_connection_close; - } - break; - - /* looking for 'Connection: upgrade' */ - case h_matching_connection_upgrade: - parser->index++; - if (parser->index > sizeof(UPGRADE) - 1 || - c != UPGRADE[parser->index]) { - h_state = h_matching_connection_token; - } else if (parser->index == sizeof(UPGRADE)-2) { - h_state = h_connection_upgrade; - } - break; - - case h_matching_connection_token: - if (ch == ',') { - h_state = h_matching_connection_token_start; - parser->index = 0; - } - break; - - case h_transfer_encoding_chunked: - if (ch != ' ') h_state = h_general; - break; - - case h_connection_keep_alive: - case h_connection_close: - case h_connection_upgrade: - if (ch == ',') { - if (h_state == h_connection_keep_alive) { - parser->flags |= F_CONNECTION_KEEP_ALIVE; - } else if (h_state == h_connection_close) { - parser->flags |= F_CONNECTION_CLOSE; - } else if (h_state == h_connection_upgrade) { - parser->flags |= F_CONNECTION_UPGRADE; - } - h_state = h_matching_connection_token_start; - parser->index = 0; - } else if (ch != ' ') { - h_state = h_matching_connection_token; - } - break; - - default: - UPDATE_STATE(s_header_value); - h_state = h_general; - break; - } - } - parser->header_state = h_state; - - COUNT_HEADER_SIZE(p - start); - - if (p == data + len) - --p; - break; - } - - case s_header_almost_done: - { - if (UNLIKELY(ch != LF)) { - SET_ERRNO(HPE_LF_EXPECTED); - goto error; - } - - UPDATE_STATE(s_header_value_lws); - break; - } - - case s_header_value_lws: - { - if (ch == ' ' || ch == '\t') { - UPDATE_STATE(s_header_value_start); - REEXECUTE(); - } - - /* finished the header */ - switch (parser->header_state) { - case h_connection_keep_alive: - parser->flags |= F_CONNECTION_KEEP_ALIVE; - break; - case h_connection_close: - parser->flags |= F_CONNECTION_CLOSE; - break; - case h_transfer_encoding_chunked: - parser->flags |= F_CHUNKED; - break; - case h_connection_upgrade: - parser->flags |= 
F_CONNECTION_UPGRADE; - break; - default: - break; - } - - UPDATE_STATE(s_header_field_start); - REEXECUTE(); - } - - case s_header_value_discard_ws_almost_done: - { - STRICT_CHECK(ch != LF); - UPDATE_STATE(s_header_value_discard_lws); - break; - } - - case s_header_value_discard_lws: - { - if (ch == ' ' || ch == '\t') { - UPDATE_STATE(s_header_value_discard_ws); - break; - } else { - switch (parser->header_state) { - case h_connection_keep_alive: - parser->flags |= F_CONNECTION_KEEP_ALIVE; - break; - case h_connection_close: - parser->flags |= F_CONNECTION_CLOSE; - break; - case h_connection_upgrade: - parser->flags |= F_CONNECTION_UPGRADE; - break; - case h_transfer_encoding_chunked: - parser->flags |= F_CHUNKED; - break; - default: - break; - } - - /* header value was empty */ - MARK(header_value); - UPDATE_STATE(s_header_field_start); - CALLBACK_DATA_NOADVANCE(header_value); - REEXECUTE(); - } - } - - case s_headers_almost_done: - { - STRICT_CHECK(ch != LF); - - if (parser->flags & F_TRAILING) { - /* End of a chunked request */ - UPDATE_STATE(s_message_done); - CALLBACK_NOTIFY_NOADVANCE(chunk_complete); - REEXECUTE(); - } - - /* Cannot use chunked encoding and a content-length header together - per the HTTP specification. */ - if ((parser->flags & F_CHUNKED) && - (parser->flags & F_CONTENTLENGTH)) { - SET_ERRNO(HPE_UNEXPECTED_CONTENT_LENGTH); - goto error; - } - - UPDATE_STATE(s_headers_done); - - /* Set this here so that on_headers_complete() callbacks can see it */ - parser->upgrade = - ((parser->flags & (F_UPGRADE | F_CONNECTION_UPGRADE)) == - (F_UPGRADE | F_CONNECTION_UPGRADE) || - parser->method == HTTP_CONNECT); - - /* Here we call the headers_complete callback. This is somewhat - * different from other callbacks because if the user returns 1, we - * will interpret that as saying that this message has no body. This - * is needed for the annoying case of receiving a response to a HEAD - * request. - * - * We'd like to use CALLBACK_NOTIFY_NOADVANCE() here but we cannot, so - * we have to simulate it by handling a change in errno below. - */ - if (settings->on_headers_complete) { - switch (settings->on_headers_complete(parser)) { - case 0: - break; - - case 2: - parser->upgrade = 1; - - case 1: - parser->flags |= F_SKIPBODY; - break; - - default: - SET_ERRNO(HPE_CB_headers_complete); - RETURN(p - data); /* Error */ - } - } - - if (HTTP_PARSER_ERRNO(parser) != HPE_OK) { - RETURN(p - data); - } - - REEXECUTE(); - } - - case s_headers_done: - { - int hasBody; - STRICT_CHECK(ch != LF); - - parser->nread = 0; - - hasBody = parser->flags & F_CHUNKED || - (parser->content_length > 0 && parser->content_length != ULLONG_MAX); - if (parser->upgrade && (parser->method == HTTP_CONNECT || - (parser->flags & F_SKIPBODY) || !hasBody)) { - /* Exit, the rest of the message is in a different protocol.
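 * When this happens http_parser_execute() returns the offset one past the
 * end of the HTTP message; everything after that offset belongs to the
 * upgraded protocol (e.g. a WebSocket stream, or the raw bytes of a
 * CONNECT tunnel), so the caller should stop feeding this parser and hand
 * the remaining data to whatever speaks that protocol.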
*/ - UPDATE_STATE(NEW_MESSAGE()); - CALLBACK_NOTIFY(message_complete); - RETURN((p - data) + 1); - } - - if (parser->flags & F_SKIPBODY) { - UPDATE_STATE(NEW_MESSAGE()); - CALLBACK_NOTIFY(message_complete); - } else if (parser->flags & F_CHUNKED) { - /* chunked encoding - ignore Content-Length header */ - UPDATE_STATE(s_chunk_size_start); - } else { - if (parser->content_length == 0) { - /* Content-Length header given but zero: Content-Length: 0\r\n */ - UPDATE_STATE(NEW_MESSAGE()); - CALLBACK_NOTIFY(message_complete); - } else if (parser->content_length != ULLONG_MAX) { - /* Content-Length header given and non-zero */ - UPDATE_STATE(s_body_identity); - } else { - if (!http_message_needs_eof(parser)) { - /* Assume content-length 0 - read the next */ - UPDATE_STATE(NEW_MESSAGE()); - CALLBACK_NOTIFY(message_complete); - } else { - /* Read body until EOF */ - UPDATE_STATE(s_body_identity_eof); - } - } - } - - break; - } - - case s_body_identity: - { - uint64_t to_read = MIN(parser->content_length, - (uint64_t) ((data + len) - p)); - - assert(parser->content_length != 0 - && parser->content_length != ULLONG_MAX); - - /* The difference between advancing content_length and p is because - * the latter will automatically advance on the next loop iteration. - * Further, if content_length ends up at 0, we want to see the last - * byte again for our message complete callback. - */ - MARK(body); - parser->content_length -= to_read; - p += to_read - 1; - - if (parser->content_length == 0) { - UPDATE_STATE(s_message_done); - - /* Mimic CALLBACK_DATA_NOADVANCE() but with one extra byte. - * - * The alternative to doing this is to wait for the next byte to - * trigger the data callback, just as in every other case. The - * problem with this is that this makes it difficult for the test - * harness to distinguish between complete-on-EOF and - * complete-on-length. It's not clear that this distinction is - * important for applications, but let's keep it for now. - */ - CALLBACK_DATA_(body, p - body_mark + 1, p - data); - REEXECUTE(); - } - - break; - } - - /* read until EOF */ - case s_body_identity_eof: - MARK(body); - p = data + len - 1; - - break; - - case s_message_done: - UPDATE_STATE(NEW_MESSAGE()); - CALLBACK_NOTIFY(message_complete); - if (parser->upgrade) { - /* Exit, the rest of the message is in a different protocol. */ - RETURN((p - data) + 1); - } - break; - - case s_chunk_size_start: - { - assert(parser->nread == 1); - assert(parser->flags & F_CHUNKED); - - unhex_val = unhex[(unsigned char)ch]; - if (UNLIKELY(unhex_val == -1)) { - SET_ERRNO(HPE_INVALID_CHUNK_SIZE); - goto error; - } - - parser->content_length = unhex_val; - UPDATE_STATE(s_chunk_size); - break; - } - - case s_chunk_size: - { - uint64_t t; - - assert(parser->flags & F_CHUNKED); - - if (ch == CR) { - UPDATE_STATE(s_chunk_size_almost_done); - break; - } - - unhex_val = unhex[(unsigned char)ch]; - - if (unhex_val == -1) { - if (ch == ';' || ch == ' ') { - UPDATE_STATE(s_chunk_parameters); - break; - } - - SET_ERRNO(HPE_INVALID_CHUNK_SIZE); - goto error; - } - - t = parser->content_length; - t *= 16; - t += unhex_val; - - /* Overflow? Test against a conservative limit for simplicity. */ - if (UNLIKELY((ULLONG_MAX - 16) / 16 < parser->content_length)) { - SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); - goto error; - } - - parser->content_length = t; - break; - } - - case s_chunk_parameters: - { - assert(parser->flags & F_CHUNKED); - /* just ignore this shit.
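 * ("This" being the chunk-extension syntax, e.g. "1e;name=value\r\n":
 * RFC 7230 section 4.1.1 permits such extensions but defines no standard
 * ones, so everything between the ';' or space and the CR is discarded.)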
TODO check for overflow */ - if (ch == CR) { - UPDATE_STATE(s_chunk_size_almost_done); - break; - } - break; - } - - case s_chunk_size_almost_done: - { - assert(parser->flags & F_CHUNKED); - STRICT_CHECK(ch != LF); - - parser->nread = 0; - - if (parser->content_length == 0) { - parser->flags |= F_TRAILING; - UPDATE_STATE(s_header_field_start); - } else { - UPDATE_STATE(s_chunk_data); - } - CALLBACK_NOTIFY(chunk_header); - break; - } - - case s_chunk_data: - { - uint64_t to_read = MIN(parser->content_length, - (uint64_t) ((data + len) - p)); - - assert(parser->flags & F_CHUNKED); - assert(parser->content_length != 0 - && parser->content_length != ULLONG_MAX); - - /* See the explanation in s_body_identity for why the content - * length and data pointers are managed this way. - */ - MARK(body); - parser->content_length -= to_read; - p += to_read - 1; - - if (parser->content_length == 0) { - UPDATE_STATE(s_chunk_data_almost_done); - } - - break; - } - - case s_chunk_data_almost_done: - assert(parser->flags & F_CHUNKED); - assert(parser->content_length == 0); - STRICT_CHECK(ch != CR); - UPDATE_STATE(s_chunk_data_done); - CALLBACK_DATA(body); - break; - - case s_chunk_data_done: - assert(parser->flags & F_CHUNKED); - STRICT_CHECK(ch != LF); - parser->nread = 0; - UPDATE_STATE(s_chunk_size_start); - CALLBACK_NOTIFY(chunk_complete); - break; - - default: - assert(0 && "unhandled state"); - SET_ERRNO(HPE_INVALID_INTERNAL_STATE); - goto error; - } - } - - /* Run callbacks for any marks that we have leftover after we ran out of - * bytes. There should be at most one of these set, so it's OK to invoke - * them in series (unset marks will not result in callbacks). - * - * We use the NOADVANCE() variety of callbacks here because 'p' has already - * overflowed 'data' and this allows us to correct for the off-by-one that - * we'd otherwise have (since CALLBACK_DATA() is meant to be run with a 'p' - * value that's in-bounds). - */ - - assert(((header_field_mark ? 1 : 0) + - (header_value_mark ? 1 : 0) + - (url_mark ? 1 : 0) + - (body_mark ? 1 : 0) + - (status_mark ? 1 : 0)) <= 1); - - CALLBACK_DATA_NOADVANCE(header_field); - CALLBACK_DATA_NOADVANCE(header_value); - CALLBACK_DATA_NOADVANCE(url); - CALLBACK_DATA_NOADVANCE(body); - CALLBACK_DATA_NOADVANCE(status); - - RETURN(len); - -error: - if (HTTP_PARSER_ERRNO(parser) == HPE_OK) { - SET_ERRNO(HPE_UNKNOWN); - } - - RETURN(p - data); -} - - -/* Does the parser need to see an EOF to find the end of the message? */ -int -http_message_needs_eof (const http_parser *parser) -{ - if (parser->type == HTTP_REQUEST) { - return 0; - } - - /* See RFC 2616 section 4.4 */ - if (parser->status_code / 100 == 1 || /* 1xx e.g.
Continue */ - parser->status_code == 204 || /* No Content */ - parser->status_code == 304 || /* Not Modified */ - parser->flags & F_SKIPBODY) { /* response to a HEAD request */ - return 0; - } - - if ((parser->flags & F_CHUNKED) || parser->content_length != ULLONG_MAX) { - return 0; - } - - return 1; -} - - -int -http_should_keep_alive (const http_parser *parser) -{ - if (parser->http_major > 0 && parser->http_minor > 0) { - /* HTTP/1.1 */ - if (parser->flags & F_CONNECTION_CLOSE) { - return 0; - } - } else { - /* HTTP/1.0 or earlier */ - if (!(parser->flags & F_CONNECTION_KEEP_ALIVE)) { - return 0; - } - } - - return !http_message_needs_eof(parser); -} - - -const char * -http_method_str (enum http_method m) -{ - return ELEM_AT(method_strings, m, ""); -} - - -void -http_parser_init (http_parser *parser, enum http_parser_type t) -{ - void *data = parser->data; /* preserve application data */ - memset(parser, 0, sizeof(*parser)); - parser->data = data; - parser->type = t; - parser->state = (t == HTTP_REQUEST ? s_start_req : (t == HTTP_RESPONSE ? s_start_res : s_start_req_or_res)); - parser->http_errno = HPE_OK; -} - -void -http_parser_settings_init(http_parser_settings *settings) -{ - memset(settings, 0, sizeof(*settings)); -} - -const char * -http_errno_name(enum http_errno err) { - assert(((size_t) err) < ARRAY_SIZE(http_strerror_tab)); - return http_strerror_tab[err].name; -} - -const char * -http_errno_description(enum http_errno err) { - assert(((size_t) err) < ARRAY_SIZE(http_strerror_tab)); - return http_strerror_tab[err].description; -} - -static enum http_host_state -http_parse_host_char(enum http_host_state s, const char ch) { - switch(s) { - case s_http_userinfo: - case s_http_userinfo_start: - if (ch == '@') { - return s_http_host_start; - } - - if (IS_USERINFO_CHAR(ch)) { - return s_http_userinfo; - } - break; - - case s_http_host_start: - if (ch == '[') { - return s_http_host_v6_start; - } - - if (IS_HOST_CHAR(ch)) { - return s_http_host; - } - - break; - - case s_http_host: - if (IS_HOST_CHAR(ch)) { - return s_http_host; - } - - /* FALLTHROUGH */ - case s_http_host_v6_end: - if (ch == ':') { - return s_http_host_port_start; - } - - break; - - case s_http_host_v6: - if (ch == ']') { - return s_http_host_v6_end; - } - - /* FALLTHROUGH */ - case s_http_host_v6_start: - if (IS_HEX(ch) || ch == ':' || ch == '.') { - return s_http_host_v6; - } - - if (s == s_http_host_v6 && ch == '%') { - return s_http_host_v6_zone_start; - } - break; - - case s_http_host_v6_zone: - if (ch == ']') { - return s_http_host_v6_end; - } - - /* FALLTHROUGH */ - case s_http_host_v6_zone_start: - /* RFC 6874 Zone ID consists of 1*( unreserved / pct-encoded) */ - if (IS_ALPHANUM(ch) || ch == '%' || ch == '.' || ch == '-' || ch == '_' || - ch == '~') { - return s_http_host_v6_zone; - } - break; - - case s_http_host_port: - case s_http_host_port_start: - if (IS_NUM(ch)) { - return s_http_host_port; - } - - break; - - default: - break; - } - return s_http_host_dead; -} - -static int -http_parse_host(const char * buf, struct http_parser_url *u, int found_at) { - enum http_host_state s; - - const char *p; - size_t buflen = u->field_data[UF_HOST].off + u->field_data[UF_HOST].len; - - assert(u->field_set & (1 << UF_HOST)); - - u->field_data[UF_HOST].len = 0; - - s = found_at ? 
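/* found_at is set when an '@' was seen during URL parsing, so the scan
 * starts in the userinfo state. For the hypothetical URL
 * "http://bob:pw@example.com:8080/", the loop below records "bob:pw" as
 * UF_USERINFO, re-anchors UF_HOST at "example.com" after the '@', and
 * collects "8080" as UF_PORT. */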
s_http_userinfo_start : s_http_host_start; - - for (p = buf + u->field_data[UF_HOST].off; p < buf + buflen; p++) { - enum http_host_state new_s = http_parse_host_char(s, *p); - - if (new_s == s_http_host_dead) { - return 1; - } - - switch(new_s) { - case s_http_host: - if (s != s_http_host) { - u->field_data[UF_HOST].off = p - buf; - } - u->field_data[UF_HOST].len++; - break; - - case s_http_host_v6: - if (s != s_http_host_v6) { - u->field_data[UF_HOST].off = p - buf; - } - u->field_data[UF_HOST].len++; - break; - - case s_http_host_v6_zone_start: - case s_http_host_v6_zone: - u->field_data[UF_HOST].len++; - break; - - case s_http_host_port: - if (s != s_http_host_port) { - u->field_data[UF_PORT].off = p - buf; - u->field_data[UF_PORT].len = 0; - u->field_set |= (1 << UF_PORT); - } - u->field_data[UF_PORT].len++; - break; - - case s_http_userinfo: - if (s != s_http_userinfo) { - u->field_data[UF_USERINFO].off = p - buf ; - u->field_data[UF_USERINFO].len = 0; - u->field_set |= (1 << UF_USERINFO); - } - u->field_data[UF_USERINFO].len++; - break; - - default: - break; - } - s = new_s; - } - - /* Make sure we don't end somewhere unexpected */ - switch (s) { - case s_http_host_start: - case s_http_host_v6_start: - case s_http_host_v6: - case s_http_host_v6_zone_start: - case s_http_host_v6_zone: - case s_http_host_port_start: - case s_http_userinfo: - case s_http_userinfo_start: - return 1; - default: - break; - } - - return 0; -} - -void -http_parser_url_init(struct http_parser_url *u) { - memset(u, 0, sizeof(*u)); -} - -int -http_parser_parse_url(const char *buf, size_t buflen, int is_connect, - struct http_parser_url *u) -{ - enum state s; - const char *p; - enum http_parser_url_fields uf, old_uf; - int found_at = 0; - - u->port = u->field_set = 0; - s = is_connect ? 
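/* A CONNECT target is authority-form, i.e. a bare "host:port" such as
 * "example.com:443" with no scheme or path, so parsing starts directly in
 * s_req_server_start; the check further down then requires that exactly
 * UF_HOST and UF_PORT were produced. */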
s_req_server_start : s_req_spaces_before_url; - old_uf = UF_MAX; - - for (p = buf; p < buf + buflen; p++) { - s = parse_url_char(s, *p); - - /* Figure out the next field that we're operating on */ - switch (s) { - case s_dead: - return 1; - - /* Skip delimiters */ - case s_req_schema_slash: - case s_req_schema_slash_slash: - case s_req_server_start: - case s_req_query_string_start: - case s_req_fragment_start: - continue; - - case s_req_schema: - uf = UF_SCHEMA; - break; - - case s_req_server_with_at: - found_at = 1; - - /* FALLTHROUGH */ - case s_req_server: - uf = UF_HOST; - break; - - case s_req_path: - uf = UF_PATH; - break; - - case s_req_query_string: - uf = UF_QUERY; - break; - - case s_req_fragment: - uf = UF_FRAGMENT; - break; - - default: - assert(!"Unexpected state"); - return 1; - } - - /* Nothing's changed; soldier on */ - if (uf == old_uf) { - u->field_data[uf].len++; - continue; - } - - u->field_data[uf].off = p - buf; - u->field_data[uf].len = 1; - - u->field_set |= (1 << uf); - old_uf = uf; - } - - /* host must be present if there is a schema */ - /* parsing http:///toto will fail */ - if ((u->field_set & (1 << UF_SCHEMA)) && - (u->field_set & (1 << UF_HOST)) == 0) { - return 1; - } - - if (u->field_set & (1 << UF_HOST)) { - if (http_parse_host(buf, u, found_at) != 0) { - return 1; - } - } - - /* CONNECT requests can only contain "hostname:port" */ - if (is_connect && u->field_set != ((1 << UF_HOST)|(1 << UF_PORT))) { - return 1; - } - - if (u->field_set & (1 << UF_PORT)) { - /* Don't bother with endp; we've already validated the string */ - unsigned long v = strtoul(buf + u->field_data[UF_PORT].off, NULL, 10); - - /* Ports have a max value of 2^16-1 */ - if (v > 0xffff) { - return 1; - } - - u->port = (uint16_t) v; - } - - return 0; -} - -void -http_parser_pause(http_parser *parser, int paused) { - /* Users should only be pausing/unpausing a parser that is not in an error - * state. In non-debug builds, there's not much that we can do about this - * other than ignore it. - */ - if (HTTP_PARSER_ERRNO(parser) == HPE_OK || - HTTP_PARSER_ERRNO(parser) == HPE_PAUSED) { - SET_ERRNO((paused) ? HPE_PAUSED : HPE_OK); - } else { - assert(0 && "Attempting to pause parser in error state"); - } -} - -int -http_body_is_final(const struct http_parser *parser) { - return parser->state == s_message_done; -} - -unsigned long -http_parser_version(void) { - return HTTP_PARSER_VERSION_MAJOR * 0x10000 | - HTTP_PARSER_VERSION_MINOR * 0x00100 | - HTTP_PARSER_VERSION_PATCH * 0x00001; -} diff --git a/vendor/http-parser/http_parser.gyp b/vendor/http-parser/http_parser.gyp deleted file mode 100644 index ef34ecaeaea..00000000000 --- a/vendor/http-parser/http_parser.gyp +++ /dev/null @@ -1,111 +0,0 @@ -# This file is used with the GYP meta build system.
-# http://code.google.com/p/gyp/ -# To build try this: -# svn co http://gyp.googlecode.com/svn/trunk gyp -# ./gyp/gyp -f make --depth=`pwd` http_parser.gyp -# ./out/Debug/test -{ - 'target_defaults': { - 'default_configuration': 'Debug', - 'configurations': { - # TODO: hoist these out and put them somewhere common, because - # RuntimeLibrary MUST MATCH across the entire project - 'Debug': { - 'defines': [ 'DEBUG', '_DEBUG' ], - 'cflags': [ '-Wall', '-Wextra', '-O0', '-g', '-ftrapv' ], - 'msvs_settings': { - 'VCCLCompilerTool': { - 'RuntimeLibrary': 1, # static debug - }, - }, - }, - 'Release': { - 'defines': [ 'NDEBUG' ], - 'cflags': [ '-Wall', '-Wextra', '-O3' ], - 'msvs_settings': { - 'VCCLCompilerTool': { - 'RuntimeLibrary': 0, # static release - }, - }, - } - }, - 'msvs_settings': { - 'VCCLCompilerTool': { - }, - 'VCLibrarianTool': { - }, - 'VCLinkerTool': { - 'GenerateDebugInformation': 'true', - }, - }, - 'conditions': [ - ['OS == "win"', { - 'defines': [ - 'WIN32' - ], - }] - ], - }, - - 'targets': [ - { - 'target_name': 'http_parser', - 'type': 'static_library', - 'include_dirs': [ '.' ], - 'direct_dependent_settings': { - 'defines': [ 'HTTP_PARSER_STRICT=0' ], - 'include_dirs': [ '.' ], - }, - 'defines': [ 'HTTP_PARSER_STRICT=0' ], - 'sources': [ './http_parser.c', ], - 'conditions': [ - ['OS=="win"', { - 'msvs_settings': { - 'VCCLCompilerTool': { - # Compile as C++. http_parser.c is actually C99, but C++ is - # close enough in this case. - 'CompileAs': 2, - }, - }, - }] - ], - }, - - { - 'target_name': 'http_parser_strict', - 'type': 'static_library', - 'include_dirs': [ '.' ], - 'direct_dependent_settings': { - 'defines': [ 'HTTP_PARSER_STRICT=1' ], - 'include_dirs': [ '.' ], - }, - 'defines': [ 'HTTP_PARSER_STRICT=1' ], - 'sources': [ './http_parser.c', ], - 'conditions': [ - ['OS=="win"', { - 'msvs_settings': { - 'VCCLCompilerTool': { - # Compile as C++. http_parser.c is actually C99, but C++ is - # close enough in this case. - 'CompileAs': 2, - }, - }, - }] - ], - }, - - { - 'target_name': 'test-nonstrict', - 'type': 'executable', - 'dependencies': [ 'http_parser' ], - 'sources': [ 'test.c' ] - }, - - { - 'target_name': 'test-strict', - 'type': 'executable', - 'dependencies': [ 'http_parser_strict' ], - 'sources': [ 'test.c' ] - } - ] -} diff --git a/vendor/http-parser/http_parser.h b/vendor/http-parser/http_parser.h deleted file mode 100644 index ea263948240..00000000000 --- a/vendor/http-parser/http_parser.h +++ /dev/null @@ -1,362 +0,0 @@ -/* Copyright Joyent, Inc. and other Node contributors. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ -#ifndef http_parser_h -#define http_parser_h -#ifdef __cplusplus -extern "C" { -#endif - -/* Also update SONAME in the Makefile whenever you change these. */ -#define HTTP_PARSER_VERSION_MAJOR 2 -#define HTTP_PARSER_VERSION_MINOR 7 -#define HTTP_PARSER_VERSION_PATCH 1 - -#include <stddef.h> -#if defined(_WIN32) && !defined(__MINGW32__) && \ - (!defined(_MSC_VER) || _MSC_VER<1600) && !defined(__WINE__) -#include <BaseTsd.h> -#include <stddef.h> -typedef __int8 int8_t; -typedef unsigned __int8 uint8_t; -typedef __int16 int16_t; -typedef unsigned __int16 uint16_t; -typedef __int32 int32_t; -typedef unsigned __int32 uint32_t; -typedef __int64 int64_t; -typedef unsigned __int64 uint64_t; -#else -#include <stdint.h> -#endif - -/* Compile with -DHTTP_PARSER_STRICT=0 to make fewer checks, but run - * faster - */ -#ifndef HTTP_PARSER_STRICT -# define HTTP_PARSER_STRICT 1 -#endif - -/* Maximum header size allowed. If the macro is not defined - * before including this header then the default is used. To - * change the maximum header size, define the macro in the build - * environment (e.g. -DHTTP_MAX_HEADER_SIZE=<value>). To remove - * the effective limit on the size of the header, define the macro - * to a very large number (e.g. -DHTTP_MAX_HEADER_SIZE=0x7fffffff) - */ -#ifndef HTTP_MAX_HEADER_SIZE -# define HTTP_MAX_HEADER_SIZE (80*1024) -#endif - -typedef struct http_parser http_parser; -typedef struct http_parser_settings http_parser_settings; - - -/* Callbacks should return non-zero to indicate an error. The parser will - * then halt execution. - * - * The one exception is on_headers_complete. In an HTTP_RESPONSE parser - * returning '1' from on_headers_complete will tell the parser that it - * should not expect a body. This is used when receiving a response to a - * HEAD request which may contain 'Content-Length' or 'Transfer-Encoding: - * chunked' headers that indicate the presence of a body. - * - * Returning `2` from on_headers_complete will tell the parser that it should - * expect neither a body nor any further responses on this connection. This is - * useful for handling responses to a CONNECT request which may not contain - * `Upgrade` or `Connection: upgrade` headers. - * - * http_data_cb does not return data chunks. It will be called arbitrarily - * many times for each string. E.g. you might get 10 callbacks for "on_url" - * each providing just a few characters more data.
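 * Callers therefore accumulate fragments across calls instead of assuming
 * one callback per element. A minimal sketch, with a hypothetical
 * application struct hanging off parser->data:
 *
 *   int on_url(http_parser *p, const char *at, size_t length) {
 *     struct conn *c = p->data;       // hypothetical per-connection state
 *     strncat(c->url, at, length);    // append this fragment
 *     return 0;                       // non-zero would abort the parse
 *   }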
- */ -typedef int (*http_data_cb) (http_parser*, const char *at, size_t length); -typedef int (*http_cb) (http_parser*); - - -/* Request Methods */ -#define HTTP_METHOD_MAP(XX) \ - XX(0, DELETE, DELETE) \ - XX(1, GET, GET) \ - XX(2, HEAD, HEAD) \ - XX(3, POST, POST) \ - XX(4, PUT, PUT) \ - /* pathological */ \ - XX(5, CONNECT, CONNECT) \ - XX(6, OPTIONS, OPTIONS) \ - XX(7, TRACE, TRACE) \ - /* WebDAV */ \ - XX(8, COPY, COPY) \ - XX(9, LOCK, LOCK) \ - XX(10, MKCOL, MKCOL) \ - XX(11, MOVE, MOVE) \ - XX(12, PROPFIND, PROPFIND) \ - XX(13, PROPPATCH, PROPPATCH) \ - XX(14, SEARCH, SEARCH) \ - XX(15, UNLOCK, UNLOCK) \ - XX(16, BIND, BIND) \ - XX(17, REBIND, REBIND) \ - XX(18, UNBIND, UNBIND) \ - XX(19, ACL, ACL) \ - /* subversion */ \ - XX(20, REPORT, REPORT) \ - XX(21, MKACTIVITY, MKACTIVITY) \ - XX(22, CHECKOUT, CHECKOUT) \ - XX(23, MERGE, MERGE) \ - /* upnp */ \ - XX(24, MSEARCH, M-SEARCH) \ - XX(25, NOTIFY, NOTIFY) \ - XX(26, SUBSCRIBE, SUBSCRIBE) \ - XX(27, UNSUBSCRIBE, UNSUBSCRIBE) \ - /* RFC-5789 */ \ - XX(28, PATCH, PATCH) \ - XX(29, PURGE, PURGE) \ - /* CalDAV */ \ - XX(30, MKCALENDAR, MKCALENDAR) \ - /* RFC-2068, section 19.6.1.2 */ \ - XX(31, LINK, LINK) \ - XX(32, UNLINK, UNLINK) \ - -enum http_method - { -#define XX(num, name, string) HTTP_##name = num, - HTTP_METHOD_MAP(XX) -#undef XX - }; - - -enum http_parser_type { HTTP_REQUEST, HTTP_RESPONSE, HTTP_BOTH }; - - -/* Flag values for http_parser.flags field */ -enum flags - { F_CHUNKED = 1 << 0 - , F_CONNECTION_KEEP_ALIVE = 1 << 1 - , F_CONNECTION_CLOSE = 1 << 2 - , F_CONNECTION_UPGRADE = 1 << 3 - , F_TRAILING = 1 << 4 - , F_UPGRADE = 1 << 5 - , F_SKIPBODY = 1 << 6 - , F_CONTENTLENGTH = 1 << 7 - }; - - -/* Map for errno-related constants - * - * The provided argument should be a macro that takes 2 arguments. 
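 * This is the usual X-macro pattern: each expansion site defines XX to
 * stamp out one entry per (name, description) pair, expands the map, and
 * undefines XX again, exactly as HTTP_ERRNO_GEN does below for the enum.
 * Another illustrative (hypothetical) use, a description lookup:
 *
 *   #define XX(n, s) case HPE_##n: return s;
 *   switch (err) { HTTP_ERRNO_MAP(XX) }
 *   #undef XX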
- */ -#define HTTP_ERRNO_MAP(XX) \ - /* No error */ \ - XX(OK, "success") \ - \ - /* Callback-related errors */ \ - XX(CB_message_begin, "the on_message_begin callback failed") \ - XX(CB_url, "the on_url callback failed") \ - XX(CB_header_field, "the on_header_field callback failed") \ - XX(CB_header_value, "the on_header_value callback failed") \ - XX(CB_headers_complete, "the on_headers_complete callback failed") \ - XX(CB_body, "the on_body callback failed") \ - XX(CB_message_complete, "the on_message_complete callback failed") \ - XX(CB_status, "the on_status callback failed") \ - XX(CB_chunk_header, "the on_chunk_header callback failed") \ - XX(CB_chunk_complete, "the on_chunk_complete callback failed") \ - \ - /* Parsing-related errors */ \ - XX(INVALID_EOF_STATE, "stream ended at an unexpected time") \ - XX(HEADER_OVERFLOW, \ - "too many header bytes seen; overflow detected") \ - XX(CLOSED_CONNECTION, \ - "data received after completed connection: close message") \ - XX(INVALID_VERSION, "invalid HTTP version") \ - XX(INVALID_STATUS, "invalid HTTP status code") \ - XX(INVALID_METHOD, "invalid HTTP method") \ - XX(INVALID_URL, "invalid URL") \ - XX(INVALID_HOST, "invalid host") \ - XX(INVALID_PORT, "invalid port") \ - XX(INVALID_PATH, "invalid path") \ - XX(INVALID_QUERY_STRING, "invalid query string") \ - XX(INVALID_FRAGMENT, "invalid fragment") \ - XX(LF_EXPECTED, "LF character expected") \ - XX(INVALID_HEADER_TOKEN, "invalid character in header") \ - XX(INVALID_CONTENT_LENGTH, \ - "invalid character in content-length header") \ - XX(UNEXPECTED_CONTENT_LENGTH, \ - "unexpected content-length header") \ - XX(INVALID_CHUNK_SIZE, \ - "invalid character in chunk size header") \ - XX(INVALID_CONSTANT, "invalid constant string") \ - XX(INVALID_INTERNAL_STATE, "encountered unexpected internal state")\ - XX(STRICT, "strict mode assertion failed") \ - XX(PAUSED, "parser is paused") \ - XX(UNKNOWN, "an unknown error occurred") - - -/* Define HPE_* values for each errno value above */ -#define HTTP_ERRNO_GEN(n, s) HPE_##n, -enum http_errno { - HTTP_ERRNO_MAP(HTTP_ERRNO_GEN) -}; -#undef HTTP_ERRNO_GEN - - -/* Get an http_errno value from an http_parser */ -#define HTTP_PARSER_ERRNO(p) ((enum http_errno) (p)->http_errno) - - -struct http_parser { - /** PRIVATE **/ - unsigned int type : 2; /* enum http_parser_type */ - unsigned int flags : 8; /* F_* values from 'flags' enum; semi-public */ - unsigned int state : 7; /* enum state from http_parser.c */ - unsigned int header_state : 7; /* enum header_state from http_parser.c */ - unsigned int index : 7; /* index into current matcher */ - unsigned int lenient_http_headers : 1; - - uint32_t nread; /* # bytes read in various scenarios */ - uint64_t content_length; /* # bytes in body (0 if no Content-Length header) */ - - /** READ-ONLY **/ - unsigned short http_major; - unsigned short http_minor; - unsigned int status_code : 16; /* responses only */ - unsigned int method : 8; /* requests only */ - unsigned int http_errno : 7; - - /* 1 = Upgrade header was present and the parser has exited because of that. - * 0 = No upgrade header present. - * Should be checked when http_parser_execute() returns in addition to - * error checking. 
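 * That is, after nparsed = http_parser_execute(...), check
 * parser->upgrade first (hand data + nparsed onward to the upgraded
 * protocol) and only then treat nparsed != len as an error, because an
 * upgraded parse legitimately stops before the end of the buffer.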
- */ - unsigned int upgrade : 1; - - /** PUBLIC **/ - void *data; /* A pointer to get hook to the "connection" or "socket" object */ -}; - - -struct http_parser_settings { - http_cb on_message_begin; - http_data_cb on_url; - http_data_cb on_status; - http_data_cb on_header_field; - http_data_cb on_header_value; - http_cb on_headers_complete; - http_data_cb on_body; - http_cb on_message_complete; - /* When on_chunk_header is called, the current chunk length is stored - * in parser->content_length. - */ - http_cb on_chunk_header; - http_cb on_chunk_complete; -}; - - -enum http_parser_url_fields - { UF_SCHEMA = 0 - , UF_HOST = 1 - , UF_PORT = 2 - , UF_PATH = 3 - , UF_QUERY = 4 - , UF_FRAGMENT = 5 - , UF_USERINFO = 6 - , UF_MAX = 7 - }; - - -/* Result structure for http_parser_parse_url(). - * - * Callers should index into field_data[] with UF_* values iff field_set - * has the relevant (1 << UF_*) bit set. As a courtesy to clients (and - * because we probably have padding left over), we convert any port to - * a uint16_t. - */ -struct http_parser_url { - uint16_t field_set; /* Bitmask of (1 << UF_*) values */ - uint16_t port; /* Converted UF_PORT string */ - - struct { - uint16_t off; /* Offset into buffer in which field starts */ - uint16_t len; /* Length of run in buffer */ - } field_data[UF_MAX]; -}; - - -/* Returns the library version. Bits 16-23 contain the major version number, - * bits 8-15 the minor version number and bits 0-7 the patch level. - * Usage example: - * - * unsigned long version = http_parser_version(); - * unsigned major = (version >> 16) & 255; - * unsigned minor = (version >> 8) & 255; - * unsigned patch = version & 255; - * printf("http_parser v%u.%u.%u\n", major, minor, patch); - */ -unsigned long http_parser_version(void); - -void http_parser_init(http_parser *parser, enum http_parser_type type); - - -/* Initialize http_parser_settings members to 0 - */ -void http_parser_settings_init(http_parser_settings *settings); - - -/* Executes the parser. Returns number of parsed bytes. Sets - * `parser->http_errno` on error. */ -size_t http_parser_execute(http_parser *parser, - const http_parser_settings *settings, - const char *data, - size_t len); - - -/* If http_should_keep_alive() in the on_headers_complete or - * on_message_complete callback returns 0, then this should be - * the last message on the connection. - * If you are the server, respond with the "Connection: close" header. - * If you are the client, close the connection. - */ -int http_should_keep_alive(const http_parser *parser); - -/* Returns a string version of the HTTP method. */ -const char *http_method_str(enum http_method m); - -/* Return a string name of the given error */ -const char *http_errno_name(enum http_errno err); - -/* Return a string description of the given error */ -const char *http_errno_description(enum http_errno err); - -/* Initialize all http_parser_url members to 0 */ -void http_parser_url_init(struct http_parser_url *u); - -/* Parse a URL; return nonzero on failure */ -int http_parser_parse_url(const char *buf, size_t buflen, - int is_connect, - struct http_parser_url *u); - -/* Pause or un-pause the parser; a nonzero value pauses */ -void http_parser_pause(http_parser *parser, int paused); - -/* Checks if this is the final chunk of the body. 
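 * Intended for use inside on_body: a sketch would flush buffered body
 * data eagerly whenever http_body_is_final(parser) returns non-zero,
 * instead of waiting for on_message_complete.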
*/ -int http_body_is_final(const http_parser *parser); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/vendor/http-parser/test.c b/vendor/http-parser/test.c deleted file mode 100644 index f5744aa07f8..00000000000 --- a/vendor/http-parser/test.c +++ /dev/null @@ -1,4226 +0,0 @@ -/* Copyright Joyent, Inc. and other Node contributors. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ -#include "http_parser.h" -#include <stdlib.h> -#include <assert.h> -#include <stdio.h> -#include <stdlib.h> /* rand */ -#include <string.h> -#include <stdarg.h> - -#if defined(__APPLE__) -# undef strlcat -# undef strlncpy -# undef strlcpy -#endif /* defined(__APPLE__) */ - -#undef TRUE -#define TRUE 1 -#undef FALSE -#define FALSE 0 - -#define MAX_HEADERS 13 -#define MAX_ELEMENT_SIZE 2048 -#define MAX_CHUNKS 16 - -#define MIN(a,b) ((a) < (b) ?
(a) : (b)) - -static http_parser *parser; - -struct message { - const char *name; // for debugging purposes - const char *raw; - enum http_parser_type type; - enum http_method method; - int status_code; - char response_status[MAX_ELEMENT_SIZE]; - char request_path[MAX_ELEMENT_SIZE]; - char request_url[MAX_ELEMENT_SIZE]; - char fragment[MAX_ELEMENT_SIZE]; - char query_string[MAX_ELEMENT_SIZE]; - char body[MAX_ELEMENT_SIZE]; - size_t body_size; - const char *host; - const char *userinfo; - uint16_t port; - int num_headers; - enum { NONE=0, FIELD, VALUE } last_header_element; - char headers [MAX_HEADERS][2][MAX_ELEMENT_SIZE]; - int should_keep_alive; - - int num_chunks; - int num_chunks_complete; - int chunk_lengths[MAX_CHUNKS]; - - const char *upgrade; // upgraded body - - unsigned short http_major; - unsigned short http_minor; - - int message_begin_cb_called; - int headers_complete_cb_called; - int message_complete_cb_called; - int message_complete_on_eof; - int body_is_final; -}; - -static int currently_parsing_eof; - -static struct message messages[5]; -static int num_messages; -static http_parser_settings *current_pause_parser; - -/* * R E Q U E S T S * */ -const struct message requests[] = -#define CURL_GET 0 -{ {.name= "curl get" - ,.type= HTTP_REQUEST - ,.raw= "GET /test HTTP/1.1\r\n" - "User-Agent: curl/7.18.0 (i486-pc-linux-gnu) libcurl/7.18.0 OpenSSL/0.9.8g zlib/1.2.3.3 libidn/1.1\r\n" - "Host: 0.0.0.0=5000\r\n" - "Accept: */*\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/test" - ,.request_url= "/test" - ,.num_headers= 3 - ,.headers= - { { "User-Agent", "curl/7.18.0 (i486-pc-linux-gnu) libcurl/7.18.0 OpenSSL/0.9.8g zlib/1.2.3.3 libidn/1.1" } - , { "Host", "0.0.0.0=5000" } - , { "Accept", "*/*" } - } - ,.body= "" - } - -#define FIREFOX_GET 1 -, {.name= "firefox get" - ,.type= HTTP_REQUEST - ,.raw= "GET /favicon.ico HTTP/1.1\r\n" - "Host: 0.0.0.0=5000\r\n" - "User-Agent: Mozilla/5.0 (X11; U; Linux i686; en-US; rv:1.9) Gecko/2008061015 Firefox/3.0\r\n" - "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" - "Accept-Language: en-us,en;q=0.5\r\n" - "Accept-Encoding: gzip,deflate\r\n" - "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" - "Keep-Alive: 300\r\n" - "Connection: keep-alive\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/favicon.ico" - ,.request_url= "/favicon.ico" - ,.num_headers= 8 - ,.headers= - { { "Host", "0.0.0.0=5000" } - , { "User-Agent", "Mozilla/5.0 (X11; U; Linux i686; en-US; rv:1.9) Gecko/2008061015 Firefox/3.0" } - , { "Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" } - , { "Accept-Language", "en-us,en;q=0.5" } - , { "Accept-Encoding", "gzip,deflate" } - , { "Accept-Charset", "ISO-8859-1,utf-8;q=0.7,*;q=0.7" } - , { "Keep-Alive", "300" } - , { "Connection", "keep-alive" } - } - ,.body= "" - } - -#define DUMBFUCK 2 -, {.name= "dumbfuck" - ,.type= HTTP_REQUEST - ,.raw= "GET /dumbfuck HTTP/1.1\r\n" - "aaaaaaaaaaaaa:++++++++++\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/dumbfuck" - ,.request_url= "/dumbfuck" - ,.num_headers= 1 - ,.headers= - { { "aaaaaaaaaaaaa", "++++++++++" } - } - 
,.body= "" - } - -#define FRAGMENT_IN_URI 3 -, {.name= "fragment in url" - ,.type= HTTP_REQUEST - ,.raw= "GET /forums/1/topics/2375?page=1#posts-17408 HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "page=1" - ,.fragment= "posts-17408" - ,.request_path= "/forums/1/topics/2375" - /* XXX request url does include fragment? */ - ,.request_url= "/forums/1/topics/2375?page=1#posts-17408" - ,.num_headers= 0 - ,.body= "" - } - -#define GET_NO_HEADERS_NO_BODY 4 -, {.name= "get no headers no body" - ,.type= HTTP_REQUEST - ,.raw= "GET /get_no_headers_no_body/world HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE /* would need Connection: close */ - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/get_no_headers_no_body/world" - ,.request_url= "/get_no_headers_no_body/world" - ,.num_headers= 0 - ,.body= "" - } - -#define GET_ONE_HEADER_NO_BODY 5 -, {.name= "get one header no body" - ,.type= HTTP_REQUEST - ,.raw= "GET /get_one_header_no_body HTTP/1.1\r\n" - "Accept: */*\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE /* would need Connection: close */ - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/get_one_header_no_body" - ,.request_url= "/get_one_header_no_body" - ,.num_headers= 1 - ,.headers= - { { "Accept" , "*/*" } - } - ,.body= "" - } - -#define GET_FUNKY_CONTENT_LENGTH 6 -, {.name= "get funky content length body hello" - ,.type= HTTP_REQUEST - ,.raw= "GET /get_funky_content_length_body_hello HTTP/1.0\r\n" - "conTENT-Length: 5\r\n" - "\r\n" - "HELLO" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/get_funky_content_length_body_hello" - ,.request_url= "/get_funky_content_length_body_hello" - ,.num_headers= 1 - ,.headers= - { { "conTENT-Length" , "5" } - } - ,.body= "HELLO" - } - -#define POST_IDENTITY_BODY_WORLD 7 -, {.name= "post identity body world" - ,.type= HTTP_REQUEST - ,.raw= "POST /post_identity_body_world?q=search#hey HTTP/1.1\r\n" - "Accept: */*\r\n" - "Transfer-Encoding: identity\r\n" - "Content-Length: 5\r\n" - "\r\n" - "World" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "q=search" - ,.fragment= "hey" - ,.request_path= "/post_identity_body_world" - ,.request_url= "/post_identity_body_world?q=search#hey" - ,.num_headers= 3 - ,.headers= - { { "Accept", "*/*" } - , { "Transfer-Encoding", "identity" } - , { "Content-Length", "5" } - } - ,.body= "World" - } - -#define POST_CHUNKED_ALL_YOUR_BASE 8 -, {.name= "post - chunked body: all your base are belong to us" - ,.type= HTTP_REQUEST - ,.raw= "POST /post_chunked_all_your_base HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "1e\r\nall your base are belong to us\r\n" - "0\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/post_chunked_all_your_base" - ,.request_url= "/post_chunked_all_your_base" - ,.num_headers= 1 - ,.headers= - { { "Transfer-Encoding" , "chunked" } - } - ,.body= "all your base are belong to us" - ,.num_chunks_complete= 2 - ,.chunk_lengths= { 0x1e } 
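/* Note: num_chunks_complete counts the terminating zero-length chunk as
 * well, which is why this fixture's single 0x1e-byte data chunk expects
 * two chunk_complete callbacks (and TWO_CHUNKS_MULT_ZERO_END below
 * expects three); chunk_lengths records only the non-empty chunks. */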
- } - -#define TWO_CHUNKS_MULT_ZERO_END 9 -, {.name= "two chunks ; triple zero ending" - ,.type= HTTP_REQUEST - ,.raw= "POST /two_chunks_mult_zero_end HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "5\r\nhello\r\n" - "6\r\n world\r\n" - "000\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/two_chunks_mult_zero_end" - ,.request_url= "/two_chunks_mult_zero_end" - ,.num_headers= 1 - ,.headers= - { { "Transfer-Encoding", "chunked" } - } - ,.body= "hello world" - ,.num_chunks_complete= 3 - ,.chunk_lengths= { 5, 6 } - } - -#define CHUNKED_W_TRAILING_HEADERS 10 -, {.name= "chunked with trailing headers. blech." - ,.type= HTTP_REQUEST - ,.raw= "POST /chunked_w_trailing_headers HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "5\r\nhello\r\n" - "6\r\n world\r\n" - "0\r\n" - "Vary: *\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/chunked_w_trailing_headers" - ,.request_url= "/chunked_w_trailing_headers" - ,.num_headers= 3 - ,.headers= - { { "Transfer-Encoding", "chunked" } - , { "Vary", "*" } - , { "Content-Type", "text/plain" } - } - ,.body= "hello world" - ,.num_chunks_complete= 3 - ,.chunk_lengths= { 5, 6 } - } - -#define CHUNKED_W_BULLSHIT_AFTER_LENGTH 11 -, {.name= "with bullshit after the length" - ,.type= HTTP_REQUEST - ,.raw= "POST /chunked_w_bullshit_after_length HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "5; ihatew3;whatthefuck=aretheseparametersfor\r\nhello\r\n" - "6; blahblah; blah\r\n world\r\n" - "0\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/chunked_w_bullshit_after_length" - ,.request_url= "/chunked_w_bullshit_after_length" - ,.num_headers= 1 - ,.headers= - { { "Transfer-Encoding", "chunked" } - } - ,.body= "hello world" - ,.num_chunks_complete= 3 - ,.chunk_lengths= { 5, 6 } - } - -#define WITH_QUOTES 12 -, {.name= "with quotes" - ,.type= HTTP_REQUEST - ,.raw= "GET /with_\"stupid\"_quotes?foo=\"bar\" HTTP/1.1\r\n\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "foo=\"bar\"" - ,.fragment= "" - ,.request_path= "/with_\"stupid\"_quotes" - ,.request_url= "/with_\"stupid\"_quotes?foo=\"bar\"" - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } - -#define APACHEBENCH_GET 13 -/* The server receiving this request SHOULD NOT wait for EOF - * to know that content-length == 0. - * How to represent this in a unit test? message_complete_on_eof - * Compare with NO_CONTENT_LENGTH_RESPONSE. - */ -, {.name = "apachebench get" - ,.type= HTTP_REQUEST - ,.raw= "GET /test HTTP/1.0\r\n" - "Host: 0.0.0.0:5000\r\n" - "User-Agent: ApacheBench/2.3\r\n" - "Accept: */*\r\n\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/test" - ,.request_url= "/test" - ,.num_headers= 3 - ,.headers= { { "Host", "0.0.0.0:5000" } - , { "User-Agent", "ApacheBench/2.3" } - , { "Accept", "*/*" } - } - ,.body= "" - } - -#define QUERY_URL_WITH_QUESTION_MARK_GET 14 -/* Some clients include '?' 
characters in query strings. - */ -, {.name = "query url with question mark" - ,.type= HTTP_REQUEST - ,.raw= "GET /test.cgi?foo=bar?baz HTTP/1.1\r\n\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "foo=bar?baz" - ,.fragment= "" - ,.request_path= "/test.cgi" - ,.request_url= "/test.cgi?foo=bar?baz" - ,.num_headers= 0 - ,.headers= {} - ,.body= "" - } - -#define PREFIX_NEWLINE_GET 15 -/* Some clients, especially after a POST in a keep-alive connection, - * will send an extra CRLF before the next request - */ -, {.name = "newline prefix get" - ,.type= HTTP_REQUEST - ,.raw= "\r\nGET /test HTTP/1.1\r\n\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/test" - ,.request_url= "/test" - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } - -#define UPGRADE_REQUEST 16 -, {.name = "upgrade request" - ,.type= HTTP_REQUEST - ,.raw= "GET /demo HTTP/1.1\r\n" - "Host: example.com\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" - "Sec-WebSocket-Protocol: sample\r\n" - "Upgrade: WebSocket\r\n" - "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" - "Origin: http://example.com\r\n" - "\r\n" - "Hot diggity dogg" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/demo" - ,.request_url= "/demo" - ,.num_headers= 7 - ,.upgrade="Hot diggity dogg" - ,.headers= { { "Host", "example.com" } - , { "Connection", "Upgrade" } - , { "Sec-WebSocket-Key2", "12998 5 Y3 1 .P00" } - , { "Sec-WebSocket-Protocol", "sample" } - , { "Upgrade", "WebSocket" } - , { "Sec-WebSocket-Key1", "4 @1 46546xW%0l 1 5" } - , { "Origin", "http://example.com" } - } - ,.body= "" - } - -#define CONNECT_REQUEST 17 -, {.name = "connect request" - ,.type= HTTP_REQUEST - ,.raw= "CONNECT 0-home0.netscape.com:443 HTTP/1.0\r\n" - "User-agent: Mozilla/1.1N\r\n" - "Proxy-authorization: basic aGVsbG86d29ybGQ=\r\n" - "\r\n" - "some data\r\n" - "and yet even more data" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_CONNECT - ,.query_string= "" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "0-home0.netscape.com:443" - ,.num_headers= 2 - ,.upgrade="some data\r\nand yet even more data" - ,.headers= { { "User-agent", "Mozilla/1.1N" } - , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } - } - ,.body= "" - } - -#define REPORT_REQ 18 -, {.name= "report request" - ,.type= HTTP_REQUEST - ,.raw= "REPORT /test HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_REPORT - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/test" - ,.request_url= "/test" - ,.num_headers= 0 - ,.headers= {} - ,.body= "" - } - -#define NO_HTTP_VERSION 19 -, {.name= "request with no http version" - ,.type= HTTP_REQUEST - ,.raw= "GET /\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 0 - ,.http_minor= 9 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 0 - ,.headers= {} - ,.body= "" - } - -#define MSEARCH_REQ 20 -, {.name= "m-search request" - ,.type= HTTP_REQUEST - ,.raw= "M-SEARCH * HTTP/1.1\r\n" - "HOST: 239.255.255.250:1900\r\n" - "MAN: 
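UPGRADE_REQUEST and CONNECT_REQUEST above both populate `.upgrade`: the parser stops at the end of the HTTP message, and everything after it ("Hot diggity dogg", the tunneled CONNECT payload) is left for the application. A minimal sketch of the caller-side contract, assuming a hypothetical handle_upgraded_data() handler:

    #include "http_parser.h"

    extern void handle_upgraded_data(const char *data, size_t len); /* hypothetical */

    /* After execute(), parser->upgrade means the connection switched
     * protocols: bytes past nparsed belong to the new protocol and
     * must never be fed back into http_parser. */
    void feed(http_parser *p, const http_parser_settings *s,
              const char *buf, size_t len) {
        size_t nparsed = http_parser_execute(p, s, buf, len);
        if (p->upgrade) {
            handle_upgraded_data(buf + nparsed, len - nparsed);
        }
    }
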
\"ssdp:discover\"\r\n" - "ST: \"ssdp:all\"\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_MSEARCH - ,.query_string= "" - ,.fragment= "" - ,.request_path= "*" - ,.request_url= "*" - ,.num_headers= 3 - ,.headers= { { "HOST", "239.255.255.250:1900" } - , { "MAN", "\"ssdp:discover\"" } - , { "ST", "\"ssdp:all\"" } - } - ,.body= "" - } - -#define LINE_FOLDING_IN_HEADER 21 -, {.name= "line folding in header value" - ,.type= HTTP_REQUEST - ,.raw= "GET / HTTP/1.1\r\n" - "Line1: abc\r\n" - "\tdef\r\n" - " ghi\r\n" - "\t\tjkl\r\n" - " mno \r\n" - "\t \tqrs\r\n" - "Line2: \t line2\t\r\n" - "Line3:\r\n" - " line3\r\n" - "Line4: \r\n" - " \r\n" - "Connection:\r\n" - " close\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 5 - ,.headers= { { "Line1", "abc\tdef ghi\t\tjkl mno \t \tqrs" } - , { "Line2", "line2\t" } - , { "Line3", "line3" } - , { "Line4", "" } - , { "Connection", "close" }, - } - ,.body= "" - } - - -#define QUERY_TERMINATED_HOST 22 -, {.name= "host terminated by a query string" - ,.type= HTTP_REQUEST - ,.raw= "GET http://hypnotoad.org?hail=all HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "hail=all" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "http://hypnotoad.org?hail=all" - ,.host= "hypnotoad.org" - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } - -#define QUERY_TERMINATED_HOSTPORT 23 -, {.name= "host:port terminated by a query string" - ,.type= HTTP_REQUEST - ,.raw= "GET http://hypnotoad.org:1234?hail=all HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "hail=all" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "http://hypnotoad.org:1234?hail=all" - ,.host= "hypnotoad.org" - ,.port= 1234 - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } - -#define SPACE_TERMINATED_HOSTPORT 24 -, {.name= "host:port terminated by a space" - ,.type= HTTP_REQUEST - ,.raw= "GET http://hypnotoad.org:1234 HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "http://hypnotoad.org:1234" - ,.host= "hypnotoad.org" - ,.port= 1234 - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } - -#define PATCH_REQ 25 -, {.name = "PATCH request" - ,.type= HTTP_REQUEST - ,.raw= "PATCH /file.txt HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "Content-Type: application/example\r\n" - "If-Match: \"e0023aa4e\"\r\n" - "Content-Length: 10\r\n" - "\r\n" - "cccccccccc" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_PATCH - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/file.txt" - ,.request_url= "/file.txt" - ,.num_headers= 4 - ,.headers= { { "Host", "www.example.com" } - , { "Content-Type", "application/example" } - , { "If-Match", "\"e0023aa4e\"" } - , { "Content-Length", "10" } - } - ,.body= "cccccccccc" - } - -#define CONNECT_CAPS_REQUEST 26 -, {.name = "connect caps request" - ,.type= HTTP_REQUEST - ,.raw= "CONNECT HOME0.NETSCAPE.COM:443 HTTP/1.0\r\n" - "User-agent: Mozilla/1.1N\r\n" 
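LINE_FOLDING_IN_HEADER pins down the obs-fold rule these expectations encode: the CRLF of a folded line is deleted while the continuation line's own whitespace survives, which is how "abc\r\n\tdef" becomes "abc\tdef" in the expected Line1 value. A buffered-string sketch of the same rule (the real parser applies it incrementally, and test 34 below repeats the exercise with bare LF line endings, which a fuller version would treat the same way):

    #include <string.h>

    /* Remove obs-folds in place: drop the CRLF, keep the SP/HTAB
     * that starts the continuation line. */
    static void unfold(char *v) {
        char *src = v, *dst = v;
        while (*src) {
            if (src[0] == '\r' && src[1] == '\n' &&
                (src[2] == ' ' || src[2] == '\t')) {
                src += 2;   /* skip CRLF; the SP/HTAB is copied below */
                continue;
            }
            *dst++ = *src++;
        }
        *dst = '\0';
    }
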
- "Proxy-authorization: basic aGVsbG86d29ybGQ=\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_CONNECT - ,.query_string= "" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "HOME0.NETSCAPE.COM:443" - ,.num_headers= 2 - ,.upgrade="" - ,.headers= { { "User-agent", "Mozilla/1.1N" } - , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } - } - ,.body= "" - } - -#if !HTTP_PARSER_STRICT -#define UTF8_PATH_REQ 27 -, {.name= "utf-8 path request" - ,.type= HTTP_REQUEST - ,.raw= "GET /δ¶/δt/pope?q=1#narf HTTP/1.1\r\n" - "Host: github.com\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "q=1" - ,.fragment= "narf" - ,.request_path= "/δ¶/δt/pope" - ,.request_url= "/δ¶/δt/pope?q=1#narf" - ,.num_headers= 1 - ,.headers= { {"Host", "github.com" } - } - ,.body= "" - } - -#define HOSTNAME_UNDERSCORE 28 -, {.name = "hostname underscore" - ,.type= HTTP_REQUEST - ,.raw= "CONNECT home_0.netscape.com:443 HTTP/1.0\r\n" - "User-agent: Mozilla/1.1N\r\n" - "Proxy-authorization: basic aGVsbG86d29ybGQ=\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_CONNECT - ,.query_string= "" - ,.fragment= "" - ,.request_path= "" - ,.request_url= "home_0.netscape.com:443" - ,.num_headers= 2 - ,.upgrade="" - ,.headers= { { "User-agent", "Mozilla/1.1N" } - , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } - } - ,.body= "" - } -#endif /* !HTTP_PARSER_STRICT */ - -/* see https://github.com/ry/http-parser/issues/47 */ -#define EAT_TRAILING_CRLF_NO_CONNECTION_CLOSE 29 -, {.name = "eat CRLF between requests, no \"Connection: close\" header" - ,.raw= "POST / HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "Content-Type: application/x-www-form-urlencoded\r\n" - "Content-Length: 4\r\n" - "\r\n" - "q=42\r\n" /* note the trailing CRLF */ - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 3 - ,.upgrade= 0 - ,.headers= { { "Host", "www.example.com" } - , { "Content-Type", "application/x-www-form-urlencoded" } - , { "Content-Length", "4" } - } - ,.body= "q=42" - } - -/* see https://github.com/ry/http-parser/issues/47 */ -#define EAT_TRAILING_CRLF_WITH_CONNECTION_CLOSE 30 -, {.name = "eat CRLF between requests even if \"Connection: close\" is set" - ,.raw= "POST / HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "Content-Type: application/x-www-form-urlencoded\r\n" - "Content-Length: 4\r\n" - "Connection: close\r\n" - "\r\n" - "q=42\r\n" /* note the trailing CRLF */ - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE /* input buffer isn't empty when on_message_complete is called */ - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 4 - ,.upgrade= 0 - ,.headers= { { "Host", "www.example.com" } - , { "Content-Type", "application/x-www-form-urlencoded" } - , { "Content-Length", "4" } - , { "Connection", "close" } - } - ,.body= "q=42" - } - -#define PURGE_REQ 31 -, {.name = "PURGE request" - ,.type= HTTP_REQUEST - ,.raw= "PURGE /file.txt HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= 
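The `#if !HTTP_PARSER_STRICT` block above is compiled out in strict builds: HTTP_PARSER_STRICT is a compile-time switch, and only lenient builds accept raw UTF-8 in paths or "_" in hostnames. To the best of my reading, the guard in http_parser.h looks like the following, so building with -DHTTP_PARSER_STRICT=0 is what enables these two tests:

    /* In http_parser.h (strict is the default): */
    #ifndef HTTP_PARSER_STRICT
    # define HTTP_PARSER_STRICT 1
    #endif
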
HTTP_PURGE - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/file.txt" - ,.request_url= "/file.txt" - ,.num_headers= 1 - ,.headers= { { "Host", "www.example.com" } } - ,.body= "" - } -
-#define SEARCH_REQ 32 -, {.name = "SEARCH request" - ,.type= HTTP_REQUEST - ,.raw= "SEARCH / HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_SEARCH - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 1 - ,.headers= { { "Host", "www.example.com" } } - ,.body= "" - } -
-#define PROXY_WITH_BASIC_AUTH 33 -, {.name= "host:port and basic_auth" - ,.type= HTTP_REQUEST - ,.raw= "GET http://a%12:b!&*$@hypnotoad.org:1234/toto HTTP/1.1\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.fragment= "" - ,.request_path= "/toto" - ,.request_url= "http://a%12:b!&*$@hypnotoad.org:1234/toto" - ,.host= "hypnotoad.org" - ,.userinfo= "a%12:b!&*$" - ,.port= 1234 - ,.num_headers= 0 - ,.headers= { } - ,.body= "" - } -
-#define LINE_FOLDING_IN_HEADER_WITH_LF 34 -, {.name= "line folding in header value with lf" - ,.type= HTTP_REQUEST - ,.raw= "GET / HTTP/1.1\n" - "Line1: abc\n" - "\tdef\n" - " ghi\n" - "\t\tjkl\n" - "  mno \n" - "\t \tqrs\n" - "Line2: \t line2\t\n" - "Line3:\n" - " line3\n" - "Line4: \n" - " \n" - "Connection:\n" - " close\n" - "\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/" - ,.request_url= "/" - ,.num_headers= 5 - ,.headers= { { "Line1", "abc\tdef ghi\t\tjkl mno \t \tqrs" } - , { "Line2", "line2\t" } - , { "Line3", "line3" } - , { "Line4", "" } - , { "Connection", "close" }, - } - ,.body= "" - } -
-#define CONNECTION_MULTI 35 -, {.name = "multiple connection header values with folding" - ,.type= HTTP_REQUEST - ,.raw= "GET /demo HTTP/1.1\r\n" - "Host: example.com\r\n" - "Connection: Something,\r\n" - " Upgrade, ,Keep-Alive\r\n" - "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" - "Sec-WebSocket-Protocol: sample\r\n" - "Upgrade: WebSocket\r\n" - "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" - "Origin: http://example.com\r\n" - "\r\n" - "Hot diggity dogg" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/demo" - ,.request_url= "/demo" - ,.num_headers= 7 - ,.upgrade="Hot diggity dogg" - ,.headers= { { "Host", "example.com" } - , { "Connection", "Something, Upgrade, ,Keep-Alive" } - , { "Sec-WebSocket-Key2", "12998 5 Y3 1 .P00" } - , { "Sec-WebSocket-Protocol", "sample" } - , { "Upgrade", "WebSocket" } - , { "Sec-WebSocket-Key1", "4 @1 46546xW%0l 1 5" } - , { "Origin", "http://example.com" } - } - ,.body= "" - } -
-#define CONNECTION_MULTI_LWS 36 -, {.name = "multiple connection header values with lws" - ,.type= HTTP_REQUEST - ,.raw= "GET /demo HTTP/1.1\r\n" - "Connection: keep-alive, upgrade\r\n" - "Upgrade: WebSocket\r\n" - "\r\n" - "Hot diggity dogg" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/demo" - ,.request_url= "/demo" - ,.num_headers= 2 - ,.upgrade="Hot diggity dogg" - ,.headers= { { "Connection", "keep-alive, upgrade" } - , { "Upgrade", "WebSocket" } - } - ,.body= "" - } -
-#define CONNECTION_MULTI_LWS_CRLF 37 -, {.name = "multiple connection header values with folding and lws" - ,.type= HTTP_REQUEST - ,.raw= "GET /demo HTTP/1.1\r\n" - "Connection: keep-alive, \r\n upgrade\r\n" - "Upgrade: WebSocket\r\n" - "\r\n" - "Hot diggity dogg" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_GET - ,.query_string= "" - ,.fragment= "" - ,.request_path= "/demo" - ,.request_url= "/demo" - ,.num_headers= 2 - ,.upgrade="Hot diggity dogg" - ,.headers= { { "Connection", "keep-alive, upgrade" } - , { "Upgrade", "WebSocket" } - } - ,.body= "" - } -
-#define UPGRADE_POST_REQUEST 38 -, {.name = "upgrade post request" - ,.type= HTTP_REQUEST - ,.raw= "POST /demo HTTP/1.1\r\n" - "Host: example.com\r\n" - "Connection: Upgrade\r\n" - "Upgrade: HTTP/2.0\r\n" - "Content-Length: 15\r\n" - "\r\n" - "sweet post body" - "Hot diggity dogg" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_POST - ,.request_path= "/demo" - ,.request_url= "/demo" - ,.num_headers= 4 - ,.upgrade="Hot diggity dogg" - ,.headers= { { "Host", "example.com" } - , { "Connection", "Upgrade" } - , { "Upgrade", "HTTP/2.0" } - , { "Content-Length", "15" } - } - ,.body= "sweet post body" - } -
-#define CONNECT_WITH_BODY_REQUEST 39 -, {.name = "connect with body request" - ,.type= HTTP_REQUEST - ,.raw= "CONNECT foo.bar.com:443 HTTP/1.0\r\n" - "User-agent: Mozilla/1.1N\r\n" - "Proxy-authorization: basic aGVsbG86d29ybGQ=\r\n" - "Content-Length: 10\r\n" - "\r\n" - "blarfcicle" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.method= HTTP_CONNECT - ,.request_url= "foo.bar.com:443" - ,.num_headers= 3 - ,.upgrade="blarfcicle" - ,.headers= { { "User-agent", "Mozilla/1.1N" } - , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } - , { "Content-Length", "10" } - } - ,.body= "" - } -
-/* Examples from the Internet draft for LINK/UNLINK methods: - * https://tools.ietf.org/id/draft-snell-link-method-01.html#rfc.section.5 - */ -
-#define LINK_REQUEST 40 -, {.name = "link request" - ,.type= HTTP_REQUEST - ,.raw= "LINK /images/my_dog.jpg HTTP/1.1\r\n" - "Host: example.com\r\n" - "Link: <http://example.com/profiles/joe>; rel=\"tag\"\r\n" - "Link: <http://example.com/profiles/sally>; rel=\"tag\"\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_LINK - ,.request_path= "/images/my_dog.jpg" - ,.request_url= "/images/my_dog.jpg" - ,.query_string= "" - ,.fragment= "" - ,.num_headers= 3 - ,.headers= { { "Host", "example.com" } - , { "Link", "<http://example.com/profiles/joe>; rel=\"tag\"" } - , { "Link", "<http://example.com/profiles/sally>; rel=\"tag\"" } - } - ,.body= "" - } -
-#define UNLINK_REQUEST 41 -, {.name = "unlink request" - ,.type= HTTP_REQUEST - ,.raw= "UNLINK /images/my_dog.jpg HTTP/1.1\r\n" - "Host: example.com\r\n" - "Link: <http://example.com/profiles/sally>; rel=\"tag\"\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.method= HTTP_UNLINK - ,.request_path= "/images/my_dog.jpg" - ,.request_url= "/images/my_dog.jpg" - ,.query_string= "" - ,.fragment= "" - ,.num_headers= 2 - ,.headers= { { "Host", "example.com" } - , { "Link", "<http://example.com/profiles/sally>; rel=\"tag\"" } - } - ,.body= "" - } -
-, {.name= NULL } /* sentinel */ -}; -
-/* * R E S P O N S E S * */ -const struct message responses[] = -#define GOOGLE_301 0 -{ {.name= "google 301" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 301 Moved Permanently\r\n" - "Location: http://www.google.com/\r\n" - "Content-Type: text/html; charset=UTF-8\r\n" - "Date: Sun, 26 Apr 2009 11:11:49 GMT\r\n" - "Expires: Tue, 26 May 2009 11:11:49 GMT\r\n" - "X-$PrototypeBI-Version: 1.6.0.3\r\n" /* $ char in header field */ - "Cache-Control: public, max-age=2592000\r\n" - "Server: gws\r\n" - "Content-Length: 219  \r\n" - "\r\n" - "<HTML><HEAD><meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n" - "<TITLE>301 Moved</TITLE></HEAD><BODY>\n" - "<H1>301 Moved</H1>\n" - "The document has moved\n" - "<A HREF=\"http://www.google.com/\">here</A>.\r\n" - "</BODY></HTML>\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 301 - ,.response_status= "Moved Permanently" - ,.num_headers= 8 - ,.headers= - { { "Location", "http://www.google.com/" } - , { "Content-Type", "text/html; charset=UTF-8" } - , { "Date", "Sun, 26 Apr 2009 11:11:49 GMT" } - , { "Expires", "Tue, 26 May 2009 11:11:49 GMT" } - , { "X-$PrototypeBI-Version", "1.6.0.3" } - , { "Cache-Control", "public, max-age=2592000" } - , { "Server", "gws" } - , { "Content-Length", "219  " } - } - ,.body= "<HTML><HEAD><meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n" - "<TITLE>301 Moved</TITLE></HEAD><BODY>\n" - "<H1>301 Moved</H1>\n" - "The document has moved\n" - "<A HREF=\"http://www.google.com/\">here</A>.\r\n" - "</BODY></HTML>\r\n" - } -
-#define NO_CONTENT_LENGTH_RESPONSE 1 -/* The client should wait for the server's EOF. That is, when content-length - * is not specified, and "Connection: close", the end of body is specified - * by the EOF. - * Compare with APACHEBENCH_GET - */ -, {.name= "no content-length response" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Date: Tue, 04 Aug 2009 07:59:32 GMT\r\n" - "Server: Apache\r\n" - "X-Powered-By: Servlet/2.5 JSP/2.1\r\n" - "Content-Type: text/xml; charset=utf-8\r\n" - "Connection: close\r\n" - "\r\n" - "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" - "<SOAP-ENV:Envelope xmlns:SOAP-ENV=\"http://schemas.xmlsoap.org/soap/envelope/\">\n" - "  <SOAP-ENV:Body>\n" - "    <SOAP-ENV:Fault>\n" - "       <faultcode>SOAP-ENV:Client</faultcode>\n" - "       <faultstring>Client Error</faultstring>\n" - "    </SOAP-ENV:Fault>\n" - "  </SOAP-ENV:Body>\n" - "</SOAP-ENV:Envelope>" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 5 - ,.headers= - { { "Date", "Tue, 04 Aug 2009 07:59:32 GMT" } - , { "Server", "Apache" } - , { "X-Powered-By", "Servlet/2.5 JSP/2.1" } - , { "Content-Type", "text/xml; charset=utf-8" } - , { "Connection", "close" } - } - ,.body= "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" - "<SOAP-ENV:Envelope xmlns:SOAP-ENV=\"http://schemas.xmlsoap.org/soap/envelope/\">\n" - "  <SOAP-ENV:Body>\n" - "    <SOAP-ENV:Fault>\n" - "       <faultcode>SOAP-ENV:Client</faultcode>\n" - "       <faultstring>Client Error</faultstring>\n" - "    </SOAP-ENV:Fault>\n" - "  </SOAP-ENV:Body>\n" - "</SOAP-ENV:Envelope>" - } -
-#define NO_HEADERS_NO_BODY_404 2 -, {.name= "404 no headers no body" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 404 Not Found\r\n\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 404 - ,.response_status= "Not Found" - ,.num_headers= 0 - ,.headers= {} - ,.body_size= 0 - ,.body= "" - } -
-#define NO_REASON_PHRASE 3 -, {.name= "301 no response phrase" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 301\r\n\r\n" - ,.should_keep_alive = FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 301 - ,.response_status= "" - ,.num_headers= 0 - ,.headers= {} - ,.body= "" - } -
-#define TRAILING_SPACE_ON_CHUNKED_BODY 4 -, {.name="200 trailing space on chunked body" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "25  \r\n" - "This is the data in the first chunk\r\n" - "\r\n" - "1C\r\n" - "and this is the second one\r\n" - "\r\n" - "0  \r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 2 - ,.headers= - { {"Content-Type", "text/plain" } - , {"Transfer-Encoding", "chunked" } - } - ,.body_size = 37+28 - ,.body = - "This is the data in the first chunk\r\n" - "and this is the second one\r\n" - ,.num_chunks_complete= 3 - ,.chunk_lengths= { 0x25, 0x1c } - } -
-#define NO_CARRIAGE_RET 5 -, {.name="no carriage ret" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\n" - "Content-Type: text/html; charset=utf-8\n" - "Connection: close\n" - "\n" - "these headers are from http://news.ycombinator.com/" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 2 - ,.headers= - { {"Content-Type", "text/html; charset=utf-8" } - , {"Connection", "close" } - } - ,.body= "these headers are from http://news.ycombinator.com/" - } -
-#define PROXY_CONNECTION 6 -, {.name="proxy connection" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Content-Type: text/html; charset=UTF-8\r\n" - "Content-Length: 11\r\n" - "Proxy-Connection: close\r\n" - "Date: Thu, 31 Dec 2009 20:55:48 +0000\r\n" - "\r\n" - "hello world" - ,.should_keep_alive= FALSE - 
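NO_CONTENT_LENGTH_RESPONSE is the message_complete_on_eof= TRUE case: with no Content-Length, no chunking, and "Connection: close", only EOF can end the body. The caller signals that EOF with a zero-length execute() call, which is what the parse(NULL, 0) helper further down in this file does. A minimal sketch (on_socket_eof is an illustrative name):

    #include "http_parser.h"

    /* Report socket EOF to the parser; for EOF-delimited bodies this
     * is the call that finally fires on_message_complete. */
    static void on_socket_eof(http_parser *p, const http_parser_settings *s) {
        http_parser_execute(p, s, NULL, 0);
    }
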
,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 4 - ,.headers= - { {"Content-Type", "text/html; charset=UTF-8" } - , {"Content-Length", "11" } - , {"Proxy-Connection", "close" } - , {"Date", "Thu, 31 Dec 2009 20:55:48 +0000"} - } - ,.body= "hello world" - } -
-#define UNDERSTORE_HEADER_KEY 7 - // shown by - // curl -o /dev/null -v "http://ad.doubleclick.net/pfadx/DARTSHELLCONFIGXML;dcmt=text/xml;" -, {.name="underscore header key" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Server: DCLK-AdSvr\r\n" - "Content-Type: text/xml\r\n" - "Content-Length: 0\r\n" - "DCLK_imp: v7;x;114750856;0-0;0;17820020;0/0;21603567/21621457/1;;~okv=;dcmt=text/xml;;~cs=o\r\n\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 4 - ,.headers= - { {"Server", "DCLK-AdSvr" } - , {"Content-Type", "text/xml" } - , {"Content-Length", "0" } - , {"DCLK_imp", "v7;x;114750856;0-0;0;17820020;0/0;21603567/21621457/1;;~okv=;dcmt=text/xml;;~cs=o" } - } - ,.body= "" - } -
-#define BONJOUR_MADAME_FR 8 -/* The client should not merge two header fields when the first one doesn't - * have a value. - */ -, {.name= "bonjourmadame.fr" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.0 301 Moved Permanently\r\n" - "Date: Thu, 03 Jun 2010 09:56:32 GMT\r\n" - "Server: Apache/2.2.3 (Red Hat)\r\n" - "Cache-Control: public\r\n" - "Pragma: \r\n" - "Location: http://www.bonjourmadame.fr/\r\n" - "Vary: Accept-Encoding\r\n" - "Content-Length: 0\r\n" - "Content-Type: text/html; charset=UTF-8\r\n" - "Connection: keep-alive\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.status_code= 301 - ,.response_status= "Moved Permanently" - ,.num_headers= 9 - ,.headers= - { { "Date", "Thu, 03 Jun 2010 09:56:32 GMT" } - , { "Server", "Apache/2.2.3 (Red Hat)" } - , { "Cache-Control", "public" } - , { "Pragma", "" } - , { "Location", "http://www.bonjourmadame.fr/" } - , { "Vary", "Accept-Encoding" } - , { "Content-Length", "0" } - , { "Content-Type", "text/html; charset=UTF-8" } - , { "Connection", "keep-alive" } - } - ,.body= "" - } -
-#define RES_FIELD_UNDERSCORE 9 -/* Should handle underscore (and other unusual characters) in header field names */ -, {.name= "field underscore" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Date: Tue, 28 Sep 2010 01:14:13 GMT\r\n" - "Server: Apache\r\n" - "Cache-Control: no-cache, must-revalidate\r\n" - "Expires: Mon, 26 Jul 1997 05:00:00 GMT\r\n" - ".et-Cookie: PlaxoCS=1274804622353690521; path=/; domain=.plaxo.com\r\n" - "Vary: Accept-Encoding\r\n" - "_eep-Alive: timeout=45\r\n" /* semantic value ignored */ - "_onnection: Keep-Alive\r\n" /* semantic value ignored */ - "Transfer-Encoding: chunked\r\n" - "Content-Type: text/html\r\n" - "Connection: close\r\n" - "\r\n" - "0\r\n\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 11 - ,.headers= - { { "Date", "Tue, 28 Sep 2010 01:14:13 GMT" } - , { "Server", "Apache" } - , { "Cache-Control", "no-cache, must-revalidate" } - , { "Expires", "Mon, 26 Jul 1997 05:00:00 GMT" } - , { ".et-Cookie", "PlaxoCS=1274804622353690521; path=/; domain=.plaxo.com" } - , { "Vary", "Accept-Encoding" } - , { "_eep-Alive", "timeout=45" } - , { "_onnection", "Keep-Alive" } - , { "Transfer-Encoding", "chunked" } - , { "Content-Type", "text/html" 
} - , { "Connection", "close" } - } - ,.body= "" - ,.num_chunks_complete= 1 - ,.chunk_lengths= {} - } - -#define NON_ASCII_IN_STATUS_LINE 10 -/* Should handle non-ASCII in status line */ -, {.name= "non-ASCII in status line" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 500 Oriëntatieprobleem\r\n" - "Date: Fri, 5 Nov 2010 23:07:12 GMT+2\r\n" - "Content-Length: 0\r\n" - "Connection: close\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 500 - ,.response_status= "Oriëntatieprobleem" - ,.num_headers= 3 - ,.headers= - { { "Date", "Fri, 5 Nov 2010 23:07:12 GMT+2" } - , { "Content-Length", "0" } - , { "Connection", "close" } - } - ,.body= "" - } - -#define HTTP_VERSION_0_9 11 -/* Should handle HTTP/0.9 */ -, {.name= "http version 0.9" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/0.9 200 OK\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 0 - ,.http_minor= 9 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 0 - ,.headers= - {} - ,.body= "" - } - -#define NO_CONTENT_LENGTH_NO_TRANSFER_ENCODING_RESPONSE 12 -/* The client should wait for the server's EOF. That is, when neither - * content-length nor transfer-encoding is specified, the end of body - * is specified by the EOF. - */ -, {.name= "neither content-length nor transfer-encoding response" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "hello world" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 1 - ,.headers= - { { "Content-Type", "text/plain" } - } - ,.body= "hello world" - } - -#define NO_BODY_HTTP10_KA_200 13 -, {.name= "HTTP/1.0 with keep-alive and EOF-terminated 200 status" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.0 200 OK\r\n" - "Connection: keep-alive\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 0 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 1 - ,.headers= - { { "Connection", "keep-alive" } - } - ,.body_size= 0 - ,.body= "" - } - -#define NO_BODY_HTTP10_KA_204 14 -, {.name= "HTTP/1.0 with keep-alive and a 204 status" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.0 204 No content\r\n" - "Connection: keep-alive\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.status_code= 204 - ,.response_status= "No content" - ,.num_headers= 1 - ,.headers= - { { "Connection", "keep-alive" } - } - ,.body_size= 0 - ,.body= "" - } - -#define NO_BODY_HTTP11_KA_200 15 -, {.name= "HTTP/1.1 with an EOF-terminated 200 status" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 0 - ,.headers={} - ,.body_size= 0 - ,.body= "" - } - -#define NO_BODY_HTTP11_KA_204 16 -, {.name= "HTTP/1.1 with a 204 status" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 204 No content\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 204 - ,.response_status= "No content" - ,.num_headers= 0 - ,.headers={} - ,.body_size= 0 - ,.body= "" - } - -#define NO_BODY_HTTP11_NOKA_204 17 -, {.name= "HTTP/1.1 with a 204 status and keep-alive disabled" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 204 No content\r\n" 
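The NO_BODY_* block pins down http_should_keep_alive(): HTTP/1.1 keeps the connection open by default, HTTP/1.0 only with "Connection: keep-alive", and any EOF-delimited body forces a close regardless of the header (which is why NO_BODY_HTTP10_KA_200 expects FALSE while the 204 variants expect TRUE). A caller-side sketch, assuming a POSIX socket fd; after_message is an illustrative name:

    #include <unistd.h>
    #include "http_parser.h"

    /* Decide connection reuse once a message completes. */
    static void after_message(http_parser *p, int fd) {
        if (!http_should_keep_alive(p))
            close(fd);  /* version, Connection header, or an
                           EOF-delimited body rules out reuse */
    }
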
- "Connection: close\r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 204 - ,.response_status= "No content" - ,.num_headers= 1 - ,.headers= - { { "Connection", "close" } - } - ,.body_size= 0 - ,.body= "" - } -
-#define NO_BODY_HTTP11_KA_CHUNKED_200 18 -, {.name= "HTTP/1.1 with chunked encoding and a 200 response" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "0\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 1 - ,.headers= - { { "Transfer-Encoding", "chunked" } - } - ,.body_size= 0 - ,.body= "" - ,.num_chunks_complete= 1 - } -
-#if !HTTP_PARSER_STRICT -#define SPACE_IN_FIELD_RES 19 -/* Should handle spaces in header fields */ -, {.name= "field space" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Server: Microsoft-IIS/6.0\r\n" - "X-Powered-By: ASP.NET\r\n" - "en-US Content-Type: text/xml\r\n" /* this is the problem */ - "Content-Type: text/xml\r\n" - "Content-Length: 16\r\n" - "Date: Fri, 23 Jul 2010 18:45:38 GMT\r\n" - "Connection: keep-alive\r\n" - "\r\n" - "hello" /* fake body */ - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 7 - ,.headers= - { { "Server", "Microsoft-IIS/6.0" } - , { "X-Powered-By", "ASP.NET" } - , { "en-US Content-Type", "text/xml" } - , { "Content-Type", "text/xml" } - , { "Content-Length", "16" } - , { "Date", "Fri, 23 Jul 2010 18:45:38 GMT" } - , { "Connection", "keep-alive" } - } - ,.body= "hello" - } -#endif /* !HTTP_PARSER_STRICT */ -
-#define AMAZON_COM 20 -, {.name= "amazon.com" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 301 MovedPermanently\r\n" - "Date: Wed, 15 May 2013 17:06:33 GMT\r\n" - "Server: Server\r\n" - "x-amz-id-1: 0GPHKXSJQ826RK7GZEB2\r\n" - "p3p: policyref=\"http://www.amazon.com/w3c/p3p.xml\",CP=\"CAO DSP LAW CUR ADM IVAo IVDo CONo OTPo OUR DELi PUBi OTRi BUS PHY ONL UNI PUR FIN COM NAV INT DEM CNT STA HEA PRE LOC GOV OTC \"\r\n" - "x-amz-id-2: STN69VZxIFSz9YJLbz1GDbxpbjG6Qjmmq5E3DxRhOUw+Et0p4hr7c/Q8qNcx4oAD\r\n" - "Location: http://www.amazon.com/Dan-Brown/e/B000AP9DSU/ref=s9_pop_gw_al1?_encoding=UTF8&refinementId=618073011&pf_rd_m=ATVPDKIKX0DER&pf_rd_s=center-2&pf_rd_r=0SHYY5BZXN3KR20BNFAY&pf_rd_t=101&pf_rd_p=1263340922&pf_rd_i=507846\r\n" - "Vary: Accept-Encoding,User-Agent\r\n" - "Content-Type: text/html; charset=ISO-8859-1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "1\r\n" - "\n\r\n" - "0\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 301 - ,.response_status= "MovedPermanently" - ,.num_headers= 9 - ,.headers= { { "Date", "Wed, 15 May 2013 17:06:33 GMT" } - , { "Server", "Server" } - , { "x-amz-id-1", "0GPHKXSJQ826RK7GZEB2" } - , { "p3p", "policyref=\"http://www.amazon.com/w3c/p3p.xml\",CP=\"CAO DSP LAW CUR ADM IVAo IVDo CONo OTPo OUR DELi PUBi OTRi BUS PHY ONL UNI PUR FIN COM NAV INT DEM CNT STA HEA PRE LOC GOV OTC \"" } - , { "x-amz-id-2", "STN69VZxIFSz9YJLbz1GDbxpbjG6Qjmmq5E3DxRhOUw+Et0p4hr7c/Q8qNcx4oAD" } - , { "Location", "http://www.amazon.com/Dan-Brown/e/B000AP9DSU/ref=s9_pop_gw_al1?_encoding=UTF8&refinementId=618073011&pf_rd_m=ATVPDKIKX0DER&pf_rd_s=center-2&pf_rd_r=0SHYY5BZXN3KR20BNFAY&pf_rd_t=101&pf_rd_p=1263340922&pf_rd_i=507846" } - , { "Vary", "Accept-Encoding,User-Agent" } - , { "Content-Type", "text/html; charset=ISO-8859-1" } - , { "Transfer-Encoding", "chunked" } - } - ,.body= "\n" - ,.num_chunks_complete= 2 - ,.chunk_lengths= { 1 } - } -
-#define EMPTY_REASON_PHRASE_AFTER_SPACE 21 -, {.name= "empty reason phrase after space" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 \r\n" - "\r\n" - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= TRUE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "" - ,.num_headers= 0 - ,.headers= {} - ,.body= "" - } -
-#define CONTENT_LENGTH_X 22 -, {.name= "Content-Length-X" - ,.type= HTTP_RESPONSE - ,.raw= "HTTP/1.1 200 OK\r\n" - "Content-Length-X: 0\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "2\r\n" - "OK\r\n" - "0\r\n" - "\r\n" - ,.should_keep_alive= TRUE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 1 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 2 - ,.headers= { { "Content-Length-X", "0" } - , { "Transfer-Encoding", "chunked" } - } - ,.body= "OK" - ,.num_chunks_complete= 2 - ,.chunk_lengths= { 2 } - } -
-, {.name= NULL } /* sentinel */ -}; -
-/* strnlen() is a POSIX.2008 addition. Can't rely on it being available so - * define it ourselves. - */ -size_t -strnlen(const char *s, size_t maxlen) -{ - const char *p; - - p = memchr(s, '\0', maxlen); - if (p == NULL) - return maxlen; - - return p - s; -} -
-size_t -strlncat(char *dst, size_t len, const char *src, size_t n) -{ - size_t slen; - size_t dlen; - size_t rlen; - size_t ncpy; - - slen = strnlen(src, n); - dlen = strnlen(dst, len); - - if (dlen < len) { - rlen = len - dlen; - ncpy = slen < rlen ? slen : (rlen - 1); - memcpy(dst + dlen, src, ncpy); - dst[dlen + ncpy] = '\0'; - } - - assert(len > slen + dlen); - return slen + dlen; -} -
-size_t -strlcat(char *dst, const char *src, size_t len) -{ - return strlncat(dst, len, src, (size_t) -1); -} -
-size_t -strlncpy(char *dst, size_t len, const char *src, size_t n) -{ - size_t slen; - size_t ncpy; - - slen = strnlen(src, n); - - if (len > 0) { - ncpy = slen < len ? slen : (len - 1); - memcpy(dst, src, ncpy); - dst[ncpy] = '\0'; - } - - assert(len > slen); - return slen; -} -
-size_t -strlcpy(char *dst, const char *src, size_t len) -{ - return strlncpy(dst, len, src, (size_t) -1); -} -
-int -request_url_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - strlncat(messages[num_messages].request_url, - sizeof(messages[num_messages].request_url), - buf, - len); - return 0; -} -
-int -header_field_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - struct message *m = &messages[num_messages]; - - if (m->last_header_element != FIELD) - m->num_headers++; - - strlncat(m->headers[m->num_headers-1][0], - sizeof(m->headers[m->num_headers-1][0]), - buf, - len); - - m->last_header_element = FIELD; - - return 0; -} -
-int -header_value_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - struct message *m = &messages[num_messages]; - - strlncat(m->headers[m->num_headers-1][1], - sizeof(m->headers[m->num_headers-1][1]), - buf, - len); - - m->last_header_element = VALUE; - - return 0; -} -
-void -check_body_is_final (const http_parser *p) -{ - if (messages[num_messages].body_is_final) { - fprintf(stderr, "\n\n *** Error http_body_is_final() should return 1 " - "on last on_body callback call " - "but it doesn't! 
***\n\n"); - assert(0); - abort(); - } - messages[num_messages].body_is_final = http_body_is_final(p); -} - -int -body_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - strlncat(messages[num_messages].body, - sizeof(messages[num_messages].body), - buf, - len); - messages[num_messages].body_size += len; - check_body_is_final(p); - // printf("body_cb: '%s'\n", requests[num_messages].body); - return 0; -} - -int -count_body_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - assert(buf); - messages[num_messages].body_size += len; - check_body_is_final(p); - return 0; -} - -int -message_begin_cb (http_parser *p) -{ - assert(p == parser); - messages[num_messages].message_begin_cb_called = TRUE; - return 0; -} - -int -headers_complete_cb (http_parser *p) -{ - assert(p == parser); - messages[num_messages].method = parser->method; - messages[num_messages].status_code = parser->status_code; - messages[num_messages].http_major = parser->http_major; - messages[num_messages].http_minor = parser->http_minor; - messages[num_messages].headers_complete_cb_called = TRUE; - messages[num_messages].should_keep_alive = http_should_keep_alive(parser); - return 0; -} - -int -message_complete_cb (http_parser *p) -{ - assert(p == parser); - if (messages[num_messages].should_keep_alive != http_should_keep_alive(parser)) - { - fprintf(stderr, "\n\n *** Error http_should_keep_alive() should have same " - "value in both on_message_complete and on_headers_complete " - "but it doesn't! ***\n\n"); - assert(0); - abort(); - } - - if (messages[num_messages].body_size && - http_body_is_final(p) && - !messages[num_messages].body_is_final) - { - fprintf(stderr, "\n\n *** Error http_body_is_final() should return 1 " - "on last on_body callback call " - "but it doesn't! 
***\n\n"); - assert(0); - abort(); - } - - messages[num_messages].message_complete_cb_called = TRUE; - - messages[num_messages].message_complete_on_eof = currently_parsing_eof; - - num_messages++; - return 0; -} - -int -response_status_cb (http_parser *p, const char *buf, size_t len) -{ - assert(p == parser); - strlncat(messages[num_messages].response_status, - sizeof(messages[num_messages].response_status), - buf, - len); - return 0; -} - -int -chunk_header_cb (http_parser *p) -{ - assert(p == parser); - int chunk_idx = messages[num_messages].num_chunks; - messages[num_messages].num_chunks++; - if (chunk_idx < MAX_CHUNKS) { - messages[num_messages].chunk_lengths[chunk_idx] = p->content_length; - } - - return 0; -} - -int -chunk_complete_cb (http_parser *p) -{ - assert(p == parser); - - /* Here we want to verify that each chunk_header_cb is matched by a - * chunk_complete_cb, so not only should the total number of calls to - * both callbacks be the same, but they also should be interleaved - * properly */ - assert(messages[num_messages].num_chunks == - messages[num_messages].num_chunks_complete + 1); - - messages[num_messages].num_chunks_complete++; - return 0; -} - -/* These dontcall_* callbacks exist so that we can verify that when we're - * paused, no additional callbacks are invoked */ -int -dontcall_message_begin_cb (http_parser *p) -{ - if (p) { } // gcc - fprintf(stderr, "\n\n*** on_message_begin() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_header_field_cb (http_parser *p, const char *buf, size_t len) -{ - if (p || buf || len) { } // gcc - fprintf(stderr, "\n\n*** on_header_field() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_header_value_cb (http_parser *p, const char *buf, size_t len) -{ - if (p || buf || len) { } // gcc - fprintf(stderr, "\n\n*** on_header_value() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_request_url_cb (http_parser *p, const char *buf, size_t len) -{ - if (p || buf || len) { } // gcc - fprintf(stderr, "\n\n*** on_request_url() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_body_cb (http_parser *p, const char *buf, size_t len) -{ - if (p || buf || len) { } // gcc - fprintf(stderr, "\n\n*** on_body_cb() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_headers_complete_cb (http_parser *p) -{ - if (p) { } // gcc - fprintf(stderr, "\n\n*** on_headers_complete() called on paused " - "parser ***\n\n"); - abort(); -} - -int -dontcall_message_complete_cb (http_parser *p) -{ - if (p) { } // gcc - fprintf(stderr, "\n\n*** on_message_complete() called on paused " - "parser ***\n\n"); - abort(); -} - -int -dontcall_response_status_cb (http_parser *p, const char *buf, size_t len) -{ - if (p || buf || len) { } // gcc - fprintf(stderr, "\n\n*** on_status() called on paused parser ***\n\n"); - abort(); -} - -int -dontcall_chunk_header_cb (http_parser *p) -{ - if (p) { } // gcc - fprintf(stderr, "\n\n*** on_chunk_header() called on paused parser ***\n\n"); - exit(1); -} - -int -dontcall_chunk_complete_cb (http_parser *p) -{ - if (p) { } // gcc - fprintf(stderr, "\n\n*** on_chunk_complete() " - "called on paused parser ***\n\n"); - exit(1); -} - -static http_parser_settings settings_dontcall = - {.on_message_begin = dontcall_message_begin_cb - ,.on_header_field = dontcall_header_field_cb - ,.on_header_value = dontcall_header_value_cb - ,.on_url = dontcall_request_url_cb - ,.on_status = dontcall_response_status_cb - ,.on_body = dontcall_body_cb - ,.on_headers_complete = 
dontcall_headers_complete_cb - ,.on_message_complete = dontcall_message_complete_cb - ,.on_chunk_header = dontcall_chunk_header_cb - ,.on_chunk_complete = dontcall_chunk_complete_cb - }; - -/* These pause_* callbacks always pause the parser and just invoke the regular - * callback that tracks content. Before returning, we overwrite the parser - * settings to point to the _dontcall variety so that we can verify that - * the pause actually did, you know, pause. */ -int -pause_message_begin_cb (http_parser *p) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return message_begin_cb(p); -} - -int -pause_header_field_cb (http_parser *p, const char *buf, size_t len) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return header_field_cb(p, buf, len); -} - -int -pause_header_value_cb (http_parser *p, const char *buf, size_t len) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return header_value_cb(p, buf, len); -} - -int -pause_request_url_cb (http_parser *p, const char *buf, size_t len) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return request_url_cb(p, buf, len); -} - -int -pause_body_cb (http_parser *p, const char *buf, size_t len) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return body_cb(p, buf, len); -} - -int -pause_headers_complete_cb (http_parser *p) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return headers_complete_cb(p); -} - -int -pause_message_complete_cb (http_parser *p) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return message_complete_cb(p); -} - -int -pause_response_status_cb (http_parser *p, const char *buf, size_t len) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return response_status_cb(p, buf, len); -} - -int -pause_chunk_header_cb (http_parser *p) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return chunk_header_cb(p); -} - -int -pause_chunk_complete_cb (http_parser *p) -{ - http_parser_pause(p, 1); - *current_pause_parser = settings_dontcall; - return chunk_complete_cb(p); -} - -int -connect_headers_complete_cb (http_parser *p) -{ - headers_complete_cb(p); - return 1; -} - -int -connect_message_complete_cb (http_parser *p) -{ - messages[num_messages].should_keep_alive = http_should_keep_alive(parser); - return message_complete_cb(p); -} - -static http_parser_settings settings_pause = - {.on_message_begin = pause_message_begin_cb - ,.on_header_field = pause_header_field_cb - ,.on_header_value = pause_header_value_cb - ,.on_url = pause_request_url_cb - ,.on_status = pause_response_status_cb - ,.on_body = pause_body_cb - ,.on_headers_complete = pause_headers_complete_cb - ,.on_message_complete = pause_message_complete_cb - ,.on_chunk_header = pause_chunk_header_cb - ,.on_chunk_complete = pause_chunk_complete_cb - }; - -static http_parser_settings settings = - {.on_message_begin = message_begin_cb - ,.on_header_field = header_field_cb - ,.on_header_value = header_value_cb - ,.on_url = request_url_cb - ,.on_status = response_status_cb - ,.on_body = body_cb - ,.on_headers_complete = headers_complete_cb - ,.on_message_complete = message_complete_cb - ,.on_chunk_header = chunk_header_cb - ,.on_chunk_complete = chunk_complete_cb - }; - -static http_parser_settings settings_count_body = - {.on_message_begin = message_begin_cb - ,.on_header_field = header_field_cb - ,.on_header_value = header_value_cb - 
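The pause_* callbacks above pause the parser and swap in the _dontcall settings to prove that no further callbacks fire. On the consuming side, a paused parser simply returns early from execute() with HPE_PAUSED; clearing the pause and re-feeding the unconsumed bytes resumes exactly where it stopped. A sketch of that resume loop (feed_all is an illustrative name):

    #include "http_parser.h"

    static size_t feed_all(http_parser *p, const http_parser_settings *s,
                           const char *buf, size_t len) {
        size_t done = 0;
        while (done < len) {
            done += http_parser_execute(p, s, buf + done, len - done);
            if (HTTP_PARSER_ERRNO(p) == HPE_PAUSED)
                http_parser_pause(p, 0);   /* unpause, then keep feeding */
            else if (HTTP_PARSER_ERRNO(p) != HPE_OK)
                break;                     /* genuine parse error */
        }
        return done;
    }
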
,.on_url = request_url_cb - ,.on_status = response_status_cb - ,.on_body = count_body_cb - ,.on_headers_complete = headers_complete_cb - ,.on_message_complete = message_complete_cb - ,.on_chunk_header = chunk_header_cb - ,.on_chunk_complete = chunk_complete_cb - }; - -static http_parser_settings settings_connect = - {.on_message_begin = message_begin_cb - ,.on_header_field = header_field_cb - ,.on_header_value = header_value_cb - ,.on_url = request_url_cb - ,.on_status = response_status_cb - ,.on_body = dontcall_body_cb - ,.on_headers_complete = connect_headers_complete_cb - ,.on_message_complete = connect_message_complete_cb - ,.on_chunk_header = chunk_header_cb - ,.on_chunk_complete = chunk_complete_cb - }; - -static http_parser_settings settings_null = - {.on_message_begin = 0 - ,.on_header_field = 0 - ,.on_header_value = 0 - ,.on_url = 0 - ,.on_status = 0 - ,.on_body = 0 - ,.on_headers_complete = 0 - ,.on_message_complete = 0 - ,.on_chunk_header = 0 - ,.on_chunk_complete = 0 - }; - -void -parser_init (enum http_parser_type type) -{ - num_messages = 0; - - assert(parser == NULL); - - parser = malloc(sizeof(http_parser)); - - http_parser_init(parser, type); - - memset(&messages, 0, sizeof messages); - -} - -void -parser_free () -{ - assert(parser); - free(parser); - parser = NULL; -} - -size_t parse (const char *buf, size_t len) -{ - size_t nparsed; - currently_parsing_eof = (len == 0); - nparsed = http_parser_execute(parser, &settings, buf, len); - return nparsed; -} - -size_t parse_count_body (const char *buf, size_t len) -{ - size_t nparsed; - currently_parsing_eof = (len == 0); - nparsed = http_parser_execute(parser, &settings_count_body, buf, len); - return nparsed; -} - -size_t parse_pause (const char *buf, size_t len) -{ - size_t nparsed; - http_parser_settings s = settings_pause; - - currently_parsing_eof = (len == 0); - current_pause_parser = &s; - nparsed = http_parser_execute(parser, current_pause_parser, buf, len); - return nparsed; -} - -size_t parse_connect (const char *buf, size_t len) -{ - size_t nparsed; - currently_parsing_eof = (len == 0); - nparsed = http_parser_execute(parser, &settings_connect, buf, len); - return nparsed; -} - -static inline int -check_str_eq (const struct message *m, - const char *prop, - const char *expected, - const char *found) { - if ((expected == NULL) != (found == NULL)) { - printf("\n*** Error: %s in '%s' ***\n\n", prop, m->name); - printf("expected %s\n", (expected == NULL) ? "NULL" : expected); - printf(" found %s\n", (found == NULL) ? 
"NULL" : found); - return 0; - } - if (expected != NULL && 0 != strcmp(expected, found)) { - printf("\n*** Error: %s in '%s' ***\n\n", prop, m->name); - printf("expected '%s'\n", expected); - printf(" found '%s'\n", found); - return 0; - } - return 1; -} - -static inline int -check_num_eq (const struct message *m, - const char *prop, - int expected, - int found) { - if (expected != found) { - printf("\n*** Error: %s in '%s' ***\n\n", prop, m->name); - printf("expected %d\n", expected); - printf(" found %d\n", found); - return 0; - } - return 1; -} - -#define MESSAGE_CHECK_STR_EQ(expected, found, prop) \ - if (!check_str_eq(expected, #prop, expected->prop, found->prop)) return 0 - -#define MESSAGE_CHECK_NUM_EQ(expected, found, prop) \ - if (!check_num_eq(expected, #prop, expected->prop, found->prop)) return 0 - -#define MESSAGE_CHECK_URL_EQ(u, expected, found, prop, fn) \ -do { \ - char ubuf[256]; \ - \ - if ((u)->field_set & (1 << (fn))) { \ - memcpy(ubuf, (found)->request_url + (u)->field_data[(fn)].off, \ - (u)->field_data[(fn)].len); \ - ubuf[(u)->field_data[(fn)].len] = '\0'; \ - } else { \ - ubuf[0] = '\0'; \ - } \ - \ - check_str_eq(expected, #prop, expected->prop, ubuf); \ -} while(0) - -int -message_eq (int index, int connect, const struct message *expected) -{ - int i; - struct message *m = &messages[index]; - - MESSAGE_CHECK_NUM_EQ(expected, m, http_major); - MESSAGE_CHECK_NUM_EQ(expected, m, http_minor); - - if (expected->type == HTTP_REQUEST) { - MESSAGE_CHECK_NUM_EQ(expected, m, method); - } else { - MESSAGE_CHECK_NUM_EQ(expected, m, status_code); - MESSAGE_CHECK_STR_EQ(expected, m, response_status); - } - - if (!connect) { - MESSAGE_CHECK_NUM_EQ(expected, m, should_keep_alive); - MESSAGE_CHECK_NUM_EQ(expected, m, message_complete_on_eof); - } - - assert(m->message_begin_cb_called); - assert(m->headers_complete_cb_called); - assert(m->message_complete_cb_called); - - - MESSAGE_CHECK_STR_EQ(expected, m, request_url); - - /* Check URL components; we can't do this w/ CONNECT since it doesn't - * send us a well-formed URL. - */ - if (*m->request_url && m->method != HTTP_CONNECT) { - struct http_parser_url u; - - if (http_parser_parse_url(m->request_url, strlen(m->request_url), 0, &u)) { - fprintf(stderr, "\n\n*** failed to parse URL %s ***\n\n", - m->request_url); - abort(); - } - - if (expected->host) { - MESSAGE_CHECK_URL_EQ(&u, expected, m, host, UF_HOST); - } - - if (expected->userinfo) { - MESSAGE_CHECK_URL_EQ(&u, expected, m, userinfo, UF_USERINFO); - } - - m->port = (u.field_set & (1 << UF_PORT)) ? 
- u.port : 0; - - MESSAGE_CHECK_URL_EQ(&u, expected, m, query_string, UF_QUERY); - MESSAGE_CHECK_URL_EQ(&u, expected, m, fragment, UF_FRAGMENT); - MESSAGE_CHECK_URL_EQ(&u, expected, m, request_path, UF_PATH); - MESSAGE_CHECK_NUM_EQ(expected, m, port); - } - - if (connect) { - check_num_eq(m, "body_size", 0, m->body_size); - } else if (expected->body_size) { - MESSAGE_CHECK_NUM_EQ(expected, m, body_size); - } else { - MESSAGE_CHECK_STR_EQ(expected, m, body); - } - - if (connect) { - check_num_eq(m, "num_chunks_complete", 0, m->num_chunks_complete); - } else { - assert(m->num_chunks == m->num_chunks_complete); - MESSAGE_CHECK_NUM_EQ(expected, m, num_chunks_complete); - for (i = 0; i < m->num_chunks && i < MAX_CHUNKS; i++) { - MESSAGE_CHECK_NUM_EQ(expected, m, chunk_lengths[i]); - } - } - - MESSAGE_CHECK_NUM_EQ(expected, m, num_headers); - - int r; - for (i = 0; i < m->num_headers; i++) { - r = check_str_eq(expected, "header field", expected->headers[i][0], m->headers[i][0]); - if (!r) return 0; - r = check_str_eq(expected, "header value", expected->headers[i][1], m->headers[i][1]); - if (!r) return 0; - } - - MESSAGE_CHECK_STR_EQ(expected, m, upgrade); - - return 1; -} - -/* Given a sequence of varargs messages, return the number of them that the - * parser should successfully parse, taking into account that upgraded - * messages prevent all subsequent messages from being parsed. - */ -size_t -count_parsed_messages(const size_t nmsgs, ...) { - size_t i; - va_list ap; - - va_start(ap, nmsgs); - - for (i = 0; i < nmsgs; i++) { - struct message *m = va_arg(ap, struct message *); - - if (m->upgrade) { - va_end(ap); - return i + 1; - } - } - - va_end(ap); - return nmsgs; -} - -/* Given a sequence of bytes and the number of these that we were able to - * parse, verify that upgrade bodies are correct. - */ -void -upgrade_message_fix(char *body, const size_t nread, const size_t nmsgs, ...) 
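message_eq() above checks URL components by running the captured request_url back through http_parser_parse_url(): field_data[] holds (offset, length) pairs into the unmodified URL string, and MESSAGE_CHECK_URL_EQ slices each component out with them. A minimal standalone sketch of the same extraction, using the PROXY_WITH_BASIC_AUTH URL from the request fixtures:

    #include <stdio.h>
    #include <string.h>
    #include "http_parser.h"

    int main(void) {
        const char *url = "http://a%12:b!&*$@hypnotoad.org:1234/toto";
        struct http_parser_url u;

        memset(&u, 0, sizeof(u));
        if (http_parser_parse_url(url, strlen(url), 0, &u) != 0)
            return 1;
        if (u.field_set & (1 << UF_USERINFO))
            printf("userinfo: %.*s\n", u.field_data[UF_USERINFO].len,
                   url + u.field_data[UF_USERINFO].off);
        printf("port: %u\n", (unsigned)u.port);  /* 1234 */
        return 0;
    }
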
{ - va_list ap; - size_t i; - size_t off = 0; - - va_start(ap, nmsgs); - - for (i = 0; i < nmsgs; i++) { - struct message *m = va_arg(ap, struct message *); - - off += strlen(m->raw); - - if (m->upgrade) { - off -= strlen(m->upgrade); - - /* Check the portion of the response after its specified upgrade */ - if (!check_str_eq(m, "upgrade", body + off, body + nread)) { - abort(); - } - - /* Fix up the response so that message_eq() will verify the beginning - * of the upgrade */ - *(body + nread + strlen(m->upgrade)) = '\0'; - messages[num_messages - 1].upgrade = body + nread; - - va_end(ap); - return; - } - } - - va_end(ap); - printf("\n\n*** Error: expected a message with upgrade ***\n"); - - abort(); -} -
-static void -print_error (const char *raw, size_t error_location) -{ - fprintf(stderr, "\n*** %s ***\n\n", - http_errno_description(HTTP_PARSER_ERRNO(parser))); - - int this_line = 0, char_len = 0; - size_t i, j, len = strlen(raw), error_location_line = 0; - for (i = 0; i < len; i++) { - if (i == error_location) this_line = 1; - switch (raw[i]) { - case '\r': - char_len = 2; - fprintf(stderr, "\\r"); - break; - - case '\n': - fprintf(stderr, "\\n\n"); - - if (this_line) goto print; - - error_location_line = 0; - continue; - - default: - char_len = 1; - fputc(raw[i], stderr); - break; - } - if (!this_line) error_location_line += char_len; - } - - fprintf(stderr, "[eof]\n"); - - print: - for (j = 0; j < error_location_line; j++) { - fputc(' ', stderr); - } - fprintf(stderr, "^\n\nerror location: %u\n", (unsigned int)error_location); -} -
-void -test_preserve_data (void) -{ - char my_data[] = "application-specific data"; - http_parser parser; - parser.data = my_data; - http_parser_init(&parser, HTTP_REQUEST); - if (parser.data != my_data) { - printf("\n*** parser.data not preserved across http_parser_init ***\n\n"); - abort(); - } -} -
-struct url_test { - const char *name; - const char *url; - int is_connect; - struct http_parser_url u; - int rv; -}; -
-const struct url_test url_tests[] = -{ {.name="proxy request" - ,.url="http://hostname/" - ,.is_connect=0 - ,.u= - {.field_set=(1 << UF_SCHEMA) | (1 << UF_HOST) | (1 << UF_PATH) - ,.port=0 - ,.field_data= - {{ 0, 4 } /* UF_SCHEMA */ - ,{ 7, 8 } /* UF_HOST */ - ,{ 0, 0 } /* UF_PORT */ - ,{ 15, 1 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } -
-, {.name="proxy request with port" - ,.url="http://hostname:444/" - ,.is_connect=0 - ,.u= - {.field_set=(1 << UF_SCHEMA) | (1 << UF_HOST) | (1 << UF_PORT) | (1 << UF_PATH) - ,.port=444 - ,.field_data= - {{ 0, 4 } /* UF_SCHEMA */ - ,{ 7, 8 } /* UF_HOST */ - ,{ 16, 3 } /* UF_PORT */ - ,{ 19, 1 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } -
-, {.name="CONNECT request" - ,.url="hostname:443" - ,.is_connect=1 - ,.u= - {.field_set=(1 << UF_HOST) | (1 << UF_PORT) - ,.port=443 - ,.field_data= - {{ 0, 0 } /* UF_SCHEMA */ - ,{ 0, 8 } /* UF_HOST */ - ,{ 9, 3 } /* UF_PORT */ - ,{ 0, 0 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } -
-, {.name="CONNECT request but not connect" - ,.url="hostname:443" - ,.is_connect=0 - ,.rv=1 - } -
-, {.name="proxy ipv6 request" - ,.url="http://[1:2::3:4]/" - ,.is_connect=0 - ,.u= - {.field_set=(1 << UF_SCHEMA) | (1 << UF_HOST) | (1 << UF_PATH) - ,.port=0 - ,.field_data= - {{ 0, 4 } /* UF_SCHEMA */ - ,{ 8, 8 } /* UF_HOST */ - ,{ 0, 0 } /* UF_PORT */ - 
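test_preserve_data above relies on http_parser_init() deliberately preserving the caller's `data` pointer; in real use that opaque pointer is how callbacks reach per-connection state. A minimal sketch, where struct conn, setup and on_url are illustrative names:

    #include "http_parser.h"

    struct conn { int fd; };   /* hypothetical per-connection state */

    static int on_url(http_parser *p, const char *at, size_t len) {
        struct conn *c = p->data;   /* recovered in every callback */
        (void)c; (void)at; (void)len;
        return 0;
    }

    static void setup(http_parser *p, struct conn *c) {
        p->data = c;                   /* survives http_parser_init() */
        http_parser_init(p, HTTP_REQUEST);
    }
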
,{ 17, 1 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } - -, {.name="proxy ipv6 request with port" - ,.url="http://[1:2::3:4]:67/" - ,.is_connect=0 - ,.u= - {.field_set=(1 << UF_SCHEMA) | (1 << UF_HOST) | (1 << UF_PORT) | (1 << UF_PATH) - ,.port=67 - ,.field_data= - {{ 0, 4 } /* UF_SCHEMA */ - ,{ 8, 8 } /* UF_HOST */ - ,{ 18, 2 } /* UF_PORT */ - ,{ 20, 1 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } - -, {.name="CONNECT ipv6 address" - ,.url="[1:2::3:4]:443" - ,.is_connect=1 - ,.u= - {.field_set=(1 << UF_HOST) | (1 << UF_PORT) - ,.port=443 - ,.field_data= - {{ 0, 0 } /* UF_SCHEMA */ - ,{ 1, 8 } /* UF_HOST */ - ,{ 11, 3 } /* UF_PORT */ - ,{ 0, 0 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } - -, {.name="ipv4 in ipv6 address" - ,.url="http://[2001:0000:0000:0000:0000:0000:1.9.1.1]/" - ,.is_connect=0 - ,.u= - {.field_set=(1 << UF_SCHEMA) | (1 << UF_HOST) | (1 << UF_PATH) - ,.port=0 - ,.field_data= - {{ 0, 4 } /* UF_SCHEMA */ - ,{ 8, 37 } /* UF_HOST */ - ,{ 0, 0 } /* UF_PORT */ - ,{ 46, 1 } /* UF_PATH */ - ,{ 0, 0 } /* UF_QUERY */ - ,{ 0, 0 } /* UF_FRAGMENT */ - ,{ 0, 0 } /* UF_USERINFO */ - } - } - ,.rv=0 - } - -, {.name="extra ? in query string" - ,.url="http://a.tbcdn.cn/p/fp/2010c/??fp-header-min.css,fp-base-min.css," - "fp-channel-min.css,fp-product-min.css,fp-mall-min.css,fp-category-min.css," - "fp-sub-min.css,fp-gdp4p-min.css,fp-css3-min.css,fp-misc-min.css?t=20101022.css" - ,.is_connect=0 - ,.u= - {.field_set=(1<field_set, u->port); - for (i = 0; i < UF_MAX; i++) { - if ((u->field_set & (1 << i)) == 0) { - printf("\tfield_data[%u]: unset\n", i); - continue; - } - - printf("\tfield_data[%u]: off: %u len: %u part: \"%.*s\n\"", - i, - u->field_data[i].off, - u->field_data[i].len, - u->field_data[i].len, - url + u->field_data[i].off); - } -} - -void -test_parse_url (void) -{ - struct http_parser_url u; - const struct url_test *test; - unsigned int i; - int rv; - - for (i = 0; i < (sizeof(url_tests) / sizeof(url_tests[0])); i++) { - test = &url_tests[i]; - memset(&u, 0, sizeof(u)); - - rv = http_parser_parse_url(test->url, - strlen(test->url), - test->is_connect, - &u); - - if (test->rv == 0) { - if (rv != 0) { - printf("\n*** http_parser_parse_url(\"%s\") \"%s\" test failed, " - "unexpected rv %d ***\n\n", test->url, test->name, rv); - abort(); - } - - if (memcmp(&u, &test->u, sizeof(u)) != 0) { - printf("\n*** http_parser_parse_url(\"%s\") \"%s\" failed ***\n", - test->url, test->name); - - printf("target http_parser_url:\n"); - dump_url(test->url, &test->u); - printf("result http_parser_url:\n"); - dump_url(test->url, &u); - - abort(); - } - } else { - /* test->rv != 0 */ - if (rv == 0) { - printf("\n*** http_parser_parse_url(\"%s\") \"%s\" test failed, " - "unexpected rv %d ***\n\n", test->url, test->name, rv); - abort(); - } - } - } -} - -void -test_method_str (void) -{ - assert(0 == strcmp("GET", http_method_str(HTTP_GET))); - assert(0 == strcmp("", http_method_str(1337))); -} - -void -test_message (const struct message *message) -{ - size_t raw_len = strlen(message->raw); - size_t msg1len; - for (msg1len = 0; msg1len < raw_len; msg1len++) { - parser_init(message->type); - - size_t read; - const char *msg1 = message->raw; - const char *msg2 = msg1 + msg1len; - size_t msg2len = raw_len - msg1len; - - if (msg1len) { - read = parse(msg1, msg1len); - - if 
(message->upgrade && parser->upgrade && num_messages > 0) { - messages[num_messages - 1].upgrade = msg1 + read; - goto test; - } - - if (read != msg1len) { - print_error(msg1, read); - abort(); - } - } - - - read = parse(msg2, msg2len); - - if (message->upgrade && parser->upgrade) { - messages[num_messages - 1].upgrade = msg2 + read; - goto test; - } - - if (read != msg2len) { - print_error(msg2, read); - abort(); - } - - read = parse(NULL, 0); - - if (read != 0) { - print_error(message->raw, read); - abort(); - } - - test: - - if (num_messages != 1) { - printf("\n*** num_messages != 1 after testing '%s' ***\n\n", message->name); - abort(); - } - - if(!message_eq(0, 0, message)) abort(); - - parser_free(); - } -} - -void -test_message_count_body (const struct message *message) -{ - parser_init(message->type); - - size_t read; - size_t l = strlen(message->raw); - size_t i, toread; - size_t chunk = 4024; - - for (i = 0; i < l; i+= chunk) { - toread = MIN(l-i, chunk); - read = parse_count_body(message->raw + i, toread); - if (read != toread) { - print_error(message->raw, read); - abort(); - } - } - - - read = parse_count_body(NULL, 0); - if (read != 0) { - print_error(message->raw, read); - abort(); - } - - if (num_messages != 1) { - printf("\n*** num_messages != 1 after testing '%s' ***\n\n", message->name); - abort(); - } - - if(!message_eq(0, 0, message)) abort(); - - parser_free(); -} - -void -test_simple (const char *buf, enum http_errno err_expected) -{ - parser_init(HTTP_REQUEST); - - enum http_errno err; - - parse(buf, strlen(buf)); - err = HTTP_PARSER_ERRNO(parser); - parse(NULL, 0); - - parser_free(); - - /* In strict mode, allow us to pass with an unexpected HPE_STRICT as - * long as the caller isn't expecting success. - */ -#if HTTP_PARSER_STRICT - if (err_expected != err && err_expected != HPE_OK && err != HPE_STRICT) { -#else - if (err_expected != err) { -#endif - fprintf(stderr, "\n*** test_simple expected %s, but saw %s ***\n\n%s\n", - http_errno_name(err_expected), http_errno_name(err), buf); - abort(); - } -} - -void -test_invalid_header_content (int req, const char* str) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? - "GET / HTTP/1.1\r\n" : - "HTTP/1.1 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = str; - size_t buflen = strlen(buf); - - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - assert(HTTP_PARSER_ERRNO(&parser) == HPE_INVALID_HEADER_TOKEN); - return; - } - - fprintf(stderr, - "\n*** Error expected but none in invalid header content test ***\n"); - abort(); -} - -void -test_invalid_header_field_content_error (int req) -{ - test_invalid_header_content(req, "Foo: F\01ailure"); - test_invalid_header_content(req, "Foo: B\02ar"); -} - -void -test_invalid_header_field (int req, const char* str) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? 
- "GET / HTTP/1.1\r\n" : - "HTTP/1.1 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = str; - size_t buflen = strlen(buf); - - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - assert(HTTP_PARSER_ERRNO(&parser) == HPE_INVALID_HEADER_TOKEN); - return; - } - - fprintf(stderr, - "\n*** Error expected but none in invalid header token test ***\n"); - abort(); -} - -void -test_invalid_header_field_token_error (int req) -{ - test_invalid_header_field(req, "Fo@: Failure"); - test_invalid_header_field(req, "Foo\01\test: Bar"); -} - -void -test_double_content_length_error (int req) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? - "GET / HTTP/1.1\r\n" : - "HTTP/1.1 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = "Content-Length: 0\r\nContent-Length: 1\r\n\r\n"; - size_t buflen = strlen(buf); - - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - assert(HTTP_PARSER_ERRNO(&parser) == HPE_UNEXPECTED_CONTENT_LENGTH); - return; - } - - fprintf(stderr, - "\n*** Error expected but none in double content-length test ***\n"); - abort(); -} - -void -test_chunked_content_length_error (int req) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? - "GET / HTTP/1.1\r\n" : - "HTTP/1.1 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = "Transfer-Encoding: chunked\r\nContent-Length: 1\r\n\r\n"; - size_t buflen = strlen(buf); - - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - assert(HTTP_PARSER_ERRNO(&parser) == HPE_UNEXPECTED_CONTENT_LENGTH); - return; - } - - fprintf(stderr, - "\n*** Error expected but none in chunked content-length test ***\n"); - abort(); -} - -void -test_header_cr_no_lf_error (int req) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? - "GET / HTTP/1.1\r\n" : - "HTTP/1.1 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = "Foo: 1\rBar: 1\r\n\r\n"; - size_t buflen = strlen(buf); - - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - assert(HTTP_PARSER_ERRNO(&parser) == HPE_LF_EXPECTED); - return; - } - - fprintf(stderr, - "\n*** Error expected but none in header whitespace test ***\n"); - abort(); -} - -void -test_header_overflow_error (int req) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - const char *buf; - buf = req ? 
"GET / HTTP/1.1\r\n" : "HTTP/1.0 200 OK\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - buf = "header-key: header-value\r\n"; - size_t buflen = strlen(buf); - - int i; - for (i = 0; i < 10000; i++) { - parsed = http_parser_execute(&parser, &settings_null, buf, buflen); - if (parsed != buflen) { - //fprintf(stderr, "error found on iter %d\n", i); - assert(HTTP_PARSER_ERRNO(&parser) == HPE_HEADER_OVERFLOW); - return; - } - } - - fprintf(stderr, "\n*** Error expected but none in header overflow test ***\n"); - abort(); -} - - -void -test_header_nread_value () -{ - http_parser parser; - http_parser_init(&parser, HTTP_REQUEST); - size_t parsed; - const char *buf; - buf = "GET / HTTP/1.1\r\nheader: value\nhdr: value\r\n"; - parsed = http_parser_execute(&parser, &settings_null, buf, strlen(buf)); - assert(parsed == strlen(buf)); - - assert(parser.nread == strlen(buf)); -} - - -static void -test_content_length_overflow (const char *buf, size_t buflen, int expect_ok) -{ - http_parser parser; - http_parser_init(&parser, HTTP_RESPONSE); - http_parser_execute(&parser, &settings_null, buf, buflen); - - if (expect_ok) - assert(HTTP_PARSER_ERRNO(&parser) == HPE_OK); - else - assert(HTTP_PARSER_ERRNO(&parser) == HPE_INVALID_CONTENT_LENGTH); -} - -void -test_header_content_length_overflow_error (void) -{ -#define X(size) \ - "HTTP/1.1 200 OK\r\n" \ - "Content-Length: " #size "\r\n" \ - "\r\n" - const char a[] = X(1844674407370955160); /* 2^64 / 10 - 1 */ - const char b[] = X(18446744073709551615); /* 2^64-1 */ - const char c[] = X(18446744073709551616); /* 2^64 */ -#undef X - test_content_length_overflow(a, sizeof(a) - 1, 1); /* expect ok */ - test_content_length_overflow(b, sizeof(b) - 1, 0); /* expect failure */ - test_content_length_overflow(c, sizeof(c) - 1, 0); /* expect failure */ -} - -void -test_chunk_content_length_overflow_error (void) -{ -#define X(size) \ - "HTTP/1.1 200 OK\r\n" \ - "Transfer-Encoding: chunked\r\n" \ - "\r\n" \ - #size "\r\n" \ - "..." - const char a[] = X(FFFFFFFFFFFFFFE); /* 2^64 / 16 - 1 */ - const char b[] = X(FFFFFFFFFFFFFFFF); /* 2^64-1 */ - const char c[] = X(10000000000000000); /* 2^64 */ -#undef X - test_content_length_overflow(a, sizeof(a) - 1, 1); /* expect ok */ - test_content_length_overflow(b, sizeof(b) - 1, 0); /* expect failure */ - test_content_length_overflow(c, sizeof(c) - 1, 0); /* expect failure */ -} - -void -test_no_overflow_long_body (int req, size_t length) -{ - http_parser parser; - http_parser_init(&parser, req ? HTTP_REQUEST : HTTP_RESPONSE); - size_t parsed; - size_t i; - char buf1[3000]; - size_t buf1len = sprintf(buf1, "%s\r\nConnection: Keep-Alive\r\nContent-Length: %lu\r\n\r\n", - req ? "POST / HTTP/1.0" : "HTTP/1.0 200 OK", (unsigned long)length); - parsed = http_parser_execute(&parser, &settings_null, buf1, buf1len); - if (parsed != buf1len) - goto err; - - for (i = 0; i < length; i++) { - char foo = 'a'; - parsed = http_parser_execute(&parser, &settings_null, &foo, 1); - if (parsed != 1) - goto err; - } - - parsed = http_parser_execute(&parser, &settings_null, buf1, buf1len); - if (parsed != buf1len) goto err; - return; - - err: - fprintf(stderr, - "\n*** error in test_no_overflow_long_body %s of length %lu ***\n", - req ? 
"REQUEST" : "RESPONSE", - (unsigned long)length); - abort(); -} - -void -test_multiple3 (const struct message *r1, const struct message *r2, const struct message *r3) -{ - int message_count = count_parsed_messages(3, r1, r2, r3); - - char total[ strlen(r1->raw) - + strlen(r2->raw) - + strlen(r3->raw) - + 1 - ]; - total[0] = '\0'; - - strcat(total, r1->raw); - strcat(total, r2->raw); - strcat(total, r3->raw); - - parser_init(r1->type); - - size_t read; - - read = parse(total, strlen(total)); - - if (parser->upgrade) { - upgrade_message_fix(total, read, 3, r1, r2, r3); - goto test; - } - - if (read != strlen(total)) { - print_error(total, read); - abort(); - } - - read = parse(NULL, 0); - - if (read != 0) { - print_error(total, read); - abort(); - } - -test: - - if (message_count != num_messages) { - fprintf(stderr, "\n\n*** Parser didn't see 3 messages only %d *** \n", num_messages); - abort(); - } - - if (!message_eq(0, 0, r1)) abort(); - if (message_count > 1 && !message_eq(1, 0, r2)) abort(); - if (message_count > 2 && !message_eq(2, 0, r3)) abort(); - - parser_free(); -} - -/* SCAN through every possible breaking to make sure the - * parser can handle getting the content in any chunks that - * might come from the socket - */ -void -test_scan (const struct message *r1, const struct message *r2, const struct message *r3) -{ - char total[80*1024] = "\0"; - char buf1[80*1024] = "\0"; - char buf2[80*1024] = "\0"; - char buf3[80*1024] = "\0"; - - strcat(total, r1->raw); - strcat(total, r2->raw); - strcat(total, r3->raw); - - size_t read; - - int total_len = strlen(total); - - int total_ops = 2 * (total_len - 1) * (total_len - 2) / 2; - int ops = 0 ; - - size_t buf1_len, buf2_len, buf3_len; - int message_count = count_parsed_messages(3, r1, r2, r3); - - int i,j,type_both; - for (type_both = 0; type_both < 2; type_both ++ ) { - for (j = 2; j < total_len; j ++ ) { - for (i = 1; i < j; i ++ ) { - - if (ops % 1000 == 0) { - printf("\b\b\b\b%3.0f%%", 100 * (float)ops /(float)total_ops); - fflush(stdout); - } - ops += 1; - - parser_init(type_both ? 
-
-        buf1_len = i;
-        strlncpy(buf1, sizeof(buf1), total, buf1_len);
-        buf1[buf1_len] = 0;
-
-        buf2_len = j - i;
-        strlncpy(buf2, sizeof(buf2), total+i, buf2_len);
-        buf2[buf2_len] = 0;
-
-        buf3_len = total_len - j;
-        strlncpy(buf3, sizeof(buf3), total+j, buf3_len);
-        buf3[buf3_len] = 0;
-
-        read = parse(buf1, buf1_len);
-
-        if (parser->upgrade) goto test;
-
-        if (read != buf1_len) {
-          print_error(buf1, read);
-          goto error;
-        }
-
-        read += parse(buf2, buf2_len);
-
-        if (parser->upgrade) goto test;
-
-        if (read != buf1_len + buf2_len) {
-          print_error(buf2, read);
-          goto error;
-        }
-
-        read += parse(buf3, buf3_len);
-
-        if (parser->upgrade) goto test;
-
-        if (read != buf1_len + buf2_len + buf3_len) {
-          print_error(buf3, read);
-          goto error;
-        }
-
-        parse(NULL, 0);
-
-test:
-        if (parser->upgrade) {
-          upgrade_message_fix(total, read, 3, r1, r2, r3);
-        }
-
-        if (message_count != num_messages) {
-          fprintf(stderr, "\n\nParser didn't see %d messages, only %d\n",
-                  message_count, num_messages);
-          goto error;
-        }
-
-        if (!message_eq(0, 0, r1)) {
-          fprintf(stderr, "\n\nError matching messages[0] in test_scan.\n");
-          goto error;
-        }
-
-        if (message_count > 1 && !message_eq(1, 0, r2)) {
-          fprintf(stderr, "\n\nError matching messages[1] in test_scan.\n");
-          goto error;
-        }
-
-        if (message_count > 2 && !message_eq(2, 0, r3)) {
-          fprintf(stderr, "\n\nError matching messages[2] in test_scan.\n");
-          goto error;
-        }
-
-        parser_free();
-      }
-    }
-  }
-  puts("\b\b\b\b100%");
-  return;
-
- error:
-  fprintf(stderr, "i=%d j=%d\n", i, j);
-  fprintf(stderr, "buf1 (%u) %s\n\n", (unsigned int)buf1_len, buf1);
-  fprintf(stderr, "buf2 (%u) %s\n\n", (unsigned int)buf2_len, buf2);
-  fprintf(stderr, "buf3 (%u) %s\n", (unsigned int)buf3_len, buf3);
-  abort();
-}
-
-// Caller is required to free the result.
-// The returned string is '\0'-terminated.
-char *
-create_large_chunked_message (int body_size_in_kb, const char* headers)
-{
-  int i;
-  size_t wrote = 0;
-  size_t headers_len = strlen(headers);
-  size_t bufsize = headers_len + (5+1024+2)*body_size_in_kb + 6;
-  char * buf = malloc(bufsize);
-
-  memcpy(buf, headers, headers_len);
-  wrote += headers_len;
-
-  for (i = 0; i < body_size_in_kb; i++) {
-    // write 1kb chunk into the body.
-    memcpy(buf + wrote, "400\r\n", 5);
-    wrote += 5;
-    memset(buf + wrote, 'C', 1024);
-    wrote += 1024;
-    strcpy(buf + wrote, "\r\n");
-    wrote += 2;
-  }
-
-  memcpy(buf + wrote, "0\r\n\r\n", 6);
-  wrote += 6;
-  assert(wrote == bufsize);
-
-  return buf;
-}
-
-/* Verify that we can pause parsing at any of the bytes in the
- * message and still get the result that we're expecting. */
-void
-test_message_pause (const struct message *msg)
-{
-  char *buf = (char*) msg->raw;
-  size_t buflen = strlen(msg->raw);
-  size_t nread;
-
-  parser_init(msg->type);
-
-  do {
-    nread = parse_pause(buf, buflen);
-
-    // We can only set the upgrade buffer once we've gotten our message
-    // completion callback.
-    if (messages[0].message_complete_cb_called &&
-        msg->upgrade &&
-        parser->upgrade) {
-      messages[0].upgrade = buf + nread;
-      goto test;
-    }
-
-    if (nread < buflen) {
-
-      // Not much to do if we failed a strict-mode check
-      if (HTTP_PARSER_ERRNO(parser) == HPE_STRICT) {
-        parser_free();
-        return;
-      }
-
-      assert (HTTP_PARSER_ERRNO(parser) == HPE_PAUSED);
-    }
-
-    buf += nread;
-    buflen -= nread;
-    http_parser_pause(parser, 0);
-  } while (buflen > 0);
-
-  nread = parse_pause(NULL, 0);
-  assert (nread == 0);
-
-test:
-  if (num_messages != 1) {
-    printf("\n*** num_messages != 1 after testing '%s' ***\n\n", msg->name);
-    abort();
-  }
-
-  if(!message_eq(0, 0, msg)) abort();
-
-  parser_free();
-}
-
-/* Verify that body and next message won't be parsed in responses to CONNECT */
-void
-test_message_connect (const struct message *msg)
-{
-  char *buf = (char*) msg->raw;
-  size_t buflen = strlen(msg->raw);
-
-  parser_init(msg->type);
-
-  parse_connect(buf, buflen);
-
-  if (num_messages != 1) {
-    printf("\n*** num_messages != 1 after testing '%s' ***\n\n", msg->name);
-    abort();
-  }
-
-  if(!message_eq(0, 1, msg)) abort();
-
-  parser_free();
-}
-
-int
-main (void)
-{
-  parser = NULL;
-  int i, j, k;
-  int request_count;
-  int response_count;
-  unsigned long version;
-  unsigned major;
-  unsigned minor;
-  unsigned patch;
-
-  version = http_parser_version();
-  major = (version >> 16) & 255;
-  minor = (version >> 8) & 255;
-  patch = version & 255;
-  printf("http_parser v%u.%u.%u (0x%06lx)\n", major, minor, patch, version);
-
-  printf("sizeof(http_parser) = %u\n", (unsigned int)sizeof(http_parser));
-
-  for (request_count = 0; requests[request_count].name; request_count++);
-  for (response_count = 0; responses[response_count].name; response_count++);
-
-  //// API
-  test_preserve_data();
-  test_parse_url();
-  test_method_str();
-
-  //// NREAD
-  test_header_nread_value();
-
-  //// OVERFLOW CONDITIONS
-
-  test_header_overflow_error(HTTP_REQUEST);
-  test_no_overflow_long_body(HTTP_REQUEST, 1000);
-  test_no_overflow_long_body(HTTP_REQUEST, 100000);
-
-  test_header_overflow_error(HTTP_RESPONSE);
-  test_no_overflow_long_body(HTTP_RESPONSE, 1000);
-  test_no_overflow_long_body(HTTP_RESPONSE, 100000);
-
-  test_header_content_length_overflow_error();
-  test_chunk_content_length_overflow_error();
-
-  //// HEADER FIELD CONDITIONS
-  test_double_content_length_error(HTTP_REQUEST);
-  test_chunked_content_length_error(HTTP_REQUEST);
-  test_header_cr_no_lf_error(HTTP_REQUEST);
-  test_invalid_header_field_token_error(HTTP_REQUEST);
-  test_invalid_header_field_content_error(HTTP_REQUEST);
-  test_double_content_length_error(HTTP_RESPONSE);
-  test_chunked_content_length_error(HTTP_RESPONSE);
-  test_header_cr_no_lf_error(HTTP_RESPONSE);
-  test_invalid_header_field_token_error(HTTP_RESPONSE);
-  test_invalid_header_field_content_error(HTTP_RESPONSE);
-
-  //// RESPONSES
-
-  for (i = 0; i < response_count; i++) {
-    test_message(&responses[i]);
-  }
-
-  for (i = 0; i < response_count; i++) {
-    test_message_pause(&responses[i]);
-  }
-
-  for (i = 0; i < response_count; i++) {
-    test_message_connect(&responses[i]);
-  }
-
-  for (i = 0; i < response_count; i++) {
-    if (!responses[i].should_keep_alive) continue;
-    for (j = 0; j < response_count; j++) {
-      if (!responses[j].should_keep_alive) continue;
-      for (k = 0; k < response_count; k++) {
-        test_multiple3(&responses[i], &responses[j], &responses[k]);
-      }
-    }
-  }
-
-  test_message_count_body(&responses[NO_HEADERS_NO_BODY_404]);
test_message_count_body(&responses[TRAILING_SPACE_ON_CHUNKED_BODY]); - - // test very large chunked response - { - char * msg = create_large_chunked_message(31337, - "HTTP/1.0 200 OK\r\n" - "Transfer-Encoding: chunked\r\n" - "Content-Type: text/plain\r\n" - "\r\n"); - struct message large_chunked = - {.name= "large chunked" - ,.type= HTTP_RESPONSE - ,.raw= msg - ,.should_keep_alive= FALSE - ,.message_complete_on_eof= FALSE - ,.http_major= 1 - ,.http_minor= 0 - ,.status_code= 200 - ,.response_status= "OK" - ,.num_headers= 2 - ,.headers= - { { "Transfer-Encoding", "chunked" } - , { "Content-Type", "text/plain" } - } - ,.body_size= 31337*1024 - ,.num_chunks_complete= 31338 - }; - for (i = 0; i < MAX_CHUNKS; i++) { - large_chunked.chunk_lengths[i] = 1024; - } - test_message_count_body(&large_chunked); - free(msg); - } - - - - printf("response scan 1/2 "); - test_scan( &responses[TRAILING_SPACE_ON_CHUNKED_BODY] - , &responses[NO_BODY_HTTP10_KA_204] - , &responses[NO_REASON_PHRASE] - ); - - printf("response scan 2/2 "); - test_scan( &responses[BONJOUR_MADAME_FR] - , &responses[UNDERSTORE_HEADER_KEY] - , &responses[NO_CARRIAGE_RET] - ); - - puts("responses okay"); - - - /// REQUESTS - - test_simple("GET / HTP/1.1\r\n\r\n", HPE_INVALID_VERSION); - - // Extended characters - see nodejs/test/parallel/test-http-headers-obstext.js - test_simple("GET / HTTP/1.1\r\n" - "Test: Düsseldorf\r\n", - HPE_OK); - - // Well-formed but incomplete - test_simple("GET / HTTP/1.1\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 6\r\n" - "\r\n" - "fooba", - HPE_OK); - - static const char *all_methods[] = { - "DELETE", - "GET", - "HEAD", - "POST", - "PUT", - //"CONNECT", //CONNECT can't be tested like other methods, it's a tunnel - "OPTIONS", - "TRACE", - "COPY", - "LOCK", - "MKCOL", - "MOVE", - "PROPFIND", - "PROPPATCH", - "SEARCH", - "UNLOCK", - "BIND", - "REBIND", - "UNBIND", - "ACL", - "REPORT", - "MKACTIVITY", - "CHECKOUT", - "MERGE", - "M-SEARCH", - "NOTIFY", - "SUBSCRIBE", - "UNSUBSCRIBE", - "PATCH", - "PURGE", - "MKCALENDAR", - "LINK", - "UNLINK", - 0 }; - const char **this_method; - for (this_method = all_methods; *this_method; this_method++) { - char buf[200]; - sprintf(buf, "%s / HTTP/1.1\r\n\r\n", *this_method); - test_simple(buf, HPE_OK); - } - - static const char *bad_methods[] = { - "ASDF", - "C******", - "COLA", - "GEM", - "GETA", - "M****", - "MKCOLA", - "PROPPATCHA", - "PUN", - "PX", - "SA", - "hello world", - 0 }; - for (this_method = bad_methods; *this_method; this_method++) { - char buf[200]; - sprintf(buf, "%s / HTTP/1.1\r\n\r\n", *this_method); - test_simple(buf, HPE_INVALID_METHOD); - } - - // illegal header field name line folding - test_simple("GET / HTTP/1.1\r\n" - "name\r\n" - " : value\r\n" - "\r\n", - HPE_INVALID_HEADER_TOKEN); - - const char *dumbfuck2 = - "GET / HTTP/1.1\r\n" - "X-SSL-Bullshit: -----BEGIN CERTIFICATE-----\r\n" - "\tMIIFbTCCBFWgAwIBAgICH4cwDQYJKoZIhvcNAQEFBQAwcDELMAkGA1UEBhMCVUsx\r\n" - "\tETAPBgNVBAoTCGVTY2llbmNlMRIwEAYDVQQLEwlBdXRob3JpdHkxCzAJBgNVBAMT\r\n" - "\tAkNBMS0wKwYJKoZIhvcNAQkBFh5jYS1vcGVyYXRvckBncmlkLXN1cHBvcnQuYWMu\r\n" - "\tdWswHhcNMDYwNzI3MTQxMzI4WhcNMDcwNzI3MTQxMzI4WjBbMQswCQYDVQQGEwJV\r\n" - "\tSzERMA8GA1UEChMIZVNjaWVuY2UxEzARBgNVBAsTCk1hbmNoZXN0ZXIxCzAJBgNV\r\n" - "\tBAcTmrsogriqMWLAk1DMRcwFQYDVQQDEw5taWNoYWVsIHBhcmQYJKoZIhvcNAQEB\r\n" - "\tBQADggEPADCCAQoCggEBANPEQBgl1IaKdSS1TbhF3hEXSl72G9J+WC/1R64fAcEF\r\n" - "\tW51rEyFYiIeZGx/BVzwXbeBoNUK41OK65sxGuflMo5gLflbwJtHBRIEKAfVVp3YR\r\n" - 
"\tgW7cMA/s/XKgL1GEC7rQw8lIZT8RApukCGqOVHSi/F1SiFlPDxuDfmdiNzL31+sL\r\n" - "\t0iwHDdNkGjy5pyBSB8Y79dsSJtCW/iaLB0/n8Sj7HgvvZJ7x0fr+RQjYOUUfrePP\r\n" - "\tu2MSpFyf+9BbC/aXgaZuiCvSR+8Snv3xApQY+fULK/xY8h8Ua51iXoQ5jrgu2SqR\r\n" - "\twgA7BUi3G8LFzMBl8FRCDYGUDy7M6QaHXx1ZWIPWNKsCAwEAAaOCAiQwggIgMAwG\r\n" - "\tA1UdEwEB/wQCMAAwEQYJYIZIAYb4QgHTTPAQDAgWgMA4GA1UdDwEB/wQEAwID6DAs\r\n" - "\tBglghkgBhvhCAQ0EHxYdVUsgZS1TY2llbmNlIFVzZXIgQ2VydGlmaWNhdGUwHQYD\r\n" - "\tVR0OBBYEFDTt/sf9PeMaZDHkUIldrDYMNTBZMIGaBgNVHSMEgZIwgY+AFAI4qxGj\r\n" - "\tloCLDdMVKwiljjDastqooXSkcjBwMQswCQYDVQQGEwJVSzERMA8GA1UEChMIZVNj\r\n" - "\taWVuY2UxEjAQBgNVBAsTCUF1dGhvcml0eTELMAkGA1UEAxMCQ0ExLTArBgkqhkiG\r\n" - "\t9w0BCQEWHmNhLW9wZXJhdG9yQGdyaWQtc3VwcG9ydC5hYy51a4IBADApBgNVHRIE\r\n" - "\tIjAggR5jYS1vcGVyYXRvckBncmlkLXN1cHBvcnQuYWMudWswGQYDVR0gBBIwEDAO\r\n" - "\tBgwrBgEEAdkvAQEBAQYwPQYJYIZIAYb4QgEEBDAWLmh0dHA6Ly9jYS5ncmlkLXN1\r\n" - "\tcHBvcnQuYWMudmT4sopwqlBWsvcHViL2NybC9jYWNybC5jcmwwPQYJYIZIAYb4QgEDBDAWLmh0\r\n" - "\tdHA6Ly9jYS5ncmlkLXN1cHBvcnQuYWMudWsvcHViL2NybC9jYWNybC5jcmwwPwYD\r\n" - "\tVR0fBDgwNjA0oDKgMIYuaHR0cDovL2NhLmdyaWQt5hYy51ay9wdWIv\r\n" - "\tY3JsL2NhY3JsLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAS/U4iiooBENGW/Hwmmd3\r\n" - "\tXCy6Zrt08YjKCzGNjorT98g8uGsqYjSxv/hmi0qlnlHs+k/3Iobc3LjS5AMYr5L8\r\n" - "\tUO7OSkgFFlLHQyC9JzPfmLCAugvzEbyv4Olnsr8hbxF1MbKZoQxUZtMVu29wjfXk\r\n" - "\thTeApBv7eaKCWpSp7MCbvgzm74izKhu3vlDk9w6qVrxePfGgpKPqfHiOoGhFnbTK\r\n" - "\twTC6o2xq5y0qZ03JonF7OJspEd3I5zKY3E+ov7/ZhW6DqT8UFvsAdjvQbXyhV8Eu\r\n" - "\tYhixw1aKEPzNjNowuIseVogKOLXxWI5vAi5HgXdS0/ES5gDGsABo4fqovUKlgop3\r\n" - "\tRA==\r\n" - "\t-----END CERTIFICATE-----\r\n" - "\r\n"; - test_simple(dumbfuck2, HPE_OK); - - const char *corrupted_connection = - "GET / HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "Connection\r\033\065\325eep-Alive\r\n" - "Accept-Encoding: gzip\r\n" - "\r\n"; - test_simple(corrupted_connection, HPE_INVALID_HEADER_TOKEN); - - const char *corrupted_header_name = - "GET / HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "X-Some-Header\r\033\065\325eep-Alive\r\n" - "Accept-Encoding: gzip\r\n" - "\r\n"; - test_simple(corrupted_header_name, HPE_INVALID_HEADER_TOKEN); - -#if 0 - // NOTE(Wed Nov 18 11:57:27 CET 2009) this seems okay. we just read body - // until EOF. 
- // - // no content-length - // error if there is a body without content length - const char *bad_get_no_headers_no_body = "GET /bad_get_no_headers_no_body/world HTTP/1.1\r\n" - "Accept: */*\r\n" - "\r\n" - "HELLO"; - test_simple(bad_get_no_headers_no_body, 0); -#endif - /* TODO sending junk and large headers gets rejected */ - - - /* check to make sure our predefined requests are okay */ - for (i = 0; requests[i].name; i++) { - test_message(&requests[i]); - } - - for (i = 0; i < request_count; i++) { - test_message_pause(&requests[i]); - } - - for (i = 0; i < request_count; i++) { - if (!requests[i].should_keep_alive) continue; - for (j = 0; j < request_count; j++) { - if (!requests[j].should_keep_alive) continue; - for (k = 0; k < request_count; k++) { - test_multiple3(&requests[i], &requests[j], &requests[k]); - } - } - } - - printf("request scan 1/4 "); - test_scan( &requests[GET_NO_HEADERS_NO_BODY] - , &requests[GET_ONE_HEADER_NO_BODY] - , &requests[GET_NO_HEADERS_NO_BODY] - ); - - printf("request scan 2/4 "); - test_scan( &requests[POST_CHUNKED_ALL_YOUR_BASE] - , &requests[POST_IDENTITY_BODY_WORLD] - , &requests[GET_FUNKY_CONTENT_LENGTH] - ); - - printf("request scan 3/4 "); - test_scan( &requests[TWO_CHUNKS_MULT_ZERO_END] - , &requests[CHUNKED_W_TRAILING_HEADERS] - , &requests[CHUNKED_W_BULLSHIT_AFTER_LENGTH] - ); - - printf("request scan 4/4 "); - test_scan( &requests[QUERY_URL_WITH_QUESTION_MARK_GET] - , &requests[PREFIX_NEWLINE_GET ] - , &requests[CONNECT_REQUEST] - ); - - puts("requests okay"); - - return 0; -}
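Aside: the exhaustive split-feed technique that test_message() and test_scan() exercise above boils down to a very small standalone check. The sketch below is illustrative only and is not part of the vendored sources; it assumes the vendored http_parser.h is on the include path and that the object compiled from http_parser.c is linked in (the request line and Host value are arbitrary placeholders).

#include <assert.h>
#include <stdio.h>
#include <string.h>
#include "http_parser.h"

int
main (void)
{
  static const char msg[] =
    "GET /demo HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "\r\n";
  const size_t len = sizeof(msg) - 1;
  size_t split;
  http_parser_settings settings;

  /* All callbacks left NULL: we only care about byte consumption
   * and the resulting parser errno. */
  memset(&settings, 0, sizeof(settings));

  /* Feed the request split at every possible byte boundary; a correct
   * parser consumes every byte no matter how the input is chunked. */
  for (split = 0; split <= len; split++) {
    http_parser parser;
    size_t nparsed;

    http_parser_init(&parser, HTTP_REQUEST);
    nparsed = http_parser_execute(&parser, &settings, msg, split);
    nparsed += http_parser_execute(&parser, &settings, msg + split, len - split);

    assert(nparsed == len);
    assert(HTTP_PARSER_ERRNO(&parser) == HPE_OK);
  }

  puts("all split positions parsed cleanly");
  return 0;
}

The deleted harness above does the same thing with three-way splits (test_scan) and layers message-recording callbacks on top, so the parsed results can be compared field by field against the expected struct message.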