Skip to content

Commit

Permalink
Merge pull request #1 from Achronus/v0.1.13
Browse files Browse the repository at this point in the history
v0.1.13 - Authentication Updates
  • Loading branch information
Achronus authored Sep 9, 2024
2 parents ce7d5a9 + d0e46a8 commit 24f77bd
Show file tree
Hide file tree
Showing 20 changed files with 402 additions and 122 deletions.
50 changes: 49 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zentra_api"
version = "0.1.12"
version = "0.1.13"
description = "A CLI tool for building FastAPI apps faster."
authors = ["Ryan Partridge <rpartridge101@gmail.com>"]
license = "MIT License"
Expand Down Expand Up @@ -35,6 +35,7 @@ pydantic = "^2.8.2"
toml = "^0.10.2"
sqlalchemy = "^2.0.31"
bcrypt = "^4.2.0"
inflect = "^7.3.1"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2"
Expand Down
6 changes: 4 additions & 2 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ def database_config() -> DatabaseConfig:
@pytest.fixture
def auth_config() -> AuthConfig:
return AuthConfig(
SECRET_KEY="supersecret",
SECRET_ACCESS_KEY="supersecretaccess",
SECRET_REFRESH_KEY="supersecretrefresh",
ALGORITHM="HS256",
ACCESS_TOKEN_EXPIRE_MINUTES=10080,
ACCESS_TOKEN_EXPIRE_MINS=15,
REFRESH_TOKEN_EXPIRE_MINUTES=10080,
TOKEN_URL="auth/token",
ROUNDS=12,
)
Expand Down
30 changes: 24 additions & 6 deletions tests/root/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,56 @@
from pydantic import ValidationError

from zentra_api.schema import Token
from zentra_api.enums import TokenType


class TestToken:
@staticmethod
def test_init():
token = Token(
access_token="valid_access_token",
token_type=TokenType.BEARER,
refresh_token="valid_refresh_token",
token_type="bearer",
)

assert isinstance(token, Token)
assert token.access_token == "valid_access_token"
assert token.refresh_token == "valid_refresh_token"
assert token.token_type == "bearer"

@staticmethod
def test_invalid_access_token():
with pytest.raises(ValidationError):
Token(access_token=None, token_type=TokenType.BEARER)
Token(access_token=None, token_type="bearer")

@staticmethod
def test_invalid_token_type():
with pytest.raises(ValidationError):
Token(access_token="valid_access_token", token_type="invalid_token_type")
Token(
access_token="valid_access_token",
refresh_token="valid_refresh_token",
token_type="invalid_token_type",
)

@staticmethod
def test_model_dump_valid():
token = Token(access_token="valid_access_token", token_type=TokenType.BEARER)
def test_invalid_refresh_token():
with pytest.raises(ValidationError):
Token(
access_token="valid_access_token",
refresh_token=None,
token_type="bearer",
)

@staticmethod
def test_model_dump_valid():
token = Token(
access_token="valid_access_token",
refresh_token="valid_refresh_token",
token_type="bearer",
)
token_dict = token.model_dump()

assert token_dict == {
"access_token": "valid_access_token",
"refresh_token": "valid_refresh_token",
"token_type": "bearer",
}
76 changes: 63 additions & 13 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
class TestSecurityUtils:
@pytest.fixture
def security_utils(self) -> SecurityUtils:
return SecurityUtils(auth=AuthConfig(SECRET_KEY="asupersecretkey"))
return SecurityUtils(
auth=AuthConfig(
SECRET_ACCESS_KEY="secretaccess", SECRET_REFRESH_KEY="secretrefresh"
)
)

@staticmethod
def test_hash_password(security_utils: SecurityUtils):
Expand All @@ -28,22 +32,37 @@ def test_verify_password(security_utils: SecurityUtils):
assert not security_utils.verify_password("wrongpassword", hashed_password)

@staticmethod
def test_expiration_with_value(security_utils: SecurityUtils):
def test_expiration_with_delta(security_utils: SecurityUtils):
expires_delta = timedelta(minutes=5)
tolerance = timedelta(seconds=2)
expected_expire_time = datetime.now(timezone.utc) + expires_delta
expire_time = security_utils.expiration(expires_delta)
expire_time = security_utils.expiration("access", expires_delta)

assert (
abs((expire_time - expected_expire_time).total_seconds())
<= tolerance.total_seconds()
), f"Expected expire time to be close to {expected_expire_time}, but got {expire_time}"

@staticmethod
def test_expiration_access_default(security_utils: SecurityUtils):
tolerance = timedelta(seconds=2)
expected_expire_time = datetime.now(timezone.utc) + security_utils.expire_mins(
"access"
)
expire_time = security_utils.expiration("access")

