Skip to content

Commit

Permalink
runtimes/core: fix validation of nested objects in unions (#1529)
Browse files Browse the repository at this point in the history
  • Loading branch information
eandre authored Oct 28, 2024
1 parent 4e98856 commit 47fae67
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 18 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions runtimes/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ rsa = { version = "0.9.6", features = ["pem"] }
flate2 = "1.0.30"
urlencoding = "2.1.3"
tower-http = { version = "0.5.2", features = ["fs"] }
serde_path_to_error = "0.1.16"

[build-dependencies]
prost-build = "0.12.3"
Expand Down
4 changes: 2 additions & 2 deletions runtimes/core/src/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ impl Error {
{
Self {
code: ErrCode::InvalidArgument,
message: public_msg.into(),
internal_message: Some(format!("{:?}", cause.into())),
message: format!("{}: {:?}", public_msg.into(), cause.into()),
internal_message: None,
stack: None,
details: None,
}
Expand Down
140 changes: 128 additions & 12 deletions runtimes/core/src/api/jsonschema/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,15 @@ impl Literal {
Literal::Float(lit) => format!("{:#?}", lit),
}
}

pub fn expecting_type(&self) -> &'static str {
match self {
Literal::Str(_) => "string",
Literal::Bool(_) => "boolean",
Literal::Int(_) => "integer",
Literal::Float(_) => "number",
}
}
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -1159,20 +1168,127 @@ impl<'a> DecodeValue<'a> {
JVal::Null => PValue::Null,
JVal::Bool(val) => PValue::Bool(val),
JVal::Number(num) => PValue::Number(num),
JVal::Array(vals) => {
let mut new_vals = Vec::with_capacity(vals.len());
for val in vals {
new_vals.push(self.transform(val)?);
JVal::Array(vals) => match self.value {
Value::Ref(idx) => return recurse_ref!(self, idx, transform, JVal::Array(vals)),
Value::Option(bov) => return recurse!(self, bov, transform, JVal::Array(vals)),
Value::Basic(Basic::Any) => {
let mut new_vals = Vec::with_capacity(vals.len());
for val in vals {
new_vals.push(self.transform(val)?);
}
PValue::Array(new_vals)
}
PValue::Array(new_vals)
}
JVal::Object(obj) => {
let mut new_obj = BTreeMap::new();
for (key, val) in obj {
new_obj.insert(key, self.transform(val)?);
Value::Array(bov) => {
let mut new_vals = Vec::with_capacity(vals.len());
for val in vals {
let val = recurse!(self, bov, transform, val)?;
new_vals.push(val);
}
PValue::Array(new_vals)
}
PValue::Object(new_obj)
}
Value::Union(candidates) => {
// First transform with the first candidate.
// If it fails validation afterwards, we need to start over.
'CandidateLoop: for c in candidates {
let mut new_vals = Vec::with_capacity(vals.len());
let vals = vals.clone();
for val in vals {
let res: Result<_, E> = recurse!(self, c, transform, val);
match res {
Ok(val) => new_vals.push(val),
Err(_) => continue 'CandidateLoop,
}
}
return Ok(PValue::Array(new_vals));
}

return Err(serde::de::Error::invalid_type(Unexpected::Seq, self));
}
Value::Basic(basic) => {
return Err(serde::de::Error::invalid_type(
Unexpected::Other(basic.expecting()),
self,
))
}
Value::Literal(lit) => {
return Err(serde::de::Error::invalid_type(
Unexpected::Other(lit.expecting_type()),
self,
))
}
Value::Map(_) | Value::Struct(_) => {
return Err(serde::de::Error::invalid_type(Unexpected::Map, self))
}
},
JVal::Object(obj) => match self.value {
Value::Ref(idx) => return recurse_ref!(self, idx, transform, JVal::Object(obj)),
Value::Option(bov) => return recurse!(self, bov, transform, JVal::Object(obj)),
Value::Basic(Basic::Any) => {
let mut new_obj = BTreeMap::new();
for (key, val) in obj {
new_obj.insert(key, self.transform(val)?);
}
PValue::Object(new_obj)
}
Value::Map(bov) => {
let mut new_obj = BTreeMap::new();
for (key, val) in obj {
let val = recurse!(self, bov, transform, val)?;
new_obj.insert(key, val);
}
PValue::Object(new_obj)
}
Value::Struct(Struct { fields }) => {
let mut new_obj = BTreeMap::new();
for (key, value) in obj {
match fields.get(key.as_str()) {
Some(entry) => {
let val = recurse!(self, &entry.value, transform, value)?;
new_obj.insert(key, val);
}
None => {
// Unknown field; ignore it.
}
}
}
PValue::Object(new_obj)
}
Value::Union(candidates) => {
// First transform with the first candidate.
// If it fails validation afterwards, we need to start over.
'CandidateLoop: for c in candidates {
let mut new_obj = BTreeMap::new();
let obj = obj.clone();
for (key, val) in obj {
let res: Result<_, E> = recurse!(self, c, transform, val);
match res {
Ok(val) => {
new_obj.insert(key, val);
}
Err(_) => continue 'CandidateLoop,
}
}
return Ok(PValue::Object(new_obj));
}

return Err(serde::de::Error::invalid_type(Unexpected::Map, self));
}
Value::Basic(basic) => {
return Err(serde::de::Error::invalid_type(
Unexpected::Other(basic.expecting()),
self,
))
}
Value::Literal(lit) => {
return Err(serde::de::Error::invalid_type(
Unexpected::Other(lit.expecting_type()),
self,
))
}
Value::Array(_) => {
return Err(serde::de::Error::invalid_type(Unexpected::Seq, self))
}
},
JVal::String(str) => match self.value {
Value::Ref(idx) => return recurse_ref!(self, idx, transform, JVal::String(str)),
Value::Option(bov) => return recurse!(self, bov, transform, JVal::String(str)),
Expand Down
13 changes: 11 additions & 2 deletions runtimes/core/src/api/jsonschema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,21 @@ impl JSONSchema {
SchemaDeserializer { cfg, schema: self }
}

pub fn deserialize<'de, T>(&self, de: T, cfg: DecodeConfig) -> Result<PValues, T::Error>
pub fn deserialize<'de, T>(
&self,
de: T,
cfg: DecodeConfig,
) -> Result<PValues, serde_path_to_error::Error<T::Error>>
where
T: Deserializer<'de>,
{
let seed = SchemaDeserializer { cfg, schema: self };
seed.deserialize(de)
let mut track = serde_path_to_error::Track::new();
let de = serde_path_to_error::Deserializer::new(de, &mut track);
match seed.deserialize(de) {
Ok(t) => Ok(t),
Err(err) => Err(serde_path_to_error::Error::new(track.path(), err)),
}
}

pub fn null() -> Self {
Expand Down

0 comments on commit 47fae67

Please sign in to comment.