Skip to content

Commit 94c3767

Browse files
committed
Add client fetch to auth dependency and create new dependency for async stuff
1 parent e4c0922 commit 94c3767

File tree

9 files changed

+102
-105
lines changed

9 files changed

+102
-105
lines changed

pv_site_api/_db_helpers.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import sqlalchemy as sa
1515
import structlog
16+
from fastapi import Depends
1617
from pvsite_datamodel.read.generation import get_pv_generation_by_sites
1718
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, InverterSQL, SiteSQL
1819
from sqlalchemy.orm import Session, aliased
@@ -24,6 +25,7 @@
2425
PVSiteMetadata,
2526
SiteForecastValues,
2627
)
28+
from .session import get_session
2729

2830
logger = structlog.stdlib.get_logger()
2931

@@ -60,12 +62,6 @@ def _get_forecasts_for_horizon(
6062
return list(session.execute(stmt))
6163

6264

63-
def _get_inverters_by_site(session: Session, site_uuid: str) -> list[Row]:
64-
query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid)
65-
66-
return query.all()
67-
68-
6965
def _get_latest_forecast_by_sites(
7066
session: Session, site_uuids: list[str], start_utc: Optional[dt.datetime] = None
7167
) -> list[Row]:
@@ -240,3 +236,14 @@ def does_site_exist(session: Session, site_uuid: str) -> bool:
240236
session.execute(sa.select(SiteSQL).where(SiteSQL.site_uuid == site_uuid)).one_or_none()
241237
is not None
242238
)
239+
240+
241+
def get_inverters_for_site(
242+
site_uuid: str, session: Session = Depends(get_session)
243+
) -> list[Row] | None:
244+
"""Path dependency to get a list of inverters for a site, or None if the site doesn't exist"""
245+
if not does_site_exist(session, site_uuid):
246+
return None
247+
248+
query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid)
249+
return query.all()

pv_site_api/auth.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import jwt
22
from fastapi import Depends, HTTPException
33
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
4+
from pvsite_datamodel.sqlmodels import ClientSQL
5+
from sqlalchemy.orm import Session
6+
7+
from .session import get_session
48

59
token_auth_scheme = HTTPBearer()
610

@@ -15,7 +19,11 @@ def __init__(self, domain: str, api_audience: str, algorithm: str):
1519

1620
self._jwks_client = jwt.PyJWKClient(f"https://{domain}/.well-known/jwks.json")
1721

18-
def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme)):
22+
def __call__(
23+
self,
24+
auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme),
25+
session: Session = Depends(get_session),
26+
):
1927
token = auth_credentials.credentials
2028

2129
try:
@@ -24,7 +32,7 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke
2432
raise HTTPException(status_code=401, detail=str(e))
2533

2634
try:
27-
payload = jwt.decode(
35+
jwt.decode(
2836
token,
2937
signing_key,
3038
algorithms=self._algorithm,
@@ -34,4 +42,11 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke
3442
except Exception as e:
3543
raise HTTPException(status_code=401, detail=str(e))
3644

37-
return payload
45+
if session is None:
46+
return None
47+
48+
# @TODO: get client corresponding to auth
49+
# See: https://github.com/openclimatefix/pv-site-api/issues/90
50+
client = session.query(ClientSQL).first()
51+
assert client is not None
52+
return client

pv_site_api/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def wrapper(*args, **kwargs): # noqa
3939
route_variables = kwargs.copy()
4040

4141
# drop session and user
42-
for var in ["session", "user"]:
42+
for var in ["session", "user", "auth"]:
4343
if var in route_variables:
4444
route_variables.pop(var)
4545

pv_site_api/enode_auth.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,35 @@ def __init__(
1212
self._token_url = token_url
1313
self._access_token = access_token
1414

15-
def auth_flow(self, request: httpx.Request):
15+
def sync_auth_flow(self, request: httpx.Request):
1616
# Add the Authorization header to the request using the current access token
1717
request.headers["Authorization"] = f"Bearer {self._access_token}"
1818
response = yield request
1919

2020
if response.status_code == 401:
2121
# The access token is no longer valid, refresh it
2222
token_response = yield self._build_refresh_request()
23+
token_response.read()
2324
self._update_access_token(token_response)
2425
# Update the request's Authorization header with the new access token
2526
request.headers["Authorization"] = f"Bearer {self._access_token}"
2627
# Resend the request with the new access token
27-
response = yield request
28-
return response
28+
yield request
29+
30+
async def async_auth_flow(self, request: httpx.Request):
31+
# Add the Authorization header to the request using the current access token
32+
request.headers["Authorization"] = f"Bearer {self._access_token}"
33+
response = yield request
34+
35+
if response.status_code == 401:
36+
# The access token is no longer valid, refresh it
37+
token_response = yield self._build_refresh_request()
38+
await token_response.aread()
39+
self._update_access_token(token_response)
40+
# Update the request's Authorization header with the new access token
41+
request.headers["Authorization"] = f"Bearer {self._access_token}"
42+
# Resend the request with the new access token
43+
yield request
2944

3045
def _build_refresh_request(self):
3146
basic_auth = httpx.BasicAuth(self._client_id, self._client_secret)
@@ -35,5 +50,4 @@ def _build_refresh_request(self):
3550
return request
3651

3752
def _update_access_token(self, response):
38-
response.read()
3953
self._access_token = response.json()["access_token"]

0 commit comments

Comments
 (0)