From a4f511f67e350fb4e4792416d42dbeaa9f1d544b Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Sun, 24 Nov 2024 10:50:21 -0500 Subject: [PATCH] fix(client): update rate limit self-check (fix #335) --- src/client.rs | 30 ++++++++++++++++++++++++------ src/subreddit.rs | 3 +++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/client.rs b/src/client.rs index 0e2c3011..ba085312 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,8 +19,7 @@ use std::{io, result::Result}; use crate::dbg_msg; use crate::oauth::{force_refresh_token, token_daemon, Oauth}; use crate::server::RequestExt; -use crate::subreddit::community; -use crate::utils::format_url; +use crate::utils::{format_url, Post}; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com"; @@ -480,11 +479,10 @@ pub async fn json(path: String, quarantine: bool) -> Result { } async fn self_check(sub: &str) -> Result<(), String> { - let request = Request::get(format!("/r/{sub}/")).body(Body::empty()).unwrap(); + let query = format!("/r/{sub}/hot.json?&raw_json=1"); - match community(request).await { - Ok(sub) if sub.status().is_success() => Ok(()), - Ok(sub) => Err(sub.status().to_string()), + match Post::fetch(&query, true).await { + Ok(_) => Ok(()), Err(e) => Err(e), } } @@ -509,6 +507,26 @@ pub async fn rate_limit_check() -> Result<(), String> { Ok(()) } +#[cfg(test)] +use {crate::config::get_setting, sealed_test::prelude::*}; + +#[tokio::test(flavor = "multi_thread")] +async fn test_rate_limit_check() { + rate_limit_check().await.unwrap(); +} + +#[test] +#[sealed_test(env = [("REDLIB_DEFAULT_SUBSCRIPTIONS", "rust")])] +fn test_default_subscriptions() { + tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap().block_on(async { + let subscriptions = get_setting("REDLIB_DEFAULT_SUBSCRIPTIONS"); + assert!(subscriptions.is_some()); + + // check rate limit + rate_limit_check().await.unwrap(); + }); +} + #[cfg(test)] static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL"; diff --git a/src/subreddit.rs b/src/subreddit.rs index 3a07bdc7..88aa542e 100644 --- a/src/subreddit.rs +++ b/src/subreddit.rs @@ -8,6 +8,7 @@ use crate::utils::{ use crate::{client::json, server::RequestExt, server::ResponseExt}; use cookie::Cookie; use hyper::{Body, Request, Response}; +use log::{debug, trace}; use rinja::Template; use once_cell::sync::Lazy; @@ -62,6 +63,7 @@ pub async fn community(req: Request) -> Result, String> { // Build Reddit API path let root = req.uri().path() == "/"; let query = req.uri().query().unwrap_or_default().to_string(); + trace!("query: {}", query); let subscribed = setting(&req, "subscriptions"); let front_page = setting(&req, "front_page"); let post_sort = req.cookie("post_sort").map_or_else(|| "hot".to_string(), |c| c.value().to_string()); @@ -123,6 +125,7 @@ pub async fn community(req: Request) -> Result, String> { } let path = format!("/r/{}/{sort}.json?{}{params}", sub_name.replace('+', "%2B"), req.uri().query().unwrap_or_default()); + debug!("Path: {}", path); let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str())); let redirect_url = url[1..].replace('?', "%3F").replace('&', "%26").replace('+', "%2B"); let filters = get_filters(&req);