Skip to content

Commit

Permalink
WIP async serenity support, getting there, not there yet
Browse files Browse the repository at this point in the history
  • Loading branch information
technetos committed Sep 13, 2021
1 parent 933b3ad commit 220d6c5
Show file tree
Hide file tree
Showing 15 changed files with 995 additions and 684 deletions.
664 changes: 445 additions & 219 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serenity = { version = "0.8.7", features = ["model"] }
serenity = { version = "0.10.8", features = ["model"] }
reqwest = { version = "0.10", features = ["json"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
diesel = { version = "1.4.0", features = ["postgres", "r2d2"] }
diesel_migrations = { version = "1.4.0", features = ["postgres"] }
reqwest = { version = "0.10", features = ["blocking", "json"] }
serde = "1.0"
serde_derive = "1.0"
lazy_static = "1.4.0"
Expand Down
57 changes: 32 additions & 25 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
use crate::{command_history::CommandHistory, commands::Args, db::DB, schema::roles, Error};
use diesel::prelude::*;
use serenity::{model::prelude::*, utils::parse_username};
use std::sync::Arc;

/// Send a reply to the channel the message was received on.
pub(crate) fn send_reply(args: &Args, message: &str) -> Result<(), Error> {
if let Some(response_id) = response_exists(args) {
pub async fn send_reply(args: Arc<Args>, message: &str) -> Result<(), Error> {
if let Some(response_id) = response_exists(args.clone()).await {
info!("editing message: {:?}", response_id);
args.msg
.channel_id
.edit_message(&args.cx, response_id, |msg| msg.content(message))?;
.edit_message(&args.clone().cx, response_id, |msg| msg.content(message))
.await?;
} else {
let command_id = args.msg.id;
let response = args.msg.channel_id.say(&args.cx, message)?;
let response = args.clone().msg.channel_id.say(&args.cx, message).await?;

let mut data = args.cx.data.write();
let mut data = args.cx.data.write().await;
let history = data.get_mut::<CommandHistory>().unwrap();
history.insert(command_id, response.id);
}

Ok(())
}

fn response_exists(args: &Args) -> Option<MessageId> {
let data = args.cx.data.read();
async fn response_exists(args: Arc<Args>) -> Option<MessageId> {
let data = args.cx.data.read().await;
let history = data.get::<CommandHistory>().unwrap();
history.get(&args.msg.id).cloned()
}

/// Determine if a member sending a message has the `Role`.
pub(crate) fn has_role(args: &Args, role: &RoleId) -> Result<bool, Error> {
pub fn has_role(args: Arc<Args>, role: &RoleId) -> Result<bool, Error> {
Ok(args
.msg
.member
Expand All @@ -38,41 +40,44 @@ pub(crate) fn has_role(args: &Args, role: &RoleId) -> Result<bool, Error> {
.contains(role))
}

fn check_permission(args: &Args, role: Option<String>) -> Result<bool, Error> {
fn check_permission(args: Arc<Args>, role: Option<String>) -> Result<bool, Error> {
use std::str::FromStr;
if let Some(role_id) = role {
Ok(has_role(args, &RoleId::from(u64::from_str(&role_id)?))?)
Ok(has_role(
args.clone(),
&RoleId::from(u64::from_str(&role_id)?),
)?)
} else {
Ok(false)
}
}

/// Return whether or not the user is a mod.
pub(crate) fn is_mod(args: &Args) -> Result<bool, Error> {
pub fn is_mod(args: Arc<Args>) -> Result<bool, Error> {
let role = roles::table
.filter(roles::name.eq("mod"))
.first::<(i32, String, String)>(&DB.get()?)
.optional()?;

check_permission(args, role.map(|(_, role_id, _)| role_id))
check_permission(args.clone(), role.map(|(_, role_id, _)| role_id))
}

pub(crate) fn is_wg_and_teams(args: &Args) -> Result<bool, Error> {
pub async fn is_wg_and_teams(args: Arc<Args>) -> Result<bool, Error> {
let role = roles::table
.filter(roles::name.eq("wg_and_teams"))
.first::<(i32, String, String)>(&DB.get()?)
.optional()?;

check_permission(args, role.map(|(_, role_id, _)| role_id))
check_permission(args.clone(), role.map(|(_, role_id, _)| role_id))
}

/// Set slow mode for a channel.
///
/// A `seconds` value of 0 will disable slowmode
pub(crate) fn slow_mode(args: Args) -> Result<(), Error> {
pub async fn slow_mode(args: Arc<Args>) -> Result<(), Error> {
use std::str::FromStr;

if is_mod(&args)? {
if is_mod(args.clone())? {
let seconds = &args
.params
.get("seconds")
Expand All @@ -85,12 +90,14 @@ pub(crate) fn slow_mode(args: Args) -> Result<(), Error> {
.ok_or("unable to retrieve channel param")?;

info!("Applying slowmode to channel {}", &channel_name);
ChannelId::from_str(channel_name)?.edit(&args.cx, |c| c.slow_mode_rate(*seconds))?;
ChannelId::from_str(channel_name)?
.edit(&args.cx, |c| c.slow_mode_rate(*seconds))
.await?;
}
Ok(())
}

pub(crate) fn slow_mode_help(args: Args) -> Result<(), Error> {
pub async fn slow_mode_help(args: Arc<Args>) -> Result<(), Error> {
let help_string = "
Set slowmode on a channel
```
Expand All @@ -107,15 +114,15 @@ will set slowmode on the `#bot-usage` channel with a delay of 10 seconds.
?slowmode #bot-usage 0
```
will disable slowmode on the `#bot-usage` channel.";
send_reply(&args, &help_string)?;
send_reply(args.clone(), &help_string).await?;
Ok(())
}

/// Kick a user from the guild.
///
/// Requires the kick members permission
pub(crate) fn kick(args: Args) -> Result<(), Error> {
if is_mod(&args)? {
pub async fn kick(args: Arc<Args>) -> Result<(), Error> {
if is_mod(args.clone())? {
let user_id = parse_username(
&args
.params
Expand All @@ -124,15 +131,15 @@ pub(crate) fn kick(args: Args) -> Result<(), Error> {
)
.ok_or("unable to retrieve user id")?;

if let Some(guild) = args.msg.guild(&args.cx) {
if let Some(guild) = args.msg.guild(&args.cx).await {
info!("Kicking user from guild");
guild.read().kick(&args.cx, UserId::from(user_id))?
guild.kick(&args.cx, UserId::from(user_id)).await?
}
}
Ok(())
}

pub(crate) fn kick_help(args: Args) -> Result<(), Error> {
pub async fn kick_help(args: Arc<Args>) -> Result<(), Error> {
let help_string = "
Kick a user from the guild
```
Expand All @@ -143,6 +150,6 @@ Kick a user from the guild
?kick @someuser
```
will kick a user from the guild.";
send_reply(&args, &help_string)?;
send_reply(args.clone(), &help_string).await?;
Ok(())
}
59 changes: 32 additions & 27 deletions src/ban.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ use crate::{
};
use diesel::prelude::*;
use serenity::{model::prelude::*, prelude::*, utils::parse_username};
use std::time::{Duration, SystemTime};
use std::{
sync::Arc,
time::{Duration, SystemTime},
};

pub(crate) fn save_ban(user_id: String, guild_id: String, hours: u64) -> Result<(), Error> {
pub fn save_ban(user_id: String, guild_id: String, hours: u64) -> Result<(), Error> {
info!("Recording ban for user {}", &user_id);
let conn = DB.get()?;
diesel::insert_into(bans::table)
Expand All @@ -22,7 +25,7 @@ pub(crate) fn save_ban(user_id: String, guild_id: String, hours: u64) -> Result<
Ok(())
}

pub(crate) fn save_unban(user_id: String, guild_id: String) -> Result<(), Error> {
pub fn save_unban(user_id: String, guild_id: String) -> Result<(), Error> {
info!("Recording unban for user {}", &user_id);
let conn = DB.get()?;
diesel::update(bans::table)
Expand All @@ -37,30 +40,34 @@ pub(crate) fn save_unban(user_id: String, guild_id: String) -> Result<(), Error>
Ok(())
}

pub(crate) fn unban_users(cx: &Context) -> Result<(), SendSyncError> {
pub async fn unban_users(cx: &Context) -> Result<(), SendSyncError> {
use std::str::FromStr;

let conn = DB.get()?;
let to_unban = bans::table
.filter(
bans::unbanned
.eq(false)
.and(bans::end_time.le(SystemTime::now())),
)
.load::<(i32, String, String, bool, SystemTime, SystemTime)>(&conn)?;

for row in &to_unban {
let to_unban = tokio::task::spawn_blocking(move || -> Result<Vec<(i32, String, String, bool, SystemTime, SystemTime)>, SendSyncError> {
let conn = DB.get()?;
Ok(bans::table
.filter(
bans::unbanned
.eq(false)
.and(bans::end_time.le(SystemTime::now())),
)
.load::<(i32, String, String, bool, SystemTime, SystemTime)>(&conn)?)
})
.await?;

for row in &to_unban? {
let guild_id = GuildId::from(u64::from_str(&row.2)?);
info!("Unbanning user {}", &row.1);
guild_id.unban(&cx, u64::from_str(&row.1)?)?;
guild_id.unban(&cx, u64::from_str(&row.1)?).await?;
}

Ok(())
}

/// Temporarily ban an user from the guild.
///
/// Requires the ban members permission
pub(crate) fn temp_ban(args: Args) -> Result<(), Error> {
pub async fn temp_ban(args: Arc<Args>) -> Result<(), Error> {
let user_id = parse_username(
&args
.params
Expand All @@ -82,25 +89,23 @@ pub(crate) fn temp_ban(args: Args) -> Result<(), Error> {
.get("reason")
.ok_or("unable to retrieve reason param")?;

if let Some(guild) = args.msg.guild(&args.cx) {
if let Some(guild) = args.msg.guild(&args.cx).await {
info!("Banning user from guild");
let user = UserId::from(user_id);

user.create_dm_channel(args.cx)?
.say(args.cx, ban_message(reason, hours))?;
user.create_dm_channel(&args.cx)
.await?
.say(&args.cx, ban_message(reason, hours))
.await?;

guild.read().ban(args.cx, &user, &"all")?;
guild.ban(&args.cx, &user, 7).await?;

save_ban(
format!("{}", user_id),
format!("{}", guild.read().id),
hours,
)?;
save_ban(format!("{}", user_id), format!("{}", guild.id), hours)?;
}
Ok(())
}

pub(crate) fn help(args: Args) -> Result<(), Error> {
pub async fn help(args: Arc<Args>) -> Result<(), Error> {
let hours = 24;
let reason = "violating the code of conduct";

Expand All @@ -125,6 +130,6 @@ will ban a user for {hours} hours and send them the following message:
reason = reason,
);

api::send_reply(&args, &help_string)?;
api::send_reply(args.clone(), &help_string).await?;
Ok(())
}
10 changes: 5 additions & 5 deletions src/command_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use std::time::Duration;

const MESSAGE_AGE_MAX: Duration = Duration::from_secs(HOUR);

pub(crate) struct CommandHistory;
pub struct CommandHistory;

impl TypeMapKey for CommandHistory {
type Value = IndexMap<MessageId, MessageId>;
}

pub(crate) fn replay_message(
pub async fn replay_message(
cx: Context,
ev: MessageUpdateEvent,
cmds: &Commands,
Expand All @@ -37,15 +37,15 @@ pub(crate) fn replay_message(
"sending edited message - {:?} {:?}",
msg.content, msg.author
);
cmds.execute(cx, &msg);
cmds.execute(cx, msg);
}
}

Ok(())
}

pub(crate) fn clear_command_history(cx: &Context) -> Result<(), SendSyncError> {
let mut data = cx.data.write();
pub async fn clear_command_history(cx: &Context) -> Result<(), SendSyncError> {
let mut data = cx.data.write().await;
let history = data.get_mut::<CommandHistory>().unwrap();

// always keep the last command in history
Expand Down
Loading

0 comments on commit 220d6c5

Please sign in to comment.