Skip to content

Commit f310ad8

Browse files
committed
Set up DNS cache
1 parent 51efaa0 commit f310ad8

File tree

7 files changed

+129
-55
lines changed

7 files changed

+129
-55
lines changed

src/client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ pub struct Client<S, T> {
5959
/// to connect and cancel a query.
6060
client_server_map: ClientServerMap,
6161

62-
/// Client parameters, e.g. user, client_encoding, etc.
62+
/// Client parameters, e.g. user, client_encoding, etc1
6363
#[allow(dead_code)]
6464
parameters: HashMap<String, String>,
6565

6666
/// Statistics
6767
stats: Reporter,
6868

69-
/// Clients want to talk to admin database.
69+
/// Clients want to talk to admin database0
7070
admin: bool,
7171

7272
/// Last address the client talked to.

src/config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl General {
208208
}
209209

210210
pub fn default_dns_max_ttl() -> u64 {
211-
30
211+
30
212212
}
213213

214214
pub fn default_healthcheck_timeout() -> u64 {
@@ -238,8 +238,8 @@ impl Default for General {
238238
ban_time: Self::default_ban_time(),
239239
log_client_connections: false,
240240
log_client_disconnections: false,
241-
dns_cache_enabled: false,
242-
dns_max_ttl: Self::default_dns_max_ttl(),
241+
dns_cache_enabled: false,
242+
dns_max_ttl: Self::default_dns_max_ttl(),
243243
autoreload: false,
244244
tls_certificate: None,
245245
tls_private_key: None,

src/dns_cache.rs

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use log::{debug, error};
22
use std::collections::{HashMap, HashSet};
3+
use std::io;
34
use std::net::IpAddr;
45
use std::sync::Arc;
56
use std::sync::RwLock;
6-
use std::io;
77
use tokio::time::{sleep, Duration};
88
use trust_dns_resolver::lookup_ip::LookupIp;
99
use trust_dns_resolver::TokioAsyncResolver;
@@ -55,7 +55,7 @@ impl From<LookupIp> for AddrSet {
5555
// // calls will be returned from cache.
5656
// resolver.has_changed("www.example.com.", addrset)
5757
// ```
58-
#[derive(Clone)]
58+
#[derive(Clone, Debug)]
5959
pub struct CachedResolver {
6060
// The configuration of the cached_resolver.
6161
pub config: CachedResolverConfig,
@@ -91,7 +91,7 @@ impl CachedResolver {
9191
//
9292
pub fn new(config: CachedResolverConfig) -> io::Result<Self> {
9393
// Construct a new Resolver with default configuration options
94-
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
94+
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
9595
let data = Arc::new(RwLock::new(HashMap::new()));
9696

9797
Ok(Self {
@@ -103,46 +103,53 @@ impl CachedResolver {
103103

104104
// Schedules the refresher
105105
pub async fn refresh_dns_entries_loop(&mut self) {
106-
let data = self.data.clone();
107-
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
106+
let data = self.data.clone();
107+
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
108108
let interval = Duration::from_secs(self.config.dns_max_ttl);
109109
loop {
110-
debug!("Begin refreshing cached DNS addresses.");
111-
// To minimize the time we hold the lock, we first create
112-
// an array with keys.
113-
let mut hostnames: Vec<String> = Vec::new();
114-
{
110+
debug!("Begin refreshing cached DNS addresses.");
111+
// To minimize the time we hold the lock, we first create
112+
// an array with keys.
113+
let mut hostnames: Vec<String> = Vec::new();
114+
{
115115
for hostname in data.read().unwrap().keys() {
116-
hostnames.push(hostname.clone());
116+
hostnames.push(hostname.clone());
117117
}
118-
}
118+
}
119119

120-
for hostname in hostnames.iter() {
120+
for hostname in hostnames.iter() {
121121
match CachedResolver::fetch_from_data_cache(data.clone(), hostname.as_str()) {
122-
Some(addrset) => {
123-
match resolver.lookup_ip(hostname).await {
124-
Ok(lookup_ip) => {
125-
let new_addrset = AddrSet::from(lookup_ip);
126-
debug!("Obtained address for host ({}) -> ({:?})", hostname, new_addrset);
122+
Some(addrset) => match resolver.lookup_ip(hostname).await {
123+
Ok(lookup_ip) => {
124+
let new_addrset = AddrSet::from(lookup_ip);
125+
debug!(
126+
"Obtained address for host ({}) -> ({:?})",
127+
hostname, new_addrset
128+
);
127129

128-
if addrset != new_addrset {
129-
debug!("Addr changed from {:?} to {:?} updating cache.", addrset, new_addrset);
130-
CachedResolver::store_in_cache(data.clone(), hostname, new_addrset);
131-
}
132-
},
133-
Err(err) => {
134-
error!("There was an error trying to resolv {}: ({}).", hostname, err);
135-
}
130+
if addrset != new_addrset {
131+
debug!(
132+
"Addr changed from {:?} to {:?} updating cache.",
133+
addrset, new_addrset
134+
);
135+
CachedResolver::store_in_cache(data.clone(), hostname, new_addrset);
136+
}
137+
}
138+
Err(err) => {
139+
error!(
140+
"There was an error trying to resolv {}: ({}).",
141+
hostname, err
142+
);
136143
}
137-
}
138-
None => {
144+
},
145+
None => {
139146
error!("Could not obtain expected address from cache, this should not happen, \
140147
as this cache does not allow deleting addresses.");
141-
}
148+
}
142149
}
143-
}
144-
debug!("Finished refreshing cached DNS addresses.");
145-
sleep(interval).await;
150+
}
151+
debug!("Finished refreshing cached DNS addresses.");
152+
sleep(interval).await;
146153
}
147154
}
148155

@@ -166,16 +173,16 @@ impl CachedResolver {
166173
// ```
167174
//
168175
pub async fn lookup_ip(&mut self, host: &str) -> ResolveResult<AddrSet> {
169-
debug!("Lookup up {} in cache", host);
176+
debug!("Lookup up {} in cache", host);
170177
match self.fetch_from_cache(host) {
171178
Some(addr_set) => {
172-
debug!("Cache hit!");
173-
Ok(addr_set)
174-
},
179+
debug!("Cache hit!");
180+
Ok(addr_set)
181+
}
175182
None => {
176-
debug!("Not found, executing a dns query!");
183+
debug!("Not found, executing a dns query!");
177184
let addr_set = AddrSet::from(self.resolver.clone().lookup_ip(host).await?);
178-
debug!("Obtained: {:?}", addr_set);
185+
debug!("Obtained: {:?}", addr_set);
179186
CachedResolver::store_in_cache(self.data.clone(), host, addr_set.clone());
180187
Ok(addr_set)
181188
}
@@ -247,7 +254,7 @@ mod tests {
247254
let mut resolver = CachedResolver::new(config).unwrap();
248255
let hostname = "www.idontexists.";
249256
let response = resolver.lookup_ip(hostname).await;
250-
assert!(matches!(response, Err(ResolveError{ .. })));
257+
assert!(matches!(response, Err(ResolveError { .. })));
251258
}
252259

253260
#[tokio::test]
@@ -256,26 +263,26 @@ mod tests {
256263
let mut resolver = CachedResolver::new(config).unwrap();
257264
let hostname = "w ww.idontexists.";
258265
let response = resolver.lookup_ip(hostname).await;
259-
assert!(matches!(response, Err(ResolveError{ .. })));
260-
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
266+
assert!(matches!(response, Err(ResolveError { .. })));
267+
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
261268
}
262269

263270
#[tokio::test]
264271
// Ok, this test is based on the fact that google does DNS RR
265272
// and does not responds with every available ip everytime, so
266273
// if I cache here, it will miss after one cache iteration or two.
267274
async fn thread() {
268-
env_logger::init();
275+
env_logger::init();
269276
let config = CachedResolverConfig { dns_max_ttl: 10 };
270277
let mut resolver = CachedResolver::new(config).unwrap();
271278
let hostname = "www.google.com.";
272279
let response = resolver.lookup_ip(hostname).await;
273280
let addr_set = response.unwrap();
274281
assert!(!resolver.has_changed(hostname, &addr_set));
275-
let mut resolver_for_refresher = resolver.clone();
276-
let _thread_handle = tokio::task::spawn(async move {
277-
resolver_for_refresher.refresh_dns_entries_loop().await;
278-
});
279-
assert!(!resolver.has_changed(hostname, &addr_set));
282+
let mut resolver_for_refresher = resolver.clone();
283+
let _thread_handle = tokio::task::spawn(async move {
284+
resolver_for_refresher.refresh_dns_entries_loop().await;
285+
});
286+
assert!(!resolver.has_changed(hostname, &addr_set));
280287
}
281288
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod config;
22
pub mod constants;
3+
pub mod dns_cache;
34
pub mod errors;
45
pub mod messages;
56
pub mod pool;

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ mod admin;
6464
mod client;
6565
mod config;
6666
mod constants;
67+
mod dns_cache;
6768
mod errors;
6869
mod messages;
6970
mod pool;
@@ -74,7 +75,6 @@ mod server;
7475
mod sharding;
7576
mod stats;
7677
mod tls;
77-
mod dns_cache;
7878

7979
use crate::config::{get_config, reload_config, VERSION};
8080
use crate::pool::{ClientServerMap, ConnectionPool};

src/pool.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::time::Instant;
1515
use crate::config::{get_config, Address, General, PoolMode, Role, User};
1616
use crate::errors::Error;
1717

18+
use crate::dns_cache::{CachedResolver, CachedResolverConfig};
1819
use crate::server::Server;
1920
use crate::sharding::ShardingFunction;
2021
use crate::stats::{get_reporter, Reporter};
@@ -132,13 +133,40 @@ pub struct ConnectionPool {
132133

133134
/// Pool configuration.
134135
pub settings: PoolSettings,
136+
137+
/// CachedResolver to use when dns_cache is enabled
138+
pub cached_resolver: Arc<Option<CachedResolver>>,
135139
}
136140

137141
impl ConnectionPool {
138142
/// Construct the connection pool from the configuration.
139143
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
140144
let config = get_config();
141145

146+
// Configure dns_cache if enabled
147+
let mut cached_resolver: Arc<Option<CachedResolver>> = Arc::new(None);
148+
if config.general.dns_cache_enabled {
149+
info!("Starting Dns cache");
150+
let cached_resolver_config = CachedResolverConfig {
151+
dns_max_ttl: config.general.dns_max_ttl,
152+
};
153+
cached_resolver = match CachedResolver::new(cached_resolver_config) {
154+
Ok(ok) => {
155+
let mut refresher_cached_resolver = ok.clone();
156+
info!("Scheduling DNS refresh loop");
157+
tokio::task::spawn(async move {
158+
refresher_cached_resolver.refresh_dns_entries_loop().await;
159+
});
160+
Arc::new(Some(ok))
161+
}
162+
Err(err) => {
163+
error!("Error Starting cached_resolver error: {:?}", err);
164+
error!("Will continue without this feature");
165+
Arc::new(None)
166+
}
167+
};
168+
};
169+
142170
let mut new_pools = HashMap::new();
143171
let mut address_id = 0;
144172

@@ -246,6 +274,7 @@ impl ConnectionPool {
246274
let mut pool = ConnectionPool {
247275
databases: shards,
248276
addresses,
277+
cached_resolver: cached_resolver.clone(),
249278
banlist: Arc::new(RwLock::new(banlist)),
250279
stats: get_reporter(),
251280
server_info: BytesMut::new(),
@@ -596,6 +625,12 @@ impl ManageConnection for ServerPool {
596625
/// Attempts to create a new connection.
597626
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
598627
info!("Creating a new server connection {:?}", self.address);
628+
629+
let cached_resolver = match get_pool(&self.database, &self.user.username) {
630+
Some(pool) => pool.cached_resolver,
631+
None => Arc::new(None),
632+
};
633+
599634
let server_id = rand::random::<i32>();
600635

601636
self.stats.server_register(
@@ -614,6 +649,7 @@ impl ManageConnection for ServerPool {
614649
&self.user,
615650
&self.database,
616651
self.client_server_map.clone(),
652+
cached_resolver,
617653
self.stats.clone(),
618654
)
619655
.await

src/server.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
/// Implementation of the PostgreSQL server (database) protocol.
2-
/// Here we are pretending to the a Postgres client.
31
use bytes::{Buf, BufMut, BytesMut};
42
use log::{debug, error, info, trace, warn};
53
use std::io::Read;
4+
/// Implementation of the PostgreSQL server (database) protocol.
5+
/// Here we are pretending to the a Postgres client.
6+
use std::sync::Arc;
67
use std::time::SystemTime;
78
use tokio::io::{AsyncReadExt, BufReader};
89
use tokio::net::{
@@ -12,6 +13,7 @@ use tokio::net::{
1213

1314
use crate::config::{Address, User};
1415
use crate::constants::*;
16+
use crate::dns_cache::{AddrSet, CachedResolver};
1517
use crate::errors::Error;
1618
use crate::messages::*;
1719
use crate::pool::ClientServerMap;
@@ -68,6 +70,12 @@ pub struct Server {
6870

6971
// Last time that a successful server send or response happened
7072
last_activity: SystemTime,
73+
74+
// Cached Resolver to be used to resolve addresses with hostnames
75+
cached_resolver: Arc<Option<CachedResolver>>,
76+
77+
// Associated addresses used
78+
addr_set: Option<AddrSet>,
7179
}
7280

7381
impl Server {
@@ -79,8 +87,22 @@ impl Server {
7987
user: &User,
8088
database: &str,
8189
client_server_map: ClientServerMap,
90+
cached_resolver: Arc<Option<CachedResolver>>,
8291
stats: Reporter,
8392
) -> Result<Server, Error> {
93+
let addr_set = match *cached_resolver.clone() {
94+
Some(ref cached_resolver) => {
95+
match cached_resolver.clone().lookup_ip(&address.host).await {
96+
Ok(ok) => Some(ok),
97+
Err(err) => {
98+
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
99+
None
100+
}
101+
}
102+
}
103+
None => None,
104+
};
105+
84106
let mut stream =
85107
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
86108
Ok(stream) => stream,
@@ -329,6 +351,8 @@ impl Server {
329351
bad: false,
330352
needs_cleanup: false,
331353
client_server_map,
354+
cached_resolver,
355+
addr_set,
332356
connected_at: chrono::offset::Utc::now().naive_utc(),
333357
stats,
334358
application_name: String::new(),
@@ -554,6 +578,12 @@ impl Server {
554578
/// Server & client are out of sync, we must discard this connection.
555579
/// This happens with clients that misbehave.
556580
pub fn is_bad(&self) -> bool {
581+
if let Some(cached_resolver) = &(*(self.cached_resolver.clone())) {
582+
if let Some(addr_set) = &self.addr_set {
583+
cached_resolver.has_changed(self.address.host.as_str(), addr_set);
584+
return true;
585+
}
586+
}
557587
self.bad
558588
}
559589

0 commit comments

Comments
 (0)