From 2c8f310958345b22cf332a802bd4eef0852eedd7 Mon Sep 17 00:00:00 2001 From: Tert0 <62036464+Tert0@users.noreply.github.com> Date: Wed, 4 Dec 2024 21:11:31 +0100 Subject: [PATCH] fix tests and codestyle --- fastapi_framework/__init__.py | 27 ++++++---------- fastapi_framework/config.py | 2 +- fastapi_framework/database.py | 11 ++++--- fastapi_framework/in_memory_backend.py | 2 +- fastapi_framework/jwt_auth.py | 3 +- fastapi_framework/logger.py | 3 +- fastapi_framework/modules.py | 2 +- fastapi_framework/rate_limit.py | 6 ++-- fastapi_framework/redis.py | 8 ++--- fastapi_framework/session.py | 7 ++-- fastapi_framework/settings.py | 6 ++-- tests/test_database.py | 45 +++++++++++++------------- tests/test_jwt_auth.py | 33 +++++++++---------- tests/test_rate_limit.py | 20 ++++++------ 14 files changed, 85 insertions(+), 90 deletions(-) diff --git a/fastapi_framework/__init__.py b/fastapi_framework/__init__.py index cc2b9547..ab6cf7df 100644 --- a/fastapi_framework/__init__.py +++ b/fastapi_framework/__init__.py @@ -5,23 +5,16 @@ from .modules import check_dependencies, disabled_modules check_dependencies() # noqa: FLK-E402 +from .config import Config, ConfigField from .database import database_dependency -from .jwt_auth import ( - create_jwt_token, - create_access_token, - create_refresh_token, - invalidate_refresh_token, - get_token, - get_data, - pwd_context, - ACCESS_TOKEN_EXPIRE_MINUTES, - REFRESH_TOKEN_EXPIRE_MINUTES, - check_refresh_token, - generate_tokens, -) -from .logger import get_logger -from .rate_limit import RateLimitManager, RateLimiter, get_uuid_user_id, RateLimitTime -from .redis import get_redis, RedisDependency, redis_dependency, Redis from .in_memory_backend import InMemoryBackend, RAMBackend -from .config import Config, ConfigField +from .jwt_auth import (ACCESS_TOKEN_EXPIRE_MINUTES, + REFRESH_TOKEN_EXPIRE_MINUTES, check_refresh_token, + create_access_token, create_jwt_token, + create_refresh_token, generate_tokens, get_data, + get_token, invalidate_refresh_token, pwd_context) +from .logger import get_logger +from .rate_limit import (RateLimiter, RateLimitManager, RateLimitTime, + get_uuid_user_id) +from .redis import Redis, RedisDependency, get_redis, redis_dependency from .session import Session diff --git a/fastapi_framework/config.py b/fastapi_framework/config.py index 586b3286..2849d4e8 100644 --- a/fastapi_framework/config.py +++ b/fastapi_framework/config.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Callable, List +from typing import Any, Callable, Dict, List, Optional from fastapi_framework import disabled_modules diff --git a/fastapi_framework/database.py b/fastapi_framework/database.py index 23fadf92..16e4306c 100644 --- a/fastapi_framework/database.py +++ b/fastapi_framework/database.py @@ -1,16 +1,19 @@ from os import getenv -from typing import TypeVar, Dict +from typing import Dict, TypeVar from dotenv import load_dotenv from sqlalchemy.engine import URL -from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker +from sqlalchemy.ext.asyncio import (AsyncEngine, AsyncSession, + async_sessionmaker, create_async_engine) from sqlalchemy.future import select as sa_select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.pool import NullPool from sqlalchemy.sql import Executable -from sqlalchemy.sql.expression import exists as sa_exists, delete as sa_delete, Delete +from sqlalchemy.sql.expression import Delete +from sqlalchemy.sql.expression import delete as sa_delete +from sqlalchemy.sql.expression import exists as sa_exists from sqlalchemy.sql.functions import count -from sqlalchemy.sql.selectable import Select, Exists +from sqlalchemy.sql.selectable import Exists, Select from .logger import get_logger diff --git a/fastapi_framework/in_memory_backend.py b/fastapi_framework/in_memory_backend.py index 486d2752..757ab979 100644 --- a/fastapi_framework/in_memory_backend.py +++ b/fastapi_framework/in_memory_backend.py @@ -1,6 +1,6 @@ import time -from typing import Dict, Any, Optional, Set, Union, List from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Set, Union class InMemoryBackend(ABC): diff --git a/fastapi_framework/jwt_auth.py b/fastapi_framework/jwt_auth.py index afdb4ff2..6adcb9fe 100644 --- a/fastapi_framework/jwt_auth.py +++ b/fastapi_framework/jwt_auth.py @@ -3,12 +3,13 @@ from typing import Dict import jwt -from .redis import Redis from dotenv import load_dotenv from fastapi import Depends, HTTPException from fastapi.security import HTTPBearer from passlib.context import CryptContext +from .redis import Redis + load_dotenv() SECRET_KEY = getenv("JWT_SECRET_KEY", "") diff --git a/fastapi_framework/logger.py b/fastapi_framework/logger.py index c3a165aa..feb421dc 100644 --- a/fastapi_framework/logger.py +++ b/fastapi_framework/logger.py @@ -1,8 +1,9 @@ import logging import sys -from dotenv import load_dotenv from os import getenv +from dotenv import load_dotenv + from .modules import disabled_modules load_dotenv() diff --git a/fastapi_framework/modules.py b/fastapi_framework/modules.py index 1dda3e0a..9da986d1 100644 --- a/fastapi_framework/modules.py +++ b/fastapi_framework/modules.py @@ -1,5 +1,5 @@ from os import getenv -from typing import List, Set, Dict +from typing import Dict, List, Set from dotenv import load_dotenv diff --git a/fastapi_framework/rate_limit.py b/fastapi_framework/rate_limit.py index 89bbc580..72ab7799 100644 --- a/fastapi_framework/rate_limit.py +++ b/fastapi_framework/rate_limit.py @@ -1,7 +1,7 @@ -from typing import Union, Callable, Dict, Coroutine, Optional, Any +from typing import Any, Callable, Coroutine, Dict, Optional, Union -from fastapi import Request, HTTPException, Response -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import HTTPException, Request, Response +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from .in_memory_backend import InMemoryBackend from .jwt_auth import get_data diff --git a/fastapi_framework/redis.py b/fastapi_framework/redis.py index a6f1e4b6..2811ee5e 100644 --- a/fastapi_framework/redis.py +++ b/fastapi_framework/redis.py @@ -1,11 +1,11 @@ -from typing import Set, Any, Optional +from os import getenv +from typing import Any, Optional, Set -from aioredis import create_redis_pool from aioredis import Redis as RedisConnection +from aioredis import create_redis_pool from dotenv import load_dotenv -from os import getenv -from .in_memory_backend import InMemoryBackend, RAMBackend +from .in_memory_backend import InMemoryBackend, RAMBackend from .modules import disabled_modules load_dotenv() diff --git a/fastapi_framework/session.py b/fastapi_framework/session.py index aab1eec7..8801f731 100644 --- a/fastapi_framework/session.py +++ b/fastapi_framework/session.py @@ -1,12 +1,13 @@ import random import string -from typing import Union, Callable, Coroutine, Type, Optional +from typing import Callable, Coroutine, Optional, Type, Union from fastapi import FastAPI -from pydantic import BaseModel from fastapi.requests import Request from fastapi.responses import Response -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from pydantic import BaseModel +from starlette.middleware.base import (BaseHTTPMiddleware, + RequestResponseEndpoint) from .redis import redis_dependency diff --git a/fastapi_framework/settings.py b/fastapi_framework/settings.py index 05446871..18384c78 100644 --- a/fastapi_framework/settings.py +++ b/fastapi_framework/settings.py @@ -1,11 +1,11 @@ from os import getenv -from typing import Union, Optional +from typing import Optional, Union from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column -from .redis import redis_dependency, Redis -from .database import database_dependency, DB, select, Base +from .database import DB, Base, database_dependency, select +from .redis import Redis, redis_dependency CACHE_TTL = int(getenv("CACHE_TTL", str(60 * 60 * 5))) diff --git a/tests/test_database.py b/tests/test_database.py index 21ba263d..af8f2e72 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,28 +1,27 @@ from os import getenv -from typing import Union, List, Dict +from random import choices +from string import ascii_letters +from typing import Dict, List from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch -from fastapi import HTTPException, FastAPI, Depends -from pydantic import BaseModel, constr, conint -from sqlalchemy import Column, String, Integer -from sqlalchemy.orm import mapped_column, Mapped +from fastapi import Depends, FastAPI, HTTPException +from httpx import ASGITransport, AsyncClient, Response +from pydantic import BaseModel, constr +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, mapped_column from fastapi_framework.database import ( - select, - filter_by, - exists, - delete, - database_dependency, DB, - DatabaseDependency, Base, + DatabaseDependency, + database_dependency, + delete, + exists, + filter_by, + select, ) -from httpx import AsyncClient, Response -from random import choices -from string import ascii_letters - app = FastAPI() @@ -137,7 +136,7 @@ async def test_add_row(self): db._session.add.assert_called_with(row) async def test_get_users(self): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/users") self.assertIsInstance(response.json(), List) @@ -145,7 +144,7 @@ async def test_get_users(self): async def test_add_user(self): username = "".join(choices(ascii_letters, k=100)) - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.post(f"/users/{username}") self.assertEqual(response.status_code, 200) @@ -157,7 +156,7 @@ async def test_add_user(self): async def test_add_user_already_exists(self): username = "".join(choices(ascii_letters, k=100)) - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.post(f"/users/{username}") response2: Response = await ac.post(f"/users/{username}") @@ -167,9 +166,9 @@ async def test_add_user_already_exists(self): async def test_get_user_by_name(self): username = "".join(choices(ascii_letters, k=100)) - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: await ac.post(f"/users/{username}") - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get(f"/users/{username}") self.assertEqual(response.status_code, 200) @@ -179,19 +178,19 @@ async def test_get_user_by_name(self): async def test_remove_user(self): username = "".join(choices(ascii_letters, k=100)) - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.post(f"/users/{username}") self.assertEqual(response.status_code, 200) - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response = await ac.delete(f"/users/{username}") self.assertEqual(response.status_code, 200) self.assertEqual(response.content.decode("utf-8"), "true") async def test_remove_user_not_exists(self): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.delete("/users/this_username_dont_exists") self.assertEqual(response.status_code, 404) diff --git a/tests/test_jwt_auth.py b/tests/test_jwt_auth.py index 7e0829ad..71c7749b 100644 --- a/tests/test_jwt_auth.py +++ b/tests/test_jwt_auth.py @@ -1,24 +1,24 @@ -from datetime import timedelta, datetime -from typing import Dict, Union, List +from datetime import datetime, timedelta +from typing import Dict, List, Union from unittest import IsolatedAsyncioTestCase -from unittest.mock import MagicMock, patch, AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import jwt from aioredis import Redis -from fastapi import FastAPI, Depends, HTTPException -from httpx import AsyncClient, Response +from fastapi import Depends, FastAPI, HTTPException +from httpx import ASGITransport, AsyncClient, Response from fastapi_framework import redis_dependency from fastapi_framework.jwt_auth import ( - get_token, - create_jwt_token, ALGORITHM, - get_data, + check_refresh_token, create_access_token, + create_jwt_token, create_refresh_token, - invalidate_refresh_token, - check_refresh_token, generate_tokens, + get_data, + get_token, + invalidate_refresh_token, ) app = FastAPI() @@ -161,10 +161,7 @@ async def test_check_refresh_token_positive(self): async def test_check_refresh_token_negative(self): redis = AsyncMock() redis.smembers = AsyncMock() - redis.smembers.return_value = [ - b"TEST_FALSE_REFRESH_TOKEN", - b"TEST_SECOND_FALSE_REFRESH_TOKEN", - ] + redis.smembers.return_value = [b"TEST_FALSE_REFRESH_TOKEN", b"TEST_SECOND_FALSE_REFRESH_TOKEN"] result = await check_refresh_token("TEST_REFRESH_TOKEN", redis) @@ -195,7 +192,7 @@ async def test_generate_tokens(self): @patch("fastapi_framework.jwt_auth.SECRET_KEY", "TEST_SECRET_KEY") async def test_login(self): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/token", params={"username": "test", "password": "123"}) self.assertEqual(response.status_code, 200) @@ -206,14 +203,14 @@ async def test_login(self): @patch("fastapi_framework.jwt_auth.SECRET_KEY", "TEST_SECRET_KEY") async def test_login_invalid_credentials(self): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/token", params={"username": "not_exists", "password": "wrong"}) self.assertEqual(response.status_code, 401) @patch("fastapi_framework.jwt_auth.SECRET_KEY", "TEST_SECRET_KEY") async def test_secret_route(self): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/token", params={"username": "test", "password": "123"}) self.assertEqual(response.status_code, 200) @@ -222,7 +219,7 @@ async def test_secret_route(self): self.assertEqual(response.json()["token_type"], "bearer") access_token = response.json()["access_token"] - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/secret", headers={"Authorization": f"Bearer {access_token}"}) self.assertEqual(response.status_code, 200) diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py index fa7c8a57..9aa66a4d 100644 --- a/tests/test_rate_limit.py +++ b/tests/test_rate_limit.py @@ -1,19 +1,19 @@ from typing import List from unittest import IsolatedAsyncioTestCase -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch -from fastapi import HTTPException, FastAPI, Depends, Request -from httpx import AsyncClient, Response +from fastapi import Depends, FastAPI, HTTPException +from httpx import ASGITransport, AsyncClient, Response +from fastapi_framework import rate_limit, redis_dependency from fastapi_framework.rate_limit import ( - RateLimitManager, RateLimiter, + RateLimitManager, RateLimitTime, - default_get_uuid, default_callback, + default_get_uuid, get_uuid_user_id, ) -from fastapi_framework import rate_limit, redis_dependency app = FastAPI() @@ -126,7 +126,7 @@ async def test_get_uuid_user_id_no_token(self, http_bearer_patch: AsyncMock, get async def test_limited_route(self): self.testing_uuid = "test_limited_route" - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: response: Response = await ac.get("/limited") self.assertEqual(response.status_code, 200) @@ -137,7 +137,7 @@ async def test_limited_route_without_init(self): RateLimitManager.redis = None with self.assertRaises(Exception): - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: await ac.get("/limited") RateLimitManager.redis = await redis_dependency() @@ -147,7 +147,7 @@ async def test_spam_limited_route_with_async_callback(self): async_callback = AsyncMock() RateLimitManager.callback = async_callback - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: for i in range(4): response: Response = await ac.get("/limited") @@ -162,7 +162,7 @@ async def test_spam_limited_route(self): self.testing_uuid = "test_spam_limited_route" responses: List[Response] = [] - async with AsyncClient(app=app, base_url="https://test") as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="https://test") as ac: for _ in range(4): responses.append(await ac.get("/limited"))