Skip to content

Wasm runtime: support unmarshalling compressed data #1898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* Runtime/wasm: support jsoo_env and keep track of backtrace status (#1881)
* Runtime: less conversion during un-marshalling (#1889)
* Compiler: improve performance of Javascript linking
* Runtime/wasm: support unmarshaling compressed data (#1898)

## Bug fixes
* Runtime: fix path normalization (#1848)
Expand Down
13 changes: 13 additions & 0 deletions compiler/tests-jsoo/test_marshal_compressed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,17 @@ let%expect_test _ =
else String.make 10000 'c'
in
Printf.printf "%s ... (%d)\n" (String.sub s 0 20) (String.length s);
[%expect {| cccccccccccccccccccc ... (10000) |}];
let tmp = Filename.temp_file "a" "txt" in
let ch = open_out tmp in
output_string ch data;
close_out ch;
let ch = open_in tmp in
let s =
if Compression.compression_supported
then Marshal.from_channel ch
else String.make 10000 'c'
in
close_in ch;
Printf.printf "%s ... (%d)\n" (String.sub s 0 20) (String.length s);
[%expect {| cccccccccccccccccccc ... (10000) |}]
166 changes: 139 additions & 27 deletions runtime/wasm/marshal.wat
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
(call $parse_header (local.get $s) (global.get $input_val_from_string)))
(if (i32.gt_s
(i32.add (local.get $ofs)
(i32.add (struct.get $marshal_header $data_len (local.get $h))
(i32.const 20)))
(i32.add (struct.get $marshal_header $header_len (local.get $h))
(struct.get $marshal_header $data_len (local.get $h))))
(array.len (local.get $str)))
(then
(call $bad_length (global.get $input_val_from_string))))
(call $decompress_input (local.get $s) (local.get $h)
(global.get $input_val_from_string))
(return_call $intern_rec (local.get $s) (local.get $h)))

(@string $truncated_obj "input_value: truncated object")
Expand All @@ -80,16 +82,37 @@

(func (export "caml_input_value") (param $ch (ref eq)) (result (ref eq))
;; ZZZ check binary channel?
(local $r i32) (local $len i32)
(local $r i32) (local $magic i32) (local $len i32)
(local $header (ref $bytes)) (local $buf (ref $bytes))
(local $s (ref $intern_state)) (local $h (ref $marshal_header))
(local.set $header (array.new $bytes (i32.const 0) (i32.const 20)))
(local.set $header (array.new $bytes (i32.const 0) (i32.const 55)))
(local.set $r
(call $caml_really_getblock
(local.get $ch) (local.get $header) (i32.const 0) (i32.const 20)))
(local.get $ch) (local.get $header) (i32.const 0) (i32.const 5)))
(if (i32.eqz (local.get $r))
(then (call $caml_raise_end_of_file)))
(if (i32.lt_u (local.get $r) (i32.const 20))
(if (i32.lt_u (local.get $r) (i32.const 5))
(then (call $caml_failwith (global.get $truncated_obj))))
(local.set $s
(call $get_intern_state (local.get $header) (i32.const 0)))
(local.set $magic (call $read32 (local.get $s)))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_big))
(then (call $too_large (global.get $input_value))))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_small))
(then (local.set $len (i32.const 15))))
(if (i32.eq (local.get $magic)
(global.get $Intext_magic_number_compressed))
(then
(local.set $len
(i32.sub
(i32.and (call $read8u (local.get $s)) (i32.const 0x3F))
(i32.const 5)))))
(if (i32.eqz (local.get $len))
(then (call $bad_object (global.get $marshal_data_size))))
(if (i32.lt_u
(call $caml_really_getblock (local.get $ch)
(local.get $header) (i32.const 5) (local.get $len))
(local.get $len))
(then (call $caml_failwith (global.get $truncated_obj))))
(local.set $s
(call $get_intern_state (local.get $header) (i32.const 0)))
Expand All @@ -103,6 +126,8 @@
(local.get $len))
(then (call $caml_failwith (global.get $truncated_obj))))
(local.set $s (call $get_intern_state (local.get $buf) (i32.const 0)))
(call $decompress_input (local.get $s) (local.get $h)
(global.get $input_value))
(return_call $intern_rec (local.get $s) (local.get $h)))

