Skip to content

Commit

Permalink
refactor: began work on sqlalchemy models
Browse files Browse the repository at this point in the history
- added sqlalchemy, began work on the models
- added alembic because that will be used for future database migrations
- added arrow because that will replace whatever parsedatetime is in the future lol
- fixed "config.yml" -> "config.toml" in error message in config.py
- removed unused import in reminder.py
- adjusted verbose docstrings in bot.py to simpler ones
  • Loading branch information
Snaacky committed Dec 22, 2024
1 parent 254fbc9 commit 323081f
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 115 deletions.
8 changes: 2 additions & 6 deletions chiya/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@

@bot.event
async def on_ready() -> None:
"""
Called when the client is done preparing the data received from Discord.
"""
"Called when the client is done preparing the data received from Discord."
log.info(f"Logged in as: {str(bot.user)}")
await bot.tree.sync(guild=discord.Object(config.guild_id))

Expand All @@ -43,9 +41,7 @@ async def setup_logger():
log.remove()

class InterceptHandler(logging.Handler):
"""
Setup up an Interceptor class to redirect all logs from the standard logging library to loguru.
"""
"Setup up an Interceptor class to redirect all logs from the standard logging library to loguru."

def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists.
Expand Down
2 changes: 0 additions & 2 deletions chiya/cogs/commands/reminder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio

import discord
from discord.ext import commands
from discord import app_commands
Expand Down
3 changes: 1 addition & 2 deletions chiya/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ class ChiyaConfig(ParentModel):
workspace = Path(__file__).parent.parent
config_file = workspace / "config.toml"


if not config_file.is_file():
raise FileNotFoundError("Unable to load config.yml, exiting...")
raise FileNotFoundError("Unable to load config.toml, exiting...")

config = ChiyaConfig.model_validate(tomllib.load(config_file.open("rb")))
182 changes: 78 additions & 104 deletions chiya/database.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,82 @@
import dataset
from loguru import logger as log
from sqlalchemy import create_engine
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy import create_engine, BigInteger, Boolean, Column, Integer, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from chiya.config import config


class Database:
def __init__(self) -> None:
host = config.database.host
database = config.database.database
user = config.database.user
password = config.database.password

if not all([host, database, user, password]):
log.error("One or more database connection variables are missing, exiting...")
raise SystemExit

# self.url = f"mysql://{user}:{password}@{host}/{database}?charset=utf8mb4"
self.url = config.database.url

def get(self) -> dataset.Database:
"""Returns the dataset database object."""
return dataset.connect(url=self.url)

def setup(self) -> None:
"""Sets up the tables needed for Chiya."""
engine = create_engine(self.url)
if not database_exists(engine.url):
create_database(engine.url)

db = self.get()

if "mod_logs" not in db:
mod_logs = db.create_table("mod_logs")
mod_logs.create_column("user_id", db.types.bigint)
mod_logs.create_column("mod_id", db.types.bigint)
mod_logs.create_column("timestamp", db.types.bigint)
mod_logs.create_column("reason", db.types.text)
mod_logs.create_column("duration", db.types.text)
mod_logs.create_column("type", db.types.text)
log.info("Created missing table: mod_logs")

if "remind_me" not in db:
remind_me = db.create_table("remind_me")
remind_me.create_column("reminder_location", db.types.bigint)
remind_me.create_column("author_id", db.types.bigint)
remind_me.create_column("date_to_remind", db.types.bigint)
remind_me.create_column("message", db.types.text)
remind_me.create_column("sent", db.types.boolean, default=False)
log.info("Created missing table: remind_me")

if "timed_mod_actions" not in db:
timed_mod_actions = db.create_table("timed_mod_actions")
timed_mod_actions.create_column("user_id", db.types.bigint)
timed_mod_actions.create_column("mod_id", db.types.bigint)
timed_mod_actions.create_column("action_type", db.types.text)
timed_mod_actions.create_column("start_time", db.types.bigint)
timed_mod_actions.create_column("end_time", db.types.bigint)
timed_mod_actions.create_column("is_done", db.types.boolean, default=False)
timed_mod_actions.create_column("reason", db.types.text)
log.info("Created missing table: timed_mod_actions")

