Skip to content

Commit 0381921

Browse files
committed
feat: include generator in package reading errors
1 parent 99ced9b commit 0381921

File tree

2 files changed

+83
-9
lines changed

2 files changed

+83
-9
lines changed

hugr-core/src/envelope.rs

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@ pub mod serde_with;
4747
pub use header::{EnvelopeConfig, EnvelopeFormat, MAGIC_NUMBERS, ZstdConfig};
4848
pub use package_json::PackageEncodingError;
4949

50-
use crate::Hugr;
50+
use crate::{Hugr, HugrView};
5151
use crate::{
5252
extension::{ExtensionRegistry, Version},
5353
package::Package,
5454
};
5555
use header::EnvelopeHeader;
56+
use std::collections::HashSet;
5657
use std::io::BufRead;
5758
use std::io::Write;
5859
use std::str::FromStr;
@@ -64,6 +65,49 @@ use itertools::Itertools as _;
6465
use crate::import::ImportError;
6566
use crate::{Extension, import::import_package};
6667

68+
/// Key used to store the name of the generator that produced the envelope.
69+
pub const GENERATOR_KEY: &str = "__generator";
70+
71+
/// Get the name of the generator from the metadata of the HUGR modules.
72+
/// If multiple modules have different generators, only the first one is returned.
73+
fn get_generator<H: HugrView>(modules: &[H]) -> Option<String> {
74+
let generators: HashSet<String> = modules
75+
.iter()
76+
.filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY))
77+
.map(|v| v.to_string())
78+
.collect();
79+
debug_assert!(
80+
generators.len() <= 1,
81+
"Multiple generators found in the package metadata: {generators:?}"
82+
);
83+
generators.into_iter().next()
84+
}
85+
86+
fn gen_str(generator: &Option<String>) -> String {
87+
match generator {
88+
Some(g) => format!("\ngenerated by {g}"),
89+
None => String::new(),
90+
}
91+
}
92+
93+
/// Wrap an error with a generator string.
94+
#[derive(Error, Debug)]
95+
#[error("{inner}{}", gen_str(&self.generator))]
96+
pub struct WithGenerator<E: std::fmt::Display> {
97+
inner: E,
98+
/// The name of the generator that produced the envelope, if any.
99+
generator: Option<String>,
100+
}
101+
102+
impl<E: std::fmt::Display> WithGenerator<E> {
103+
fn new(err: E, modules: &[impl HugrView]) -> Self {
104+
Self {
105+
inner: err,
106+
generator: get_generator(modules),
107+
}
108+
}
109+
}
110+
67111
/// Read a HUGR envelope from a reader.
68112
///
69113
/// Returns the deserialized package and the configuration used to encode it.
@@ -216,6 +260,7 @@ pub enum EnvelopeError {
216260
/// The source error.
217261
#[from]
218262
source: ImportError,
263+
// TODO add generator to model import errors
219264
},
220265
/// Error reading a HUGR model payload.
221266
#[error(transparent)]
@@ -454,7 +499,7 @@ fn check_breaking_extensions(
454499
hugr: impl crate::HugrView,
455500
registry: &ExtensionRegistry,
456501
) -> Result<(), ExtensionBreakingError> {
457-
let Some(exts) = hugr.get_metadata(hugr.module_root(), &USED_EXTENSIONS_KEY) else {
502+
let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
458503
return Ok(()); // No used extensions metadata, nothing to check
459504
};
460505
let used_exts: Vec<UsedExtension> = serde_json::from_value(exts.clone())?; // TODO handle errors properly
@@ -690,4 +735,28 @@ pub(crate) mod test {
690735
Err(ExtensionBreakingError::Deserialization(_))
691736
);
692737
}
738+
739+
#[test]
740+
fn test_with_generator_error_message() {
741+
let test_ext = Extension::new(ExtensionId::new_unchecked("test"), Version::new(1, 0, 0));
742+
let registry = ExtensionRegistry::new([Arc::new(test_ext)]);
743+
744+
let mut hugr = simple_package().modules.remove(0);
745+
746+
// Set a generator name in the metadata
747+
let generator_name = json!({ "name": "TestGenerator", "version": "1.2.3" });
748+
hugr.set_metadata(hugr.module_root(), GENERATOR_KEY, generator_name.clone());
749+
750+
// Set incompatible extension version in metadata
751+
let used_exts = json!([{ "name": "test", "version": "2.0.0" }]);
752+
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
753+
754+
// Create the error and wrap it with WithGenerator
755+
let err = check_breaking_extensions(&hugr, &registry).unwrap_err();
756+
let with_gen = WithGenerator::new(err, &[&hugr]);
757+
758+
let err_msg = with_gen.to_string();
759+
assert!(err_msg.contains("Extension 'test' version mismatch"));
760+
assert!(err_msg.contains(generator_name.to_string().as_str()));
761+
}
693762
}

hugr-core/src/envelope/package_json.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use derive_more::{Display, Error, From};
33
use itertools::Itertools;
44
use std::io;
55

6+
use super::{ExtensionBreakingError, WithGenerator, check_breaking_extensions};
67
use crate::extension::ExtensionRegistry;
78
use crate::extension::resolution::ExtensionResolutionError;
89
use crate::hugr::ExtensionError;
@@ -21,20 +22,24 @@ pub(super) fn from_json_reader(
2122
extensions: pkg_extensions,
2223
} = serde_json::from_value::<PackageDeser>(val.clone())?;
2324
let mut modules = modules.into_iter().map(|h| h.0).collect_vec();
24-
2525
let pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
2626
pkg_extensions,
2727
&extension_registry.into(),
28-
)?;
28+
)
29+
.map_err(|err| WithGenerator::new(err, &modules))?;
2930

3031
// Resolve the operations in the modules using the defined registries.
3132
let mut combined_registry = extension_registry.clone();
3233
combined_registry.extend(&pkg_extensions);
3334

34-
for module in &mut modules {
35-
super::check_breaking_extensions(&module, &combined_registry)?;
36-
module.resolve_extension_defs(&combined_registry)?;
35+
for module in &modules {
36+
check_breaking_extensions(module, &combined_registry)
37+
.map_err(|err| WithGenerator::new(err, &modules))?;
3738
}
39+
modules
40+
.iter_mut()
41+
.try_for_each(|module| module.resolve_extension_defs(&combined_registry))
42+
.map_err(|err| WithGenerator::new(err, &modules))?;
3843

3944
Ok(Package {
4045
modules,
@@ -65,9 +70,9 @@ pub enum PackageEncodingError {
6570
/// Error raised while reading from a file.
6671
IOError(io::Error),
6772
/// Could not resolve the extension needed to encode the hugr.
68-
ExtensionResolution(ExtensionResolutionError),
73+
ExtensionResolution(WithGenerator<ExtensionResolutionError>),
6974
/// Error raised while checking for breaking extension version mismatch.
70-
ExtensionVersion(super::ExtensionBreakingError),
75+
ExtensionVersion(WithGenerator<ExtensionBreakingError>),
7176
/// Could not resolve the runtime extensions for the hugr.
7277
RuntimeExtensionResolution(ExtensionError),
7378
}

0 commit comments

Comments
 (0)