assert (
abs((expire_time - expected_expire_time).total_seconds())
<= tolerance.total_seconds()
), f"Expected expire time to be close to {expected_expire_time}, but got {expire_time}"

@staticmethod
def test_expiration_default(security_utils: SecurityUtils):
def test_expiration_refresh_default(security_utils: SecurityUtils):
tolerance = timedelta(seconds=2)
expected_expire_time = datetime.now(timezone.utc) + security_utils.expire_mins()
expire_time = security_utils.expiration()
expected_expire_time = datetime.now(timezone.utc) + security_utils.expire_mins(
"refresh"
)
expire_time = security_utils.expiration("refresh")

assert (
abs((expire_time - expected_expire_time).total_seconds())
Expand Down Expand Up @@ -77,27 +96,58 @@ def test_create_access_token(security_utils: SecurityUtils):
token = security_utils.create_access_token(data)
decoded_data = jwt.decode(
token,
key=security_utils.auth.SECRET_KEY,
key=security_utils.auth.SECRET_ACCESS_KEY,
algorithms=[security_utils.auth.ALGORITHM],
)
assert decoded_data["sub"] == "testuser", (token, decoded_data)

@staticmethod
def test_create_refresh_token(security_utils: SecurityUtils):
data = {"sub": "testuser"}
token = security_utils.create_refresh_token(data)
decoded_data = jwt.decode(
token,
key=security_utils.auth.SECRET_REFRESH_KEY,
algorithms=[security_utils.auth.ALGORITHM],
)
assert decoded_data["sub"] == "testuser", (token, decoded_data)

@staticmethod
def test_get_token_data(security_utils: SecurityUtils):
def test_verify_access_token(security_utils: SecurityUtils):
data = {"sub": "testuser"}
token = security_utils.create_access_token(data)
token_data = security_utils.verify_token(token)
token_data = security_utils.verify_access_token(token)
assert token_data == "testuser"

@staticmethod
def test_verify_refresh_token(security_utils: SecurityUtils):
data = {"sub": "testuser"}
token = security_utils.create_refresh_token(data)
token_data = security_utils.verify_refresh_token(token)
assert token_data == "testuser"

@staticmethod
def test_get_token_data_invalid_token(security_utils: SecurityUtils):
def test_invalid_access_token(security_utils: SecurityUtils):
with pytest.raises(HTTPException):
security_utils.verify_access_token("invalidtoken")

@staticmethod
def test_invalid_refresh_token(security_utils: SecurityUtils):
with pytest.raises(HTTPException):
security_utils.verify_refresh_token("invalidtoken")

@staticmethod
def test_empty_token_data_access(security_utils: SecurityUtils):
data = {"sub": None}
token = security_utils.create_access_token(data)

with pytest.raises(HTTPException):
security_utils.verify_token("invalidtoken")
security_utils.verify_access_token(token)

@staticmethod
def test_empty_token_data(security_utils: SecurityUtils):
def test_empty_token_data_refresh(security_utils: SecurityUtils):
data = {"sub": None}
token = security_utils.create_access_token(data)

with pytest.raises(HTTPException):
security_utils.verify_token(token)
security_utils.verify_refresh_token(token)
8 changes: 4 additions & 4 deletions zentra_api/auth/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class BcryptContext(BaseModel):
A custom context for `bcrypt` hashing.
Parameters:
- `rounds` (`integer, optional`) - the computational cost factor for hashing. `12` by default
rounds (integer, optional): the computational cost factor for hashing. `12` by default
"""

rounds: int = 12
Expand All @@ -17,7 +17,7 @@ def hash(self, password: str) -> str:
Hashes a password. Returns the hashed password.
Parameters:
- `password` (`string`) - the plain password to hash
password (string): the plain password to hash
"""
salt = bcrypt.gensalt(rounds=self.rounds)
hashed_password = bcrypt.hashpw(password.encode("utf-8"), salt)
Expand All @@ -28,7 +28,7 @@ def verify(self, password: str, hashed_password: str) -> bool:
Verifies a password against a given hash. Returns `True` if the password matches, `False` otherwise.
Parameters:
- `password` (`string`) - the plain password to verify
- `hashed_password` (`string`) - The hashed password to verify against
password (string): the plain password to verify
hashed_password (string): The hashed password to verify against
"""
return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8"))
6 changes: 0 additions & 6 deletions zentra_api/auth/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,3 @@ class JWTAlgorithm(StrEnum):
HS256 = "HS256"
HS384 = "HS384"
HS512 = "HS512"


class DeploymentType(StrEnum):
RAILWAY = "railway"
DOCKERFILE = "dockerfile"
DOCKER_COMPOSE = "docker_compose"
Loading

0 comments on commit 24f77bd

Please sign in to comment.