if "tickets" not in db:
tickets = db.create_table("tickets")
tickets.create_column("user_id", db.types.bigint)
tickets.create_column("guild", db.types.bigint)
tickets.create_column("timestamp", db.types.bigint)
tickets.create_column("ticket_subject", db.types.text)
tickets.create_column("ticket_message", db.types.text)
tickets.create_column("log_url", db.types.text)
tickets.create_column("status", db.types.boolean)
log.info("Created missing table: tickets")

if "starboard" not in db:
starboard = db.create_table("starboard")
starboard.create_column("channel_id", db.types.bigint)
starboard.create_column("message_id", db.types.bigint)
starboard.create_column("star_embed_id", db.types.bigint)
log.info("Created missing table: starboard")

if "joyboard" not in db:
joyboard = db.create_table("joyboard")
joyboard.create_column("channel_id", db.types.bigint)
joyboard.create_column("message_id", db.types.bigint)
joyboard.create_column("joy_embed_id", db.types.bigint)
log.info("Created missing table: joyboard")

if "highlights" not in db:
highlights = db.create_table("highlights")
highlights.create_column("term", db.types.text)
highlights.create_column("users", db.types.text)
log.info("Created missing table: highlights")

# utf8mb4_unicode_ci is required to support emojis and other unicode.
# dataset does not expose collation in any capacity so rather than
# checking an object property, we have to do this hacky way of checking
# the charset via queries and updating it where necessary.
# for table in db.tables:
# charset = next(db.query(f"SHOW TABLE STATUS WHERE NAME = '{table}';"))["Collation"]
# if charset == "utf8mb4_unicode_ci":
# continue
# db.query(f"ALTER TABLE {table} CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
# log.info(f"Converted table to utf8mb4_unicode_ci: {table}")

db.commit()
db.close()
Base = declarative_base()
engine = create_engine(config.database.url, connect_args={"check_same_thread": False})
session = sessionmaker(autocommit=False, autoflush=False, bind=engine)


class BaseModel(Base):
__abstract__ = True

def save(self):
session.add(self)
session.commit()
return self

def delete(self):
session.delete(self)
session.commit()
return self

def flush(self):
session.add(self)
session.flush()
return self


class ModLog(Base):
__tablename__ = "mod_logs"

id = Column(Integer, primar_key=True)
user_id = Column(BigInteger, nullable=False)
mod_id = Column(BigInteger, nullable=False)
timestamp = Column(BigInteger, nullable=False)
reason = Column(Text, nullable=False)
duration = Column(Text, nullable=False)
type = Column(Text, nullable=False)


class RemindMe(Base):
__tablename__ = "remind_me"

id = Column(Integer, primar_key=True)
reminder_location = Column(BigInteger, nullable=False)
author_id = Column(BigInteger, nullable=False)
date_to_remind = Column(BigInteger, nullable=False)
message = Column(Text, nullable=False)
sent = Column(Boolean, nullable=False, default=False)


class Ticket(Base):
__tablename__ = "tickets"

id = Column(Integer, primary_key=True)
user_id = Column(BigInteger, nullable=False)
guild = Column(BigInteger, nullable=False)
timestamp = Column(BigInteger, nullable=False)
ticket_subject = Column(Text, nullable=False)
ticket_message = Column(Text, nullable=False)
log_url = Column(Text, nullable=False)
status = Column(Boolean)


class Joyboard(Base):
__tablename__ = "joyboard"

id = Column(Integer, primary_key=True)
channel_id = Column(BigInteger, nullable=False)
message_id = Column(BigInteger, nullable=False)
joy_embed_id = Column(BigInteger, nullable=False)


class Highlight(Base):
__tablename__ = "highlights"

id = Column(Integer, primary_key=True)
term = Column(Text, nullable=False)
users = Column(Text, nullable=False)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ description = "A moderation-heavy general purpose Discord bot"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"alembic>=1.14.0",
"arrow>=1.3.0",
"dataset==1.6.2",
"discord-ext-menus>=1.1",
"discord-py[speed]==2.4.0",
Expand All @@ -14,6 +16,7 @@ dependencies = [
"privatebinapi==1.0.0",
"pydantic>=2.10.4",
"requests>=2.32.3",
"sqlalchemy>=1.4.54",
"sqlalchemy-utils>=0.41.2",
]

Expand All @@ -24,4 +27,4 @@ dev = [
]

[tool.ruff]
line-length = 120
line-length = 120
49 changes: 49 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 323081f

Please sign in to comment.