Skip to content

Commit

Permalink
Use resources (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
Masynchin authored Jul 15, 2023
1 parent df8ac22 commit b129169
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 133 deletions.
4 changes: 3 additions & 1 deletion app/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from app.bot.main import main

main()
asyncio.run(main())
65 changes: 34 additions & 31 deletions app/bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aiogram import Bot, Dispatcher
from aiogram.fsm.storage.memory import MemoryStorage
import aiohttp

from app import config
from app.bot.handlers import (
Expand All @@ -27,7 +28,7 @@
from app.bot.polling import Polling
from app.bot.webhook import Webhook
from app.bot.task import MailingTask
from app.db import Subscribers
from app.db import Subscribers, async_session
from app.logger import logger
from app.weather import OwmWeather

Expand All @@ -37,40 +38,42 @@


@logger.catch(level="CRITICAL")
def main():
async def main():
"""Главная функция, отвечающая за запуск бота и рассылки"""
bot = Bot(token=config.BOT_TOKEN)
dp = Dispatcher(storage=MemoryStorage())

logger.info("Запуск")

db = Subscribers()
weather = OwmWeather.for_che(config.WEATHER_API_KEY)
task = MailingTask.default(db, weather)
routes = [
Welcome(),
Info(),
CurrentWeather(weather),
HourForecast(weather),
ExactHourOptions(),
ExactHourForecast(weather),
DailyForecast(weather),
ExactDayOptions(),
ExactDayForecast(weather),
MailingInfo(db),
SubscribeToMailing(),
SetMailingHour(),
SetMailingMinute(db),
ChangeMailingTime(),
ChangeMailingHour(),
ChangeMailingMinute(db),
CancelMailing(db),
Errors(),
]
for route in routes:
route.register(dp)
async with async_session() as db_session, \
aiohttp.ClientSession() as client_session:
db = Subscribers(db_session)
weather = OwmWeather.for_che(config.WEATHER_API_KEY, client_session)
task = MailingTask.default(db, weather)
routes = [
Welcome(),
Info(),
CurrentWeather(weather),
HourForecast(weather),
ExactHourOptions(),
ExactHourForecast(weather),
DailyForecast(weather),
ExactDayOptions(),
ExactDayForecast(weather),
MailingInfo(db),
SubscribeToMailing(),
SetMailingHour(),
SetMailingMinute(db),
ChangeMailingTime(),
ChangeMailingHour(),
ChangeMailingMinute(db),
CancelMailing(db),
Errors(),
]
for route in routes:
route.register(dp)

if config.RUN_TYPE == "polling":
Polling(dp, tasks=[task]).run(bot)
elif config.RUN_TYPE == "webhook":
...
if config.RUN_TYPE == "polling":
await Polling(dp, tasks=[task]).run(bot)
elif config.RUN_TYPE == "webhook":
...
7 changes: 5 additions & 2 deletions app/bot/polling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import suppress

from app.db import create_db


Expand All @@ -8,9 +10,10 @@ def __init__(self, dp, tasks):
self.dp = dp
self.tasks = tasks

def run(self, bot):
async def run(self, bot):
self.dp.startup.register(on_startup(bot, self.tasks))
self.dp.run_polling(bot)
with suppress(KeyboardInterrupt, SystemExit):
await self.dp.start_polling(bot)


def on_startup(bot, tasks):
Expand Down
45 changes: 21 additions & 24 deletions app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,44 @@ class Subscriber(Base):
class Subscribers:
"""БД с подписчиками"""

def __init__(self, session):
self.session = session

async def add(self, user_id, mailing_time):
"""Регистрация в БД нового подписчика рассылки"""
async with async_session() as session:
subscriber = Subscriber(id=user_id, mailing_time=mailing_time)
session.add(subscriber)
await session.commit()
subscriber = Subscriber(id=user_id, mailing_time=mailing_time)
self.session.add(subscriber)
await self.session.commit()

async def new_time(self, user_id, new_mailing_time):
"""Меняем время рассылки подписчика"""
async with async_session() as session:
subscriber = await session.get(Subscriber, user_id)
subscriber.mailing_time = new_mailing_time
await session.commit()
subscriber = await self.session.get(Subscriber, user_id)
subscriber.mailing_time = new_mailing_time
await self.session.commit()

async def delete(self, user_id):
"""Удаление подписчика из БД"""
async with async_session() as session:
subscriber = await session.get(Subscriber, user_id)
await session.delete(subscriber)
await session.commit()
subscriber = await self.session.get(Subscriber, user_id)
await self.session.delete(subscriber)
await self.session.commit()

async def of_time(self, mailing_time):
"""Все подписчики с данным временем рассылки"""
async with async_session() as session:
statement = select(Subscriber).where(
Subscriber.mailing_time == mailing_time
)
subscribers = await session.execute(statement)
return subscribers.scalars().all()
statement = select(Subscriber).where(
Subscriber.mailing_time == mailing_time
)
subscribers = await self.session.execute(statement)
return subscribers.scalars().all()

async def exists(self, user_id):
"""Проверяем наличие пользователя в подписке"""
async with async_session() as session:
subscriber = await session.get(Subscriber, user_id)
return subscriber is not None
subscriber = await self.session.get(Subscriber, user_id)
return subscriber is not None

async def time(self, user_id):
"""Время рассылки данного подписчика"""
async with async_session() as session:
subscriber = await session.get(Subscriber, user_id)
return subscriber.mailing_time
subscriber = await self.session.get(Subscriber, user_id)
return subscriber.mailing_time


async def create_db():
Expand Down
56 changes: 32 additions & 24 deletions app/weather.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""API погоды"""

import time
from urllib.parse import urlencode

import aiohttp
from async_lru import alru_cache

from app.forecasts import CurrentForecast, DailyForecast, HourlyForecast
Expand All @@ -13,17 +11,21 @@
class OwmWeather:
"""Погода с OpenWeatherMap"""

def __init__(self, url, cache_time):
self.url = url
self.cache_time = cache_time
def __init__(self, api):
self.api = api

@classmethod
def default(cls, url):
def from_url(cls, url, session, cache_time):
"""Со ссылкой и временем жизни кеша"""
return cls(OwmApi(url, session, cache_time))

@classmethod
def default(cls, url, session):
"""Со значением времени жизни кеша по умолчанию"""
return cls(url, cache_time=300)
return cls.from_url(url, session, cache_time=300)

@classmethod
def from_geo(cls, lat, lon, api_key):
def from_geo(cls, lat, lon, api_key, session):
"""Для конкретного места по координатам"""
url = "https://api.openweathermap.org/data/2.5/onecall?" + urlencode({
"lat": lat,
Expand All @@ -33,43 +35,41 @@ def from_geo(cls, lat, lon, api_key):
"exclude": "minutely",
"lang": "ru",
})
return cls.default(url)
return cls.default(url, session)

@classmethod
def for_che(cls, api_key):
def for_che(cls, api_key, session):
"""Для Череповца"""
return cls.from_geo(lat=59.09, lon=37.91, api_key=api_key)

async def weather(self):
"""Кешированная погода с OpenWeatherMap"""
return await _get_weather(self.url, time.time() // self.cache_time)
return cls.from_geo(
lat=59.09, lon=37.91, api_key=api_key, session=session
)

async def current(self):
"""Текущая погода"""
weather = await self.weather()
weather = await self.api()
return CurrentForecast(weather.current, weather.alerts)

async def hourly(self, timestamp):
"""Прогноз на час"""
weather = await self.weather()
weather = await self.api()
forecast = _next(weather.hourly, timestamp)
return HourlyForecast(forecast, weather.alerts)

async def exact_hour(self, hour):
"""Прогноз на конкретный час"""
weather = await self.weather()
weather = await self.api()
forecast = _exact_hour(weather.hourly, hour)
return HourlyForecast(forecast, weather.alerts)

async def daily(self, timestamp):
"""Прогноз на день"""
weather = await self.weather()
weather = await self.api()
forecast = _next(weather.daily, timestamp)
return DailyForecast(forecast, weather.alerts)

async def exact_day(self, day):
"""Получение прогноза в конкретный день"""
weather = await self.weather()
weather = await self.api()
forecast = _exact_day(weather.daily, day)
return DailyForecast(forecast, weather.alerts)

Expand Down Expand Up @@ -101,10 +101,18 @@ def _exact_day(forecasts, day):
return forecast


@alru_cache(maxsize=1)
async def _get_weather(url, time):
"""Кешированный прогноз погоды в виде WeatherResponse"""
async with aiohttp.ClientSession() as session:
def OwmApi(url, session, cache_time):
"""OpenWeatherMap API.
Функция, а не класс, потому что не понимаю,
как красиво через класс сделать кеширование
"""

@alru_cache(maxsize=1, ttl=cache_time)
async def request():
"""Кешированный прогноз погоды в виде WeatherResponse"""
async with session.get(url) as response:
data = await response.json()
return WeatherResponse(**data)

return request
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ aiogram==3.0.0b6
aiohttp==3.8.4
aiosignal==1.3.1
aiosqlite==0.17.0
async-lru==1.0.3
async-lru==2.0.3
async-timeout==4.0.2
attrs==23.1.0
charset-normalizer==3.2.0
Expand Down
24 changes: 13 additions & 11 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import datetime as dt

import pytest
import pytest_asyncio

from app.db import Subscribers, create_db
from app.db import Subscribers, async_session, create_db


mailing_time = dt.time(hour=18, minute=45)
Expand All @@ -16,16 +17,17 @@ def event_loop():
loop.close()


@pytest.fixture(scope="module", autouse=True)
async def init_db():
@pytest_asyncio.fixture
async def session():
"""Инициализация ДБ для тестов этого модуля"""
await create_db()
yield
async with async_session() as session:
await create_db()
yield session


@pytest.mark.asyncio
async def test_add():
db = Subscribers()
async def test_add(session):
db = Subscribers(session)

await db.add(user_id=0, mailing_time=mailing_time)

Expand All @@ -37,8 +39,8 @@ async def test_add():


@pytest.mark.asyncio
async def test_change_subscriber_time():
db = Subscribers()
async def test_change_subscriber_time(session):
db = Subscribers(session)

await db.add(user_id=0, mailing_time=mailing_time)
assert await db.time(user_id=0) == mailing_time
Expand All @@ -51,8 +53,8 @@ async def test_change_subscriber_time():


@pytest.mark.asyncio
async def test_delete():
db = Subscribers()
async def test_delete(session):
db = Subscribers(session)

before = await db.of_time(mailing_time)
await db.add(user_id=0, mailing_time=mailing_time)
Expand Down
Loading

0 comments on commit b129169

Please sign in to comment.