1+ use crate :: types:: TYPE_SERIAL_REGISTRY ;
12use crate :: * ;
23use cobalt_utils:: misc:: new_lifetime_mut;
34use either:: Either :: { self , * } ;
45use hashbrown:: hash_map:: { Entry , HashMap } ;
56use hashbrown:: HashSet ;
67use inkwell:: { builder:: Builder , context:: Context , module:: Module } ;
78use owned_chars:: OwnedCharsExt ;
9+ use serde:: de:: DeserializeSeed ;
810use std:: cell:: { Cell , RefCell } ;
9- use std:: io:: { self , BufRead , Read , Write } ;
11+ use std:: fmt:: { self , Debug , Formatter } ;
12+ use std:: io:: { Read , Write } ;
1013use std:: mem:: MaybeUninit ;
1114use std:: pin:: Pin ;
12- use thiserror:: Error ;
1315
14- type HeaderVersionType = u16 ;
1516/// Simple number to check if a header is compatible for loading
1617/// Bump this whenever a breaking change is made to the format
17- const HEADER_FMT_VERSION : HeaderVersionType = 0 ;
18- #[ derive( Debug , Clone , Copy , PartialEq , Eq , Error ) ]
19- #[ error( "expected header version {HEADER_FMT_VERSION}, found version {0}" ) ]
20- pub struct HeaderVersionError ( pub HeaderVersionType ) ;
21- impl From < HeaderVersionError > for io:: Error {
22- fn from ( value : HeaderVersionError ) -> Self {
23- io:: Error :: new ( io:: ErrorKind :: Other , value)
24- }
25- }
18+ const HEADER_FMT_VERSION : u16 = 0 ;
2619
2720#[ derive( Clone , PartialEq , Eq , Debug ) ]
2821pub struct Flags {
@@ -33,6 +26,7 @@ pub struct Flags {
3326 pub all_move_metadata : bool ,
3427 pub private_syms : bool ,
3528 pub skip_header_version_check : bool ,
29+ pub add_type_map : bool ,
3630}
3731impl Default for Flags {
3832 fn default ( ) -> Self {
@@ -44,6 +38,7 @@ impl Default for Flags {
4438 all_move_metadata : false ,
4539 private_syms : true ,
4640 skip_header_version_check : false ,
41+ add_type_map : false ,
4742 }
4843 }
4944}
@@ -131,7 +126,7 @@ impl<'src, 'ctx> CompCtx<'src, 'ctx> {
131126 ]
132127 . into_iter ( )
133128 . map ( |( k, v) | ( k. into ( ) , v. into ( ) ) )
134- . collect :: < HashMap < _ , _ > > ( )
129+ . collect :: < std :: collections :: HashMap < _ , _ > > ( )
135130 . into ( ) ,
136131 ) ) ) ) ) ,
137132 name : Cell :: new ( MaybeUninit :: new ( "." . to_string ( ) ) ) ,
@@ -294,46 +289,11 @@ impl<'src, 'ctx> CompCtx<'src, 'ctx> {
294289 }
295290 Some ( v)
296291 }
297- pub fn save < W : Write > ( & self , out : & mut W ) -> io:: Result < ( ) > {
298- out. write_all ( & HEADER_FMT_VERSION . to_be_bytes ( ) ) ?;
299- for info in inventory:: iter :: < types:: TypeLoader > {
300- out. write_all ( & info. kind . get ( ) . to_be_bytes ( ) ) ?;
301- ( info. save_header ) ( out) ?;
302- }
303- out. write_all ( & [ 0 ] ) ?;
304- self . with_vars ( |v| v. save ( out) )
292+ pub fn save < W : Write > ( & self , buf : & mut W ) -> serde_json:: Result < ( ) > {
293+ serde_json:: to_writer ( buf, self )
305294 }
306- pub fn load < R : Read + BufRead > ( & self , buf : & mut R ) -> io:: Result < Vec < Cow < ' src , str > > > {
307- {
308- let mut arr = [ 0 ; std:: mem:: size_of :: < HeaderVersionType > ( ) ] ;
309- buf. read_exact ( & mut arr) ?;
310- let version = HeaderVersionType :: from_be_bytes ( arr) ;
311- if !( self . flags . skip_header_version_check || version == HEADER_FMT_VERSION ) {
312- Err ( HeaderVersionError ( version) ) ?;
313- }
314- }
315- let mut out = vec ! [ ] ;
316- while !buf. fill_buf ( ) ?. is_empty ( ) {
317- let mut bytes = [ 0u8 ; 8 ] ;
318- loop {
319- buf. read_exact ( & mut bytes) ?;
320- let Some ( kind) = std:: num:: NonZeroU64 :: new ( u64:: from_be_bytes ( bytes) ) else {
321- break ;
322- } ;
323- let guard = types:: TYPE_SERIAL_REGISTRY . guard ( ) ;
324- let info = types:: TYPE_SERIAL_REGISTRY
325- . get ( & kind, & guard)
326- . ok_or_else ( || {
327- io:: Error :: new (
328- io:: ErrorKind :: InvalidData ,
329- format ! ( "unknown type ID {kind:8>0X}" ) ,
330- )
331- } ) ?;
332- ( info. load_header ) ( buf) ?;
333- }
334- out. append ( & mut self . with_vars ( |v| v. load ( buf, self ) ) ?) ;
335- }
336- Ok ( out)
295+ pub fn load < R : Read > ( & self , buf : & mut R ) -> serde_json:: Result < Vec < String > > {
296+ self . deserialize ( & mut serde_json:: Deserializer :: from_reader ( buf) )
337297 }
338298}
339299impl Drop for CompCtx < ' _ , ' _ > {
@@ -346,3 +306,210 @@ impl Drop for CompCtx<'_, '_> {
346306 }
347307 }
348308}
309+ struct FnDeserializer <
310+ T ,
311+ F : FnOnce ( & mut dyn erased_serde:: Deserializer ) -> Result < T , erased_serde:: Error > ,
312+ > ( pub F ) ;
313+ impl < ' de , T , F : FnOnce ( & mut dyn erased_serde:: Deserializer ) -> Result < T , erased_serde:: Error > >
314+ DeserializeSeed < ' de > for FnDeserializer < T , F >
315+ {
316+ type Value = T ;
317+ fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
318+ where
319+ D : Deserializer < ' de > ,
320+ {
321+ ( self . 0 ) ( & mut <dyn erased_serde:: Deserializer >:: erase ( deserializer) )
322+ . map_err ( de:: Error :: custom)
323+ }
324+ }
325+ #[ derive( Debug , Clone , Copy , Serialize , Deserialize ) ]
326+ #[ serde( transparent) ]
327+ struct HexArray ( #[ serde( with = "hex::serde" ) ] [ u8 ; 8 ] ) ;
328+ struct CtxTypeSerde < ' a , ' s , ' c > ( & ' a CompCtx < ' s , ' c > ) ;
329+ impl Serialize for CtxTypeSerde < ' _ , ' _ , ' _ > {
330+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
331+ where
332+ S : Serializer ,
333+ {
334+ use ser:: * ;
335+ let tsr = TYPE_SERIAL_REGISTRY . pin ( ) ;
336+ let mut map = serializer. serialize_map ( Some (
337+ tsr. iter ( ) . filter ( |( _, info) | ( info. has_header ) ( ) ) . count ( ) ,
338+ ) ) ?;
339+ for ( id, info) in & tsr {
340+ if ( info. has_header ) ( ) {
341+ map. serialize_entry ( & HexArray ( id. to_le_bytes ( ) ) , & ( info. erased_header ) ( ) ) ?;
342+ }
343+ }
344+ map. end ( )
345+ }
346+ }
347+ impl < ' de > de:: Visitor < ' de > for CtxTypeSerde < ' _ , ' _ , ' _ > {
348+ type Value = ( ) ;
349+ fn expecting ( & self , formatter : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
350+ formatter. write_str ( "a map of type headers" )
351+ }
352+ fn visit_map < A > ( self , mut map : A ) -> Result < Self :: Value , A :: Error >
353+ where
354+ A : de:: MapAccess < ' de > ,
355+ {
356+ let tsr = TYPE_SERIAL_REGISTRY . pin ( ) ;
357+ while let Some ( id) = map. next_key :: < HexArray > ( ) ? {
358+ let Some ( loader) = tsr. get ( & u64:: from_le_bytes ( id. 0 ) ) else {
359+ return Err ( de:: Error :: custom ( "unknown type ID {:0>16x}" ) ) ;
360+ } ;
361+ map. next_value_seed ( FnDeserializer ( loader. load_header ) ) ?;
362+ }
363+ Ok ( ( ) )
364+ }
365+ }
366+ impl < ' de > DeserializeSeed < ' de > for CtxTypeSerde < ' _ , ' _ , ' _ > {
367+ type Value = ( ) ;
368+ fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
369+ where
370+ D : Deserializer < ' de > ,
371+ {
372+ deserializer. deserialize_map ( self )
373+ }
374+ }
375+ impl Serialize for CompCtx < ' _ , ' _ > {
376+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
377+ where
378+ S : Serializer ,
379+ {
380+ use ser:: * ;
381+ SERIALIZATION_CONTEXT . with ( |c| {
382+ let p = unsafe {
383+ std:: mem:: transmute :: < * const CompCtx < ' _ , ' _ > , * const CompCtx < ' static , ' static > > (
384+ self as _ ,
385+ )
386+ } ;
387+ if let Some ( ptr) = c. replace ( Some ( ContextPointer :: new ( p) ) ) {
388+ if * ptr != p {
389+ panic ! ( "serialization context is already in use with an address of {ptr:#?}" ) ;
390+ }
391+ }
392+ } ) ;
393+ let mut map =
394+ serializer. serialize_struct ( "Context" , 3 + usize:: from ( self . flags . add_type_map ) ) ?;
395+ map. serialize_field ( "version" , & HEADER_FMT_VERSION ) ?;
396+ if self . flags . add_type_map {
397+ map. serialize_field (
398+ "names" ,
399+ & TYPE_SERIAL_REGISTRY
400+ . pin ( )
401+ . iter ( )
402+ . map ( |( k, v) | ( hex:: encode ( k. to_le_bytes ( ) ) , v. name ) )
403+ . collect :: < hashbrown:: HashMap < _ , _ > > ( ) ,
404+ ) ?;
405+ }
406+ map. serialize_field ( "types" , & CtxTypeSerde ( self ) ) ?;
407+ self . with_vars ( |v| map. serialize_field ( "vars" , v) ) ?;
408+ SERIALIZATION_CONTEXT . with ( |c| {
409+ c. replace ( None )
410+ . expect ( "serialization context is empty after serialization" )
411+ } ) ;
412+ map. end ( )
413+ }
414+ }
415+ #[ derive( Deserialize ) ]
416+ #[ serde( bound = "'a: 'de" ) ]
417+ struct ContextDeProxy < ' a > {
418+ version : u16 ,
419+ #[ serde( rename = "names" ) ]
420+ _names : Option < serde:: de:: IgnoredAny > , // if it gets into the serialization, ignore it for deserialization - it should be stable
421+ #[ serde( borrow = "'a" ) ]
422+ types : serde:: __private:: de:: Content < ' a > ,
423+ #[ serde( borrow = "'a" ) ]
424+ vars : serde:: __private:: de:: Content < ' a > ,
425+ }
426+ impl < ' de > DeserializeSeed < ' de > for & CompCtx < ' _ , ' _ > {
427+ type Value = Vec < String > ;
428+ fn deserialize < D > ( mut self , deserializer : D ) -> Result < Self :: Value , D :: Error >
429+ where
430+ D : Deserializer < ' de > ,
431+ {
432+ use de:: * ;
433+ SERIALIZATION_CONTEXT . with ( |c| {
434+ let p = unsafe {
435+ std:: mem:: transmute :: < * const CompCtx < ' _ , ' _ > , * const CompCtx < ' static , ' static > > (
436+ self as _ ,
437+ )
438+ } ;
439+ if let Some ( ptr) = c. replace ( Some ( ContextPointer :: new ( p) ) ) {
440+ if * ptr != p {
441+ panic ! ( "serialization context is already in use with an address of {ptr:#?}" ) ;
442+ }
443+ }
444+ } ) ;
445+ let proxy = ContextDeProxy :: deserialize ( deserializer) ?;
446+ if proxy. version != HEADER_FMT_VERSION {
447+ return Err ( D :: Error :: custom ( format ! ( "this header was saved with version {}, but version {HEADER_FMT_VERSION} is expected" , proxy. version) ) ) ;
448+ }
449+ CtxTypeSerde ( self )
450+ . deserialize ( serde:: __private:: de:: ContentDeserializer :: new ( proxy. types ) ) ?;
451+ let vars = VarMap :: deserialize_state (
452+ & mut self ,
453+ serde:: __private:: de:: ContentDeserializer :: new ( proxy. vars ) ,
454+ ) ?;
455+ SERIALIZATION_CONTEXT . with ( |c| {
456+ c. replace ( None )
457+ . expect ( "serialization context is empty after serialization" )
458+ } ) ;
459+ Ok ( self . with_vars ( |v| varmap:: merge ( & mut v. symbols , vars. symbols ) ) )
460+ }
461+ }
462+
463+ /// Wrapper around a context pointer, maybe with a backtace
464+ pub struct ContextPointer {
465+ ptr : * const CompCtx < ' static , ' static > ,
466+ #[ cfg( debug_assertions) ]
467+ trace : std:: backtrace:: Backtrace ,
468+ }
469+ impl ContextPointer {
470+ pub fn new ( ptr : * const CompCtx < ' static , ' static > ) -> Self {
471+ Self {
472+ trace : std:: backtrace:: Backtrace :: capture ( ) ,
473+ ptr,
474+ }
475+ }
476+ }
477+ impl std:: ops:: Deref for ContextPointer {
478+ type Target = * const CompCtx < ' static , ' static > ;
479+ fn deref ( & self ) -> & Self :: Target {
480+ & self . ptr
481+ }
482+ }
483+ impl Debug for ContextPointer {
484+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> fmt:: Result {
485+ write ! ( f, "{:p}" , self . ptr) ?;
486+ #[ cfg( debug_assertions) ]
487+ {
488+ if f. alternate ( ) {
489+ if self . trace . status ( ) == std:: backtrace:: BacktraceStatus :: Captured {
490+ write ! ( f, "at: \n {}" , self . trace) ?;
491+ } else {
492+ f. write_str ( "without backtrace" ) ?;
493+ }
494+ }
495+ }
496+ Ok ( ( ) )
497+ }
498+ }
499+ /// Get the context pointer from a cell
500+ /// Super unsafe lmao
501+ ///
502+ /// # Safety
503+ /// `SERIALIZATION_CONTEXT`` must be valid
504+ pub unsafe fn get_ctx_ptr < ' a , ' s , ' c > ( cell : & Cell < Option < ContextPointer > > ) -> & ' a CompCtx < ' s , ' c > {
505+ let opt = cell. replace ( None ) ;
506+ let cp = opt. expect ( "expected pointer in serialization context" ) ;
507+ let ptr = cp. ptr ;
508+ cell. set ( Some ( cp) ) ;
509+ #[ allow( clippy:: unnecessary_cast) ]
510+ & * std:: mem:: transmute :: < * const CompCtx < ' static , ' static > , * const CompCtx < ' s , ' c > > ( ptr)
511+ }
512+ thread_local ! {
513+ /// CompCtx, should only have a value during de/serialization
514+ pub static SERIALIZATION_CONTEXT : Cell <Option <ContextPointer >> = const { Cell :: new( None ) } ;
515+ }
0 commit comments