Skip to content

Feature/async #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ on:
push:
branches:
- master

jobs:
test:
name: test
runs-on: ${{ matrix.os }}
strategy:
matrix:
build: [linux_3.8, windows_3.8, mac_3.8, linux_3.7]
build: [linux_3.9, linux_3.8, windows_3.8, mac_3.8]
include:
- build: linux_3.9
os: ubuntu-latest
python: 3.9
- build: linux_3.8
os: ubuntu-latest
python: 3.8
Expand All @@ -24,9 +27,6 @@ jobs:
- build: mac_3.8
os: macos-latest
python: 3.8
- build: linux_3.7
os: ubuntu-latest
python: 3.7
steps:
- name: Checkout repository
uses: actions/checkout@v2
Expand All @@ -35,39 +35,39 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel
pip install -r requirements.txt

# test all the builds apart from linux_3.8...
- name: Test with pytest
if: matrix.build != 'linux_3.8'
run: pytest

# only do the test coverage for linux_3.8
- name: Produce coverage report
if: matrix.build == 'linux_3.8'
run: pytest --cov=fastapi_sqlalchemy --cov-report=xml

- name: Upload coverage report
if: matrix.build == 'linux_3.8'
uses: codecov/codecov-action@v1
with:
file: ./coverage.xml

lint:
name: lint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8

- name: Install dependencies
run: pip install flake8
Expand All @@ -85,13 +85,13 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8

- name: Install dependencies
# isort needs all of the packages to be installed so it can
# isort needs all of the packages to be installed so it can
# tell which are third party and which are first party
run: pip install -r requirements.txt

- name: Check formatting of imports
run: isort --check-only --diff --verbose

Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Python 3.7

- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8

- name: Install the dependencies
run: pip install --upgrade pip wheel setuptools

- name: Build the distributions
run: python setup.py sdist bdist_wheel

- name: Upload to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
Expand All @@ -36,10 +36,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Get the version
run: echo ::set-env name=VERSION::${GITHUB_REF#refs/tags/}

- name: Create release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
8 changes: 4 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Usage inside of a route

app.add_middleware(DBSessionMiddleware, db_url="sqlite://")

# once the middleware is applied, any route can then access the database session
# once the middleware is applied, any route can then access the database session
# from the global ``db``

@app.get("/users")
Expand All @@ -49,7 +49,7 @@ Usage inside of a route

return users

Note that the session object provided by ``db.session`` is based on the Python3.7+ ``ContextVar``. This means that
Note that the session object provided by ``db.session`` is based on the Python3.8+ ``ContextVar``. This means that
each session is linked to the individual request context in which it was created.

Usage outside of a route
Expand Down Expand Up @@ -82,15 +82,15 @@ Sometimes it is useful to be able to access the database outside the context of
"""Count the number of users in the database and save it into the user_counts table."""

# we are outside of a request context, therefore we cannot rely on ``DBSessionMiddleware``
# to create a database session for us. Instead, we can use the same ``db`` object and
# to create a database session for us. Instead, we can use the same ``db`` object and
# use it as a context manager, like so:

with db():
user_count = db.session.query(User).count()

db.session.add(UserCount(user_count))
db.session.commit()

# no longer able to access a database session once the db() context manager has ended

return users
Expand Down
5 changes: 3 additions & 2 deletions fastapi_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi_sqlalchemy.middleware import DBSessionMiddleware, db
from fastapi_sqlalchemy.async_middleware import AsyncDBSessionMiddleware, async_db

__all__ = ["db", "DBSessionMiddleware"]
__all__ = ["db", "DBSessionMiddleware", "async_db", "AsyncDBSessionMiddleware"]

__version__ = "0.2.1"
__version__ = "0.3.0"
89 changes: 89 additions & 0 deletions fastapi_sqlalchemy/async_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from contextvars import ContextVar
from typing import Dict, Optional, Union

from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp

from fastapi_sqlalchemy.exceptions import (
MissingSessionError, SessionNotInitialisedError
)

_Session: sessionmaker = None
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)


class AsyncDBSessionMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app: ASGIApp,
db_url: Optional[Union[str, URL]] = None,
custom_engine: Optional[AsyncEngine] = None,
engine_args: Dict = None,
session_args: Dict = None,
commit_on_exit: bool = False,
):
super().__init__(app)
global _Session
engine_args = engine_args or {}
self.commit_on_exit = commit_on_exit

session_args = session_args or {}
if not custom_engine and not db_url:
raise ValueError("You need to pass a db_url or a custom_engine parameter.")
if not custom_engine:
engine = create_async_engine(db_url, future=True, **engine_args)
else:
engine = custom_engine
_Session = sessionmaker(bind=engine, class_=AsyncSession, **session_args)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
async with async_db(commit_on_exit=self.commit_on_exit):
response = await call_next(request)
return response


class AsyncDBSessionMeta(type):
# using this metaclass means that we can access db.session as a property at a class level,
# rather than db().session
@property
def session(self) -> AsyncSession:
"""Return an instance of Session local to the current async context."""
if _Session is None:
raise SessionNotInitialisedError

session = _session.get()
if session is None:
raise MissingSessionError

return session


class AsyncDBSession(metaclass=AsyncDBSessionMeta):
def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
self.token = None
self.session_args = session_args or {}
self.commit_on_exit = commit_on_exit

async def __aenter__(self):
if not isinstance(_Session, sessionmaker):
raise SessionNotInitialisedError
self.token = _session.set(_Session(**self.session_args))
return type(self)

async def __aexit__(self, exc_type, exc_value, traceback):
sess = _session.get()
if exc_type is not None:
await sess.rollback()

if self.commit_on_exit:
await sess.commit()

await sess.close()
_session.reset(self.token)


async_db: AsyncDBSessionMeta = AsyncDBSession
80 changes: 45 additions & 35 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
appdirs==1.4.3
atomicwrites==1.3.0
attrs==19.3.0
black==19.10b0
certifi==2019.9.11
chardet==3.0.4
click==7.1.1
coverage==4.5.4
entrypoints==0.3
fastapi==0.52.0
flake8==3.7.9
idna==2.8
importlib-metadata==1.5.0
isort==4.3.21
aiosqlite==0.17.0
anyio==3.5.0
async-generator==1.10
atomicwrites==1.4.0
attrs==21.4.0
black==22.3.0
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
click==8.1.3
colorama==0.4.4
coverage==6.3.2
fastapi==0.77.1
flake8==4.0.1
greenlet==1.1.2
h11==0.12.0
httpcore==0.14.7
httpx==0.22.0
idna==3.3
iniconfig==1.1.1
isort==5.10.1
mccabe==0.6.1
more-itertools==7.2.0
packaging==19.2
pathspec==0.7.0
pluggy==0.13.0
py==1.8.0
pycodestyle==2.5.0
pydantic==0.32.2
pyflakes==2.1.1
pyparsing==2.4.2
pytest==5.2.2
pytest-cov==2.8.1
PyYAML==5.3.1
regex==2020.2.20
requests==2.22.0
six==1.12.0
SQLAlchemy==1.3.10
starlette==0.13.2
toml==0.10.0
typed-ast==1.4.1
urllib3==1.25.6
mypy-extensions==0.4.3
outcome==1.1.0
packaging==21.3
pathspec==0.9.0
platformdirs==2.5.2
pluggy==1.0.0
py==1.11.0
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.9.0
pyflakes==2.4.0
pyparsing==3.0.9
pytest==7.1.2
pytest-cov==3.0.0
requests==2.27.1
rfc3986==1.5.0
sniffio==1.2.0
sortedcontainers==2.4.0
SQLAlchemy==1.4.36
starlette==0.19.1
tomli==2.0.1
trio==0.20.0
typing_extensions==4.2.0
urllib3==1.26.9
wcwidth==0.1.7
zipp==3.1.0
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
packages=["fastapi_sqlalchemy"],
package_data={"fastapi_sqlalchemy": ["py.typed"]},
zip_safe=False,
python_requires=">=3.7",
python_requires=">=3.8",
install_requires=["starlette>=0.12.9", "SQLAlchemy>=1.2"],
classifiers=[
"Development Status :: 4 - Beta",
Expand All @@ -34,9 +34,8 @@
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
Expand Down
Loading