Skip to content

Commit

Permalink
upgrade to async serenity and partially migrate to sqlx
Browse files Browse the repository at this point in the history
  • Loading branch information
technetos committed Jan 7, 2022
1 parent 220d6c5 commit 1ec8819
Show file tree
Hide file tree
Showing 15 changed files with 1,165 additions and 893 deletions.
1,056 changes: 619 additions & 437 deletions Cargo.lock

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,37 @@ license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serenity = { version = "0.10.8", features = ["model"] }
reqwest = { version = "0.10", features = ["json"] }
futures = { version = "0.3" }
reqwest = { version = "0.11", features = ["json"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3" }
diesel = { version = "1.4.0", features = ["postgres", "r2d2"] }
diesel_migrations = { version = "1.4.0", features = ["postgres"] }
serde = "1.0"
serde_derive = "1.0"
lazy_static = "1.4.0"
log = "0.4.0"
env_logger = "0.7.1"
envy = "0.4"
indexmap = "1.6"


[dependencies.sqlx]
features = [
"runtime-tokio-native-tls",
"postgres",
"chrono",
]
version = "0.5"

[dependencies.serenity]
default-features = false
features = [
"builder",
"cache",
"client",
"gateway",
"model",
"utils",
"rustls_backend",
]
version = "0.10"
54 changes: 41 additions & 13 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{command_history::CommandHistory, commands::Args, db::DB, schema::roles, Error};
use diesel::prelude::*;
use crate::{command::Auth, command_history::CommandHistory, commands::Args, Error};
use indexmap::IndexMap;
use serenity::{model::prelude::*, utils::parse_username};
use std::sync::Arc;
use tracing::info;

/// Send a reply to the channel the message was received on.
pub async fn send_reply(args: Arc<Args>, message: &str) -> Result<(), Error> {
Expand Down Expand Up @@ -53,31 +54,58 @@ fn check_permission(args: Arc<Args>, role: Option<String>) -> Result<bool, Error
}

/// Return whether or not the user is a mod.
pub fn is_mod(args: Arc<Args>) -> Result<bool, Error> {
let role = roles::table
.filter(roles::name.eq("mod"))
.first::<(i32, String, String)>(&DB.get()?)
.optional()?;
pub async fn is_mod(args: Arc<Args>) -> Result<bool, Error> {
let role: Option<(i32, String, String)> =
sqlx::query_as("select * from roles where name = 'mod'")
.fetch_optional(&*args.db)
.await?;

check_permission(args.clone(), role.map(|(_, role_id, _)| role_id))
}

pub async fn is_wg_and_teams(args: Arc<Args>) -> Result<bool, Error> {
let role = roles::table
.filter(roles::name.eq("wg_and_teams"))
.first::<(i32, String, String)>(&DB.get()?)
.optional()?;
let role: Option<(i32, String, String)> =
sqlx::query_as("select * from roles where name = 'wg_and_teams'")
.fetch_optional(&*args.db)
.await?;

check_permission(args.clone(), role.map(|(_, role_id, _)| role_id))
}

pub async fn main_menu(
args: Arc<Args>,
commands: &IndexMap<&'static str, (&'static str, &'static Auth)>,
) -> String {
use futures::stream::{self, StreamExt};

let mut menu = format!("Commands:\n");

menu = stream::iter(commands)
.fold(menu, |mut menu, (base_cmd, (description, auth))| {
let args_clone = args.clone();
async move {
if let Ok(true) = auth.call(args_clone).await {
menu += &format!("\t{cmd:<12}{desc}\n", cmd = base_cmd, desc = description);
}
menu
}
})
.await;

menu += &format!("\t{help:<12}This menu\n", help = "?help");
menu += "\nType ?help command for more info on a command.";
menu += "\n\nAdditional Info:\n";
menu += "\tYou can edit your message to the bot and the bot will edit its response.";
menu
}

/// Set slow mode for a channel.
///
/// A `seconds` value of 0 will disable slowmode
pub async fn slow_mode(args: Arc<Args>) -> Result<(), Error> {
use std::str::FromStr;

if is_mod(args.clone())? {
if is_mod(args.clone()).await? {
let seconds = &args
.params
.get("seconds")
Expand Down Expand Up @@ -122,7 +150,7 @@ will disable slowmode on the `#bot-usage` channel.";
///
/// Requires the kick members permission
pub async fn kick(args: Arc<Args>) -> Result<(), Error> {
if is_mod(args.clone())? {
if is_mod(args.clone()).await? {
let user_id = parse_username(
&args
.params
Expand Down
88 changes: 48 additions & 40 deletions src/ban.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,63 @@
use crate::{
api, commands::Args, db::DB, schema::bans, text::ban_message, Error, SendSyncError, HOUR,
};
use diesel::prelude::*;
use crate::{api, commands::Args, db::DB, schema::bans, text::ban_message, Error, HOUR};
use serenity::{model::prelude::*, prelude::*, utils::parse_username};
use sqlx::{
postgres::PgPool,
types::chrono::{DateTime, Utc},
};
use std::{
sync::Arc,
time::{Duration, SystemTime},
};

pub fn save_ban(user_id: String, guild_id: String, hours: u64) -> Result<(), Error> {
use tracing::info;

pub async fn save_ban(
user_id: String,
guild_id: String,
hours: u64,
db: Arc<PgPool>,
) -> Result<(), Error> {
info!("Recording ban for user {}", &user_id);
let conn = DB.get()?;
diesel::insert_into(bans::table)
.values((
bans::user_id.eq(user_id),
bans::guild_id.eq(guild_id),
bans::start_time.eq(SystemTime::now()),
bans::end_time.eq(SystemTime::now()
.checked_add(Duration::new(hours * HOUR, 0))
.ok_or("out of range Duration for ban end_time")?),
))
.execute(&conn)?;
sqlx::query(
"insert into bans(user_id, guild_id, start_time, end_time) values ($1, $2, $3, $4)",
)
.bind(user_id)
.bind(guild_id)
.bind(DateTime::<Utc>::from(SystemTime::now()))
.bind(DateTime::<Utc>::from(
SystemTime::now()
.checked_add(Duration::new(hours * HOUR, 0))
.ok_or("out of range Duration for ban end_time")?,
))
.execute(&*db)
.await?;

Ok(())
}

pub fn save_unban(user_id: String, guild_id: String) -> Result<(), Error> {
pub async fn save_unban(user_id: String, guild_id: String, db: Arc<PgPool>) -> Result<(), Error> {
info!("Recording unban for user {}", &user_id);
let conn = DB.get()?;
diesel::update(bans::table)
.filter(
bans::user_id
.eq(user_id)
.and(bans::guild_id.eq(guild_id).and(bans::unbanned.eq(false))),
)
.set(bans::unbanned.eq(true))
.execute(&conn)?;
sqlx::query(
"update bans set unbanned = true where user_id = $1 and guild_id = $2 and unbanned = false",
)
.bind(user_id)
.bind(guild_id)
.execute(&*db)
.await?;

Ok(())
}

pub async fn unban_users(cx: &Context) -> Result<(), SendSyncError> {
pub async fn unban_users(cx: &Context, db: Arc<PgPool>) -> Result<(), Error> {
use std::str::FromStr;

let to_unban = tokio::task::spawn_blocking(move || -> Result<Vec<(i32, String, String, bool, SystemTime, SystemTime)>, SendSyncError> {
let conn = DB.get()?;
Ok(bans::table
.filter(
bans::unbanned
.eq(false)
.and(bans::end_time.le(SystemTime::now())),
)
.load::<(i32, String, String, bool, SystemTime, SystemTime)>(&conn)?)
})
.await?;
let to_unban: Vec<(i32, String, String, bool, DateTime<Utc>, DateTime<Utc>)> =
sqlx::query_as("select * from bans where unbanned = false and end_time < $1")
.bind(DateTime::<Utc>::from(SystemTime::now()))
.fetch_all(&*db)
.await?;

for row in &to_unban? {
for row in &to_unban {
let guild_id = GuildId::from(u64::from_str(&row.2)?);
info!("Unbanning user {}", &row.1);
guild_id.unban(&cx, u64::from_str(&row.1)?).await?;
Expand Down Expand Up @@ -100,7 +102,13 @@ pub async fn temp_ban(args: Arc<Args>) -> Result<(), Error> {

guild.ban(&args.cx, &user, 7).await?;

save_ban(format!("{}", user_id), format!("{}", guild.id), hours)?;
save_ban(
format!("{}", user_id),
format!("{}", guild.id),
hours,
args.db.clone(),
)
.await?;
}
Ok(())
}
Expand Down
60 changes: 60 additions & 0 deletions src/command.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use crate::{commands::Args, Error};
use std::{future::Future, pin::Pin, sync::Arc};

type ResultFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

pub trait AsyncFn<A, T>: 'static {
fn call(&self, args: A) -> ResultFuture<T, Error>;
}

impl<A, F, G, T> AsyncFn<A, T> for F
where
F: Fn(A) -> G + 'static,
G: Future<Output = Result<T, Error>> + Send + 'static,
{
fn call(&self, args: A) -> ResultFuture<T, Error> {
let fut = (self)(args);
Box::pin(async move { fut.await })
}
}

pub type Handler = dyn AsyncFn<Arc<Args>, ()> + Send + Sync;
pub type Auth = dyn AsyncFn<Arc<Args>, bool> + Send + Sync;

pub enum CommandKind {
Base,
Protected,
Help,
}

pub struct Command {
pub kind: CommandKind,
pub auth: &'static Auth,
pub handler: &'static Handler,
}

impl Command {
pub fn new(handler: &'static Handler) -> Self {
Self {
kind: CommandKind::Base,
auth: &|_| async { Ok(true) },
handler,
}
}

pub fn new_with_auth(handler: &'static Handler, auth: &'static Auth) -> Self {
Self {
kind: CommandKind::Protected,
auth,
handler,
}
}

pub fn help() -> Self {
Self {
kind: CommandKind::Help,
auth: &|_| async { Ok(true) },
handler: &|_| async { Ok(()) },
}
}
}
13 changes: 9 additions & 4 deletions src/command_history.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::{
commands::{Commands, PREFIX},
Error, SendSyncError, HOUR,
Error, HOUR,
};
use indexmap::IndexMap;
use reqwest::Client as HttpClient;
use serenity::{model::prelude::*, prelude::*, utils::CustomMessage};
use std::time::Duration;
use sqlx::postgres::PgPool;
use std::{sync::Arc, time::Duration};
use tracing::info;

const MESSAGE_AGE_MAX: Duration = Duration::from_secs(HOUR);

Expand All @@ -18,6 +21,8 @@ pub async fn replay_message(
cx: Context,
ev: MessageUpdateEvent,
cmds: &Commands,
http: Arc<HttpClient>,
db: Arc<PgPool>,
) -> Result<(), Error> {
let age = ev.timestamp.and_then(|create| {
ev.edited_timestamp
Expand All @@ -37,14 +42,14 @@ pub async fn replay_message(
"sending edited message - {:?} {:?}",
msg.content, msg.author
);
cmds.execute(cx, msg);
cmds.execute(cx, msg, http, db).await;
}
}

Ok(())
}

pub async fn clear_command_history(cx: &Context) -> Result<(), SendSyncError> {
pub async fn clear_command_history(cx: &Context) -> Result<(), Error> {
let mut data = cx.data.write().await;
let history = data.get_mut::<CommandHistory>().unwrap();

Expand Down
Loading

0 comments on commit 1ec8819

Please sign in to comment.