Skip to content

Commit 95d591f

Browse files
authored
Merge pull request #134 from cobalt-language/serde-libs
Use `serde` to serialize library headers
2 parents 66d0dd9 + c119149 commit 95d591f

File tree

23 files changed

+1048
-969
lines changed

23 files changed

+1048
-969
lines changed

cobalt-ast/Cargo.toml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,28 @@ description.workspace = true
99
documentation.workspace = true
1010

1111
[dependencies]
12-
cobalt-errors = { path = "../cobalt-errors" }
12+
cobalt-errors = { path = "../cobalt-errors", features = ["serde"] }
1313
cobalt-llvm = { path = "../cobalt-llvm", default-features = false }
1414
cobalt-utils = { path = "../cobalt-utils" }
1515
bitvec = "1.0.1"
1616
either = "1.9.0"
1717
glob = "0.3.1"
1818
owned_chars = "0.3.2"
19-
bstr = { version = "1.7.0", default-features = false, features = ["std"] }
19+
bstr = { version = "1.7.0", default-features = false, features = ["std", "serde"] }
2020
derive_more = "0.99.17"
2121
ref-cast = "1.0.20"
2222
slab = "0.4.9"
2323
inventory = "0.3.12"
2424
once_cell = "1.18.0"
2525
hashbrown = "0.14.2"
26-
flurry = "0.4.0"
26+
flurry = { version = "0.4.0", features = ["serde"] }
2727
thiserror = "1.0.51"
28+
serde = {version = "1.0", features = ["derive", "rc"]}
29+
serde_state = "0.4.8"
30+
serde_derive_state = "0.4.10"
31+
const-identify = "0.1.1"
32+
erased-serde = "0.4.1"
33+
deranged = { version = "0.3.10", features = ["serde"] }
34+
serde_json = "1.0.111"
35+
hex = { version = "0.4.3", features = ["serde"] }
36+
serde_bytes = "0.11.14"

