Skip to content

Commit 955c612

Browse files
committed
Update typing in api_jwt
1 parent d64f155 commit 955c612

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

jwt/api_jwt.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222

2323
class PyJWT:
24-
def __init__(self, options=None):
24+
def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
2525
if options is None:
2626
options = {}
27-
self.options = {**self._get_default_options(), **options}
27+
self.options: dict[str, Any] = {**self._get_default_options(), **options}
2828

2929
@staticmethod
3030
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -157,7 +157,7 @@ def decode(
157157
leeway: Union[int, float, timedelta] = 0,
158158
# kwargs
159159
**kwargs,
160-
) -> Dict[str, Any]:
160+
) -> Any:
161161
if kwargs:
162162
warnings.warn(
163163
"passing additional kwargs to decode() is deprecated "
@@ -178,7 +178,14 @@ def decode(
178178
)
179179
return decoded["payload"]
180180

181-
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
181+
def _validate_claims(
182+
self,
183+
payload: dict[str, Any],
184+
options: dict[str, Any],
185+
audience=None,
186+
issuer=None,
187+
leeway: float | timedelta = 0,
188+
) -> None:
182189
if isinstance(leeway, timedelta):
183190
leeway = leeway.total_seconds()
184191

@@ -204,12 +211,21 @@ def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=
204211
if options["verify_aud"]:
205212
self._validate_aud(payload, audience)
206213

207-
def _validate_required_claims(self, payload, options):
214+
def _validate_required_claims(
215+
self,
216+
payload: dict[str, Any],
217+
options: dict[str, Any],
218+
) -> None:
208219
for claim in options["require"]:
209220
if payload.get(claim) is None:
210221
raise MissingRequiredClaimError(claim)
211222

212-
def _validate_iat(self, payload, now, leeway):
223+
def _validate_iat(
224+
self,
225+
payload: dict[str, Any],
226+
now: float,
227+
leeway: float,
228+
) -> None:
213229
iat = payload["iat"]
214230
try:
215231
int(iat)
@@ -218,7 +234,12 @@ def _validate_iat(self, payload, now, leeway):
218234
if iat > (now + leeway):
219235
raise ImmatureSignatureError("The token is not yet valid (iat)")
220236

221-
def _validate_nbf(self, payload, now, leeway):
237+
def _validate_nbf(
238+
self,
239+
payload: dict[str, Any],
240+
now: float,
241+
leeway: float,
242+
) -> None:
222243
try:
223244
nbf = int(payload["nbf"])
224245
except ValueError:
@@ -227,7 +248,12 @@ def _validate_nbf(self, payload, now, leeway):
227248
if nbf > (now + leeway):
228249
raise ImmatureSignatureError("The token is not yet valid (nbf)")
229250

230-
def _validate_exp(self, payload, now, leeway):
251+
def _validate_exp(
252+
self,
253+
payload: dict[str, Any],
254+
now: float,
255+
leeway: float,
256+
) -> None:
231257
try:
232258
exp = int(payload["exp"])
233259
except ValueError:
@@ -236,7 +262,11 @@ def _validate_exp(self, payload, now, leeway):
236262
if exp <= (now - leeway):
237263
raise ExpiredSignatureError("Signature has expired")
238264

239-
def _validate_aud(self, payload, audience):
265+
def _validate_aud(
266+
self,
267+
payload: dict[str, Any],
268+
audience: Optional[Union[str, Iterable[str]]],
269+
) -> None:
240270
if audience is None:
241271
if "aud" not in payload or not payload["aud"]:
242272
return
@@ -264,7 +294,7 @@ def _validate_aud(self, payload, audience):
264294
if all(aud not in audience_claims for aud in audience):
265295
raise InvalidAudienceError("Invalid audience")
266296

267-
def _validate_iss(self, payload, issuer):
297+
def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
268298
if issuer is None:
269299
return
270300

0 commit comments

Comments
 (0)