Skip to content

Commit

Permalink
Updates based on review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
markmandel committed Sep 17, 2020
1 parent 75e232c commit 7bb5c5a
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 80 deletions.
30 changes: 7 additions & 23 deletions src/extensions/filter_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
* limitations under the License.
*/

use std::fmt::{self, Formatter};
use std::sync::Arc;

use prometheus::Registry;

use crate::config::{Config, ValidationError};
use crate::extensions::{
CreateFilterArgs, DownstreamContext, DownstreamResponse, Filter, FilterRegistry,
UpstreamContext, UpstreamResponse,
};
use prometheus::Registry;
use std::fmt::{self, Formatter};
use std::sync::Arc;

/// FilterChain implements a chain of Filters amd the implementation
/// of passing the information between Filters for each filter function
Expand Down Expand Up @@ -89,14 +91,7 @@ impl Filter for FilterChain {
for f in &self.filters {
match f.on_downstream_receive(ctx) {
None => return None,
Some(response) => {
ctx = DownstreamContext::new(
response.endpoints,
from,
response.contents,
response.values,
)
}
Some(response) => ctx = DownstreamContext::with_response(from, response),
}
}
Some(ctx.into())
Expand All @@ -110,13 +105,7 @@ impl Filter for FilterChain {
match f.on_upstream_receive(ctx) {
None => return None,
Some(response) => {
ctx = UpstreamContext::new(
endpoint,
from,
to,
response.contents,
response.values,
);
ctx = UpstreamContext::with_response(endpoint, from, to, response);
}
}
}
Expand All @@ -135,7 +124,6 @@ mod tests {
use crate::test_utils::{logger, noop_endpoint, TestFilter};

use super::*;
use std::collections::HashMap;

#[test]
fn from_config() {
Expand Down Expand Up @@ -203,7 +191,6 @@ mod tests {
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
HashMap::new(),
))
.unwrap();

Expand All @@ -227,7 +214,6 @@ mod tests {
endpoints_fixture[0].address,
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
HashMap::new(),
))
.unwrap();

Expand All @@ -254,7 +240,6 @@ mod tests {
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
HashMap::new(),
))
.unwrap();

Expand All @@ -279,7 +264,6 @@ mod tests {
endpoints_fixture[0].address,
"127.0.0.1:70".parse().unwrap(),
"hello".as_bytes().to_vec(),
HashMap::new(),
))
.unwrap();
assert_eq!(
Expand Down
50 changes: 33 additions & 17 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,23 @@ pub struct UpstreamResponse {

impl DownstreamContext {
/// Creates a new [`DownstreamContext`]
pub fn new(
endpoints: Vec<EndPoint>,
from: SocketAddr,
contents: Vec<u8>,
values: HashMap<String, Box<dyn Any + Send>>,
) -> Self {
pub fn new(endpoints: Vec<EndPoint>, from: SocketAddr, contents: Vec<u8>) -> Self {
Self {
endpoints,
from,
contents,
values,
values: HashMap::new(),
phantom: PhantomData,
}
}

/// Creates a new [`DownstreamContext`] from a [`DownstreamResponse`]
pub fn with_response(from: SocketAddr, response: DownstreamResponse) -> Self {
Self {
endpoints: response.endpoints,
from,
contents: response.contents,
values: response.values,
phantom: PhantomData,
}
}
Expand All @@ -130,14 +136,30 @@ impl UpstreamContext<'_> {
from: SocketAddr,
to: SocketAddr,
contents: Vec<u8>,
values: HashMap<String, Box<dyn Any + Send>>,
) -> UpstreamContext {
UpstreamContext {
endpoint,
from,
to,
contents,
values,
values: HashMap::new(),
phantom: PhantomData,
}
}

/// Creates a new [`UpstreamContext`] from a [`UpstreamResponse`]
pub fn with_response(
endpoint: &EndPoint,
from: SocketAddr,
to: SocketAddr,
response: UpstreamResponse,
) -> UpstreamContext {
UpstreamContext {
endpoint,
from,
to,
contents: response.contents,
values: response.values,
phantom: PhantomData,
}
}
Expand Down Expand Up @@ -335,16 +357,10 @@ mod tests {
};

assert!(filter
.on_downstream_receive(DownstreamContext::new(vec![], addr, vec![], HashMap::new()))
.on_downstream_receive(DownstreamContext::new(vec![], addr, vec![]))
.is_some());
assert!(filter
.on_upstream_receive(UpstreamContext::new(
&endpoint,
addr,
addr,
vec![],
HashMap::new()
))
.on_upstream_receive(UpstreamContext::new(&endpoint, addr, addr, vec![],))
.is_some());
}
}
3 changes: 0 additions & 3 deletions src/extensions/filters/local_rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ impl Filter for RateLimitFilter {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::time::Duration;

use prometheus::Registry;
Expand Down Expand Up @@ -288,7 +287,6 @@ mod tests {
vec![],
"127.0.0.1:8080".parse().unwrap(),
vec![9],
HashMap::new(),
))
.is_none(),);
}
Expand All @@ -305,7 +303,6 @@ mod tests {
vec![],
"127.0.0.1:8080".parse().unwrap(),
vec![9],
HashMap::new(),
))
.unwrap();
assert_eq!((result.endpoints, result.contents), (vec![], vec![9]));
Expand Down
1 change: 0 additions & 1 deletion src/proxy/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ impl Server {
lb_policy.choose_endpoints(),
recv_addr,
packet.to_vec(),
HashMap::new(),
));

