Skip to content

Commit

Permalink
Pass context objects in filter APIs
Browse files Browse the repository at this point in the history
Since the objects are exposed to the filter API we would
like to be able to control how they're created so that we don't
break code whenever we e.g add new fields - for this we use
phantoms, we might also need to make the `Context::new()` constructors
private from the filter functions since a filter shouldn't need
to create them to begin with.

Fixes #94
  • Loading branch information
iffyio committed Sep 9, 2020
1 parent c23ceac commit 83d6909
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 179 deletions.
89 changes: 40 additions & 49 deletions src/extensions/filter_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
*/

use std::io::{Error, ErrorKind, Result};
use std::net::SocketAddr;
use std::sync::Arc;

use crate::config::{Config, EndPoint};
use crate::extensions::{CreateFilterArgs, Filter, FilterRegistry};
use crate::config::Config;
use crate::extensions::{
CreateFilterArgs, DownstreamContext, DownstreamResponse, Filter, FilterRegistry,
UpstreamContext, UpstreamResponse,
};
use prometheus::Registry;

/// FilterChain implements a chain of Filters amd the implementation
Expand Down Expand Up @@ -66,43 +68,32 @@ impl FilterChain {
}

impl Filter for FilterChain {
fn on_downstream_receive(
&self,
endpoints: &[EndPoint],
from: SocketAddr,
contents: Vec<u8>,
) -> Option<(Vec<EndPoint>, Vec<u8>)> {
let mut e = endpoints.to_vec();
let mut c = contents;
fn on_downstream_receive(&self, mut ctx: DownstreamContext) -> Option<DownstreamResponse> {
let from = ctx.from;
for f in &self.filters {
match f.on_downstream_receive(&e, from, c) {
match f.on_downstream_receive(ctx) {
None => return None,
Some((endpoints, contents)) => {
e = endpoints;
c = contents;
Some(response) => {
ctx = DownstreamContext::new(response.endpoints, from, response.contents)
}
}
}
Some((e, c))
Some(ctx.into())
}

fn on_upstream_receive(
&self,
endpoint: &EndPoint,
from: SocketAddr,
to: SocketAddr,
contents: Vec<u8>,
) -> Option<Vec<u8>> {
let mut c = contents;
fn on_upstream_receive(&self, mut ctx: UpstreamContext) -> Option<UpstreamResponse> {
let endpoint = ctx.endpoint;
let from = ctx.from;
let to = ctx.to;
for f in &self.filters {
match f.on_upstream_receive(&endpoint, from, to, c) {
match f.on_upstream_receive(ctx) {
None => return None,
Some(contents) => {
c = contents;
Some(response) => {
ctx = UpstreamContext::new(endpoint, from, to, response.contents);
}
}
}
Some(c)
Some(ctx.into())
}
}

Expand All @@ -111,7 +102,7 @@ mod tests {
use std::str::from_utf8;

use crate::config;
use crate::config::{ConnectionConfig, Local};
use crate::config::{ConnectionConfig, EndPoint, Local};
use crate::extensions::filters::DebugFilterFactory;
use crate::extensions::{default_registry, FilterFactory};
use crate::test_utils::{logger, noop_endpoint, TestFilter};
Expand Down Expand Up @@ -179,33 +170,33 @@ mod tests {

let endpoints_fixture = endpoints();

let (eps, content) = chain
.on_downstream_receive(
&endpoints_fixture,
let response = chain
.on_downstream_receive(DownstreamContext::new(
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
)
))
.unwrap();

let mut expected = endpoints_fixture.clone();
expected.push(noop_endpoint());
assert_eq!(expected, eps);
assert_eq!(expected, response.endpoints);
assert_eq!(
"hello:odr:127.0.0.1:70",
from_utf8(content.as_slice()).unwrap()
from_utf8(response.contents.as_slice()).unwrap()
);

let content = chain
.on_upstream_receive(
let response = chain
.on_upstream_receive(UpstreamContext::new(
&endpoints_fixture[0],
endpoints_fixture[0].address,
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
)
))
.unwrap();
assert_eq!(
"hello:our:one:127.0.0.1:80:127.0.0.1:70",
from_utf8(content.as_slice()).unwrap()
from_utf8(response.contents.as_slice()).unwrap()
);
}

Expand All @@ -215,34 +206,34 @@ mod tests {

let endpoints_fixture = endpoints();

let (eps, content) = chain
.on_downstream_receive(
&endpoints_fixture,
let response = chain
.on_downstream_receive(DownstreamContext::new(
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
)
))
.unwrap();

let mut expected = endpoints_fixture.clone();
expected.push(noop_endpoint());
expected.push(noop_endpoint());
assert_eq!(expected, eps);
assert_eq!(expected, response.endpoints);
assert_eq!(
"hello:odr:127.0.0.1:70:odr:127.0.0.1:70",
from_utf8(content.as_slice()).unwrap()
from_utf8(response.contents.as_slice()).unwrap()
);

let content = chain
.on_upstream_receive(
let response = chain
.on_upstream_receive(UpstreamContext::new(
&endpoints_fixture[0],
endpoints_fixture[0].address,
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
)
))
.unwrap();
assert_eq!(
"hello:our:one:127.0.0.1:80:127.0.0.1:70:our:one:127.0.0.1:80:127.0.0.1:70",
from_utf8(content.as_slice()).unwrap()
from_utf8(response.contents.as_slice()).unwrap()
);
}
}
161 changes: 126 additions & 35 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,137 @@ use std::fmt;
use std::net::SocketAddr;

use prometheus::{Error as MetricsError, Registry};
use serde::export::Formatter;

use crate::config::{ConnectionConfig, EndPoint};
use std::marker::PhantomData;

/// Contains the input arguments to [on_downstream_receive](crate::extensions::filter_registry::Filter::on_downstream_receive)
pub struct DownstreamContext {
/// The upstream endpoints that the packet will be forwarded to.
pub endpoints: Vec<EndPoint>,
/// The source of the received packet.
pub from: SocketAddr,
/// Contents of the received packet.
pub contents: Vec<u8>,
// Enforce using constructor to create this struct.
phantom: PhantomData<()>,
}

/// Contains the output of [on_downstream_receive](crate::extensions::filter_registry::Filter::on_downstream_receive)
///
/// New instances are created from a [`DownstreamContext`]
///
/// ```rust
/// # use quilkin::extensions::{DownstreamContext, DownstreamResponse};
/// fn on_downstream_receive(ctx: DownstreamContext) -> Option<DownstreamContext> {
/// Some(ctx.into())
/// }
/// ```
pub struct DownstreamResponse {
/// The upstream endpoints that the packet should be forwarded to.
pub endpoints: Vec<EndPoint>,
/// Contents of the packet to be forwarded.
pub contents: Vec<u8>,
// Enforce using constructor to create this struct.
phantom: PhantomData<()>,
}

/// Contains the input arguments to [on_upstream_receive](crate::extensions::filter_registry::Filter::on_upstream_receive)
pub struct UpstreamContext<'a> {
/// The upstream endpoint that we're expecting packets from.
pub endpoint: &'a EndPoint,
/// The source of the received packet.
pub from: SocketAddr,
/// The destination of the received packet.
pub to: SocketAddr,
/// Contents of the received packet.
pub contents: Vec<u8>,
// Enforce using constructor to create this struct.
phantom: PhantomData<()>,
}

/// Contains the output of [on_upstream_receive](crate::extensions::filter_registry::Filter::on_upstream_receive)
///
/// New instances are created from an [`UpstreamContext`]
///
/// ```rust
/// # use quilkin::extensions::{UpstreamContext, UpstreamResponse};
/// fn on_upstream_receive(ctx: UpstreamContext) -> Option<UpstreamContext> {
/// Some(ctx.into())
/// }
/// ```
pub struct UpstreamResponse {
/// Contents of the packet to be sent back to the original sender.
pub contents: Vec<u8>,
// Enforce using constructor to create this struct.
phantom: PhantomData<()>,
}

impl DownstreamContext {
/// Creates a new [`DownstreamContext`]
pub fn new(endpoints: Vec<EndPoint>, from: SocketAddr, contents: Vec<u8>) -> Self {
Self {
endpoints,
from,
contents,
phantom: PhantomData,
}
}
}

impl From<DownstreamContext> for DownstreamResponse {
fn from(ctx: DownstreamContext) -> Self {
Self {
endpoints: ctx.endpoints,
contents: ctx.contents,
phantom: ctx.phantom,
}
}
}

impl UpstreamContext<'_> {
/// Creates a new [`UpstreamContext`]
pub fn new(
endpoint: &EndPoint,
from: SocketAddr,
to: SocketAddr,
contents: Vec<u8>,
) -> UpstreamContext {
UpstreamContext {
endpoint,
from,
to,
contents,
phantom: PhantomData,
}
}
}

impl From<UpstreamContext<'_>> for UpstreamResponse {
fn from(ctx: UpstreamContext) -> Self {
Self {
contents: ctx.contents,
phantom: ctx.phantom,
}
}
}

/// Filter is a trait for routing and manipulating packets.
pub trait Filter: Send + Sync {
/// on_downstream_receive filters packets received from the local port, and potentially sends them
/// to configured endpoints.
/// This function should return the array of endpoints that the packet should be sent to,
/// and the packet that should be sent (which may be manipulated) as well.
/// This function should return a [`DownstreamResponse`] containing the array of
/// endpoints that the packet should be sent to and the packet that should be
/// sent (which may be manipulated) as well.
/// If the packet should be rejected, return None.
fn on_downstream_receive(
&self,
endpoints: &[EndPoint],
from: SocketAddr,
contents: Vec<u8>,
) -> Option<(Vec<EndPoint>, Vec<u8>)>;
fn on_downstream_receive(&self, ctx: DownstreamContext) -> Option<DownstreamResponse>;

/// on_upstream_receive filters packets received from `from`, to a given endpoint,
/// that are going back to the original sender.
/// This function should return the packet to be sent (which may be manipulated).
/// on_upstream_receive filters packets received upstream and destined
/// for a given endpoint, that are going back to the original sender.
/// This function should return an [`UpstreamResponse`] containing the packet to
/// be sent (which may be manipulated).
/// If the packet should be rejected, return None.
fn on_upstream_receive(
&self,
endpoint: &EndPoint,
from: SocketAddr,
to: SocketAddr,
contents: Vec<u8>,
) -> Option<Vec<u8>>;
fn on_upstream_receive(&self, ctx: UpstreamContext) -> Option<UpstreamResponse>;
}

#[derive(Debug, PartialEq)]
Expand All @@ -60,7 +162,7 @@ pub enum Error {
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::NotFound(key) => write!(f, "filter {} is not found", key),
Error::FieldInvalid { field, reason } => {
Expand Down Expand Up @@ -163,22 +265,11 @@ mod tests {
struct TestFilter {}

impl Filter for TestFilter {
fn on_downstream_receive(
&self,
_: &[EndPoint],
_: SocketAddr,
_: Vec<u8>,
) -> Option<(Vec<EndPoint>, Vec<u8>)> {
fn on_downstream_receive(&self, _: DownstreamContext) -> Option<DownstreamResponse> {
None
}

fn on_upstream_receive(
&self,
_: &EndPoint,
_: SocketAddr,
_: SocketAddr,
_: Vec<u8>,
) -> Option<Vec<u8>> {
fn on_upstream_receive(&self, _: UpstreamContext) -> Option<UpstreamResponse> {
None
}
}
Expand Down Expand Up @@ -219,10 +310,10 @@ mod tests {
};

assert!(filter
.on_downstream_receive(&vec![], addr, vec![])
.on_downstream_receive(DownstreamContext::new(vec![], addr, vec![]))
.is_some());
assert!(filter
.on_upstream_receive(&endpoint, addr, addr, vec![])
.on_upstream_receive(UpstreamContext::new(&endpoint, addr, addr, vec![]))
.is_some());
}
}
Loading

0 comments on commit 83d6909

Please sign in to comment.