(type $block (array (mut (ref eq))))
Expand All @@ -111,6 +136,9 @@
(type $float_array (array (mut f64)))
(type $js (struct (field anyref)))

(type $decompress
(func (param (ref $bytes) i32 i32 i32) (result (ref $bytes))))

(type $compare
(func (param (ref eq)) (param (ref eq)) (param i32) (result i32)))
(type $hash
Expand All @@ -134,6 +162,8 @@

(global $Intext_magic_number_small i32 (i32.const 0x8495A6BE))
(global $Intext_magic_number_big i32 (i32.const 0x8495A6BF))
(global $Intext_magic_number_compressed i32 (i32.const 0x8495A6BD))


(global $PREFIX_SMALL_BLOCK i32 (i32.const 0x80))
(global $PREFIX_SMALL_INT i32 (i32.const 0x40))
Expand Down Expand Up @@ -163,15 +193,17 @@

(type $intern_state
(struct
(field $src (ref $bytes))
(field $src (mut (ref $bytes)))
(field $pos (mut i32))
(field $obj_table (mut (ref null $block)))
(field $obj_counter (mut i32))))
(field $obj_counter (mut i32))
(field $overflow (mut i32))))

(func $get_intern_state
(param $src (ref $bytes)) (param $pos i32) (result (ref $intern_state))
(struct.new $intern_state
(local.get $src) (local.get $pos) (ref.null $block) (i32.const 0)))
(local.get $src) (local.get $pos) (ref.null $block)
(i32.const 0) (i32.const 0)))

(func $read8u (param $s (ref $intern_state)) (result i32)
(local $pos i32) (local $res i32)
Expand Down Expand Up @@ -425,6 +457,30 @@
(call $caml_failwith (global.get $unknown_custom))
(ref.i31 (i32.const 0)))

(global $caml_intern_decompress_input (export "caml_intern_decompress_input")
(mut (ref null $decompress))
(ref.null $decompress))

(func $decompress_input
   (param $s (ref $intern_state)) (param $h (ref $marshal_header))
   (param $prim (ref eq))
   ;; When the header $h is flagged as compressed, replace the source buffer
   ;; of the intern state $s with the result of calling the function stored
   ;; in the global $caml_intern_decompress_input, then reset the read
   ;; position to 0. When the header is not flagged, $s is left untouched.
   ;; If the global is null, fail with a message prefixed by $prim.
   (local $decompressor (ref null $decompress))
   (if (struct.get $marshal_header $compressed (local.get $h))
      (then
         (local.set $decompressor
            (global.get $caml_intern_decompress_input))
         ;; No decompression function has been installed in the global.
         (if (ref.is_null (local.get $decompressor))
            (then
               (call $caml_failwith
                  (call $caml_string_concat (local.get $prim)
                     (@string ": compressed object, cannot decompress")))
               (return)))
         ;; Arguments: current buffer, current position, compressed
         ;; length, and expected uncompressed length.
         (struct.set $intern_state $src (local.get $s)
            (call_ref $decompress
               (struct.get $intern_state $src (local.get $s))
               (struct.get $intern_state $pos (local.get $s))
               (struct.get $marshal_header $data_len (local.get $h))
               (struct.get $marshal_header $uncompressed_data_len
                  (local.get $h))
               (ref.as_non_null (local.get $decompressor))))
         ;; Reading restarts at the beginning of the new buffer.
         (struct.set $intern_state $pos (local.get $s) (i32.const 0)))))

