Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of endpoint_send_filter #58

Merged
merged 2 commits into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ impl Server {
for endpoint in endpoints.iter() {
if let Err(err) = Server::ensure_session(
&log,
chain.clone(),
sessions.clone(),
recv_addr,
&endpoint,
Expand Down Expand Up @@ -249,6 +250,7 @@ impl Server {
/// ensure_session makes sure there is a value session for the name in the sessions map
async fn ensure_session(
log: &Logger,
chain: Arc<FilterChain>,
sessions: SessionMap,
from: SocketAddr,
dest: &EndPoint,
Expand All @@ -260,7 +262,7 @@ impl Server {
return Ok(());
}
}
let s = Session::new(log, from, dest.clone(), sender).await?;
let s = Session::new(log, chain, from, dest.clone(), sender).await?;
{
let mut map = sessions.write().await;
map.insert(s.key(), Mutex::new(s));
Expand Down Expand Up @@ -523,7 +525,11 @@ mod tests {
.await;

assert_eq!(
format!("hello:lrf:127.0.0.1:{}", result.addr.port()),
format!(
"hello:lrf:127.0.0.1:{}:esf:address-0:127.0.0.1:{}",
result.addr.port(),
result.addr.port()
),
result.msg
);
}
Expand Down Expand Up @@ -575,9 +581,16 @@ mod tests {
{
assert!(map.read().await.is_empty());
}
Server::ensure_session(&log, map.clone(), from, &endpoint, sender)
.await
.unwrap();
Server::ensure_session(
&log,
Arc::new(FilterChain::new(vec![])),
map.clone(),
from,
&endpoint,
sender,
)
.await
.unwrap();

let rmap = map.read().await;
let key = (from, dest);
Expand Down Expand Up @@ -653,9 +666,16 @@ mod tests {
connection_ids: vec![],
};

Server::ensure_session(&log, sessions.clone(), from, &endpoint, send)
.await
.unwrap();
Server::ensure_session(
&log,
Arc::new(FilterChain::new(vec![])),
sessions.clone(),
from,
&endpoint,
send,
)
.await
.unwrap();

let key = (from, to);
// gate, to ensure valid state
Expand Down Expand Up @@ -706,9 +726,16 @@ mod tests {
};

server.run_prune_sessions(&sessions);
Server::ensure_session(&log, sessions.clone(), from, &endpoint, send)
.await
.unwrap();
Server::ensure_session(
&log,
Arc::new(FilterChain::new(vec![])),
sessions.clone(),
from,
&endpoint,
send,
)
.await
.unwrap();

// session map should be the same since, we haven't passed expiry
time::advance(Duration::new(SESSION_TIMEOUT_SECONDS / 2, 0)).await;
Expand Down
78 changes: 67 additions & 11 deletions src/server/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ use tokio::sync::{mpsc, watch, RwLock};
use tokio::time::{Duration, Instant};

use crate::config::EndPoint;
use crate::extensions::{Filter, FilterChain};

/// SESSION_TIMEOUT_SECONDS is the default session timeout - which is one minute.
pub const SESSION_TIMEOUT_SECONDS: u64 = 60;

/// Session encapsulates a UDP stream session
pub struct Session {
log: Logger,
chain: Arc<FilterChain>,
send: SendHalf,
/// dest is where to send data to
dest: EndPoint,
Expand Down Expand Up @@ -75,6 +77,7 @@ impl Session {
/// from its ephemeral port from endpoint(s)
pub async fn new(
base: &Logger,
chain: Arc<FilterChain>,
from: SocketAddr,
dest: EndPoint,
sender: mpsc::Sender<Packet>,
Expand All @@ -84,6 +87,7 @@ impl Session {
let (closer, closed) = watch::channel::<bool>(false);
let mut s = Session {
log: base.new(o!("source" => "server::Session", "from" => from, "dest_name" => dest.name.clone(), "dest_address" => dest.address.clone())),
chain,
send,
from,
dest,
Expand Down Expand Up @@ -186,7 +190,15 @@ impl Session {
/// Sends a packet to the Session's dest.
pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
debug!(self.log, "Sending packet"; "dest_name" => &self.dest.name, "dest_address" => &self.dest.address, "contents" => from_utf8(buf).unwrap());
return self.send.send_to(buf, &self.dest.address).await;

if let Some(data) = self
.chain
.endpoint_send_filter(&self.dest, self.from, buf.to_vec())
{
return self.send.send_to(data.as_slice(), &self.dest.address).await;
}

Ok(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we change the return type for this function to e.g Result<Option<size>> or similar enum that would allow the caller to differentiate between a dropped packet and a packet without payload (that could have size 0?)? In the drop case the returned value would be e.g Ok(None) so that we won't have to worry about cases where size 0 might be special.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strong agree on this one, indeed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent suggestion. Will jump on it 👍

}

/// is_closed returns if the Session is closed or not.
Expand All @@ -208,7 +220,7 @@ mod tests {
use tokio::time;
use tokio::time::delay_for;

use crate::test_utils::{ephemeral_socket, logger, recv_udp};
use crate::test_utils::{ephemeral_socket, logger, recv_udp, TestFilter};

use super::*;

Expand All @@ -226,9 +238,15 @@ mod tests {
};
let (send_packet, mut recv_packet) = mpsc::channel::<Packet>(5);

let mut sess = Session::new(&log, local_addr, endpoint, send_packet)
.await
.unwrap();
let mut sess = Session::new(
&log,
Arc::new(FilterChain::new(vec![])),
local_addr,
endpoint,
send_packet,
)
.await
.unwrap();

let initial_expiration: Instant;
{
Expand Down Expand Up @@ -272,6 +290,8 @@ mod tests {
async fn session_send_to() {
let log = logger();
let msg = "hello";

// without a filter
let (sender, _) = mpsc::channel::<Packet>(1);
let (local_addr, wait) = recv_udp().await;
let endpoint = EndPoint {
Expand All @@ -280,11 +300,41 @@ mod tests {
connection_ids: vec![],
};

let mut session = Session::new(&log, local_addr, endpoint.clone(), sender)
.await
.unwrap();
let mut session = Session::new(
&log,
Arc::new(FilterChain::new(vec![])),
local_addr,
endpoint.clone(),
sender,
)
.await
.unwrap();
session.send_to(msg.as_bytes()).await.unwrap();
assert_eq!(msg, wait.await.unwrap());

// with a filters
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'filters' should be 'filter' ?

let (sender, _) = mpsc::channel::<Packet>(1);
let (local_addr, wait) = recv_udp().await;
let endpoint = EndPoint {
name: "endpoint".to_string(),
address: local_addr,
connection_ids: vec![],
};
let mut session = Session::new(
&log,
Arc::new(FilterChain::new(vec![Arc::new(TestFilter {})])),
local_addr,
endpoint.clone(),
sender,
)
.await
.unwrap();

session.send_to(msg.as_bytes()).await.unwrap();
assert_eq!(
format!("{}:esf:{}:{}", msg, endpoint.name, local_addr),
wait.await.unwrap()
);
}

#[tokio::test]
Expand All @@ -300,9 +350,15 @@ mod tests {
};

info!(log, ">> creating sessions");
let sess = Session::new(&log, local_addr, endpoint, send_packet)
.await
.unwrap();
let sess = Session::new(
&log,
Arc::new(FilterChain::new(vec![])),
local_addr,
endpoint,
send_packet,
)
.await
.unwrap();
info!(log, ">> session created and running");

assert!(!sess.is_closed(), "session should not be closed");
Expand Down