diff --git a/Cargo.toml b/Cargo.toml index 6f2ee538..742217c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ slog-term = { version = "2.9.0", optional = true } thiserror = "1.0.30" tokio = { version = "1.17.0", optional = true, default-features = false } tokio-stream = { version = "0.1.8", optional = true, default-features = false } +serde_json = "1.0.127" [features] default = ["runtime-tokio", "incremental"] diff --git a/src/frontend.rs b/src/frontend.rs new file mode 100644 index 00000000..2c41d096 --- /dev/null +++ b/src/frontend.rs @@ -0,0 +1,102 @@ +use crate::{CoreApi, Enforcer}; +use std::collections::HashMap; + +pub fn casbin_js_get_permission_for_user( + e: &Enforcer, + _user: &str, +) -> Result> { + let model = e.get_model(); + let mut m = HashMap::new(); + + m.insert("m", serde_json::Value::from(model.to_text())); + + let mut p_rules = Vec::new(); + if let Some(assertions) = model.get_model().get("p") { + for (ptype, _assertion) in assertions { + let policies = model.get_policy("p", ptype); + for rules in policies { + let mut rule = vec![ptype.to_string()]; + rule.extend(rules); + p_rules.push(rule); + } + } + } + m.insert("p", serde_json::Value::from(p_rules)); + + let mut g_rules = Vec::new(); + if let Some(assertions) = model.get_model().get("g") { + for (ptype, _assertion) in assertions { + let policies = model.get_policy("g", ptype); + for rules in policies { + let mut rule = vec![ptype.to_string()]; + rule.extend(rules); + g_rules.push(rule); + } + } + } + m.insert("g", serde_json::Value::from(g_rules)); + + let result = serde_json::to_string(&m)?; + Ok(result) +} + +#[cfg(test)] +mod tests { + use crate::frontend::casbin_js_get_permission_for_user; + use crate::prelude::*; + + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr( + all(feature = "runtime-async-std", not(target_arch = "wasm32")), + async_std::test + )] + #[cfg_attr( + all(feature = "runtime-tokio", not(target_arch = "wasm32")), + tokio::test + )] + async fn test_casbin_js_get_permission_for_user() { + use serde_json::Value; + use std::fs; + use std::io::Read; + + let model_path = "examples/rbac_model.conf"; + let policy_path = "examples/rbac_with_hierarchy_policy.csv"; + let e = Enforcer::new(model_path, policy_path).await.unwrap(); + + let received_string = + casbin_js_get_permission_for_user(&e, "alice").unwrap(); + let received: Value = serde_json::from_str(&received_string).unwrap(); + + let mut expected_model = String::new(); + fs::File::open(model_path) + .unwrap() + .read_to_string(&mut expected_model) + .unwrap(); + let expected_model_str = + expected_model.replace("\r\n", "\n").replace("\n\n", "\n"); + + assert_eq!( + received["m"].as_str().unwrap().trim(), + expected_model_str.trim() + ); + + let mut expected_policies = String::new(); + fs::File::open(policy_path) + .unwrap() + .read_to_string(&mut expected_policies) + .unwrap(); + let expected_policies_items: Vec<&str> = + expected_policies.split(&[',', '\n'][..]).collect(); + + let mut i = 0; + for s_arr in received["p"].as_array().unwrap() { + for s in s_arr.as_array().unwrap() { + assert_eq!( + s.as_str().unwrap().trim(), + expected_policies_items[i].trim() + ); + i += 1; + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 01006207..0ed2bc93 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ mod util; mod watcher; pub mod error; +pub mod frontend; pub mod prelude; #[cfg(not(target_arch = "wasm32"))] diff --git a/src/model/default_model.rs b/src/model/default_model.rs index b11a7a72..4f6cfeac 100644 --- a/src/model/default_model.rs +++ b/src/model/default_model.rs @@ -384,6 +384,66 @@ impl Model for DefaultModel { } (res, rules_removed) } + + fn to_text(&self) -> String { + let mut token_patterns = HashMap::new(); + let p_pattern = regex::Regex::new(r"^p_").unwrap(); + let r_pattern = regex::Regex::new(r"^r_").unwrap(); + + for ptype in ["r", "p"] { + if let Some(assertion) = self.model.get(ptype) { + for token in &assertion[ptype].tokens { + let new_token = p_pattern.replace_all(token, "p."); + let new_token = r_pattern.replace_all(&new_token, "r."); + token_patterns.insert(token.clone(), new_token.to_string()); + } + } + } + + if let Some(assertions) = self.model.get("e") { + if let Some(assertion) = assertions.get("e") { + if assertion.value.contains("p_eft") { + token_patterns + .insert("p_eft".to_string(), "p.eft".to_string()); + } + } + } + + let mut s = String::new(); + + let write_string = |sec: &str, s: &mut String| { + if let Some(assertions) = self.model.get(sec) { + for (_ptype, assertion) in assertions { + let mut value = assertion.value.clone(); + for (token_pattern, new_token) in &token_patterns { + value = value.replace(token_pattern, new_token); + } + s.push_str(&format!("{} = {}\n", sec, value)); + } + } + }; + + s.push_str("[request_definition]\n"); + write_string("r", &mut s); + s.push_str("[policy_definition]\n"); + write_string("p", &mut s); + + if self.model.contains_key("g") { + s.push_str("[role_definition]\n"); + if let Some(assertions) = self.model.get("g") { + for (ptype, assertion) in assertions { + s.push_str(&format!("{} = {}\n", ptype, assertion.value)); + } + } + } + + s.push_str("[policy_effect]\n"); + write_string("e", &mut s); + s.push_str("[matchers]\n"); + write_string("m", &mut s); + + s + } } #[cfg(test)] diff --git a/src/model/mod.rs b/src/model/mod.rs index 970d91cb..30a2540d 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -72,4 +72,5 @@ pub trait Model: Send + Sync { field_index: usize, field_values: Vec, ) -> (bool, Vec>); + fn to_text(&self) -> String; }