if let Some(response) = result {
Expand Down
14 changes: 5 additions & 9 deletions src/proxy/sessions/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ use tokio::select;
use tokio::sync::{mpsc, watch, RwLock};
use tokio::time::{Duration, Instant};

use super::metrics::Metrics;
use crate::config::EndPoint;
use crate::extensions::{Filter, FilterChain, UpstreamContext};
use std::collections::HashMap;

use super::metrics::Metrics;

/// SESSION_TIMEOUT_SECONDS is the default session timeout - which is one minute.
pub const SESSION_TIMEOUT_SECONDS: u64 = 60;
Expand Down Expand Up @@ -211,13 +211,9 @@ impl Session {
debug!(log, "Received packet"; "from" => from, "endpoint_name" => &endpoint.name, "endpoint_addr" => &endpoint.address, "contents" => from_utf8(packet).unwrap());
Session::inc_expiration(expiration).await;

if let Some(response) = chain.on_upstream_receive(UpstreamContext::new(
endpoint,
from,
to,
packet.to_vec(),
HashMap::new(),
)) {
if let Some(response) =
chain.on_upstream_receive(UpstreamContext::new(endpoint, from, to, packet.to_vec()))
{
if let Err(err) = sender.send(Packet::new(to, response.contents)).await {
metrics.errors_total.inc();
error!(log, "Error sending packet to channel"; "error" => %err);
Expand Down
36 changes: 9 additions & 27 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/

use std::collections::HashMap;
/// Common utilities for testing
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::str::from_utf8;
Expand Down Expand Up @@ -64,18 +63,10 @@ impl Filter for TestFilter {
ctx.endpoints.push(noop_endpoint());

// append values on each run
let key = "downstream";
match ctx.values.get(key) {
None => {
ctx.values
.insert(key.into(), Box::new("receive".to_string()));
}
Some(value) => {
let mut value = value.downcast_ref::<String>().unwrap().clone();
value.push_str(":receive");
ctx.values.insert(key.into(), Box::new(value));
}
}
ctx.values
.entry("downstream".into())
.and_modify(|e| e.downcast_mut::<String>().unwrap().push_str(":receive"))
.or_insert_with(|| Box::new("receive".to_string()));

ctx.contents
.append(&mut format!(":odr:{}", ctx.from).into_bytes());
Expand All @@ -84,18 +75,11 @@ impl Filter for TestFilter {

fn on_upstream_receive(&self, mut ctx: UpstreamContext) -> Option<UpstreamResponse> {
// append values on each run
let key = "upstream";
match ctx.values.get(key) {
None => {
ctx.values
.insert(key.into(), Box::new("receive".to_string()));
}
Some(value) => {
let mut value = value.downcast_ref::<String>().unwrap().clone();
value.push_str(":receive");
ctx.values.insert(key.into(), Box::new(value));
}
}
ctx.values
.entry("upstream".into())
.and_modify(|e| e.downcast_mut::<String>().unwrap().push_str(":receive"))
.or_insert_with(|| Box::new("receive".to_string()));

ctx.contents.append(
&mut format!(":our:{}:{}:{}", ctx.endpoint.name, ctx.from, ctx.to).into_bytes(),
);
Expand Down Expand Up @@ -341,7 +325,6 @@ where
endpoints.clone(),
from,
contents.clone(),
HashMap::new(),
)) {
None => unreachable!("should return a result"),
Some(response) => {
Expand All @@ -368,7 +351,6 @@ where
endpoint.address,
"127.0.0.1:70".parse().unwrap(),
contents.clone(),
HashMap::new(),
)) {
None => unreachable!("should return a result"),
Some(response) => assert_eq!(contents, response.contents),
Expand Down

0 comments on commit 7bb5c5a

Please sign in to comment.