From 1412d4444abe4e94eafffebb56dab116bfeedd03 Mon Sep 17 00:00:00 2001 From: Axel Jacobsen Date: Sat, 30 Dec 2023 18:00:41 -0800 Subject: [PATCH] refactoring! --- src/bots/arb_bot.rs | 17 ++-- src/bots/ewma_bot.rs | 21 ++--- src/coms.rs | 202 ++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 2 +- src/market_handler.rs | 160 ++++++--------------------------- src/utils.rs | 88 ------------------ 6 files changed, 252 insertions(+), 238 deletions(-) create mode 100644 src/coms.rs delete mode 100644 src/utils.rs diff --git a/src/bots/arb_bot.rs b/src/bots/arb_bot.rs index 39d29ad..67a0399 100644 --- a/src/bots/arb_bot.rs +++ b/src/bots/arb_bot.rs @@ -6,22 +6,23 @@ use tokio::sync::{broadcast, mpsc}; use crate::bots::Bot; use crate::manifold_types; -use crate::market_handler; + +use crate::coms::{InternalPacket,Method}; pub struct ArbitrageBot { id: String, market: manifold_types::FullMarket, answers: HashMap, - bot_to_mh_tx: mpsc::Sender, - mh_to_bot_rx: broadcast::Receiver, + bot_to_mh_tx: mpsc::Sender, + mh_to_bot_rx: broadcast::Receiver, } impl ArbitrageBot { pub fn new( id: String, market: manifold_types::FullMarket, - bot_to_mh_tx: mpsc::Sender, - mh_to_bot_rx: broadcast::Receiver, + bot_to_mh_tx: mpsc::Sender, + mh_to_bot_rx: broadcast::Receiver, ) -> Self { let mut id_to_answers = HashMap::new(); @@ -92,10 +93,10 @@ impl ArbitrageBot { fn botbet_to_internal_coms_packet( &self, bet: manifold_types::BotBet, - ) -> market_handler::InternalPacket { - market_handler::InternalPacket::new( + ) -> InternalPacket { + InternalPacket::new( self.get_id(), - market_handler::Method::Post, + Method::Post, "bet".to_string(), vec![], Some(serde_json::json!({ diff --git a/src/bots/ewma_bot.rs b/src/bots/ewma_bot.rs index 02942d8..9f62d5c 100644 --- a/src/bots/ewma_bot.rs +++ b/src/bots/ewma_bot.rs @@ -4,7 +4,8 @@ use tokio::sync::{broadcast, mpsc}; use crate::bots::Bot; use crate::manifold_types; -use crate::market_handler; + +use crate::coms::{Method,InternalPacket}; struct EWMA { s0: f64, @@ -30,8 +31,8 @@ pub struct EWMABot { id: String, market: manifold_types::FullMarket, - bot_to_mh_tx: mpsc::Sender, - mh_to_bot_rx: broadcast::Receiver, + bot_to_mh_tx: mpsc::Sender, + mh_to_bot_rx: broadcast::Receiver, ewma_1: EWMA, ewma_2: EWMA, @@ -45,8 +46,8 @@ impl EWMABot { pub fn new( id: String, market: manifold_types::FullMarket, - bot_to_mh_tx: mpsc::Sender, - mh_to_bot_rx: broadcast::Receiver, + bot_to_mh_tx: mpsc::Sender, + mh_to_bot_rx: broadcast::Receiver, alpha_1: f64, alpha_2: f64, ) -> Self { @@ -65,7 +66,7 @@ impl EWMABot { } } - async fn make_trades(&mut self, trades: Vec) { + async fn make_trades(&mut self, trades: Vec) { for trade in trades { self.bot_to_mh_tx.send(trade).await.unwrap(); @@ -124,9 +125,9 @@ impl Bot for EWMABot { match self.update_prob(&bet) { manifold_types::Side::Buy => { - let buy_bet = market_handler::InternalPacket::new( + let buy_bet = InternalPacket::new( self.get_id(), - market_handler::Method::Post, + Method::Post, "bet".to_string(), vec![], Some(serde_json::json!({ @@ -139,9 +140,9 @@ impl Bot for EWMABot { self.make_trades(vec![buy_bet]).await; } manifold_types::Side::Sell => { - let sell_bet = market_handler::InternalPacket::new( + let sell_bet = InternalPacket::new( self.get_id(), - market_handler::Method::Post, + Method::Post, format!("market/{}/sell", bet.contract_id), vec![], Some(serde_json::json!({ diff --git a/src/coms.rs b/src/coms.rs new file mode 100644 index 0000000..a901e0b --- /dev/null +++ b/src/coms.rs @@ -0,0 +1,202 @@ +use std::env; + + + + +use log::{debug, error}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use tokio::time::{Duration}; + + +use crate::rate_limiter; + + +use crate::errors; + +const MANIFOLD_API_URL: &str = "https://api.manifold.markets/v0"; + +fn get_env_key(key: &str) -> Result { + match env::var(key) { + Ok(key) => Ok(format!("Key {key}")), + Err(e) => Err(format!("couldn't find Manifold API key: {e}")), + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Method { + Get, + Post, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InternalPacket { + pub bot_id: String, + method: Method, + endpoint: String, + query_params: Vec<(String, String)>, + data: Option, + response: Option, +} + +impl InternalPacket { + pub fn new( + bot_id: String, + method: Method, + endpoint: String, + query_params: Vec<(String, String)>, + data: Option, + ) -> Self { + Self { + bot_id, + method, + endpoint, + query_params, + data, + response: None, + } + } + + pub fn response_from_existing(packet: &InternalPacket, response: String) -> Self { + Self { + bot_id: packet.bot_id.clone(), + method: packet.method.clone(), + endpoint: packet.endpoint.clone(), + query_params: packet.query_params.clone(), + data: packet.data.clone(), + response: Some(response), + } + } +} + +pub async fn get_endpoint( + endpoint: String, + query_params: &[(String, String)], +) -> Result { + debug!( + "get endpoint; endpoint '{endpoint}'; query params '{:?}'", + query_params, + ); + + let client = reqwest::Client::new(); + + let req = client + .get(format!("{MANIFOLD_API_URL}/{endpoint}")) + .query(&query_params) + .header("Authorization", get_env_key("MANIFOLD_KEY").unwrap()); + + let resp = req.send().await?; + + if resp.status().is_success() { + Ok(resp) + } else { + error!("api error (bad status code) {resp:?} {query_params:?}"); + Err(resp.error_for_status().unwrap_err()) + } +} + +pub async fn post_endpoint( + endpoint: String, + query_params: &[(String, String)], + data: Option, +) -> Result { + debug!( + "post endpoint; endpoint '{endpoint}'; query params '{:?}'; data '{:?}'", + query_params, data + ); + + let client = reqwest::Client::new(); + let req = client + .post(format!("{MANIFOLD_API_URL}/{endpoint}")) + .query(&query_params) + .header("Authorization", get_env_key("MANIFOLD_KEY").unwrap()); + + let data_clone = data.clone(); + + let resp = if let Some(data) = data { + let reqq = req.json(&data); + reqq.send().await? + } else { + req.send().await? + }; + + if resp.status().is_success() { + Ok(resp) + } else { + error!("api error (bad status code) {resp:?} {query_params:?} {data_clone:?}"); + Err(resp.error_for_status().unwrap_err()) + } +} + +pub async fn response_into( + resp: reqwest::Response, +) -> Result { + let body = resp.text().await?; + let from_json = serde_json::from_str::(&body); + match from_json { + Ok(t) => Ok(t), + Err(e) => { + error!("Couldn't parse response {body}"); + Err(errors::ReqwestResponseParsing::SerdeError(e)) + } + } +} + + +pub async fn rate_limited_post_endpoint( + mut write_rate_limiter: rate_limiter::RateLimiter, + endpoint: String, + query_params: &[(String, String)], + data: Option, +) -> Result { + if write_rate_limiter.block_for_average_pace_then_commit(Duration::from_secs(60)) { + post_endpoint(endpoint, query_params, data).await + } else { + panic!( + "rate limiter timed out; this shouldn't be possible, \ + most likely rate limit is set wrong" + ); + } +} + +pub async fn rate_limited_get_endpoint( + mut read_rate_limiter: rate_limiter::RateLimiter, + endpoint: String, + query_params: &[(String, String)], +) -> Result { + if read_rate_limiter.block_for_average_pace_then_commit(Duration::from_secs(1)) { + get_endpoint(endpoint, query_params).await + } else { + panic!( + "rate limiter timed out; this shouldn't be possible, \ + most likely rate limit is set wrong" + ); + } +} + +pub async fn send_internal_packet( + read_rate_limiter: &rate_limiter::RateLimiter, + write_rate_limiter: &rate_limiter::RateLimiter, + internal_coms_packet: &InternalPacket, +) -> Result { + match internal_coms_packet.method { + Method::Get => { + rate_limited_get_endpoint( + read_rate_limiter.clone(), + internal_coms_packet.endpoint.clone(), + &internal_coms_packet.query_params, + ) + .await + } + Method::Post => { + rate_limited_post_endpoint( + write_rate_limiter.clone(), + internal_coms_packet.endpoint.clone(), + &internal_coms_packet.query_params, + internal_coms_packet.data.clone(), + ) + .await + } + } +} diff --git a/src/main.rs b/src/main.rs index 65c6a9f..1c4440a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod errors; mod manifold_types; mod market_handler; mod rate_limiter; -mod utils; +mod coms; use crate::bots::arb_bot::ArbitrageBot; use crate::bots::ewma_bot::EWMABot; diff --git a/src/market_handler.rs b/src/market_handler.rs index 7d19505..75bc781 100644 --- a/src/market_handler.rs +++ b/src/market_handler.rs @@ -5,126 +5,24 @@ use std::sync::{ }; use log::{debug, error, info, warn}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; + + use tokio::sync::{broadcast, mpsc}; use tokio::time::{sleep, Duration}; use crate::errors; use crate::manifold_types; use crate::rate_limiter; -use crate::utils; - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum Method { - Get, - Post, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct InternalPacket { - bot_id: String, - method: Method, - endpoint: String, - query_params: Vec<(String, String)>, - data: Option, - response: Option, -} - -impl InternalPacket { - pub fn new( - bot_id: String, - method: Method, - endpoint: String, - query_params: Vec<(String, String)>, - data: Option, - ) -> Self { - Self { - bot_id, - method, - endpoint, - query_params, - data, - response: None, - } - } - - pub fn response_from_existing(packet: &InternalPacket, response: String) -> Self { - Self { - bot_id: packet.bot_id.clone(), - method: packet.method.clone(), - endpoint: packet.endpoint.clone(), - query_params: packet.query_params.clone(), - data: packet.data.clone(), - response: Some(response), - } - } -} +use crate::coms; -async fn rate_limited_post_endpoint( - mut write_rate_limiter: rate_limiter::RateLimiter, - endpoint: String, - query_params: &[(String, String)], - data: Option, -) -> Result { - if write_rate_limiter.block_for_average_pace_then_commit(Duration::from_secs(60)) { - utils::post_endpoint(endpoint, query_params, data).await - } else { - panic!( - "rate limiter timed out; this shouldn't be possible, \ - most likely rate limit is set wrong" - ); - } -} - -async fn rate_limited_get_endpoint( - mut read_rate_limiter: rate_limiter::RateLimiter, - endpoint: String, - query_params: &[(String, String)], -) -> Result { - if read_rate_limiter.block_for_average_pace_then_commit(Duration::from_secs(1)) { - utils::get_endpoint(endpoint, query_params).await - } else { - panic!( - "rate limiter timed out; this shouldn't be possible, \ - most likely rate limit is set wrong" - ); - } -} - -pub async fn send_internal_packet( - read_rate_limiter: &rate_limiter::RateLimiter, - write_rate_limiter: &rate_limiter::RateLimiter, - internal_coms_packet: &InternalPacket, -) -> Result { - match internal_coms_packet.method { - Method::Get => { - rate_limited_get_endpoint( - read_rate_limiter.clone(), - internal_coms_packet.endpoint.clone(), - &internal_coms_packet.query_params, - ) - .await - } - Method::Post => { - rate_limited_post_endpoint( - write_rate_limiter.clone(), - internal_coms_packet.endpoint.clone(), - &internal_coms_packet.query_params, - internal_coms_packet.data.clone(), - ) - .await - } - } -} #[allow(dead_code)] #[derive(Debug)] pub struct MarketHandler { halt_flag: Arc, - bots_to_mh_tx: mpsc::Sender, - bot_out_channel: Arc>>>, + bots_to_mh_tx: mpsc::Sender, + bot_out_channel: Arc>>>, read_rate_limiter: rate_limiter::RateLimiter, write_rate_limiter: rate_limiter::RateLimiter, @@ -137,8 +35,8 @@ impl MarketHandler { pub fn new() -> Self { let halt_flag = Arc::new(AtomicBool::new(false)); - let (bots_to_mh_tx, bots_to_mh_rx) = mpsc::channel::(256); - let bot_out_channel: Arc>>> = + let (bots_to_mh_tx, bots_to_mh_rx) = mpsc::channel::(256); + let bot_out_channel: Arc>>> = Arc::new(Mutex::new(HashMap::new())); let halt_flag_clone = halt_flag.clone(); @@ -167,9 +65,9 @@ impl MarketHandler { } pub fn send_to_bots( - bot_out_channel: &Arc>>>, + bot_out_channel: &Arc>>>, bot_id: &String, - packet: InternalPacket, + packet: coms::InternalPacket, ) { bot_out_channel .lock() @@ -184,8 +82,8 @@ impl MarketHandler { write_rate_limiter: rate_limiter::RateLimiter, read_rate_limiter: rate_limiter::RateLimiter, halt_flag: Arc, - mut bots_to_mh_rx: mpsc::Receiver, - bot_out_channel: Arc>>>, + mut bots_to_mh_rx: mpsc::Receiver, + bot_out_channel: Arc>>>, ) { while !halt_flag.load(Ordering::SeqCst) { let internal_coms_packet = match bots_to_mh_rx.recv().await { @@ -195,7 +93,7 @@ impl MarketHandler { debug!("got internal_coms packet {:?}", internal_coms_packet); - let maybe_res = send_internal_packet( + let maybe_res = coms::send_internal_packet( &read_rate_limiter, &write_rate_limiter, &internal_coms_packet, @@ -206,7 +104,7 @@ impl MarketHandler { Ok(res) => res, Err(e) => { error!("api error {e}"); - let packet = InternalPacket::response_from_existing( + let packet = coms::InternalPacket::response_from_existing( &internal_coms_packet, format!("api error {e}"), ); @@ -220,7 +118,7 @@ impl MarketHandler { .await .unwrap(); - let packet = InternalPacket::response_from_existing(&internal_coms_packet, res); + let packet = coms::InternalPacket::response_from_existing(&internal_coms_packet, res); Self::send_to_bots(&bot_out_channel, &internal_coms_packet.bot_id, packet); } } @@ -230,7 +128,7 @@ impl MarketHandler { } pub async fn check_alive(&self) -> bool { - let resp = rate_limited_get_endpoint(self.read_rate_limiter.clone(), "me".to_string(), &[]) + let resp = coms::rate_limited_get_endpoint(self.read_rate_limiter.clone(), "me".to_string(), &[]) .await .unwrap(); @@ -238,7 +136,7 @@ impl MarketHandler { } pub async fn whoami(&self) -> Result { - let resp = rate_limited_get_endpoint(self.read_rate_limiter.clone(), "me".to_string(), &[]) + let resp = coms::rate_limited_get_endpoint(self.read_rate_limiter.clone(), "me".to_string(), &[]) .await .unwrap(); @@ -266,7 +164,7 @@ impl MarketHandler { ("limit".to_string(), "1000".to_string()), ]; - let bets_response = rate_limited_get_endpoint( + let bets_response = coms::rate_limited_get_endpoint( self.read_rate_limiter.clone(), "bets".to_string(), ¶ms, @@ -335,7 +233,7 @@ impl MarketHandler { "answerId": pos.answer_id, })); - let sell_response = rate_limited_post_endpoint( + let sell_response = coms::rate_limited_post_endpoint( self.write_rate_limiter.clone(), format!("market/{}/sell", pos.contract_id), &[], @@ -365,7 +263,7 @@ impl MarketHandler { &self, term: String, ) -> Result { - let resp = rate_limited_get_endpoint( + let resp = coms::rate_limited_get_endpoint( self.read_rate_limiter.clone(), "search-markets".to_string(), &[ @@ -376,7 +274,7 @@ impl MarketHandler { .await .unwrap(); - let lite_market_req = utils::response_into::>(resp).await; + let lite_market_req = coms::response_into::>(resp).await; let lite_market = match lite_market_req { Ok(mut markets) => { if markets.len() == 1 { @@ -392,7 +290,7 @@ impl MarketHandler { Err(e) => Err(e), }?; - let full_market = rate_limited_get_endpoint( + let full_market = coms::rate_limited_get_endpoint( self.read_rate_limiter.clone(), format!("market/{}", lite_market.as_ref().unwrap().id), &[], @@ -400,7 +298,7 @@ impl MarketHandler { .await .unwrap(); - utils::response_into::(full_market).await + coms::response_into::(full_market).await } /// Initializes a tx, rx pair for the bot. The tx channel is used by the @@ -412,8 +310,8 @@ impl MarketHandler { bot_id: String, ) -> Result< ( - mpsc::Sender, - broadcast::Receiver, + mpsc::Sender, + broadcast::Receiver, ), String, > { @@ -424,7 +322,7 @@ impl MarketHandler { let bot_to_mh_tx = self.bots_to_mh_tx.clone(); - let (tx_bot, rx_bot) = broadcast::channel::(4); + let (tx_bot, rx_bot) = broadcast::channel::(4); self.bot_out_channel.lock().unwrap().insert(bot_id, tx_bot); Ok((bot_to_mh_tx, rx_bot)) @@ -464,7 +362,7 @@ impl MarketHandler { let mut base_query = query_params.to_vec(); base_query.push(("limit".to_string(), "1".to_string())); - let response = rate_limited_get_endpoint( + let response = coms::rate_limited_get_endpoint( self.read_rate_limiter.clone(), "bets".to_string(), &base_query, @@ -472,7 +370,7 @@ impl MarketHandler { .await .expect("Couldn't get most recent bet from api"); - let mut most_recent_id = utils::response_into::>(response) + let mut most_recent_id = coms::response_into::>(response) .await .expect("Couldn't convert json into Bet") .pop() @@ -498,7 +396,7 @@ impl MarketHandler { continue; } - let maybe_resp = rate_limited_get_endpoint( + let maybe_resp = coms::rate_limited_get_endpoint( read_rate_limiter_clone.clone(), "bets".to_string(), ¶ms, @@ -513,7 +411,7 @@ impl MarketHandler { } }; - let bets = utils::response_into::>(resp) + let bets = coms::response_into::>(resp) .await .expect("Couldn't convert json into Bet"); diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 3a1484c..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::env; - -use log::{debug, error}; -use serde_json::Value; - -use crate::errors; - -const MANIFOLD_API_URL: &str = "https://api.manifold.markets/v0"; - -fn get_env_key(key: &str) -> Result { - match env::var(key) { - Ok(key) => Ok(format!("Key {key}")), - Err(e) => Err(format!("couldn't find Manifold API key: {e}")), - } -} - -pub async fn get_endpoint( - endpoint: String, - query_params: &[(String, String)], -) -> Result { - debug!( - "get endpoint; endpoint '{endpoint}'; query params '{:?}'", - query_params, - ); - - let client = reqwest::Client::new(); - - let req = client - .get(format!("{MANIFOLD_API_URL}/{endpoint}")) - .query(&query_params) - .header("Authorization", get_env_key("MANIFOLD_KEY").unwrap()); - - let resp = req.send().await?; - - if resp.status().is_success() { - Ok(resp) - } else { - error!("api error (bad status code) {resp:?} {query_params:?}"); - Err(resp.error_for_status().unwrap_err()) - } -} - -pub async fn post_endpoint( - endpoint: String, - query_params: &[(String, String)], - data: Option, -) -> Result { - debug!( - "post endpoint; endpoint '{endpoint}'; query params '{:?}'; data '{:?}'", - query_params, data - ); - - let client = reqwest::Client::new(); - let req = client - .post(format!("{MANIFOLD_API_URL}/{endpoint}")) - .query(&query_params) - .header("Authorization", get_env_key("MANIFOLD_KEY").unwrap()); - - let data_clone = data.clone(); - - let resp = if let Some(data) = data { - let reqq = req.json(&data); - reqq.send().await? - } else { - req.send().await? - }; - - if resp.status().is_success() { - Ok(resp) - } else { - error!("api error (bad status code) {resp:?} {query_params:?} {data_clone:?}"); - Err(resp.error_for_status().unwrap_err()) - } -} - -pub async fn response_into( - resp: reqwest::Response, -) -> Result { - let body = resp.text().await?; - let from_json = serde_json::from_str::(&body); - match from_json { - Ok(t) => Ok(t), - Err(e) => { - error!("Couldn't parse response {body}"); - Err(errors::ReqwestResponseParsing::SerdeError(e)) - } - } -}