(func $intern_rec
(param $s (ref $intern_state)) (param $h (ref $marshal_header))
(result (ref eq))
Expand Down Expand Up @@ -586,10 +642,12 @@
(br $done))
))))
;; read_shared
(local.set $ofs
(i32.sub
(struct.get $intern_state $obj_counter (local.get $s))
(local.get $ofs)))
(if (i32.eqz (struct.get $marshal_header $compressed (local.get $h)))
(then
(local.set $ofs
(i32.sub
(struct.get $intern_state $obj_counter (local.get $s))
(local.get $ofs)))))
(local.set $v
(array.get $block
(ref.as_non_null
Expand Down Expand Up @@ -665,28 +723,71 @@

(type $marshal_header
(struct
(field $header_len i32)
(field $data_len i32)
(field $num_objects i32)))
(field $uncompressed_data_len i32)
(field $num_objects i32)
(field $compressed i32)))

(func $readvlq (param $s (ref $intern_state)) (result i32)
   ;; Decode one variable-length quantity from the input of $s. Each byte
   ;; contributes its low 7 bits, most significant group first; a set high
   ;; bit (0x80) means another byte follows. If shifting the accumulator
   ;; left by 7 loses bits, the $overflow field of $s is set to 1; decoding
   ;; still continues and the (truncated) value is returned.
   (local $byte i32) (local $acc i32) (local $shifted i32)
   (local.set $byte (call $read8u (local.get $s)))
   (local.set $acc (i32.and (local.get $byte) (i32.const 0x7F)))
   (block $done
      (loop $more
         ;; Stop once the last byte read carries no continuation bit.
         (br_if $done
            (i32.eqz (i32.and (local.get $byte) (i32.const 0x80))))
         (local.set $byte (call $read8u (local.get $s)))
         (local.set $shifted (i32.shl (local.get $acc) (i32.const 7)))
         ;; Shifting back must restore the accumulator; otherwise high
         ;; bits were dropped.
         (if (i32.ne (local.get $acc)
                (i32.shr_u (local.get $shifted) (i32.const 7)))
            (then
               (struct.set $intern_state $overflow (local.get $s)
                  (i32.const 1))))
         (local.set $acc
            (i32.or (local.get $shifted)
               (i32.and (local.get $byte) (i32.const 0x7f))))
         (br $more)))
   (local.get $acc))

(func $parse_header
(param $s (ref $intern_state)) (param $prim (ref eq))
(result (ref $marshal_header))
(local $magic i32)
(local $data_len i32) (local $num_objects i32) (local $whsize i32)
(local $magic i32) (local $header_len i32)
(local $data_len i32) (local $uncompressed_data_len i32)
(local $num_objects i32) (local $whsize i32) (local $compressed i32)
(local.set $magic (call $read32 (local.get $s)))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_big))
(then
(call $too_large (local.get $prim))))
(if (i32.ne (local.get $magic) (global.get $Intext_magic_number_small))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_small))
(then
(call $bad_object (local.get $prim))))
(local.set $data_len (call $read32 (local.get $s)))
(local.set $num_objects (call $read32 (local.get $s)))
(drop (call $read32 (local.get $s)))
(drop (call $read32 (local.get $s)))
(local.set $header_len (i32.const 20))
(local.set $data_len (call $read32 (local.get $s)))
(local.set $uncompressed_data_len (local.get $data_len))
(local.set $num_objects (call $read32 (local.get $s)))
(drop (call $read32 (local.get $s)))
(drop (call $read32 (local.get $s))))
(else (if (i32.eq (local.get $magic)
(global.get $Intext_magic_number_compressed))
(then
(local.set $header_len
(i32.and (call $read8u (local.get $s)) (i32.const 0x3F)))
(local.set $data_len (call $readvlq (local.get $s)))
(local.set $uncompressed_data_len (call $readvlq (local.get $s)))
(local.set $num_objects (call $readvlq (local.get $s)))
(drop (call $readvlq (local.get $s)))
(if (struct.get $intern_state $overflow (local.get $s))
(then (call $too_large (local.get $prim))))
(drop (call $readvlq (local.get $s)))
(local.set $compressed (i32.const 1)))
(else
(call $bad_object (local.get $prim))))))
(struct.new $marshal_header
(local.get $header_len)
(local.get $data_len)
(local.get $num_objects)))
(local.get $uncompressed_data_len)
(local.get $num_objects)
(local.get $compressed)))

(@string $marshal_data_size "Marshal.data_size")

