Skip to content

Commit

Permalink
redo web download logic to have configured instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Oct 21, 2024
1 parent e36b374 commit aec47c9
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 155 deletions.
9 changes: 8 additions & 1 deletion assyst-common/src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@ pub struct Entitlements {
pub premium_server_sku_id: u64,
}

/// A single Cobalt API instance, as listed under `cobalt_api` in the configuration.
///
/// The web media download logic posts download requests to `url` and sends `key` as the
/// value of the `Authorization` header.
#[derive(Deserialize)]
#[derive(Deserialize, Clone)]
pub struct CobaltApiInstance {
/// URL of the instance that download requests are POSTed to.
pub url: String,
/// Value sent in the `Authorization` header of requests to this instance.
pub key: String,
}

/// URLs of external services, deserialized from the configuration file and read via
/// `CONFIG.urls`.
#[derive(Deserialize, Clone)]
pub struct Urls {
/// Proxy URLs. NOTE(review): presumably the pool the downloader's proxy rotation draws
/// from; confirm against the downloader.
pub proxy: Vec<String>,
/// Filer URL. NOTE(review): exact usage is not visible here; confirm.
pub filer: String,
/// Eval URL. NOTE(review): exact usage is not visible here; confirm.
pub eval: String,
/// Bad translation URL.
pub bad_translation: String,
/// Cobalt API instances handed to the download command for web media downloads.
pub cobalt_api: Vec<CobaltApiInstance>,
}

