diff --git a/Cargo.lock b/Cargo.lock index 2e42223b..41c2d7c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "autocfg" version = "1.1.0" @@ -264,6 +270,27 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "directories" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + [[package]] name = "either" version = "1.10.0" @@ -334,6 +361,17 @@ dependencies = [ "libc", ] +[[package]] +name = "getrandom" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "glob" version = "0.3.1" @@ -497,16 +535,28 @@ dependencies = [ "libc", ] +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.4.2", + "libc", +] + [[package]] name = "libshpool" version = "0.6.2" dependencies = [ "anyhow", + "assert_matches", "bincode", "byteorder", "chrono", "clap", "crossbeam-channel", + "directories", "lazy_static", "libc", "libproc", @@ -698,6 +748,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -746,6 +802,17 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.10.3" @@ -992,6 +1059,26 @@ dependencies = [ "home", ] +[[package]] +name = "thiserror" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "thread_local" version = "1.1.8" diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index 0efd2062..1437a94a 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -42,6 +42,7 @@ tempfile = "3" # RAII tmp files strip-ansi-escapes = "0.2.0" # cleaning up strings for pager display notify = "6" # watch config file for updates libproc = "0.14.8" # sniffing shells by examining the subprocess +directories = "5.0.1" # get paths for configurations # rusty wrapper for unix apis [dependencies.nix] @@ -55,3 +56,4 @@ features = ["std", "fmt", "tracing-log", "smallvec"] [dev-dependencies] ntest = "0.9" # test timeouts +assert_matches = "1.5" # assert_matches macro diff --git a/libshpool/src/config.rs b/libshpool/src/config.rs index e1450853..0467d48f 100644 --- a/libshpool/src/config.rs +++ b/libshpool/src/config.rs @@ -14,17 +14,18 @@ use std::{ collections::HashMap, - fs, - path::{Path, PathBuf}, + fs, io, + path::Path, sync::{Arc, RwLock, RwLockReadGuard}, }; use anyhow::Context; -use notify::Watcher; +use directories::ProjectDirs; use serde_derive::Deserialize; use tracing::{info, warn}; -use super::{daemon::keybindings, user}; +use super::config_watcher::ConfigWatcher; +use super::daemon::keybindings; /// Exposes the shpool config file, watching for file updates /// so that the user does not need to restart the daemon when @@ -33,90 +34,114 @@ use super::{daemon::keybindings, user}; /// Users should never cache the config value directly and always /// access the config through the manager. The config may change /// at any time, so any cached config value could become stale. +#[derive(Clone)] pub struct Manager { /// The config value. config: Arc>, - watcher: Option>, + #[allow(dead_code)] + watcher: Arc, + dirs: ProjectDirs, } impl Manager { - // Create a new config manager. - pub fn new(config_file: Option<&str>) -> anyhow::Result { - let user_info = user::info()?; - let mut default_config_path = PathBuf::from(user_info.home_dir); - - let (config, config_path) = if let Some(config_path) = config_file { - info!("parsing explicitly passed in config ({})", config_path); - let config_str = fs::read_to_string(config_path).context("reading config toml (1)")?; - let config = toml::from_str(&config_str).context("parsing config file (1)")?; - - (config, Some(String::from(config_path))) - } else { - default_config_path.push(".config"); - default_config_path.push("shpool"); - default_config_path.push("config.toml"); - if default_config_path.exists() { - let config_str = - fs::read_to_string(&default_config_path).context("reading config toml (2)")?; - let config = toml::from_str(&config_str).context("parsing config file (2)")?; - - (config, default_config_path.clone().to_str().map(String::from)) - } else { - (Config::default(), None) + /// Create a new config manager. + /// + /// Unless given as the `config_file` argument, config files are read from the following paths + /// in the reverse priority order: + /// - System level config: /etc/shpool/config.toml + /// - User level config: $XDG_CONFIG_HOME/shpool/config.toml or $HOME/.config/shpool/config.toml + /// if $XDG_CONFIG_HOME is not set + /// For each top level field, values read later will overrides those read eariler. The exact + /// merging strategy is as defined in `Config::or`. + pub fn new<'a>(config_file: Option<&'a str>) -> anyhow::Result { + let dirs = ProjectDirs::from("", "", "shpool") + .context("no valid home directory path could be retrieved from the operating system")?; + + let config_files: Vec<&Path> = match config_file { + None => { + vec![ + // TODO: this won't work on Windows. Consider app-dirs2 which supports + // system level paths, but are less maintained than directories-rs. + Path::new("/etc/shpool/config.toml"), + dirs.config_dir(), + ] + } + Some(config_file) => { + info!("parsing explicitly passed in config ({})", config_file); + vec![Path::new(config_file)] } }; - info!("starting with config: {:?}", config); - let mut manager = Manager { config: Arc::new(RwLock::new(config)), watcher: None }; - - if let Some(watch_path) = config_path { - let config_slot = Arc::clone(&manager.config); - let reload_path = watch_path.clone(); - let mut watcher = notify::recommended_watcher(move |res| match res { - Ok(event) => { - info!("config file modify event: {:?}", event); - - let config_str = match fs::read_to_string(&reload_path) { - Ok(s) => s, - Err(e) => { - warn!("error reading config file: {:?}", e); - return; - } - }; - - let config = match toml::from_str(&config_str) { - Ok(c) => c, - Err(e) => { - warn!("error parsing config file: {:?}", e); - return; - } - }; - info!("new config: {:?}", config); - - let mut manager_config = config_slot.write().unwrap(); - *manager_config = config; + let config = Self::load(&config_files).context("loading initial config")?; + info!("starting with config: {:?}", config); + let config = Arc::new(RwLock::new(config)); + + let watcher = { + let config = config.clone(); + // create a owned version of config_files to move to the watcher thread. + let config_files: Vec<_> = config_files.iter().map(|f| f.to_path_buf()).collect(); + ConfigWatcher::new(move || { + info!("reloading config"); + let mut config = config.write().unwrap(); + match Self::load(&config_files) { + Ok(c) => { + info!("new config: {:?}", c); + *config = c; + } + Err(err) => warn!("error loading config file: {:?}", err), } - Err(e) => warn!("config file watch err: {:?}", e), }) - .context("building watcher")?; - watcher - .watch(Path::new(&watch_path), notify::RecursiveMode::NonRecursive) - .context("registering config file for watching")?; - manager.watcher = Some(Arc::new(watcher)); + .context("building watcher")? + }; + for path in config_files { + watcher.watch(path).context("registering config file for watching")?; } + let manager = Manager { config, watcher: Arc::new(watcher), dirs }; Ok(manager) } - // Get the current config value. + /// Get the current config value. pub fn get(&self) -> RwLockReadGuard<'_, Config> { self.config.read().unwrap() } -} -impl std::clone::Clone for Manager { - fn clone(&self) -> Self { - Manager { config: Arc::clone(&self.config), watcher: self.watcher.as_ref().map(Arc::clone) } + /// Get the ProjectDirs instance created at startup. + #[allow(dead_code)] + pub fn dirs(&self) -> &ProjectDirs { + &self.dirs + } + + /// Load config by merging configurations from a list of Paths. + /// + /// Paths come later in the list takes higher priority. + /// Merge strategy is as defined in `Config::or`. + fn load(config_files: impl IntoIterator>) -> anyhow::Result { + let mut config = Config::default(); + for path in config_files { + let path = path.as_ref(); + let config_str = match fs::read_to_string(path) { + // It is okay if the file is not there. + Err(e) if e.kind() == io::ErrorKind::NotFound => continue, + Err(e) => { + warn!("error reading config file: {:?}", e); + return Err(e) + .with_context(|| format!("reading config toml {}", path.to_string_lossy())); + } + Ok(s) => s, + }; + let new_config: Config = match toml::from_str(&config_str) { + Err(e) => { + warn!("error parsing config file: {:?}", e); + return Err(e).with_context(|| { + format!("parsing config toml {}", path.to_string_lossy()) + }); + } + Ok(c) => c, + }; + config = new_config.or(config); + } + Ok(config) } } @@ -212,6 +237,37 @@ pub struct Config { pub motd_args: Option>, } +impl Config { + // Merge with `another` Config instance, with `self` taking higher priority. + // + // Top level options with simple value are directly taken from the higher priority instance. + // The merge strategy for list or map are handled case by case. Please refer to each options' + // documentation for details. + pub fn or(self, another: Config) -> Config { + Config { + norc: self.norc.or(another.norc), + noecho: self.noecho.or(another.noecho), + nosymlink_ssh_auth_sock: self + .nosymlink_ssh_auth_sock + .or(another.nosymlink_ssh_auth_sock), + noread_etc_environment: self.noread_etc_environment.or(another.noread_etc_environment), + shell: self.shell.or(another.shell), + // TODO: Check this + env: self.env.or(another.env), + forward_env: self.forward_env.or(another.forward_env), + initial_path: self.initial_path.or(another.initial_path), + session_restore_mode: self.session_restore_mode.or(another.session_restore_mode), + output_spool_lines: self.output_spool_lines.or(another.output_spool_lines), + // TODO: check this + keybinding: self.keybinding.or(another.keybinding), + prompt_prefix: self.prompt_prefix.or(another.prompt_prefix), + motd: self.motd.or(another.motd), + // TODO: check this + motd_args: self.motd_args.or(another.motd_args), + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct Keybinding { /// The keybinding to map to an action. The syntax for these keybindings @@ -296,4 +352,85 @@ mod test { Ok(()) } + + mod merge { + use super::*; + use assert_matches::assert_matches; + + #[test] + fn simple_value() -> anyhow::Result<()> { + // 4 values are chosen to cover all combinations of None and Some cases. + let higher = Config { + norc: None, + noecho: None, + shell: Some("abc".to_string()), + session_restore_mode: Some(SessionRestoreMode::Simple), + ..Default::default() + }; + let lower = Config { + norc: Some(true), + noecho: None, + shell: None, + session_restore_mode: Some(SessionRestoreMode::Lines(42)), + ..Default::default() + }; + + assert_matches!(higher.or(lower), Config { + norc: Some(true), + noecho: None, + shell: Some(shell), + session_restore_mode: Some(SessionRestoreMode::Simple), + .. + } if shell == "abc"); + Ok(()) + } + + #[test] + fn vec_value() -> anyhow::Result<()> { + let higher = Config { + forward_env: Some(vec!["abc".to_string(), "efg".to_string()]), + motd_args: None, + ..Default::default() + }; + let lower = Config { + forward_env: None, + motd_args: Some(vec!["hij".to_string(), "klm".to_string()]), + ..Default::default() + }; + + let actual = higher.or(lower); + assert_eq!(actual.forward_env, Some(vec!["abc".to_string(), "efg".to_string()])); + assert_eq!(actual.motd_args, Some(vec!["hij".to_string(), "klm".to_string()])); + Ok(()) + } + + #[test] + fn map_value() -> anyhow::Result<()> { + let higher = Config { + env: Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])), + ..Default::default() + }; + let lower = Config { + env: Some(HashMap::from([ + ("key3".to_string(), "value3".to_string()), + ("key4".to_string(), "value4".to_string()), + ])), + ..Default::default() + }; + + let actual = higher.or(lower); + + assert_eq!( + actual.env, + Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])) + ); + Ok(()) + } + } } diff --git a/libshpool/src/config_watcher.rs b/libshpool/src/config_watcher.rs new file mode 100644 index 00000000..80a62f9e --- /dev/null +++ b/libshpool/src/config_watcher.rs @@ -0,0 +1,676 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use crossbeam_channel::{select, unbounded, Receiver, RecvError, Sender}; +use notify::{ + event::ModifyKind, recommended_watcher, Event, EventKind, INotifyWatcher, RecursiveMode, + Watcher as _, +}; +use std::{ + collections::{hash_map::Entry, HashMap}, + mem, + path::{Path, PathBuf}, + thread::{self, JoinHandle}, + time::{Duration, Instant}, +}; +use tracing::{debug, error}; + +/// Watches on `path`, returnes the watched path, which is the closest existing ancestor of `path`, and +/// the immediate child that is of interest. +pub fn best_effort_watch<'a>( + watcher: &mut INotifyWatcher, + path: &'a Path, +) -> anyhow::Result<(&'a Path, Option<&'a Path>)> { + let mut watched_path = Err(anyhow::anyhow!("empty path")); + // Ok or last Err + for watch_path in path.ancestors() { + match watcher.watch(watch_path, RecursiveMode::NonRecursive) { + Ok(_) => { + watched_path = Ok(watch_path); + break; + } + Err(err) => watched_path = Err(err.into()), + } + } + // watched path could be any ancestor of the original path + let watched_path = watched_path.context("add notify watch for config file")?; + let remaining_path = path + .strip_prefix(watched_path) + .expect("watched_path was obtained as an ancestor of path, yet it is not a prefix"); + let immediate_child = remaining_path.iter().next(); + debug!("Actually watching {}, ic {:?}", watched_path.display(), &immediate_child); + Ok((watched_path, immediate_child.map(Path::new))) +} + +/// Notify watcher to detect config file changes. +/// +/// Notable features: +/// - handles non-existing config files +/// - support watching multiple files +/// - configurable debounce time for reload +/// +/// For simplicity, reload doesn't distinguish which file was changed. It is expected that all +/// config files need to be reload regardless which one changed. +/// +/// # Examples +/// ``` +/// let watcher = ConfigWatcher::new(|| println!("RELOAD CONFIG")).unwrap(); +/// watcher.watch("/some/path/config.toml"); +/// ```` +pub struct ConfigWatcher { + // for sending watch requests + tx: Sender, + + worker: JoinHandle<()>, + + // for receiving debug info from worker thread, test only + #[cfg(test)] + debug_rx: Receiver<()>, +} + +impl ConfigWatcher { + /// Creates a new [`ConfigWatcher`] with default debounce time 100ms. + /// + /// Event processing happens in another thread, so the passed in `handler` is expected to properly handle synchronization and locking. + /// + /// # Errors + /// Returns error if the creation of underlying `notify` watcher or worker thread failed. + pub fn new(handler: impl FnMut() + Send + 'static) -> anyhow::Result { + Self::with_debounce(handler, Duration::from_millis(100)) + } + + /// Creates a new [`ConfigWatcher`] with default debounce time `reload_debounce`. + /// + /// Event processing happens in another thread, so the passed in `handler` is expected to properly handle synchronization and locking. + /// + /// # Arguments + /// * `handler` - The handler is called when the watcher determines there is a need to reload config files + /// * `reload_debounce` - Reloads happen within `reload_debounce` time will only trigger the handler once + /// + /// # Errors + /// Returns error if the creation of underlying `notify` watcher or worker thread failed. + pub fn with_debounce( + handler: impl FnMut() + Send + 'static, + reload_debounce: Duration, + ) -> anyhow::Result { + let (notify_tx, notify_rx) = unbounded(); + let (req_tx, req_rx) = unbounded(); + + #[cfg(test)] + let (debug_tx, debug_rx) = unbounded(); + + let watcher = recommended_watcher(notify_tx).context("create notify watcher")?; + + let mut inner = ConfigWatcherInner { + reload_debounce, + reload_deadline: None, + handler, + watcher, + notify_rx, + req_rx, + #[cfg(test)] + debug_tx, + paths: Default::default(), + }; + let worker = thread::Builder::new() + .name("config-reload".to_string()) + .spawn(move || { + if let Err(err) = inner.run() { + error!("config reload thread returned error: {:?}", err); + } + }) + .context("create config reload thread")?; + + Ok(Self { + tx: req_tx, + worker, + #[cfg(test)] + debug_rx, + }) + } + + /// Adds a watch on `path`. + /// + /// # Errors + /// Returns error if the underlying thread is gone, e.g. the worker thread encountered fatal + /// error and stopped its event loop. + pub fn watch(&self, path: impl AsRef) -> anyhow::Result<()> { + self.tx.send(path.as_ref().to_owned()).context("send AddWatch to ConfigWatcherInnfer")?; + Ok(()) + } + + /// Stops watching, shutting down the worker thread. + #[allow(dead_code)] + pub fn stop(self) { + mem::drop(self.tx); + self.worker.join().unwrap(); + } + + /// Worker is idle and ready for the next event. Debug/test only. + #[cfg(test)] + fn worker_ready(&self) { + self.debug_rx.recv().unwrap(); + debug!("worker ready"); + } +} + +struct ConfigWatcherInner { + // time to wait before actual reloading + reload_debounce: Duration, + // deadline to do a reload + reload_deadline: Option, + + handler: Handler, + + watcher: INotifyWatcher, + // receiving notify events + notify_rx: Receiver>, + // receiving watch requests + req_rx: Receiver, + // from target_path to (watched_path, immediate_child_path) + paths: HashMap, + + // for sending out debug info + #[cfg(test)] + debug_tx: Sender<()>, +} + +// Outcomes of selecting channels in the worker thread +enum Outcome { + // A notify event occurred + Event(notify::Result), + // A control command from outside + AddWatch(PathBuf), + // Timeout on notify event, trigger reload + Timeout, + // Any channel was disconnected + Disconnected, +} + +impl From, RecvError>> for Outcome { + fn from(value: Result, RecvError>) -> Self { + match value { + Ok(v) => Self::Event(v), + Err(RecvError) => Self::Disconnected, + } + } +} + +impl From> for Outcome { + fn from(value: Result) -> Self { + match value { + Ok(v) => Self::AddWatch(v), + Err(RecvError) => Self::Disconnected, + } + } +} + +impl ConfigWatcherInner { + // get next event to work on + fn select(&self) -> Outcome { + debug!("now {:?} select with ddl {:?}", Instant::now(), &self.reload_deadline); + + // only impose a deadline if there is pending reload + let timeout = self + .reload_deadline + .map(crossbeam_channel::at) + .unwrap_or_else(crossbeam_channel::never); + // first try poll + if let Ok(res) = self.notify_rx.try_recv() { + return Outcome::Event(res); + } + if let Ok(res) = self.req_rx.try_recv() { + return Outcome::AddWatch(res); + } + if timeout.try_recv().is_ok() { + return Outcome::Timeout; + } + + // nothing ready to act immediately, notify debug_tx + #[cfg(test)] + self.debug_tx.send(()).unwrap(); + + // finally blocking wait + let res = select! { + recv(self.notify_rx) -> res => Outcome::from(res), + recv(self.req_rx) -> res => Outcome::from(res), + recv(timeout) -> _ => Outcome::Timeout, + }; + res + } + + fn trigger_reload(&mut self) { + self.reload_deadline = + self.reload_deadline.or_else(|| Some(Instant::now() + self.reload_debounce)); + debug!("defer config reloading to {:?}!", &self.reload_deadline); + } +} + +impl ConfigWatcherInner +where + Handler: FnMut(), +{ + // run forever loop to reload config, only return when there is error to create any watches. + fn run(&mut self) -> anyhow::Result<()> { + loop { + match self.select() { + Outcome::Event(res) => { + let (rewatch, mut reload) = match res { + Err(error) => { + error!("Error: {error:?}"); + (ReWatch::All, false) + } + Ok(event) => handle_event(event, &self.paths), + }; + debug!("rewatch = {rewatch:?}, reload = {reload}"); + match rewatch { + ReWatch::Some(rewatch_paths) => { + for (path, watched_path) in rewatch_paths { + if let Err(err) = self.watcher.unwatch(&watched_path) { + // error sometimes is expected if the watched_path is simply removed, in that case notify will automatically remove the watch. + error!("error unwatch {:?}", err); + } else { + debug!("unwatched {}", watched_path.display()); + } + reload = watch_and_add(&mut self.watcher, self.paths.entry(path)) + || reload; + } + } + ReWatch::All => { + // drain paths and collect into vec first, to avoid keeping a mutable borrow on + // self.paths + let paths: Vec<_> = self.paths.drain().collect(); + for (path, _) in paths { + watch_and_add(&mut self.watcher, self.paths.entry(path)); + } + } + }; + + if reload { + self.trigger_reload(); + } + } + Outcome::AddWatch(path) => { + match self.paths.entry(path) { + Entry::Occupied(e) => { + error!("{} is already being watched", e.key().display()); + } + e @ Entry::Vacant(_) => { + if watch_and_add(&mut self.watcher, e) { + self.trigger_reload(); + } + } + }; + } + Outcome::Timeout => { + debug!("reload config!"); + self.reload_deadline = None; + (self.handler)(); + } + Outcome::Disconnected => { + debug!("notify watcher or main handle dropped"); + break; + } + } + } + Ok(()) + } +} + +#[derive(Debug, PartialEq, Eq)] +enum ReWatch { + // rewatch a few (target path, watched path) + Some(Vec<(PathBuf, PathBuf)>), + // rewatch all paths + All, +} + +// return wether need to rewatch, and whether need to reload +fn handle_event(event: Event, paths: &HashMap) -> (ReWatch, bool) { + if event.need_rescan() { + debug!("need rescan"); + return (ReWatch::All, true); + } + + // this event is about one of the watched target + let is_original = event.paths.iter().any(|p| paths.contains_key(p)); + + match event.kind { + // create/remove in any segment in path + EventKind::Remove(_) | EventKind::Create(_) | EventKind::Modify(ModifyKind::Name(_)) => { + debug!("create/remove: {:?}", event); + // find all path entries about this event + let rewatch = paths + .iter() + .filter(|(_, (watched_path, immediate_child_path))| { + event.paths.iter().any(|p| p == watched_path || p == immediate_child_path) + }) + .map(|(path, (watched_path, _))| (path.to_owned(), watched_path.to_owned())) + .collect(); + (ReWatch::Some(rewatch), is_original) + } + // modification in any segment in path + EventKind::Modify(_) => { + debug!("modify: {:?}", event); + (ReWatch::Some(vec![]), is_original) + } + _ => { + debug!("ignore {:?}", event); + + (ReWatch::Some(vec![]), false) + } + } +} + +// Add a watch at `path`, update paths `entry` if success, or remove `entry` if failed. +// Note that this will overwrite any existing state. +// Return whether reload is needed. +fn watch_and_add(watcher: &mut INotifyWatcher, entry: Entry) -> bool { + // make a version of watch path that doesn't retain a borrow in its return value + let watch_path_owned = |watcher: &mut INotifyWatcher, path: &Path| { + best_effort_watch(watcher, path) + .map(|(w, ic)| (w.to_owned(), w.join(ic.unwrap_or_else(|| Path::new(""))))) + }; + match watch_path_owned(watcher, entry.key()) { + Ok((watched_path, immediate_child_path)) => { + let reload = &watched_path == entry.key(); + // update entry after `match watch_a_path(...)`, as that takes a borrow on entry (entry.key()) + match entry { + Entry::Occupied(mut entry) => { + entry.insert((watched_path, immediate_child_path)); + } + Entry::Vacant(entry) => { + entry.insert((watched_path, immediate_child_path)); + } + } + if reload { + debug!("Force reload since now watching on target file"); + } + reload + } + Err(err) => { + error!("Failed to watch {}: {:?}", entry.key().display(), err); + if let Entry::Occupied(entry) = entry { + entry.remove(); + } + true + } + } +} + +#[cfg(test)] +#[rustfmt::skip::attributes(test_case)] +mod test { + use super::*; + use std::fs; + use tempfile::TempDir; + + mod watch { + use super::*; + use std::fs; + + #[test] + fn all_non_existing() { + let mut watcher = recommended_watcher(|_| {}).unwrap(); + + let (watched_path, immediate_child) = + best_effort_watch(&mut watcher, Path::new("/non_existing/subdir")).unwrap(); + + assert_eq!(watched_path, Path::new("/")); + assert_eq!(immediate_child, Some(Path::new("non_existing"))); + } + + #[test] + fn non_existing_parent() { + let tmpdir = tempfile::tempdir().unwrap(); + let target_path = tmpdir.path().join(Path::new("sub1/sub2/c.txt")); + + let parent_path = target_path.parent().unwrap().parent().unwrap(); + + fs::create_dir_all(parent_path).unwrap(); + + let mut watcher = recommended_watcher(|_| {}).unwrap(); + let (watched_path, immediate_child) = + best_effort_watch(&mut watcher, &target_path).unwrap(); + + assert_eq!(watched_path, parent_path); + assert_eq!(immediate_child, Some(Path::new("sub2"))); + } + + #[test] + fn existing_file() { + let tmpdir = tempfile::tempdir().unwrap(); + let target_path = tmpdir.path().join(Path::new("sub1/sub2/c.txt")); + + let parent_path = target_path.parent().unwrap(); + + fs::create_dir_all(parent_path).unwrap(); + fs::write(&target_path, "test").unwrap(); + + let mut watcher = recommended_watcher(|_| {}).unwrap(); + let (watched_path, immediate_child) = + best_effort_watch(&mut watcher, &target_path).unwrap(); + + assert_eq!(watched_path, target_path); + assert_eq!(immediate_child, None); + } + } + + mod handle_event { + use super::*; + use assert_matches::assert_matches; + use notify::{ + event::{CreateKind, ModifyKind, RemoveKind, RenameMode}, + Event, EventKind, + }; + use ntest::test_case; + + fn paths_entry(target: &str, watched: &str) -> (PathBuf, (PathBuf, PathBuf)) { + let target = PathBuf::from(target); + let base = PathBuf::from(watched); + let immediate = + base.join(target.strip_prefix(&base).unwrap().iter().next().unwrap_or_default()); + (target, (base, immediate)) + } + + // create event from spec: + // path1 [path2] + // `base` is prepended to all paths + fn event_from_spec(base: &str, evt: &str) -> notify::Event { + let base = Path::new(base); + let (evt, path) = evt.split_once(' ').unwrap_or((evt, "")); + match evt { + "create" => { + Event::new(EventKind::Create(CreateKind::Any)).add_path(base.join(path)) + } + "mv" => { + let (src, dst) = path.split_once(' ').unwrap(); + Event::new(EventKind::Modify(ModifyKind::Name(RenameMode::Both))) + .add_path(base.join(src)) + .add_path(base.join(dst)) + } + "mvselfother" => Event::new(EventKind::Modify(ModifyKind::Name(RenameMode::Both))) + .add_path(base.to_owned()) + .add_path(PathBuf::from("/some/other/path")), + "modify" => { + Event::new(EventKind::Modify(ModifyKind::Any)).add_path(base.join(path)) + } + "modifyself" => { + Event::new(EventKind::Modify(ModifyKind::Any)).add_path(base.to_owned()) + } + "rm" => Event::new(EventKind::Remove(RemoveKind::Any)).add_path(base.join(path)), + "rmself" => { + Event::new(EventKind::Remove(RemoveKind::Any)).add_path(base.to_owned()) + } + _ => panic!("malformatted event spec"), + } + } + + #[test] + fn need_rescan() { + let event = notify::Event::default().set_flag(notify::event::Flag::Rescan); + let paths = Default::default(); + let (rewatch, reload) = handle_event(event, &paths); + assert_eq!(rewatch, ReWatch::All); + assert!(reload); + } + + const TARGET: &str = "/base/sub/config.toml"; + + #[test_case(TARGET, "/base", "create sub", true, false, name = "base_create_sub")] + #[test_case(TARGET, "/base", "create other", false, false, name = "base_create_other")] + #[test_case(TARGET, "/base", "mv other sub", true, false, name = "base_other_to_sub")] + #[test_case(TARGET, "/base", "mv other another", false, false, name = "base_other_to_another")] + #[test_case(TARGET, "/base", "mv sub other", true, false, name = "base_sub_to_other")] + #[test_case(TARGET, "/base", "rm sub", true, false, name = "base_rm_sub")] + #[test_case(TARGET, "/base", "rm other", false, false, name = "base_rm_other")] + #[test_case(TARGET, "/base", "modify other.toml", false, false, name = "base_modify_other")] + #[test_case(TARGET, "/base/sub", "create config.toml", true, true, name = "sub_create_cfg")] + #[test_case(TARGET, "/base/sub", "mv other.toml config.toml", true, true, name = "sub_other_to_cfg")] + #[test_case(TARGET, "/base/sub", "mv other.toml another.toml", false, false, name = "sub_other_to_another")] + #[test_case(TARGET, "/base/sub", "modify config.toml", false, true, name = "sub_modify_cfg")] + #[test_case(TARGET, "/base/sub", "modify other.toml", false, false, name = "sub_modify_other")] + #[test_case(TARGET, "/base/sub", "rmself", true, false, name = "sub_rm_self")] + #[test_case(TARGET, "/base/sub/config.toml", "rmself", true, true, name = "cfg_rm_self")] + #[test_case(TARGET, "/base/sub/config.toml", "mvselfother", true, true, name = "cfg_self_to_other")] + #[test_case(TARGET, "/base/sub/config.toml", "modifyself", false, true, name = "cfg_modify_self")] + fn single_path( + target: &str, + watched: &str, + evt: &str, + expected_rewatch: bool, + expected_reload: bool, + ) { + let paths = HashMap::from([paths_entry(target, watched)]); + let event = event_from_spec(watched, evt); + + let (rewatch, reload) = handle_event(event, &paths); + + let expected_rewatch = if expected_rewatch { + ReWatch::Some(vec![(PathBuf::from(target), PathBuf::from(watched))]) + } else { + ReWatch::Some(vec![]) + }; + assert_eq!(rewatch, expected_rewatch); + assert_eq!(reload, expected_reload); + } + + #[test] + fn both_paths_are_updated() { + let paths = HashMap::from([ + paths_entry("/base/sub/config.toml", "/base"), + paths_entry("/base/other/another.toml", "/base"), + ]); + let event = event_from_spec("/base", "rm /base"); + + let (rewatch, reload) = handle_event(event, &paths); + + assert_matches!(rewatch, ReWatch::Some(p) if p.len() == 2); + assert!(!reload); + } + } + + // Smaller debounce time for faster testing + const DEBOUNCE_TIME: Duration = Duration::from_millis(50); + + struct WatcherState { + #[allow(dead_code)] + tmpdir: TempDir, + base_path: PathBuf, + target_path: PathBuf, + rx: Receiver<()>, + watcher: ConfigWatcher, + } + + // Setup file structure at /`base`, configure watcher to watch /`base`/`target` + fn setup(base: &str, target: &str) -> anyhow::Result { + let tmpdir = tempfile::tempdir()?; + let base_path = tmpdir.path().join(base); + let target_path = base_path.join(target); + assert!(target_path.strip_prefix(&base_path).is_ok()); + + fs::create_dir_all(&base_path)?; + + let (tx, rx) = unbounded(); + let watcher = ConfigWatcher::with_debounce(move || tx.send(()).unwrap(), DEBOUNCE_TIME)?; + watcher.watch(&target_path)?; + + // wait for watcher thread to be ready + watcher.worker_ready(); + + Ok(WatcherState { tmpdir, base_path, target_path, rx, watcher }) + } + + // Wait for watcher to do its work and drop the watcher to close the channel + fn drop_watcher(watcher: ConfigWatcher) { + // sleep time larger than 1 debounce time + thread::sleep(DEBOUNCE_TIME * 2); + watcher.worker_ready(); + } + + #[test] + fn debounce() { + let state = setup("base", "sub/config.toml").unwrap(); + + fs::create_dir_all(state.target_path.parent().unwrap()).unwrap(); + + state.watcher.worker_ready(); + fs::write(&state.target_path, "test").unwrap(); + + state.watcher.worker_ready(); + fs::write(&state.target_path, "another").unwrap(); + + drop_watcher(state.watcher); + + let reloads: Vec<_> = state.rx.into_iter().collect(); + assert_eq!(reloads.len(), 1); + } + + #[test] + fn writes_larger_than_debounce() { + let state = setup("base", "sub/config.toml").unwrap(); + + fs::create_dir_all(state.target_path.parent().unwrap()).unwrap(); + state.watcher.worker_ready(); + fs::write(&state.target_path, "test").unwrap(); + + thread::sleep(DEBOUNCE_TIME * 2); + state.watcher.worker_ready(); + fs::write(&state.target_path, "another").unwrap(); + + drop_watcher(state.watcher); + + let reloads: Vec<_> = state.rx.into_iter().collect(); + assert_eq!(reloads.len(), 2); + } + + // /base, mv /base/other (with config.toml) /base/sub (with config.toml) => rewatch, reload + #[test] + fn move_multiple_levels_in_place() { + let state = setup("base", "sub/config.toml").unwrap(); + + // create /base/other/config.toml + fs::create_dir_all(state.base_path.join("other")).unwrap(); + fs::write(state.base_path.join("other/config.toml"), "test").unwrap(); + + // mv /base/other /base/sub + fs::rename(state.base_path.join("other"), state.base_path.join("sub")).unwrap(); + + drop_watcher(state.watcher); + + let reloads: Vec<_> = state.rx.into_iter().collect(); + assert_eq!(reloads.len(), 1); + } +} diff --git a/libshpool/src/lib.rs b/libshpool/src/lib.rs index c0e4401b..1f219806 100644 --- a/libshpool/src/lib.rs +++ b/libshpool/src/lib.rs @@ -30,6 +30,7 @@ use tracing_subscriber::fmt::format::FmtSpan; mod attach; mod common; mod config; +mod config_watcher; mod consts; mod daemon; mod detach; diff --git a/shpool/tests/attach.rs b/shpool/tests/attach.rs index 6b7a3336..0e393eed 100644 --- a/shpool/tests/attach.rs +++ b/shpool/tests/attach.rs @@ -1287,6 +1287,9 @@ fn dynamic_config_change() -> anyhow::Result<()> { let config_contents = config_tmpl.replace("REPLACE_ME", "NEW_VALUE"); fs::write(&config_file, config_contents)?; + // Wait for longer than debounce time + thread::sleep(time::Duration::from_secs(2)); + // When we spawn a new session, it should pick up the new value let mut attach_proc = daemon_proc.attach("sh2", Default::default()).context("starting attach proc")?;