-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: began work on sqlalchemy models
- added sqlalchemy, began work on the models - added alembic because that will be used for future database migrations - added arrow because that will replace whatever parsedatetime is in the future lol - fixed "config.yml" -> "config.toml" in error message in config.py - removed unused import in reminder.py - adjusted verbose docstrings in bot.py to simpler ones
- Loading branch information
Showing
6 changed files
with
134 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
import asyncio | ||
|
||
import discord | ||
from discord.ext import commands | ||
from discord import app_commands | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,108 +1,82 @@ | ||
import dataset | ||
from loguru import logger as log | ||
from sqlalchemy import create_engine | ||
from sqlalchemy_utils import database_exists, create_database | ||
from sqlalchemy import create_engine, BigInteger, Boolean, Column, Integer, Text | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
from chiya.config import config | ||
|
||
|
||
class Database: | ||
def __init__(self) -> None: | ||
host = config.database.host | ||
database = config.database.database | ||
user = config.database.user | ||
password = config.database.password | ||
|
||
if not all([host, database, user, password]): | ||
log.error("One or more database connection variables are missing, exiting...") | ||
raise SystemExit | ||
|
||
# self.url = f"mysql://{user}:{password}@{host}/{database}?charset=utf8mb4" | ||
self.url = config.database.url | ||
|
||
def get(self) -> dataset.Database: | ||
"""Returns the dataset database object.""" | ||
return dataset.connect(url=self.url) | ||
|
||
def setup(self) -> None: | ||
"""Sets up the tables needed for Chiya.""" | ||
engine = create_engine(self.url) | ||
if not database_exists(engine.url): | ||
create_database(engine.url) | ||
|
||
db = self.get() | ||
|
||
if "mod_logs" not in db: | ||
mod_logs = db.create_table("mod_logs") | ||
mod_logs.create_column("user_id", db.types.bigint) | ||
mod_logs.create_column("mod_id", db.types.bigint) | ||
mod_logs.create_column("timestamp", db.types.bigint) | ||
mod_logs.create_column("reason", db.types.text) | ||
mod_logs.create_column("duration", db.types.text) | ||
mod_logs.create_column("type", db.types.text) | ||
log.info("Created missing table: mod_logs") | ||
|
||
if "remind_me" not in db: | ||
remind_me = db.create_table("remind_me") | ||
remind_me.create_column("reminder_location", db.types.bigint) | ||
remind_me.create_column("author_id", db.types.bigint) | ||
remind_me.create_column("date_to_remind", db.types.bigint) | ||
remind_me.create_column("message", db.types.text) | ||
remind_me.create_column("sent", db.types.boolean, default=False) | ||
log.info("Created missing table: remind_me") | ||
|
||
if "timed_mod_actions" not in db: | ||
timed_mod_actions = db.create_table("timed_mod_actions") | ||
timed_mod_actions.create_column("user_id", db.types.bigint) | ||
timed_mod_actions.create_column("mod_id", db.types.bigint) | ||
timed_mod_actions.create_column("action_type", db.types.text) | ||
timed_mod_actions.create_column("start_time", db.types.bigint) | ||
timed_mod_actions.create_column("end_time", db.types.bigint) | ||
timed_mod_actions.create_column("is_done", db.types.boolean, default=False) | ||
timed_mod_actions.create_column("reason", db.types.text) | ||
log.info("Created missing table: timed_mod_actions") | ||
|
||
if "tickets" not in db: | ||
tickets = db.create_table("tickets") | ||
tickets.create_column("user_id", db.types.bigint) | ||
tickets.create_column("guild", db.types.bigint) | ||
tickets.create_column("timestamp", db.types.bigint) | ||
tickets.create_column("ticket_subject", db.types.text) | ||
tickets.create_column("ticket_message", db.types.text) | ||
tickets.create_column("log_url", db.types.text) | ||
tickets.create_column("status", db.types.boolean) | ||
log.info("Created missing table: tickets") | ||
|
||
if "starboard" not in db: | ||
starboard = db.create_table("starboard") | ||
starboard.create_column("channel_id", db.types.bigint) | ||
starboard.create_column("message_id", db.types.bigint) | ||
starboard.create_column("star_embed_id", db.types.bigint) | ||
log.info("Created missing table: starboard") | ||
|
||
if "joyboard" not in db: | ||
joyboard = db.create_table("joyboard") | ||
joyboard.create_column("channel_id", db.types.bigint) | ||
joyboard.create_column("message_id", db.types.bigint) | ||
joyboard.create_column("joy_embed_id", db.types.bigint) | ||
log.info("Created missing table: joyboard") | ||
|
||
if "highlights" not in db: | ||
highlights = db.create_table("highlights") | ||
highlights.create_column("term", db.types.text) | ||
highlights.create_column("users", db.types.text) | ||
log.info("Created missing table: highlights") | ||
|
||
# utf8mb4_unicode_ci is required to support emojis and other unicode. | ||
# dataset does not expose collation in any capacity so rather than | ||
# checking an object property, we have to do this hacky way of checking | ||
# the charset via queries and updating it where necessary. | ||
# for table in db.tables: | ||
# charset = next(db.query(f"SHOW TABLE STATUS WHERE NAME = '{table}';"))["Collation"] | ||
# if charset == "utf8mb4_unicode_ci": | ||
# continue | ||
# db.query(f"ALTER TABLE {table} CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") | ||
# log.info(f"Converted table to utf8mb4_unicode_ci: {table}") | ||
|
||
db.commit() | ||
db.close() | ||
Base = declarative_base() | ||
engine = create_engine(config.database.url, connect_args={"check_same_thread": False}) | ||
session = sessionmaker(autocommit=False, autoflush=False, bind=engine) | ||
|
||
|
||
class BaseModel(Base): | ||
__abstract__ = True | ||
|
||
def save(self): | ||
session.add(self) | ||
session.commit() | ||
return self | ||
|
||
def delete(self): | ||
session.delete(self) | ||
session.commit() | ||
return self | ||
|
||
def flush(self): | ||
session.add(self) | ||
session.flush() | ||
return self | ||
|
||
|
||
class ModLog(Base): | ||
__tablename__ = "mod_logs" | ||
|
||
id = Column(Integer, primar_key=True) | ||
user_id = Column(BigInteger, nullable=False) | ||
mod_id = Column(BigInteger, nullable=False) | ||
timestamp = Column(BigInteger, nullable=False) | ||
reason = Column(Text, nullable=False) | ||
duration = Column(Text, nullable=False) | ||
type = Column(Text, nullable=False) | ||
|
||
|
||
class RemindMe(Base): | ||
__tablename__ = "remind_me" | ||
|
||
id = Column(Integer, primar_key=True) | ||
reminder_location = Column(BigInteger, nullable=False) | ||
author_id = Column(BigInteger, nullable=False) | ||
date_to_remind = Column(BigInteger, nullable=False) | ||
message = Column(Text, nullable=False) | ||
sent = Column(Boolean, nullable=False, default=False) | ||
|
||
|
||
class Ticket(Base): | ||
__tablename__ = "tickets" | ||
|
||
id = Column(Integer, primary_key=True) | ||
user_id = Column(BigInteger, nullable=False) | ||
guild = Column(BigInteger, nullable=False) | ||
timestamp = Column(BigInteger, nullable=False) | ||
ticket_subject = Column(Text, nullable=False) | ||
ticket_message = Column(Text, nullable=False) | ||
log_url = Column(Text, nullable=False) | ||
status = Column(Boolean) | ||
|
||
|
||
class Joyboard(Base): | ||
__tablename__ = "joyboard" | ||
|
||
id = Column(Integer, primary_key=True) | ||
channel_id = Column(BigInteger, nullable=False) | ||
message_id = Column(BigInteger, nullable=False) | ||
joy_embed_id = Column(BigInteger, nullable=False) | ||
|
||
|
||
class Highlight(Base): | ||
__tablename__ = "highlights" | ||
|
||
id = Column(Integer, primary_key=True) | ||
term = Column(Text, nullable=False) | ||
users = Column(Text, nullable=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.