Expand All @@ -703,19 +804,30 @@
(func (export "caml_marshal_data_size")
(param $buf (ref eq)) (param $ofs (ref eq)) (result (ref eq))
(local $s (ref $intern_state))
(local $magic i32)
(local $magic i32) (local $header_len i32)
(local.set $s
(call $get_intern_state
(ref.cast (ref $bytes) (local.get $buf))
(i31.get_u (ref.cast (ref i31) (local.get $ofs)))))
(local.set $magic (call $read32 (local.get $s)))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_big))
(then (call $too_large (global.get $marshal_data_size))))
(if (i32.ne (local.get $magic) (global.get $Intext_magic_number_small))
(then (call $bad_object (global.get $marshal_data_size))))
(if (i32.eq (local.get $magic) (global.get $Intext_magic_number_small))
(then
(local.set $header_len (i32.const 20)))
(else (if (i32.eq (local.get $magic)
(global.get $Intext_magic_number_compressed))
(then
(local.set $header_len
(i32.and (call $read8u (local.get $s)) (i32.const 0x3F)))
(drop (call $readvlq (local.get $s)))
(if (struct.get $intern_state $overflow (local.get $s))
(then (call $too_large (global.get $marshal_data_size)))))
(else
(call $bad_object (global.get $marshal_data_size))))))
(ref.i31
(i32.add
(i32.sub (i32.const 20)
(i32.sub (local.get $header_len)
(global.get $caml_marshal_header_size))
(call $read32 (local.get $s)))))

Expand Down
42 changes: 41 additions & 1 deletion runtime/wasm/zstd.wat
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,47 @@
;; Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

(module
(@if (>= ocaml_version (5 1 0))
(@then
(import "bindings" "ta_new" (func $ta_new (param i32) (result (ref extern))))
(import "bindings" "ta_blit_from_bytes"
(func $ta_blit_from_bytes
(param (ref $bytes)) (param i32) (param (ref extern)) (param i32)
(param i32)))
(import "bindings" "ta_blit_to_bytes"
(func $ta_blit_to_bytes
(param (ref extern)) (param i32) (param (ref $bytes)) (param i32)
(param i32)))
(import "marshal" "caml_intern_decompress_input"
(global $caml_intern_decompress_input (mut (ref null $decompress))))
(import "js" "zstd_decompress"
(func $zstd_decompress (param (ref extern)) (param (ref extern))))

(type $bytes (array (mut i8)))
(type $decompress
(func (param (ref $bytes) i32 i32 i32) (result (ref $bytes))))

(func $decompress
   (param $input (ref $bytes)) (param $pos i32) (param $len i32)
   (param $out_len i32) (result (ref $bytes))
   ;; Returns a fresh byte array of length $out_len filled from the output
   ;; buffer handed to the imported $zstd_decompress.
   ;; NOTE(review): presumably $ta_new allocates a JS typed array of the
   ;; given size and the ta_blit_* imports copy between typed arrays and
   ;; Wasm byte arrays -- confirm against the JS bindings.
   (local $src_ta (ref extern)) (local $dst_ta (ref extern))
   (local $result (ref $bytes))
   ;; Allocate the input buffer first, then the output buffer.
   (local.set $src_ta (call $ta_new (local.get $len)))
   (local.set $dst_ta (call $ta_new (local.get $out_len)))
   ;; Hand over $len bytes of $input starting at offset $pos.
   (call $ta_blit_from_bytes
      (local.get $input) (local.get $pos)
      (local.get $src_ta) (i32.const 0)
      (local.get $len))
   (call $zstd_decompress (local.get $src_ta) (local.get $dst_ta))
   ;; Copy the whole output buffer into a new zero-initialised byte array;
   ;; its length is exactly $out_len.
   (call $ta_blit_to_bytes
      (local.get $dst_ta) (i32.const 0)
      (local.tee $result
         (array.new $bytes (i32.const 0) (local.get $out_len)))
      (i32.const 0)
      (local.get $out_len))
   (local.get $result))

(func (export "caml_zstd_initialize") (param (ref eq)) (result (ref eq))
(ref.i31 (i32.const 0)))
(global.set $caml_intern_decompress_input (ref.func $decompress))
(ref.i31 (i32.const 1)))
))
)
Loading