diff --git a/Cargo.lock b/Cargo.lock index 7e9ff79121e0..cc99b0ff7c2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1494,6 +1494,7 @@ dependencies = [ "device_tree", "displaydoc", "event-manager", + "itertools 0.12.0", "kvm-bindings", "kvm-ioctls", "lazy_static", diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml index e637e18e14fe..a0d4b535fa64 100644 --- a/src/vmm/Cargo.toml +++ b/src/vmm/Cargo.toml @@ -50,6 +50,7 @@ vm-fdt = "0.2.0" criterion = { version = "0.5.0", default-features = false } device_tree = "1.1.0" proptest = { version = "1.0.0", default-features = false, features = ["std"] } +itertools = "0.12.0" [features] tracing = ["log-instrument"] diff --git a/src/vmm/src/logger/logging.rs b/src/vmm/src/logger/logging.rs index d111994363a4..31d0e278375f 100644 --- a/src/vmm/src/logger/logging.rs +++ b/src/vmm/src/logger/logging.rs @@ -10,7 +10,7 @@ use std::sync::{Mutex, OnceLock}; use std::thread; use log::{Log, Metadata, Record}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use utils::time::LocalTime; use super::metrics::{IncMetric, METRICS}; @@ -200,7 +200,7 @@ pub struct LoggerConfig { /// the log level filter. It would be a breaking change to no longer support this. In the next /// breaking release this should be removed (replaced with `log::LevelFilter` and only supporting /// its default deserialization). -#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] pub enum LevelFilter { /// [`log::LevelFilter:Off`] #[serde(alias = "OFF")] @@ -233,6 +233,25 @@ impl From for log::LevelFilter { } } } +impl<'de> Deserialize<'de> for LevelFilter { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + let key = String::deserialize(deserializer)?; + let level = match key.to_lowercase().as_str() { + "off" => Ok(LevelFilter::Off), + "trace" => Ok(LevelFilter::Trace), + "debug" => Ok(LevelFilter::Debug), + "info" => Ok(LevelFilter::Info), + "warn" | "warning" => Ok(LevelFilter::Warn), + "error" => Ok(LevelFilter::Error), + _ => Err(D::Error::custom("Invalid LevelFilter")), + }; + level + } +} /// Error type for [`::from_str`]. #[derive(Debug, PartialEq, Eq, thiserror::Error)] @@ -288,6 +307,35 @@ mod tests { ); } #[test] + fn levelfilter_from_str_all_variants() { + use itertools::Itertools; + + #[derive(Debug, Deserialize)] + struct Foo { + #[allow(dead_code)] + level: LevelFilter, + } + + for level in ["off", "trace", "debug", "info", "warn", "warning", "error"] { + let multi = level.chars().map(|_| 0..=1).multi_cartesian_product(); + for combination in multi { + let variant = level + .chars() + .zip_eq(combination) + .map(|(c, v)| match v { + 0 => c.to_ascii_lowercase(), + 1 => c.to_ascii_uppercase(), + _ => unreachable!(), + }) + .collect::(); + + let ex = format!("{} \"level\": \"{}\" {}", "{", variant, "}"); + assert!(LevelFilter::from_str(&variant).is_ok(), "{variant}"); + assert!(serde_json::from_str::(&ex).is_ok(), "{ex}"); + } + } + } + #[test] fn levelfilter_from_str() { assert_eq!( LevelFilter::from_str("bad"),