Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 103 additions & 32 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ use crate::{extension::ExtensionRegistry, package::Package};
use header::EnvelopeHeader;
use std::io::BufRead;
use std::io::Write;
use std::str::FromStr;

#[allow(unused_imports)]
use itertools::Itertools as _;

use crate::import::ImportError;
use crate::{import::import_package, Extension};

/// Read a HUGR envelope from a reader.
///
Expand Down Expand Up @@ -219,6 +221,16 @@ pub enum EnvelopeError {
/// The source error.
source: hugr_model::v0::binary::WriteError,
},
/// Error parsing a HUGR model text payload.
ModelTextRead {
/// The source error.
source: hugr_model::v0::ast::ParseError,
},
/// Error resolving a parsed HUGR model text payload.
ModelTextResolve {
/// The source error.
source: hugr_model::v0::ast::ResolveError,
},
}

/// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper.
Expand All @@ -233,6 +245,9 @@ fn read_impl(
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
decode_model(payload, registry, header.format)
}
EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
decode_model_ast(payload, registry, header.format)
}
}
}

Expand All @@ -248,7 +263,6 @@ fn decode_model(
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<Package, EnvelopeError> {
use crate::{import::import_package, Extension};
use hugr_model::v0::bumpalo::Bump;

if format.model_version() != Some(0) {
Expand All @@ -262,7 +276,7 @@ fn decode_model(
let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;

let mut extension_registry = extension_registry.clone();
if format.append_extensions() {
if format == EnvelopeFormat::ModelWithExtensions {
let extra_extensions: Vec<Extension> =
serde_json::from_reader::<_, Vec<Extension>>(stream)?;
for ext in extra_extensions {
Expand All @@ -273,6 +287,54 @@ fn decode_model(
Ok(import_package(&model_package, &extension_registry)?)
}

/// Read a HUGR model text payload from a reader.
///
/// Parameters:
/// - `stream`: The reader to read the envelope from.
/// - `extension_registry`: An extension registry with additional extensions to use when
/// decoding the HUGR, if they are not already included in the package.
/// - `format`: The format of the payload.
fn decode_model_ast(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no test coverage of this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the tests only checked the roundtrip of the header.

I added a new envelope roundtrip test for all model kinds, for empty/single/multi-hugr packages.

And also fixed a bug with the decoding :)

mut stream: impl BufRead,
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<Package, EnvelopeError> {
use crate::import::import_package;
use hugr_model::v0::bumpalo::Bump;

if format.model_version() != Some(0) {
return Err(EnvelopeError::FormatUnsupported {
format,
feature: None,
});
}

let mut extension_registry = extension_registry.clone();
if format == EnvelopeFormat::ModelTextWithExtensions {
let deserializer = serde_json::Deserializer::from_reader(&mut stream);
// Deserialize the first json object, leaving the rest of the reader unconsumed.
let extra_extensions = deserializer
.into_iter::<Vec<Extension>>()
.next()
.unwrap_or(Ok(vec![]))?;
for ext in extra_extensions {
extension_registry.register_updated(ext);
}
}

// Read the rest of the stream into a string, then parse it.
//
// `read_to_string` consumes the reader until EOF, so the extensions cannot be
// appended after the package; they are read first instead.
let mut buffer = String::new();
stream.read_to_string(&mut buffer)?;
let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;

let bump = Bump::default();
let model_package = ast_package.resolve(&bump)?;

Ok(import_package(&model_package, &extension_registry)?)
}

/// Internal implementation of [`write_envelope`] to call with/without the zstd compression wrapper.
fn write_impl<'h>(
writer: impl Write,
Expand All @@ -283,7 +345,10 @@ fn write_impl<'h>(
match config.format {
#[allow(deprecated)]
EnvelopeFormat::PackageJson => package_json::to_json_writer(hugrs, extensions, writer)?,
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
EnvelopeFormat::Model
| EnvelopeFormat::ModelWithExtensions
| EnvelopeFormat::ModelText
| EnvelopeFormat::ModelTextWithExtensions => {
encode_model(writer, hugrs, extensions, config.format)?
}
}
Expand All @@ -307,11 +372,27 @@ fn encode_model<'h>(
});
}

// Prepend extensions for the text model.
if format == EnvelopeFormat::ModelTextWithExtensions {
serde_json::to_writer(&mut writer, &extensions.iter().collect_vec())?;
}

let bump = Bump::default();
let model_package = export_package(hugrs, extensions, &bump);
write_to_writer(&model_package, &mut writer)?;

if format.append_extensions() {
match format {
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
write_to_writer(&model_package, &mut writer)?;
}
EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
let model_package = model_package.as_ast().unwrap();
writeln!(writer, "{model_package}")?;
}
_ => unreachable!(),
}

// Append extensions for the binary model.
if format == EnvelopeFormat::ModelWithExtensions {
serde_json::to_writer(writer, &extensions.iter().collect_vec())?;
}

Expand Down Expand Up @@ -418,34 +499,24 @@ pub(crate) mod test {
}

#[rstest]
//#[case::empty(Package::default())] // Not currently supported
#[case::simple(simple_package())]
//#[case::multi(multi_module_package())] // Not currently supported
fn module_exts_roundtrip(#[case] package: Package) {
// Empty packages
#[case::empty_model(Package::default(), EnvelopeFormat::Model)]
#[case::empty_model_exts(Package::default(), EnvelopeFormat::ModelWithExtensions)]
#[case::empty_text(Package::default(), EnvelopeFormat::ModelText)]
#[case::empty_text_exts(Package::default(), EnvelopeFormat::ModelTextWithExtensions)]
// Single hugrs
#[case::simple_bin(simple_package(), EnvelopeFormat::Model)]
#[case::simple_bin_exts(simple_package(), EnvelopeFormat::ModelWithExtensions)]
#[case::simple_text(simple_package(), EnvelopeFormat::ModelText)]
#[case::simple_text_exts(simple_package(), EnvelopeFormat::ModelTextWithExtensions)]
// Multiple hugrs
#[case::multi_bin(multi_module_package(), EnvelopeFormat::Model)]
#[case::multi_bin_exts(multi_module_package(), EnvelopeFormat::ModelWithExtensions)]
#[case::multi_text(multi_module_package(), EnvelopeFormat::ModelText)]
#[case::multi_text_exts(multi_module_package(), EnvelopeFormat::ModelTextWithExtensions)]
fn model_roundtrip(#[case] package: Package, #[case] format: EnvelopeFormat) {
let mut buffer = Vec::new();
let config = EnvelopeConfig {
format: EnvelopeFormat::ModelWithExtensions,
zstd: None,
};
package.store(&mut buffer, config).unwrap();
let (decoded_config, new_package) =
read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();

assert_eq!(config.format, decoded_config.format);
assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());
assert_eq!(package, new_package);
}

#[rstest]
//#[case::empty(Package::default())] // Not currently supported
#[case::simple(simple_package())]
//#[case::multi(multi_module_package())] // Not currently supported
fn module_roundtrip(#[case] package: Package) {
let mut buffer = Vec::new();
let config = EnvelopeConfig {
format: EnvelopeFormat::Model,
zstd: None,
};
let config = EnvelopeConfig { format, zstd: None };
package.store(&mut buffer, config).unwrap();

let (decoded_config, new_package) =
Expand Down
50 changes: 39 additions & 11 deletions hugr-core/src/envelope/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,30 @@ pub(super) struct EnvelopeHeader {
pub enum EnvelopeFormat {
/// `hugr-model` v0 binary capnproto message.
Model = 1,
/// `hugr-model` v0 binary capnproto message followed by a json-encoded [crate::extension::ExtensionRegistry].
//
// This is a temporary format required until the model adds support for extensions.
/// `hugr-model` v0 binary capnproto message followed by a json-encoded
/// [crate::extension::ExtensionRegistry].
///
/// This is a temporary format required until the model adds support for
/// extensions.
ModelWithExtensions = 2,
/// Human-readable S-expression encoding using [`hugr_model::v0`].
///
/// Uses a printable ascii value as the discriminant so the envelope can be
/// read as text.
///
/// :caution: This format does not yet support extension encoding, so it should
/// be avoided.
//
// TODO: Update comment once extension encoding is supported.
ModelText = 40, // '(' in ascii
/// Human-readable S-expression encoding using [`hugr_model::v0`].
///
/// Uses a printable ascii value as the discriminant so the envelope can be
/// read as text.
///
/// This is a temporary format required until the model adds support for
/// extensions.
ModelTextWithExtensions = 41, // ')' in ascii
/// Json-encoded [crate::package::Package]
///
/// Uses a printable ascii value as the discriminant so the envelope can be
Expand All @@ -50,15 +70,13 @@ pub enum EnvelopeFormat {
static_assertions::assert_eq_size!(EnvelopeFormat, u8);

impl EnvelopeFormat {
/// Returns whether to encode the extensions as json after the hugr payload.
pub fn append_extensions(self) -> bool {
matches!(self, Self::ModelWithExtensions)
}

/// If the format is a model format, returns its version number.
pub fn model_version(self) -> Option<u32> {
match self {
Self::Model | Self::ModelWithExtensions => Some(0),
Self::Model
| Self::ModelWithExtensions
| Self::ModelText
| Self::ModelTextWithExtensions => Some(0),
_ => None,
}
}
Expand All @@ -67,7 +85,10 @@ impl EnvelopeFormat {
///
/// If true, the encoded envelope can be read as text.
pub fn ascii_printable(self) -> bool {
matches!(self, Self::PackageJson)
matches!(
self,
Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
)
}
}

Expand Down Expand Up @@ -117,7 +138,7 @@ impl EnvelopeConfig {
pub const fn binary() -> Self {
Self {
format: EnvelopeFormat::ModelWithExtensions,
zstd: None,
zstd: Some(ZstdConfig::default_level()),
}
}
}
Expand All @@ -137,6 +158,11 @@ pub struct ZstdConfig {
}

impl ZstdConfig {
/// Create a new zstd configuration with default compression level.
///
/// The `level` field is left unset (`None`) rather than fixed to a number.
pub const fn default_level() -> Self {
Self { level: None }
}

/// Returns the zstd compression level to pass to the zstd library.
///
/// Uses [zstd::DEFAULT_COMPRESSION_LEVEL] if the level is not set.
Expand Down Expand Up @@ -224,6 +250,8 @@ mod tests {
#[rstest]
#[case(EnvelopeFormat::Model)]
#[case(EnvelopeFormat::ModelWithExtensions)]
#[case(EnvelopeFormat::ModelText)]
#[case(EnvelopeFormat::ModelTextWithExtensions)]
#[case(EnvelopeFormat::PackageJson)]
fn header_round_trip(#[case] format: EnvelopeFormat) {
// With zstd compression
Expand Down
1 change: 1 addition & 0 deletions hugr-model/src/v0/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod python;
mod resolve;
mod view;

pub use parse::ParseError;
pub use resolve::ResolveError;

/// A package in the hugr AST.
Expand Down
Loading