#[derive(Deserialize)]
Expand Down
4 changes: 2 additions & 2 deletions assyst-core/src/command/services/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use anyhow::Context;
use assyst_common::config::CONFIG;
use assyst_common::util::{filetype, format_duration, sanitise_filename};
use assyst_proc_macro::command;
use assyst_string_fmt::Markdown;
Expand Down Expand Up @@ -117,8 +118,7 @@ impl ParseArgument for DownloadFlags {
]
)]
pub async fn download(ctxt: CommandCtxt<'_>, url: Word, options: DownloadFlags) -> anyhow::Result<()> {
let mut opts =
WebDownloadOpts::from_download_flags(options, ctxt.assyst().rest_cache_handler.get_web_download_urls());
let mut opts = WebDownloadOpts::from_download_flags(options, CONFIG.urls.clone().cobalt_api);

if url.0.to_ascii_lowercase().contains("youtube.com/playlist") {
let videos = get_youtube_playlist_entries(&url.0).await?;
Expand Down
2 changes: 2 additions & 0 deletions assyst-core/src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async fn download_with_proxy(
) -> Result<impl Stream<Item = Result<Bytes, reqwest::Error>>, DownloadError> {
let resp = client
.get(format!("{}/proxy", get_next_proxy()))
.header("User-Agent", "Assyst Discord Bot (https://github.com/jacherr/assyst2)")
.query(&[("url", url), ("limit", &limit.to_string())])
.timeout(Duration::from_secs(10))
.send()
Expand All @@ -70,6 +71,7 @@ async fn download_no_proxy(
) -> Result<impl Stream<Item = Result<Bytes, reqwest::Error>>, DownloadError> {
Ok(client
.get(url)
.header("User-Agent", "Assyst Discord Bot (https://github.com/jacherr/assyst2)")
.send()
.await
.map_err(DownloadError::Reqwest)?
Expand Down
18 changes: 1 addition & 17 deletions assyst-core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ use command::registry::register_interaction_commands;
use gateway_handler::handle_raw_event;
use gateway_handler::incoming_event::IncomingEvent;
use rest::patreon::init_patreon_refresh;
use rest::web_media_download::get_web_download_api_urls;
use task::tasks::refresh_entitlements::refresh_entitlements;
use task::tasks::refresh_web_download_urls::refresh_web_download_urls;
use task::tasks::reminders::handle_reminders;
use tokio::spawn;
use tracing::{debug, info /* trace */};
use tracing::{info /* trace */};
use twilight_gateway::EventTypeFlags;
use twilight_model::id::marker::WebhookMarker;
use twilight_model::id::Id;
Expand Down Expand Up @@ -143,14 +141,6 @@ async fn main() {
info!("Reminder processing disabled in config.dev.disable_reminder_check: not registering task");
}

assyst.register_task(Task::new_delayed(
assyst.clone(),
Duration::from_secs(60 * 10),
Duration::from_secs(60 * 10),
function_task_callback!(refresh_web_download_urls),
));
info!("Registered web download url refreshing task");

assyst.register_task(Task::new(
assyst.clone(),
Duration::from_secs(60 * 10),
Expand Down Expand Up @@ -239,12 +229,6 @@ async fn main() {
}
});

info!("Caching web download API URLs");
let web_download_urls = get_web_download_api_urls(&a.reqwest_client).await.unwrap_or(vec![]);
info!("Got {} URLs to cache", web_download_urls.len());
debug!(?web_download_urls);
a.rest_cache_handler.set_web_download_urls(web_download_urls);

loop {
std::thread::sleep(Duration::from_secs(1));
}
Expand Down
17 changes: 0 additions & 17 deletions assyst-core/src/rest/rest_cache_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub struct RestCacheHandler {
channel_nsfw_status: Cache<u64, bool>,
/// Guild ID -> User ID
guild_owners: Cache<u64, u64>,
/// List of all download URLs.
web_download_urls: Cache<String, ()>,
}
impl RestCacheHandler {
pub fn new(client: Arc<HttpClient>) -> RestCacheHandler {
Expand All @@ -43,7 +41,6 @@ impl RestCacheHandler {
guild_upload_limits: default_cache(),
channel_nsfw_status: default_cache(),
guild_owners: default_cache(),
web_download_urls: Cache::builder().build(),
}
}

Expand All @@ -56,10 +53,6 @@ impl RestCacheHandler {
size += self.guild_upload_limits.entry_count() * size_of::<(u64, u64)>() as u64;
size += self.channel_nsfw_status.entry_count() * size_of::<(u64, bool)>() as u64;
size += self.guild_owners.entry_count() * size_of::<(u64, u64)>() as u64;
size += self
.web_download_urls
.iter()
.fold(0, |acc, x| acc + x.0.as_bytes().len()) as u64;
size
}

Expand Down Expand Up @@ -184,14 +177,4 @@ impl RestCacheHandler {

Ok(owner == user_id || member_is_manager)
}

pub fn set_web_download_urls(&self, urls: Vec<String>) {
for url in urls {
self.web_download_urls.insert(url, ());
}
}

/// Returns the keys currently stored in the web download URL cache.
pub fn get_web_download_urls(&self) -> Vec<Arc<String>> {
    let mut cached_urls = Vec::new();
    for (url, _) in self.web_download_urls.iter() {
        cached_urls.push(url);
    }
    cached_urls
}
}
114 changes: 15 additions & 99 deletions assyst-core/src/rest/web_media_download.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Duration;

use anyhow::{bail, Context};
use assyst_common::util::{format_duration, string_from_likely_utf8};
use futures_util::future::join_all;
use assyst_common::config::config::CobaltApiInstance;
use assyst_common::util::string_from_likely_utf8;
use rand::seq::SliceRandom;
use rand::thread_rng;
use reqwest::{Client, StatusCode};
Expand All @@ -27,11 +26,11 @@ pub static TEST_URL_TIMEOUT: Duration = Duration::from_secs(15);
pub struct WebDownloadOpts {
pub audio_only: Option<bool>,
pub quality: Option<String>,
pub urls: Vec<Arc<String>>,
pub urls: Vec<CobaltApiInstance>,
pub verbose: bool,
}
impl WebDownloadOpts {
pub fn from_download_flags(flags: DownloadFlags, urls: Vec<Arc<String>>) -> Self {
pub fn from_download_flags(flags: DownloadFlags, urls: Vec<CobaltApiInstance>) -> Self {
Self {
audio_only: Some(flags.audio),
quality: if flags.quality != 0 {
Expand Down Expand Up @@ -76,99 +75,13 @@ pub struct InstancesQueryResult {
pub protocol: String,
}

/// Checks whether a web download route is usable.
///
/// Downloads `TEST_URL` through `url` at quality "144" and measures how long the whole
/// attempt takes. The route is accepted only when the download succeeds and the elapsed
/// time is below `TEST_URL_TIMEOUT`. Attempts that reach the timeout are rejected without
/// logging.
///
/// Returns `true` if the route is accepted, `false` otherwise.
async fn test_route(client: &Client, url: &str) -> bool {
    let started_at = Instant::now();

    let outcome = download_web_media(
        client,
        TEST_URL,
        WebDownloadOpts {
            audio_only: Some(false),
            quality: Some("144".to_owned()),
            urls: vec![Arc::new(url.to_owned())],
            verbose: false,
        },
    )
    .await;

    let took = started_at.elapsed();

    // Too slow: rejected regardless of the download result, and nothing is logged.
    if took >= TEST_URL_TIMEOUT {
        return false;
    }

    match outcome {
        Ok(_) => {
            debug!(
                "Web download URL {url} took {} to download test media",
                format_duration(&took)
            );
            true
        },
        Err(err) => {
            debug!("Web download URL {url} failed to download test media ({})", err);
            false
        },
    }
}

/// Returns the list of web download API URLs to use.
///
/// Currently this ignores `client` and always returns a single hard-coded URL
/// (`https://api.cobalt.tools/api/json`).
///
/// The instance discovery logic is disabled (commented out below). When enabled, it fetched
/// `INSTANCES_ROUTE`, kept entries whose protocol is `https` and whose score is at least
/// `TEST_SCORE_THRESHOLD`, and then kept only the URLs for which `test_route` returned `true`
/// within `TEST_URL_TIMEOUT`.
pub async fn get_web_download_api_urls(client: &Client) -> anyhow::Result<Vec<String>> {
// Disabled instance discovery, kept for reference.
/*
let res = client
.get(INSTANCES_ROUTE)
.header("accept", "application/json")
.header("User-Agent", "Assyst Discord Bot (https://github.com/jacherr/assyst2)")
.send()
.await?;
let json = res.json::<Vec<InstancesQueryResult>>().await?;
let test_urls = json
.iter()
.filter_map(|entry: &InstancesQueryResult| {
if entry.protocol == "https" && entry.score >= TEST_SCORE_THRESHOLD {
Some(format!("https://{}/api/json", entry.api))
} else {
None
}
})
.map(|url| {
debug!("Testing web download API URL {}", url);
let c = client.clone();
timeout(
TEST_URL_TIMEOUT,
tokio::spawn(async move {
let res = test_route(&c, &url).await;
(url, res)
}),
)
})
.collect::<Vec<_>>();
let valid_urls = join_all(test_urls)
.await
.into_iter()
.filter_map(|res| res.ok())
.map(|res| res.unwrap())
.filter(|res| res.1)
.map(|res| res.0)
.collect::<Vec<_>>();
Ok(valid_urls)*/

// Fixed fallback used while discovery is disabled.
Ok(vec!["https://api.cobalt.tools/api/json".to_owned()])
}

/// Attempts to download web media. The instances in `opts.urls` are shuffled and tried in turn
/// until one succeeds; fails immediately if `opts.urls` is empty.
pub async fn download_web_media(client: &Client, url: &str, opts: WebDownloadOpts) -> anyhow::Result<Vec<u8>> {
let encoded_url = urlencoding::encode(url).to_string();

let urls = {
let mut urls = opts.urls;
if urls.is_empty() {
bail!("The download command is temporarily disabled due to abuse. Please try again later.");
bail!("No available instances are defined.");
}
urls.shuffle(&mut thread_rng());
urls
Expand All @@ -178,19 +91,22 @@ pub async fn download_web_media(client: &Client, url: &str, opts: WebDownloadOpt
let mut err: String = String::new();

for route in urls {
let key = route.key.clone();
let route = route.url.clone();

debug!("trying url: {route} for web media {url}");

let res = client
.post((*route).clone())
.post(route)
.header("accept", "application/json")
.header("content-type", "application/json")
.header("User-Agent", "Assyst Discord Bot (https://github.com/jacherr/assyst2)")
.header("Authorization", key)
.json(&json!({
"url": encoded_url,
"isAudioOnly": opts.audio_only.unwrap_or(false),
"aFormat": "mp3",
"isNoTTWatermark": true,
"vQuality": opts.quality.clone().unwrap_or("720".to_owned())
"url": url,
"downloadMode": if opts.audio_only.unwrap_or(false) { "audio" } else { "auto" },
"audioFormat": "mp3",
"videoQuality": opts.quality.clone().unwrap_or("720".to_owned()),
}))
.timeout(Duration::from_secs(60))
.send()
Expand Down
1 change: 0 additions & 1 deletion assyst-core/src/task/tasks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod get_premium_users;
pub mod refresh_entitlements;
pub mod refresh_web_download_urls;
pub mod reminders;
pub mod top_gg_stats;
18 changes: 0 additions & 18 deletions assyst-core/src/task/tasks/refresh_web_download_urls.rs

This file was deleted.

2 changes: 2 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ proxy = []
filer = ""
# Bad translation URL.
bad_translation = ""
# Cobalt API instances
cobalt_api = [{ url = "", key = "" }]

[authentication]
# Token to authenticate with Discord.
Expand Down

0 comments on commit aec47c9

Please sign in to comment.