Skip to content

Commit

Permalink
Add Oauth flow for GitHub, remove reports modal, use oauth token for …
Browse files Browse the repository at this point in the history
…inviting members
  • Loading branch information
cbrxyz committed Aug 23, 2024
1 parent 63ef4b2 commit d12b85e
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 36 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ icalendar==5.0.11
recurring-ical-events==2.1.2
aiosmtplib==3.0.1
better-ipc==2.0.3
sqlalchemy==2.0.22
aiosqlite==0.20.0
9 changes: 9 additions & 0 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
from google.auth import crypt
from google.oauth2.service_account import Credentials
from rich.logging import RichHandler
from sqlalchemy.ext.asyncio import create_async_engine

from .anonymous import AnonymousReportView
from .calendar import CalendarView
from .constants import Team
from .db import Base, DatabaseFactory
from .env import (
DATABASE_ENGINE_URL,
DISCORD_TOKEN,
GITHUB_TOKEN,
GSPREAD_PRIVATE_KEY,
Expand Down Expand Up @@ -125,6 +128,7 @@ class MILBot(commands.Bot):
tasks: TaskManager
github: GitHub
verifier: Verifier
db_factory: DatabaseFactory

def __init__(self):
super().__init__(
Expand All @@ -143,11 +147,16 @@ async def on_ready(self):
if not self.change_status.is_running():
self.change_status.start()
self.tasks.start()
engine = create_async_engine(DATABASE_ENGINE_URL)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
self.db_factory = DatabaseFactory(bot=self, engine=engine)
await self.fetch_vars()
await self.reports_cog.update_report_channel.run_immediately()

async def close(self):
await self.session.close()
await self.db_factory.close()
await super().close()

@ext_tasks.loop(hours=1)
Expand Down
83 changes: 83 additions & 0 deletions src/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from sqlalchemy import (
BigInteger,
String,
select,
)
from sqlalchemy.ext.asyncio import (
AsyncAttrs,
AsyncEngine,
AsyncSession,
)
from sqlalchemy.orm import DeclarativeBase, mapped_column

if TYPE_CHECKING:
from .bot import MILBot


logger = logging.getLogger(__name__)


class Base(AsyncAttrs, DeclarativeBase):
pass


class GitHubOauthMember(Base):
__tablename__ = "github_oauth_member"

discord_id = mapped_column(BigInteger, primary_key=True)
device_code = mapped_column(String, nullable=False)
access_token = mapped_column(String, nullable=True)


class Database(AsyncSession):
def __init__(self, *, bot: MILBot, engine: AsyncEngine):
self.bot = bot
self.engine = engine
super().__init__(bind=engine, expire_on_commit=False)

async def __aenter__(self) -> Database:
return self

async def __aexit__(self, *args) -> None:
await self.close()

async def add_github_oauth_member(
self,
discord_id: int,
device_code: str,
access_token: str,
):
member = GitHubOauthMember(
discord_id=discord_id,
device_code=device_code,
access_token=access_token,
)
await self.merge(member)
await self.commit()

async def get_github_oauth_member(
self,
discord_id: int,
) -> GitHubOauthMember | None:
result = await self.execute(
select(GitHubOauthMember).where(GitHubOauthMember.discord_id == discord_id),
)
response = result.scalars().first()
return response


class DatabaseFactory:
def __init__(self, *, engine: AsyncEngine, bot: MILBot):
self.engine = engine
self.bot = bot

def __call__(self) -> Database:
return Database(bot=self.bot, engine=self.engine)

async def close(self):
await self.engine.dispose()
4 changes: 4 additions & 0 deletions src/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def ensure_string(name: str, optional: bool = False) -> str | None:
GITHUB_TOKEN = ensure_string("GITHUB_TOKEN")
LEADERS_MEETING_NOTES_URL = ensure_string("MEETING_NOTES_URL")
LEADERS_MEETING_URL = ensure_string("MEETING_URL")
DATABASE_ENGINE_URL = ensure_string("DATABASE_ENGINE_URL")

# Calendars
GENERAL_CALENDAR = ensure_string("GENERAL_CALENDAR", True)
Expand All @@ -51,5 +52,8 @@ def ensure_string(name: str, optional: bool = False) -> str | None:
# Email
EMAIL_USERNAME = ensure_string("EMAIL_USERNAME", True)
EMAIL_PASSWORD = ensure_string("EMAIL_PASSWORD", True)

WEBHOOK_SERVER_PORT = ensure_string("WEBHOOK_SERVER_PORT", True)
GITHUB_OAUTH_CLIENT_ID = ensure_string("GITHUB_OAUTH_CLIENT_ID", True)
GITHUB_OAUTH_CLIENT_SECRET = ensure_string("GITHUB_OAUTH_CLIENT_SECRET", True)
IPC_PORT = ensure_string("IPC_PORT", True)
63 changes: 59 additions & 4 deletions src/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import discord
from discord.ext import commands

from .env import GITHUB_TOKEN
from .env import GITHUB_OAUTH_CLIENT_ID, GITHUB_TOKEN
from .github_types import (
Branch,
CheckRunsData,
Expand Down Expand Up @@ -65,6 +65,13 @@ async def on_submit(self, interaction: discord.Interaction):
)
raise e

async with self.bot.db_factory() as db:
oauth_user = await db.get_github_oauth_member(interaction.user.id)
if not oauth_user:
return await interaction.response.send_message(
f"You have not connected your GitHub account. Please connect your account first in {self.bot.member_services_channel.mention}!",
ephemeral=True,
)
try:
# If the org is uf-mil, invite to the "Developers" team
if self.org_name == "uf-mil":
Expand All @@ -73,20 +80,30 @@ async def on_submit(self, interaction: discord.Interaction):
user["id"],
self.org_name,
team["id"],
oauth_user.access_token,
)
else:
await self.bot.github.invite_user_to_org(user["id"], self.org_name)
await self.bot.github.invite_user_to_org(
user["id"],
self.org_name,
user_access_token=oauth_user.access_token,
)
await interaction.response.send_message(
f"Successfully invited {username} to {self.org_name}.",
ephemeral=True,
)
except aiohttp.ClientResponseError as e:
if e.status == 403:
await interaction.response.send_message(
"Your GitHub account does not have the necessary permissions to invite users to the organization.",
ephemeral=True,
)
if e.status == 422:
await interaction.response.send_message(
"Validaton failed, the user might already be in the organization.",
ephemeral=True,
)
return
return
except Exception:
await interaction.response.send_message(
f"Failed to invite {username} to {self.org_name}.",
Expand Down Expand Up @@ -155,14 +172,15 @@ async def fetch(
method: Literal["GET", "POST"] = "GET",
extra_headers: dict[str, str] | None = None,
data: dict[str, Any] | str | None = None,
user_access_token: str | None = None,
):
"""
Fetches a URL with the given method and headers.
Raises ClientResponseError if the response status is not 2xx.
"""
headers = {
"Authorization": f"Bearer {self.auth_token}",
"Authorization": f"Bearer {user_access_token or self.auth_token}",
}
if extra_headers:
headers.update(extra_headers)
Expand All @@ -179,6 +197,41 @@ async def fetch(
response.raise_for_status()
return await response.json()

async def get_oauth_device_code(self) -> dict[str, Any]:
url = "https://github.com/login/device/code"
extra_headers = {
"Accept": "application/json",
}
data = {
"client_id": GITHUB_OAUTH_CLIENT_ID,
"scope": "repo",
}
response = await self.fetch(
url,
method="POST",
extra_headers=extra_headers,
data=data,
)
return response

async def get_oauth_access_token(self, device_code: str) -> dict[str, str]:
url = "https://github.com/login/oauth/access_token"
extra_headers = {
"Accept": "application/json",
}
data = {
"client_id": GITHUB_OAUTH_CLIENT_ID,
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
response = await self.fetch(
url,
method="POST",
extra_headers=extra_headers,
data=data,
)
return response

async def get_repo(self, repo_name: str) -> Repository:
url = f"https://api.github.com/repos/{repo_name}"
return await self.fetch(url)
Expand Down Expand Up @@ -305,6 +358,7 @@ async def invite_user_to_org(
user_id: int,
org_name: str,
team_id: int | None = None,
user_access_token: str | None = None,
) -> Invitation:
url = f"https://api.github.com/orgs/{org_name}/invitations"
extra_headers = {
Expand All @@ -321,6 +375,7 @@ async def invite_user_to_org(
method="POST",
extra_headers=extra_headers,
data=str_data,
user_access_token=user_access_token,
)

async def get_team(self, org_name: str, team_name: str) -> OrganizationTeam:
Expand Down
94 changes: 62 additions & 32 deletions src/reports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import calendar
import datetime
import itertools
Expand Down Expand Up @@ -532,45 +533,74 @@ async def on_submit(self, interaction: discord.Interaction) -> None:
)


class SubmitButton(discord.ui.Button):
class OauthSetupButton(discord.ui.Button):
def __init__(self, bot: MILBot):
self.bot = bot
if not is_active():
super().__init__(
label="Reports are not currently active.",
style=discord.ButtonStyle.grey,
disabled=True,
custom_id="reports_view:submit",
)
elif datetime.datetime.today().weekday() in [calendar.MONDAY, calendar.TUESDAY]:
super().__init__(
label="Reports can only be submitted between Wednesday and Sunday.",
style=discord.ButtonStyle.red,
disabled=True,
custom_id="reports_view:submit",
)
else:
super().__init__(
label="Submit your report!",
style=discord.ButtonStyle.green,
custom_id="reports_view:submit",
)
super().__init__(
label="Connect/Re-connect your GitHub account",
style=discord.ButtonStyle.green,
custom_id="reports_view:oauth_connect",
)

async def callback(self, interaction: discord.Interaction):
# If button is triggered on Monday or Tuesday, send error message
if datetime.datetime.today().weekday() in [calendar.MONDAY, calendar.TUESDAY]:
return await interaction.response.send_message(
":x: Weekly reports should be submitted between Wednesday and Sunday. While occasional exceptions can be made if you miss a week—simply inform your team lead—this should not become a regular occurrence. Be aware that the submission window closes promptly at 11:59pm on Sunday.",
assert isinstance(interaction.user, discord.Member)
if (
self.bot.egn4912_role not in interaction.user.roles
and self.bot.leaders_role not in interaction.user.roles
):
await interaction.response.send_message(
"❌ You must be an active member of EGN4912 to connect your GitHub account.",
ephemeral=True,
)
return

if not is_active():
return await interaction.response.send_message(
"❌ The weekly reports system is currently inactive due to the interim period between semesters. Please wait until next semester to document any work you have completed in between semesters. Thank you!",
ephemeral=True,
device_code_response = await self.bot.github.get_oauth_device_code()
code, device_code = (
device_code_response["user_code"],
device_code_response["device_code"],
)
button = discord.ui.Button(
label="Authorize GitHub",
url="https://github.com/login/device",
)
view = MILBotView()
view.add_item(button)
expires_in_dt = datetime.datetime.now() + datetime.timedelta(
seconds=device_code_response["expires_in"],
)
await interaction.response.send_message(
f"To authorize your GitHub account, please visit the link below using the button and enter the following code:\n`{code}`\n\n* Please note that it may take a few seconds after authorizing in your browser to appear in Discord, due to GitHub limitations.\n* This authorization attempt will expire {discord.utils.format_dt(expires_in_dt, 'R')}.",
view=view,
ephemeral=True,
)
access_token = None
while not access_token and datetime.datetime.now() < expires_in_dt:
await asyncio.sleep(device_code_response["interval"])
resp = await self.bot.github.get_oauth_access_token(device_code)
if "access_token" in resp:
access_token = resp["access_token"]
if "error" in resp and resp["error"] == "access_denied":
await interaction.edit_original_response(
content="❌ Authorization was denied (did you hit cancel?). Please try again.",
view=None,
)
return
if access_token:
async with self.bot.db_factory() as db:
await db.add_github_oauth_member(
interaction.user.id,
device_code,
access_token,
)
await interaction.edit_original_response(
content="Thanks! Your GitHub account has been successfully connected.",
view=None,
)
else:
await interaction.edit_original_response(
content="❌ Authorization expired. Please try again.",
view=None,
)
# Send modal where user fills out report
await interaction.response.send_modal(ReportsModal(self.bot))


class ReportHistoryButton(discord.ui.Button):
Expand Down Expand Up @@ -672,7 +702,7 @@ class ReportsView(MILBotView):
def __init__(self, bot: MILBot):
self.bot = bot
super().__init__(timeout=None)
self.add_item(SubmitButton(bot))
self.add_item(OauthSetupButton(bot))
self.add_item(ReportHistoryButton(bot))


Expand Down

0 comments on commit d12b85e

Please sign in to comment.