From 70dcb2e6cbe235b111d12b25cdabbff190b34701 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Thu, 8 Aug 2024 00:35:32 +0100 Subject: [PATCH] First pass to implement socks5 support --- Cargo.lock | 14 ++++ crates/librqbit/Cargo.toml | 6 +- crates/librqbit/src/dht_utils.rs | 21 ++++- crates/librqbit/src/lib.rs | 1 + crates/librqbit/src/peer_connection.rs | 11 ++- crates/librqbit/src/peer_info_reader/mod.rs | 21 +++-- crates/librqbit/src/session.rs | 37 +++++++-- crates/librqbit/src/stream_connect.rs | 82 +++++++++++++++++++ crates/librqbit/src/torrent_state/live/mod.rs | 2 + crates/librqbit/src/torrent_state/mod.rs | 10 +++ crates/rqbit/src/main.rs | 13 ++- 11 files changed, 195 insertions(+), 23 deletions(-) create mode 100644 crates/librqbit/src/stream_connect.rs diff --git a/Cargo.lock b/Cargo.lock index 16673eb5..8d90173c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1346,6 +1346,7 @@ dependencies = [ "size_format", "tempfile", "tokio", + "tokio-socks", "tokio-stream", "tokio-test", "tokio-util", @@ -2145,6 +2146,7 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-socks", "tower-service", "url", "wasm-bindgen", @@ -2700,6 +2702,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.15" diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 1c600b10..4d4261ef 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -42,7 +42,10 @@ anyhow = "1" itertools = "0.12" http = "1" regex = "1" -reqwest = { version = "0.12", default-features = false, features = ["json"] } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "socks", +] } urlencoding = "2" byteorder = "1" bincode = 
"1" @@ -75,6 +78,7 @@ async-stream = "0.3.5" memmap2 = { version = "0.9.4" } lru = { version = "0.12.3", optional = true } mime_guess = { version = "2.0.5", default-features = false } +tokio-socks = "0.5.2" [dev-dependencies] futures = { version = "0.3" } diff --git a/crates/librqbit/src/dht_utils.rs b/crates/librqbit/src/dht_utils.rs index f0aa20f4..d348a7b2 100644 --- a/crates/librqbit/src/dht_utils.rs +++ b/crates/librqbit/src/dht_utils.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, net::SocketAddr}; +use std::{collections::HashSet, net::SocketAddr, sync::Arc}; use anyhow::Context; use buffers::ByteBufOwned; @@ -8,6 +8,7 @@ use tracing::{debug, error_span, Instrument}; use crate::{ peer_connection::PeerConnectionOptions, peer_info_reader, spawn_utils::BlockingSpawner, + stream_connect::StreamConnector, }; use librqbit_core::hash_id::Id20; @@ -30,6 +31,7 @@ pub async fn read_metainfo_from_peer_receiver + Unp initial_addrs: Vec, addrs_stream: A, peer_connection_options: Option, + connector: Arc, ) -> ReadMetainfoResult { let mut seen = HashSet::::new(); let mut addrs = addrs_stream; @@ -38,6 +40,7 @@ pub async fn read_metainfo_from_peer_receiver + Unp let read_info_guarded = |addr| { let semaphore = &semaphore; + let connector = connector.clone(); async move { let token = semaphore.acquire().await?; let ret = peer_info_reader::read_metainfo_from_peer( @@ -46,6 +49,7 @@ pub async fn read_metainfo_from_peer_receiver + Unp info_hash, peer_connection_options, BlockingSpawner::new(true), + connector, ) .instrument(error_span!("read_metainfo_from_peer", ?addr)) .await @@ -93,7 +97,10 @@ mod tests { use librqbit_core::peer_id::generate_peer_id; use super::*; - use std::{str::FromStr, sync::Once}; + use std::{ + str::FromStr, + sync::{Arc, Once}, + }; static LOG_INIT: Once = Once::new(); @@ -114,7 +121,15 @@ mod tests { let peer_rx = dht.get_peers(info_hash, None).unwrap(); let peer_id = generate_peer_id(); - match read_metainfo_from_peer_receiver(peer_id, info_hash, 
Vec::new(), peer_rx, None).await + match read_metainfo_from_peer_receiver( + peer_id, + info_hash, + Vec::new(), + peer_rx, + None, + Arc::new(Default::default()), + ) + .await { ReadMetainfoResult::Found { info, .. } => dbg!(info), ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"), diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index fb16b4a2..8990e0e0 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -41,6 +41,7 @@ mod read_buf; mod session; mod spawn_utils; pub mod storage; +mod stream_connect; mod torrent_state; pub mod tracing_subscriber_config_utils; mod type_aliases; diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 65b4c303..3409a542 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -1,5 +1,6 @@ use std::{ net::SocketAddr, + sync::Arc, time::{Duration, Instant}, }; @@ -21,7 +22,7 @@ use serde_with::serde_as; use tokio::time::timeout; use tracing::{debug, trace}; -use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; +use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner, stream_connect::StreamConnector}; pub trait PeerConnectionHandler { fn on_connected(&self, _connection_time: Duration) {} @@ -65,6 +66,7 @@ pub(crate) struct PeerConnection { peer_id: Id20, options: PeerConnectionOptions, spawner: BlockingSpawner, + connector: Arc, } pub(crate) async fn with_timeout( @@ -88,6 +90,7 @@ impl PeerConnection { handler: H, options: Option, spawner: BlockingSpawner, + connector: Arc, ) -> Self { PeerConnection { handler, @@ -96,6 +99,7 @@ impl PeerConnection { peer_id, spawner, options: options.unwrap_or_default(), + connector, } } @@ -169,7 +173,8 @@ impl PeerConnection { .unwrap_or_else(|| Duration::from_secs(10)); let now = Instant::now(); - let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr)) + let conn = self.connector.connect(self.addr); 
+ let mut conn = with_timeout(connect_timeout, conn) .await .context("error connecting")?; self.handler.on_connected(now.elapsed()); @@ -218,7 +223,7 @@ impl PeerConnection { handshake_supports_extended: bool, mut read_buf: ReadBuf, mut write_buf: Vec, - mut conn: tokio::net::TcpStream, + mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, mut have_broadcast: tokio::sync::broadcast::Receiver, ) -> anyhow::Result<()> { diff --git a/crates/librqbit/src/peer_info_reader/mod.rs b/crates/librqbit/src/peer_info_reader/mod.rs index 28e2232b..dc3e0b90 100644 --- a/crates/librqbit/src/peer_info_reader/mod.rs +++ b/crates/librqbit/src/peer_info_reader/mod.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use bencode::from_bytes; use buffers::{ByteBuf, ByteBufOwned}; @@ -22,6 +22,7 @@ use crate::{ PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, }, spawn_utils::BlockingSpawner, + stream_connect::StreamConnector, }; pub(crate) async fn read_metainfo_from_peer( @@ -30,6 +31,7 @@ pub(crate) async fn read_metainfo_from_peer( info_hash: Id20, peer_connection_options: Option, spawner: BlockingSpawner, + connector: Arc, ) -> anyhow::Result> { let (result_tx, result_rx) = tokio::sync::oneshot::channel::>>(); @@ -48,6 +50,7 @@ pub(crate) async fn read_metainfo_from_peer( handler, peer_connection_options, spawner, + connector, ); let result_reader = async move { result_rx.await? 
}; @@ -234,6 +237,7 @@ impl PeerConnectionHandler for Handler { #[cfg(test)] mod tests { + use std::sync::Arc; use std::{net::SocketAddr, str::FromStr, sync::Once}; use librqbit_core::hash_id::Id20; @@ -260,10 +264,15 @@ mod tests { let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap(); let peer_id = generate_peer_id(); let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap(); - dbg!( - read_metainfo_from_peer(addr, peer_id, info_hash, None, BlockingSpawner::new(true)) - .await - .unwrap() - ); + dbg!(read_metainfo_from_peer( + addr, + peer_id, + info_hash, + None, + BlockingSpawner::new(true), + Arc::new(Default::default()) + ) + .await + .unwrap()); } } diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index e61eda04..2470056d 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -19,6 +19,7 @@ use crate::{ storage::{ filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage, }, + stream_connect::{SocksProxyConfig, StreamConnector}, torrent_state::{ ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, }, @@ -197,6 +198,7 @@ pub struct Session { default_storage_factory: Option, reqwest_client: reqwest::Client, + connector: Arc, // This is stored for all tasks to stop when session is dropped. _cancellation_token_drop_guard: DropGuard, @@ -413,11 +415,6 @@ impl<'a> AddTorrent<'a> { } } -pub struct SocksProxyConfig { - // must start with socks5 - pub url: String, -} - #[derive(Default)] pub struct SessionOptions { /// Turn on to disable DHT. 
@@ -449,7 +446,8 @@ pub struct SessionOptions { pub default_storage_factory: Option, - pub socks_proxy: Option, + // socks5://[username:password@]host:port + pub socks_proxy_url: Option, } async fn create_tcp_listener( @@ -548,9 +546,27 @@ impl Session { }) .unwrap_or_default(); - let reqwest_client = reqwest::Client::builder() - .build() - .context("error building HTTP(S) client")?; + let proxy_config = match opts.socks_proxy_url.as_ref() { + Some(pu) => Some( + SocksProxyConfig::parse(pu) + .with_context(|| format!("error parsing proxy url {}", pu))?, + ), + None => None, + }; + + let reqwest_client = { + let builder = if let Some(proxy_url) = opts.socks_proxy_url.as_ref() { + let proxy = reqwest::Proxy::all(proxy_url) + .context("error creating socks5 proxy for HTTP")?; + reqwest::Client::builder().proxy(proxy) + } else { + reqwest::Client::builder() + }; + + builder.build().context("error building HTTP(S) client")? + }; + + let stream_connector = Arc::new(StreamConnector::from(proxy_config)); let session = Arc::new(Self { persistence_filename, @@ -566,6 +582,7 @@ impl Session { disk_write_tx, default_storage_factory: opts.default_storage_factory, reqwest_client, + connector: stream_connector, }); if let Some(mut disk_write_rx) = disk_write_rx { @@ -919,6 +936,7 @@ impl Session { opts.initial_peers.clone().unwrap_or_default(), peer_rx, Some(self.merge_peer_opts(opts.peer_opts)), + self.connector.clone(), ) .await { @@ -1088,6 +1106,7 @@ impl Session { .allow_overwrite(opts.overwrite) .spawner(self.spawner) .trackers(trackers) + .connector(self.connector.clone()) .peer_id(self.peer_id); if let Some(d) = self.disk_write_tx.clone() { diff --git a/crates/librqbit/src/stream_connect.rs b/crates/librqbit/src/stream_connect.rs new file mode 100644 index 00000000..11e91e0b --- /dev/null +++ b/crates/librqbit/src/stream_connect.rs @@ -0,0 +1,82 @@ +use std::net::SocketAddr; + +use anyhow::Context; + +#[derive(Debug, Clone)] +pub(crate) struct SocksProxyConfig { + pub 
host: String, + pub port: u16, + pub username_password: Option<(String, String)>, +} + +impl SocksProxyConfig { + pub fn parse(url: &str) -> anyhow::Result { + let url = ::url::Url::parse(url).context("invalid proxy URL")?; + if url.scheme() != "socks5" { + anyhow::bail!("proxy URL should have socks5 scheme"); + } + let host = url.host_str().context("missing host")?; + let port = url.port().context("missing port")?; + let up = url + .password() + .map(|p| (url.username().to_owned(), p.to_owned())); + Ok(Self { + host: host.to_owned(), + port, + username_password: up, + }) + } + + async fn connect( + &self, + addr: SocketAddr, + ) -> anyhow::Result { + let proxy_addr = (self.host.as_str(), self.port); + + if let Some((username, password)) = self.username_password.as_ref() { + tokio_socks::tcp::Socks5Stream::connect_with_password( + proxy_addr, + addr, + username.as_str(), + password.as_str(), + ) + .await + .context("error connecting to proxy") + } else { + tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr) + .await + .context("error connecting to proxy") + } + } +} + +#[derive(Debug, Default)] +pub(crate) struct StreamConnector { + proxy_config: Option, +} + +impl From> for StreamConnector { + fn from(proxy_config: Option) -> Self { + Self { proxy_config } + } +} + +pub(crate) trait AsyncReadWrite: + tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin +{ +} + +impl AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {} + +impl StreamConnector { + pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result> { + if let Some(proxy) = self.proxy_config.as_ref() { + return Ok(Box::new(proxy.connect(addr).await?)); + } + Ok(Box::new( + tokio::net::TcpStream::connect(addr) + .await + .context("error connecting")?, + )) + } +} diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 721109c4..bd235030 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs 
+++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -382,6 +382,7 @@ impl TorrentStateLive { &handler, Some(options), self.meta.spawner, + self.meta.connector.clone(), ); let requester = handler.task_peer_chunk_requester(); @@ -444,6 +445,7 @@ impl TorrentStateLive { &handler, Some(options), state.meta.spawner, + state.meta.connector.clone(), ); let requester = handler.task_peer_chunk_requester(); diff --git a/crates/librqbit/src/torrent_state/mod.rs b/crates/librqbit/src/torrent_state/mod.rs index 08a24358..6e544da1 100644 --- a/crates/librqbit/src/torrent_state/mod.rs +++ b/crates/librqbit/src/torrent_state/mod.rs @@ -37,6 +37,7 @@ use crate::chunk_tracker::ChunkTracker; use crate::file_info::FileInfo; use crate::spawn_utils::BlockingSpawner; use crate::storage::BoxStorageFactory; +use crate::stream_connect::StreamConnector; use crate::torrent_state::stats::LiveStats; use crate::type_aliases::DiskWorkQueueSender; use crate::type_aliases::FileInfos; @@ -106,6 +107,7 @@ pub struct ManagedTorrentInfo { pub file_infos: FileInfos, pub span: tracing::Span, pub(crate) options: ManagedTorrentOptions, + pub(crate) connector: Arc, } pub struct ManagedTorrent { @@ -509,6 +511,7 @@ pub(crate) struct ManagedTorrentBuilder { allow_overwrite: bool, storage_factory: BoxStorageFactory, disk_writer: Option, + connector: Arc, } impl ManagedTorrentBuilder { @@ -532,6 +535,7 @@ impl ManagedTorrentBuilder { output_folder, storage_factory, disk_writer: None, + connector: Arc::new(Default::default()), } } @@ -580,6 +584,11 @@ impl ManagedTorrentBuilder { self } + pub fn connector(&mut self, value: Arc) -> &mut Self { + self.connector = value; + self + } + pub fn build(self, span: tracing::Span) -> anyhow::Result { let lengths = Lengths::from_torrent(&self.info)?; let file_infos = self @@ -612,6 +621,7 @@ impl ManagedTorrentBuilder { output_folder: self.output_folder, disk_write_queue: self.disk_writer, }, + connector: self.connector, }); let initializing = 
Arc::new(TorrentStateInitializing::new( diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 4046676f..782c370e 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -115,6 +115,13 @@ struct Opts { /// If you use it, you know what you are doing. #[arg(long)] experimental_mmap_storage: bool, + + /// Provide a socks5 URL. + /// The format is socks5://[username:password@]host:port + /// + /// Alternatively, set this as an environment variable RQBIT_SOCKS_PROXY_URL + #[arg(long)] + socks_url: Option, } #[derive(Parser)] @@ -281,6 +288,10 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { Err(e) => warn!("failed increasing open file limit: {:#}", e), }; + let socks_url = opts + .socks_url + .or_else(|| std::env::var("RQBIT_SOCKS_PROXY_URL").ok()); + let mut sopts = SessionOptions { disable_dht: opts.disable_dht, disable_dht_persistence: opts.disable_dht_persistence, @@ -320,7 +331,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { wrap(FilesystemStorageFactory::default()).boxed() } }), - socks_proxy: None, + socks_proxy_url: socks_url, }; let stats_printer = |session: Arc| async move {