cobalt-ast/src/ast/funcs.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,7 @@ impl<'src> AST<'src> for FnDefAST<'src> {
11881188
std::mem::drop(graph);
11891189
if dtor.is_some() {
11901190
if let Some(ty) = params
1191-
.get(0)
1191+
.first()
11921192
.and_then(|ty| ty.0.downcast::<types::Reference>())
11931193
.and_then(|ty| {
11941194
ty.base()
@@ -1475,7 +1475,7 @@ impl<'src> AST<'src> for FnDefAST<'src> {
14751475
std::mem::drop(graph);
14761476
if dtor.is_some() {
14771477
if let Some(ty) = params
1478-
.get(0)
1478+
.first()
14791479
.and_then(|ty| ty.0.downcast::<types::Reference>())
14801480
.and_then(|ty| {
14811481
ty.base()

cobalt-ast/src/ast/vars.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ impl<'src> AST<'src> for VarDefAST<'src> {
819819
val.name = self
820820
.name
821821
.ids
822-
.get(0)
822+
.first()
823823
.map(|x| (x.0.clone(), ctx.lex_scope.get()));
824824
val.frozen = (!self.is_mut).then_some(self.loc);
825825
match if ctx.is_const.get() || !self.is_mut {
@@ -851,7 +851,7 @@ impl<'src> AST<'src> for VarDefAST<'src> {
851851
val.name = self
852852
.name
853853
.ids
854-
.get(0)
854+
.first()
855855
.map(|x| (x.0.clone(), ctx.lex_scope.get()));
856856
ctx.with_vars(|v| {
857857
v.insert(

cobalt-ast/src/context.rs

Lines changed: 219 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,21 @@
1+
use crate::types::TYPE_SERIAL_REGISTRY;
12
use crate::*;
23
use cobalt_utils::misc::new_lifetime_mut;
34
use either::Either::{self, *};
45
use hashbrown::hash_map::{Entry, HashMap};
56
use hashbrown::HashSet;
67
use inkwell::{builder::Builder, context::Context, module::Module};
78
use owned_chars::OwnedCharsExt;
9+
use serde::de::DeserializeSeed;
810
use std::cell::{Cell, RefCell};
9-
use std::io::{self, BufRead, Read, Write};
11+
use std::fmt::{self, Debug, Formatter};
12+
use std::io::{Read, Write};
1013
use std::mem::MaybeUninit;
1114
use std::pin::Pin;
12-
use thiserror::Error;
1315

14-
type HeaderVersionType = u16;
1516
/// Simple number to check if a header is compatible for loading
1617
/// Bump this whenever a breaking change is made to the format
17-
const HEADER_FMT_VERSION: HeaderVersionType = 0;
18-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
19-
#[error("expected header version {HEADER_FMT_VERSION}, found version {0}")]
20-
pub struct HeaderVersionError(pub HeaderVersionType);
21-
impl From<HeaderVersionError> for io::Error {
22-
fn from(value: HeaderVersionError) -> Self {
23-
io::Error::new(io::ErrorKind::Other, value)
24-
}
25-
}
18+
const HEADER_FMT_VERSION: u16 = 0;
2619

2720
#[derive(Clone, PartialEq, Eq, Debug)]
2821
pub struct Flags {
@@ -33,6 +26,7 @@ pub struct Flags {
3326
pub all_move_metadata: bool,
3427
pub private_syms: bool,
3528
pub skip_header_version_check: bool,
29+
pub add_type_map: bool,
3630
}
3731
impl Default for Flags {
3832
fn default() -> Self {
@@ -44,6 +38,7 @@ impl Default for Flags {
4438
all_move_metadata: false,
4539
private_syms: true,
4640
skip_header_version_check: false,
41+
add_type_map: false,
4742
}
4843
}
4944
}
@@ -131,7 +126,7 @@ impl<'src, 'ctx> CompCtx<'src, 'ctx> {
131126
]
132127
.into_iter()
133128
.map(|(k, v)| (k.into(), v.into()))
134-
.collect::<HashMap<_, _>>()
129+
.collect::<std::collections::HashMap<_, _>>()
135130
.into(),
136131
))))),
137132
name: Cell::new(MaybeUninit::new(".".to_string())),
@@ -294,46 +289,11 @@ impl<'src, 'ctx> CompCtx<'src, 'ctx> {
294289
}
295290
Some(v)
296291
}
297-
pub fn save<W: Write>(&self, out: &mut W) -> io::Result<()> {
298-
out.write_all(&HEADER_FMT_VERSION.to_be_bytes())?;
299-
for info in inventory::iter::<types::TypeLoader> {
300-
out.write_all(&info.kind.get().to_be_bytes())?;
301-
(info.save_header)(out)?;
302-
}
303-
out.write_all(&[0])?;
304-
self.with_vars(|v| v.save(out))
292+
pub fn save<W: Write>(&self, buf: &mut W) -> serde_json::Result<()> {
293+
serde_json::to_writer(buf, self)
305294
}
306-
pub fn load<R: Read + BufRead>(&self, buf: &mut R) -> io::Result<Vec<Cow<'src, str>>> {
307-
{
308-
let mut arr = [0; std::mem::size_of::<HeaderVersionType>()];
309-
buf.read_exact(&mut arr)?;
310-
let version = HeaderVersionType::from_be_bytes(arr);
311-
if !(self.flags.skip_header_version_check || version == HEADER_FMT_VERSION) {
312-
Err(HeaderVersionError(version))?;
313-
}
314-
}
315-
let mut out = vec![];
316-
while !buf.fill_buf()?.is_empty() {
317-
let mut bytes = [0u8; 8];
318-
loop {
319-
buf.read_exact(&mut bytes)?;
320-
let Some(kind) = std::num::NonZeroU64::new(u64::from_be_bytes(bytes)) else {
321-
break;
322-
};
323-
let guard = types::TYPE_SERIAL_REGISTRY.guard();
324-
let info = types::TYPE_SERIAL_REGISTRY
325-
.get(&kind, &guard)
326-
.ok_or_else(|| {
327-
io::Error::new(
328-
io::ErrorKind::InvalidData,
329-
format!("unknown type ID {kind:8>0X}"),
330-
)
331-
})?;
332-
(info.load_header)(buf)?;
333-
}
334-
out.append(&mut self.with_vars(|v| v.load(buf, self))?);
335-
}
336-
Ok(out)
295+
pub fn load<R: Read>(&self, buf: &mut R) -> serde_json::Result<Vec<String>> {
296+
self.deserialize(&mut serde_json::Deserializer::from_reader(buf))
337297
}
338298
}
339299
impl Drop for CompCtx<'_, '_> {
@@ -346,3 +306,210 @@ impl Drop for CompCtx<'_, '_> {
346306
}
347307
}
348308
}
309+
struct FnDeserializer<
310+
T,
311+
F: FnOnce(&mut dyn erased_serde::Deserializer) -> Result<T, erased_serde::Error>,
312+
>(pub F);
313+
impl<'de, T, F: FnOnce(&mut dyn erased_serde::Deserializer) -> Result<T, erased_serde::Error>>
314+
DeserializeSeed<'de> for FnDeserializer<T, F>
315+
{
316+
type Value = T;
317+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
318+
where
319+
D: Deserializer<'de>,
320+
{
321+
(self.0)(&mut <dyn erased_serde::Deserializer>::erase(deserializer))
322+
.map_err(de::Error::custom)
323+
}
324+
}
325+
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
326+
#[serde(transparent)]
327+
struct HexArray(#[serde(with = "hex::serde")] [u8; 8]);
328+
struct CtxTypeSerde<'a, 's, 'c>(&'a CompCtx<'s, 'c>);
329+
impl Serialize for CtxTypeSerde<'_, '_, '_> {
330+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
331+
where
332+
S: Serializer,
333+
{
334+
use ser::*;
335+
let tsr = TYPE_SERIAL_REGISTRY.pin();
336+
let mut map = serializer.serialize_map(Some(
337+
tsr.iter().filter(|(_, info)| (info.has_header)()).count(),
338+
))?;
339+
for (id, info) in &tsr {
340+
if (info.has_header)() {
341+
map.serialize_entry(&HexArray(id.to_le_bytes()), &(info.erased_header)())?;
342+
}
343+
}
344+
map.end()
345+
}
346+
}
347+
impl<'de> de::Visitor<'de> for CtxTypeSerde<'_, '_, '_> {
348+
type Value = ();
349+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
350+
formatter.write_str("a map of type headers")
351+
}
352+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
353+
where
354+
A: de::MapAccess<'de>,
355+
{
356+
let tsr = TYPE_SERIAL_REGISTRY.pin();
357+
while let Some(id) = map.next_key::<HexArray>()? {
358+
let Some(loader) = tsr.get(&u64::from_le_bytes(id.0)) else {
359+
return Err(de::Error::custom("unknown type ID {:0>16x}"));
360+
};
361+
map.next_value_seed(FnDeserializer(loader.load_header))?;
362+
}
363+
Ok(())
364+
}
365+
}
366+
impl<'de> DeserializeSeed<'de> for CtxTypeSerde<'_, '_, '_> {
367+
type Value = ();
368+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
369+
where
370+
D: Deserializer<'de>,
371+
{
372+
deserializer.deserialize_map(self)
373+
}
374+
}
375+
impl Serialize for CompCtx<'_, '_> {
376+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
377+
where
378+
S: Serializer,
379+
{
380+
use ser::*;
381+
SERIALIZATION_CONTEXT.with(|c| {
382+
let p = unsafe {
383+
std::mem::transmute::<*const CompCtx<'_, '_>, *const CompCtx<'static, 'static>>(
384+
self as _,
385+
)
386+
};
387+
if let Some(ptr) = c.replace(Some(ContextPointer::new(p))) {
388+
if *ptr != p {
389+
panic!("serialization context is already in use with an address of {ptr:#?}");
390+
}
391+
}
392+
});
393+
let mut map =
394+
serializer.serialize_struct("Context", 3 + usize::from(self.flags.add_type_map))?;
395+
map.serialize_field("version", &HEADER_FMT_VERSION)?;
396+
if self.flags.add_type_map {
397+
map.serialize_field(
398+
"names",
399+
&TYPE_SERIAL_REGISTRY
400+
.pin()
401+
.iter()
402+
.map(|(k, v)| (hex::encode(k.to_le_bytes()), v.name))
403+
.collect::<hashbrown::HashMap<_, _>>(),
404+
)?;
405+
}
406+
map.serialize_field("types", &CtxTypeSerde(self))?;
407+
self.with_vars(|v| map.serialize_field("vars", v))?;
408+
SERIALIZATION_CONTEXT.with(|c| {
409+
c.replace(None)
410+
.expect("serialization context is empty after serialization")
411+
});
412+
map.end()
413+
}
414+
}
415+
#[derive(Deserialize)]
416+
#[serde(bound = "'a: 'de")]
417+
struct ContextDeProxy<'a> {
418+
version: u16,
419+
#[serde(rename = "names")]
420+
_names: Option<serde::de::IgnoredAny>, // if it gets into the serialization, ignore it for deserialization - it should be stable
421+
#[serde(borrow = "'a")]
422+
types: serde::__private::de::Content<'a>,
423+
#[serde(borrow = "'a")]
424+
vars: serde::__private::de::Content<'a>,
425+
}
426+
impl<'de> DeserializeSeed<'de> for &CompCtx<'_, '_> {
427+
type Value = Vec<String>;
428+
fn deserialize<D>(mut self, deserializer: D) -> Result<Self::Value, D::Error>
429+
where
430+
D: Deserializer<'de>,
431+
{
432+
use de::*;
433+
SERIALIZATION_CONTEXT.with(|c| {
434+
let p = unsafe {
435+
std::mem::transmute::<*const CompCtx<'_, '_>, *const CompCtx<'static, 'static>>(
436+
self as _,
437+
)
438+
};
439+
if let Some(ptr) = c.replace(Some(ContextPointer::new(p))) {
440+
if *ptr != p {
441+
panic!("serialization context is already in use with an address of {ptr:#?}");
442+
}
443+
}
444+
});
445+
let proxy = ContextDeProxy::deserialize(deserializer)?;
446+
if proxy.version != HEADER_FMT_VERSION {
447+
return Err(D::Error::custom(format!("this header was saved with version {}, but version {HEADER_FMT_VERSION} is expected", proxy.version)));
448+
}
449+
CtxTypeSerde(self)
450+
.deserialize(serde::__private::de::ContentDeserializer::new(proxy.types))?;
451+
let vars = VarMap::deserialize_state(
452+
&mut self,
453+
serde::__private::de::ContentDeserializer::new(proxy.vars),
454+
)?;
455+
SERIALIZATION_CONTEXT.with(|c| {
456+
c.replace(None)
457+
.expect("serialization context is empty after serialization")
458+
});
459+
Ok(self.with_vars(|v| varmap::merge(&mut v.symbols, vars.symbols)))
460+
}
461+
}
462+
463+
/// Wrapper around a context pointer, maybe with a backtace
464+
pub struct ContextPointer {
465+
ptr: *const CompCtx<'static, 'static>,
466+
#[cfg(debug_assertions)]
467+
trace: std::backtrace::Backtrace,
468+
}
469+
impl ContextPointer {
470+
pub fn new(ptr: *const CompCtx<'static, 'static>) -> Self {
471+
Self {
472+
trace: std::backtrace::Backtrace::capture(),
473+
ptr,
474+
}
475+
}
476+
}
477+
impl std::ops::Deref for ContextPointer {
478+
type Target = *const CompCtx<'static, 'static>;
479+
fn deref(&self) -> &Self::Target {
480+
&self.ptr
481+
}
482+
}
483+
impl Debug for ContextPointer {
484+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
485+
write!(f, "{:p}", self.ptr)?;
486+
#[cfg(debug_assertions)]
487+
{
488+
if f.alternate() {
489+
if self.trace.status() == std::backtrace::BacktraceStatus::Captured {
490+
write!(f, "at: \n{}", self.trace)?;
491+
} else {
492+
f.write_str("without backtrace")?;
493+
}
494+
}
495+
}
496+
Ok(())
497+
}
498+
}
499+
/// Get the context pointer from a cell
500+
/// Super unsafe lmao
501+
///
502+
/// # Safety
503+
/// `SERIALIZATION_CONTEXT`` must be valid
504+
pub unsafe fn get_ctx_ptr<'a, 's, 'c>(cell: &Cell<Option<ContextPointer>>) -> &'a CompCtx<'s, 'c> {
505+
let opt = cell.replace(None);
506+
let cp = opt.expect("expected pointer in serialization context");
507+
let ptr = cp.ptr;
508+
cell.set(Some(cp));
509+
#[allow(clippy::unnecessary_cast)]
510+
&*std::mem::transmute::<*const CompCtx<'static, 'static>, *const CompCtx<'s, 'c>>(ptr)
511+
}
512+
thread_local! {
513+
/// CompCtx, should only have a value during de/serialization
514+
pub static SERIALIZATION_CONTEXT: Cell<Option<ContextPointer>> = const {Cell::new(None)};
515+
}

0 commit comments

Comments
 (0)