Skip to content

Feat/tls #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 92 additions & 73 deletions pkarr/src/extra/endpoints/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,103 +3,115 @@ use crate::{
rdata::{RData, SVCB},
ResourceRecord,
},
SignedPacket,
PublicKey, SignedPacket,
};
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr, ToSocketAddrs},
};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};

use pubky_timestamp::Timestamp;
use rand::{seq::SliceRandom, thread_rng};

#[derive(Debug, Clone)]
/// An alternative Endpoint for a `qname`, from either [RData::SVCB] or [RData::HTTPS] dns records
pub struct Endpoint {
target: String,
// public_key: PublicKey,
port: Option<u16>,
public_key: PublicKey,
port: u16,
/// SocketAddrs from the [SignedPacket]
addrs: Vec<IpAddr>,
}

impl Endpoint {
/// 1. Find the SVCB or HTTPS records with the lowest priority
/// 2. Choose a random one of the list of the above
/// 3. If the target is `.`, check A and AAAA records see [rfc9460](https://www.rfc-editor.org/rfc/rfc9460#name-special-handling-of-in-targ)
pub(crate) fn find(
/// Returns a stack of endpoints from a SignedPacket
///
/// 1. Find the SVCB or HTTPS records
/// 2. Sort them by priority (reverse)
/// 3. Shuffle records within each priority
/// 3. If the target is `.`, keep track of A and AAAA records see [rfc9460](https://www.rfc-editor.org/rfc/rfc9460#name-special-handling-of-in-targ)
pub(crate) fn parse(
signed_packet: &SignedPacket,
target: &str,
is_svcb: bool,
) -> Option<Endpoint> {
let mut lowest_priority = u16::MAX;
let mut lowest_priority_index = 0;
let mut records = vec![];

for record in signed_packet.resource_records(target) {
if let Some(svcb) = get_svcb(record, is_svcb) {
match svcb.priority.cmp(&lowest_priority) {
std::cmp::Ordering::Equal => records.push(svcb),
std::cmp::Ordering::Less => {
lowest_priority_index = records.len();
lowest_priority = svcb.priority;
records.push(svcb)
}
_ => {}
) -> Vec<Endpoint> {
let mut records = signed_packet
.resource_records(target)
.filter_map(|record| get_svcb(record, is_svcb))
.collect::<Vec<_>>();

// TODO: support wildcard?

// Shuffle the vector first
let mut rng = thread_rng();
records.shuffle(&mut rng);
// Sort by priority
records.sort_by(|a, b| b.priority.cmp(&a.priority));

let mut addrs = HashSet::new();
for record in signed_packet.resource_records("@") {
match &record.rdata {
RData::A(ip) => {
addrs.insert(IpAddr::V4(ip.address.into()));
}
RData::AAAA(ip) => {
addrs.insert(IpAddr::V6(ip.address.into()));
}
_ => {}
}
}
let addrs = addrs.into_iter().collect::<Vec<_>>();

// Good enough random selection
let now = Timestamp::now().as_u64();
let slice = &records[lowest_priority_index..];
let index = if slice.is_empty() {
0
} else {
(now as usize) % slice.len()
};
records
.into_iter()
.map(|s| {
let target = s.target.to_string();

slice.get(index).map(|s| {
let target = s.target.to_string();
let port = s
.get_param(SVCB::PORT)
.map(|bytes| {
let mut arr = [0_u8; 2];
arr[0] = bytes[0];
arr[1] = bytes[1];

let mut addrs: Vec<IpAddr> = vec![];
u16::from_be_bytes(arr)
})
.unwrap_or_default();

if &target == "." {
for record in signed_packet.resource_records("@") {
match &record.rdata {
RData::A(ip) => addrs.push(IpAddr::V4(ip.address.into())),
RData::AAAA(ip) => addrs.push(IpAddr::V6(ip.address.into())),
_ => {}
}
Endpoint {
target,
port,
public_key: signed_packet.public_key(),
addrs: if s.target.to_string() == "." {
addrs.clone()
} else {
Vec::with_capacity(0)
},
}
}

Endpoint {
target,
// public_key: signed_packet.public_key(),
port: s.get_param(SVCB::PORT).map(|bytes| {
let mut arr = [0_u8; 2];
arr[0] = bytes[0];
arr[1] = bytes[1];

u16::from_be_bytes(arr)
}),
addrs,
}
})
})
.collect::<Vec<_>>()
}

/// Return the endpoint target, i.e the domain it points to
/// "." means this endpoint points to its own [Endpoint::public_key]
pub fn target(&self) -> &str {
/// Returns the [SVCB] record's `target` value.
///
/// Useful in web browsers where we can't use [Self::to_socket_addrs]
pub fn domain(&self) -> &str {
&self.target
}

pub fn port(&self) -> Option<u16> {
self.port
/// Return the [PublicKey] of the [SignedPacket] this endpoint was found at.
///
/// This is useful as the [PublicKey] of the endpoint (server), and could be
/// used for TLS.
pub fn public_key(&self) -> &PublicKey {
&self.public_key
}

/// Return an iterator of [SocketAddr], either by resolving the [Endpoint::target] using normal DNS,
/// Return an iterator of [SocketAddr], either by resolving the [Endpoint::domain] using normal DNS,
/// or, if the target is ".", return the [RData::A] or [RData::AAAA] records
/// from the endpoint's [SignedPacket], if available.
pub fn to_socket_addrs(&self) -> Vec<SocketAddr> {
if self.target == "." {
let port = self.port.unwrap_or(0);
let port = self.port;

return self
.addrs
Expand All @@ -111,7 +123,7 @@ impl Endpoint {
if cfg!(target_arch = "wasm32") {
vec![]
} else {
format!("{}:{}", self.target, self.port.unwrap_or(0))
format!("{}:{}", self.target, self.port)
.to_socket_addrs()
.map_or(vec![], |v| v.collect::<Vec<_>>())
}
Expand Down Expand Up @@ -149,7 +161,7 @@ mod tests {
use crate::{dns, Keypair};

#[tokio::test]
async fn endpoint_target() {
async fn endpoint_domain() {
let mut packet = dns::Packet::new_reply(0);
packet.answers.push(dns::ResourceRecord::new(
dns::Name::new("foo").unwrap(),
Expand Down Expand Up @@ -183,12 +195,16 @@ mod tests {
let tld = keypair.public_key();

// Follow foo.tld HTTPS records
let endpoint = Endpoint::find(&signed_packet, &format!("foo.{tld}"), false).unwrap();
assert_eq!(endpoint.target, "https.example.com");
let endpoint = Endpoint::parse(&signed_packet, &format!("foo.{tld}"), false)
.pop()
.unwrap();
assert_eq!(endpoint.domain(), "https.example.com");

// Follow _foo.tld SVCB records
let endpoint = Endpoint::find(&signed_packet, &format!("_foo.{tld}"), true).unwrap();
assert_eq!(endpoint.target, "protocol.example.com");
let endpoint = Endpoint::parse(&signed_packet, &format!("_foo.{tld}"), true)
.pop()
.unwrap();
assert_eq!(endpoint.domain(), "protocol.example.com");
}

#[test]
Expand Down Expand Up @@ -220,16 +236,19 @@ mod tests {
let signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap();

// Follow foo.tld HTTPS records
let endpoint = Endpoint::find(
let endpoint = Endpoint::parse(
&signed_packet,
&signed_packet.public_key().to_string(),
false,
)
.pop()
.unwrap();

assert_eq!(endpoint.target, ".");
assert_eq!(endpoint.domain(), ".");

let mut addrs = endpoint.to_socket_addrs();
addrs.sort();

let addrs = endpoint.to_socket_addrs();
assert_eq!(
addrs.into_iter().map(|s| s.to_string()).collect::<Vec<_>>(),
vec!["209.151.148.15:6881", "[2a05:d014:275:6201::64]:6881"]
Expand Down
Loading