Skip to content

Commit

Permalink
Optimized blitting between strings and bigstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
vouillon committed Oct 7, 2024
1 parent e005fbb commit 6f790a7
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 68 deletions.
48 changes: 29 additions & 19 deletions runtime/wasm/bigarray.wat
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@
(import "bindings" "ta_subarray"
(func $ta_subarray
(param (ref extern)) (param i32) (param i32) (result (ref extern))))
(import "bindings" "ta_blit_from_string"
(func $ta_blit_from_string
(param (ref $string)) (param i32) (param (ref extern)) (param i32)
(param i32)))
(import "bindings" "ta_blit_to_string"
(func $ta_blit_to_string
(param (ref extern)) (param i32) (param (ref $string)) (param i32)
(param i32)))
(import "fail" "caml_bound_error" (func $caml_bound_error))
(import "fail" "caml_raise_out_of_memory" (func $caml_raise_out_of_memory))
(import "fail" "caml_invalid_argument"
Expand Down Expand Up @@ -2016,43 +2024,33 @@
(func $caml_string_of_array (export "caml_string_of_array")
(param (ref eq)) (result (ref eq))
;; used to convert a typed array to a string
(local $a (ref extern)) (local $len i32) (local $i i32)
(local $a (ref extern)) (local $len i32)
(local $s (ref $string))
(local.set $a
(ref.as_non_null (extern.convert_any (call $unwrap (local.get 0)))))
(local.set $len (call $ta_length (local.get $a)))
(local.set $s (array.new $string (i32.const 0) (local.get $len)))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(array.set $string (local.get $s) (local.get $i)
(call $ta_get_ui8 (local.get $a) (local.get $i)))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(call $ta_blit_to_string
(local.get $a) (i32.const 0) (local.get $s) (i32.const 0)
(local.get $len))
(local.get $s))

(export "caml_uint8_array_of_bytes" (func $caml_uint8_array_of_string))
(func $caml_uint8_array_of_string (export "caml_uint8_array_of_string")
(param (ref eq)) (result (ref eq))
;; Convert a string to a typed array
(local $ta (ref extern)) (local $len i32) (local $i i32)
(local $ta (ref extern)) (local $len i32)
(local $s (ref $string))
(local.set $s (ref.cast (ref $string) (local.get 0)))
(local.set $len (array.len (local.get $s)))
(local.set $ta
(call $ta_create
(i32.const 3) ;; Uint8Array
(local.get $len)))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(call $ta_set_ui8
(local.get $ta)
(local.get $i)
(ref.i31 (array.get $string (local.get $s) (local.get $i))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(call $wrap (extern.internalize (local.get $ta))))
(call $ta_blit_from_string
(local.get $s) (i32.const 0) (local.get $ta) (i32.const 0)
(local.get $len))
(call $wrap (any.convert_extern (local.get $ta))))

(func (export "caml_ba_get_kind") (param (ref eq)) (result i32)
(struct.get $bigarray $ba_kind (ref.cast (ref $bigarray) (local.get 0))))
Expand Down Expand Up @@ -2082,4 +2080,16 @@
(local.get $num_dims)
(local.get $kind)
(local.get $layout)))

(func (export "string_set")
(param $s externref) (param $i i32) (param $v i32)
(array.set $string
(ref.cast (ref null $string) (any.convert_extern (local.get $s)))
(local.get $i) (local.get $v)))

(func (export "string_get")
(param $s externref) (param $i i32) (result i32)
(array.get $string
(ref.cast (ref null $string) (any.convert_extern (local.get $s)))
(local.get $i)))
)
39 changes: 18 additions & 21 deletions runtime/wasm/bigstring.wat
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
(func $ta_length (param (ref extern)) (result i32)))
(import "bindings" "ta_bytes"
(func $ta_bytes (param anyref) (result anyref)))
(import "bindings" "ta_blit_from_string"
(func $ta_blit_from_string
(param (ref $string)) (param i32) (param (ref extern)) (param i32)
(param i32)))
(import "bindings" "ta_blit_to_string"
(func $ta_blit_to_string
(param (ref extern)) (param i32) (param (ref $string)) (param i32)
(param i32)))
(import "hash" "caml_hash_mix_int"
(func $caml_hash_mix_int (param i32) (param i32) (result i32)))

Expand Down Expand Up @@ -202,47 +210,36 @@
(param $str1 (ref eq)) (param $vpos1 (ref eq))
(param $ba2 (ref eq)) (param $vpos2 (ref eq))
(param $vlen (ref eq)) (result (ref eq))
(local $i i32) (local $pos1 i32) (local $pos2 i32) (local $len i32)
(local $pos1 i32) (local $pos2 i32) (local $len i32)
(local $s1 (ref $string))
(local $d2 (ref extern))
(local.set $s1 (ref.cast (ref $string) (local.get $str1)))
(local.set $pos1 (i31.get_s (ref.cast (ref i31) (local.get $vpos1))))
(local.set $d2 (call $caml_ba_get_data (local.get $ba2)))
(local.set $pos2 (i31.get_s (ref.cast (ref i31) (local.get $vpos2))))
(local.set $len (i31.get_s (ref.cast (ref i31) (local.get $vlen))))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(call $ta_set_ui8 (local.get $d2)
(i32.add (local.get $pos2) (local.get $i))
(ref.i31
(array.get_u $string (local.get $s1)
(i32.add (local.get $pos1) (local.get $i)))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(call $ta_blit_from_string
(local.get $s1) (local.get $pos1)
(local.get $d2) (local.get $pos2)
(local.get $len))
(ref.i31 (i32.const 0)))

(func (export "caml_bigstring_blit_ba_to_bytes")
(param $ba1 (ref eq)) (param $vpos1 (ref eq))
(param $str2 (ref eq)) (param $vpos2 (ref eq))
(param $vlen (ref eq)) (result (ref eq))
(local $i i32) (local $pos1 i32) (local $pos2 i32) (local $len i32)
(local $pos1 i32) (local $pos2 i32) (local $len i32)
(local $d1 (ref extern))
(local $s2 (ref $string))
(local.set $d1 (call $caml_ba_get_data (local.get $ba1)))
(local.set $pos1 (i31.get_s (ref.cast (ref i31) (local.get $vpos1))))
(local.set $s2 (ref.cast (ref $string) (local.get $str2)))
(local.set $pos2 (i31.get_s (ref.cast (ref i31) (local.get $vpos2))))
(local.set $len (i31.get_s (ref.cast (ref i31) (local.get $vlen))))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(array.set $string (local.get $s2)
(i32.add (local.get $pos2) (local.get $i))
(call $ta_get_ui8 (local.get $d1)
(i32.add (local.get $pos1) (local.get $i))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(call $ta_blit_to_string
(local.get $d1) (local.get $pos1)
(local.get $s2) (local.get $pos2)
(local.get $len))
(ref.i31 (i32.const 0)))

(func (export "caml_bigstring_blit_ba_to_ba")
Expand Down
10 changes: 9 additions & 1 deletion runtime/wasm/deps.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"name": "root",
"reaches": ["init", "exn", "mem", "strings"],
"reaches": ["init", "exn", "mem", "strings", "string_get", "string_set"],
"root": true
},
{
Expand All @@ -20,6 +20,14 @@
"name": "strings",
"export": "caml_extract_string"
},
{
"name": "string_get",
"export": "string_get"
},
{
"name": "string_set",
"export": "string_set"
},
{
"name": "callback",
"export": "caml_callback"
Expand Down
42 changes: 15 additions & 27 deletions runtime/wasm/io.wat
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@
(func $ta_set_ui8 (param (ref extern)) (param i32) (param i32))) ;; ZZZ ??
(import "bindings" "ta_get_ui8"
(func $ta_get_ui8 (param (ref extern)) (param i32) (result i32)))
(import "bindings" "ta_blit_from_string"
(func $ta_blit_from_string
(param (ref $string)) (param i32) (param (ref extern)) (param i32)
(param i32)))
(import "bindings" "ta_blit_to_string"
(func $ta_blit_to_string
(param (ref extern)) (param i32) (param (ref $string)) (param i32)
(param i32)))
(import "custom" "custom_compare_id"
(func $custom_compare_id
(param (ref eq)) (param (ref eq)) (param i32) (result i32)))
Expand Down Expand Up @@ -330,20 +338,6 @@
(i64.add (local.get $offset) (i64.extend_i32_u (local.get $n))))
(local.get $n))

(func $copy_from_buffer
(param $buf (ref extern)) (param $curr i32)
(param $s (ref $string)) (param $pos i32) (param $len i32)
(local $i i32)
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(array.set $string (local.get $s)
(i32.add (local.get $pos) (local.get $i))
(call $ta_get_ui8 (local.get $buf)
(i32.add (local.get $curr) (local.get $i))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop)))))

(func $caml_refill (param $ch (ref $channel)) (result i32)
(local $n i32)
(local $buf (ref extern))
Expand Down Expand Up @@ -374,7 +368,7 @@
(then
(if (i32.gt_u (local.get $len) (local.get $avail))
(then (local.set $len (local.get $avail))))
(call $copy_from_buffer
(call $ta_blit_to_string
(struct.get $channel $buffer (local.get $ch))
(struct.get $channel $curr (local.get $ch))
(local.get $s) (local.get $pos)
Expand All @@ -389,7 +383,7 @@
(struct.set $channel $max (local.get $ch) (local.get $nread))
(if (i32.gt_u (local.get $len) (local.get $nread))
(then (local.set $len (local.get $nread))))
(call $copy_from_buffer
(call $ta_blit_to_string
(struct.get $channel $buffer (local.get $ch))
(i32.const 0)
(local.get $s) (local.get $pos)
Expand Down Expand Up @@ -445,7 +439,7 @@
(local.set $curr (i32.const 0))
(if (i32.gt_u (local.get $len) (local.get $nread))
(then (local.set $len (local.get $nread))))))))
(call $copy_from_buffer
(call $ta_blit_to_string
(local.get $buf) (local.get $curr)
(local.get $s) (local.get $pos) (local.get $len))
(struct.set $channel $curr (local.get $ch)
Expand Down Expand Up @@ -730,23 +724,17 @@
(func $caml_putblock
(param $ch (ref $channel)) (param $s (ref $string)) (param $pos i32)
(param $len i32) (result i32)
(local $free i32) (local $curr i32) (local $i i32)
(local $free i32) (local $curr i32)
(local $buf (ref extern))
(local.set $curr (struct.get $channel $curr (local.get $ch)))
(local.set $free
(i32.sub (struct.get $channel $size (local.get $ch)) (local.get $curr)))
(if (i32.ge_u (local.get $len) (local.get $free))
(then (local.set $len (local.get $free))))
(local.set $buf (struct.get $channel $buffer (local.get $ch)))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(call $ta_set_ui8 (local.get $buf)
(i32.add (local.get $curr) (local.get $i))
(array.get_u $string (local.get $s)
(i32.add (local.get $pos) (local.get $i))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(call $ta_blit_from_string
(local.get $s) (local.get $pos)
(local.get $buf) (local.get $curr) (local.get $len))
(struct.set $channel $curr (local.get $ch)
(i32.add (local.get $curr) (local.get $len)))
(if (i32.ge_u (local.get $len) (local.get $free))
Expand Down
8 changes: 8 additions & 0 deletions runtime/wasm/runtime.js
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@
ta_copy: (ta, t, s, n) => ta.copyWithin(t, s, n),
ta_bytes: (a) =>
new Uint8Array(a.buffer, a.byteOffset, a.length * a.BYTES_PER_ELEMENT),
ta_blit_from_string: (s, p1, a, p2, l) => {
for (let i = 0; i < l; i++) a[p2 + i] = string_get(s, p1 + i);
},
ta_blit_to_string: (a, p1, s, p2, l) => {
for (let i = 0; i < l; i++) string_set(s, p2 + i, a[p1 + i]);
},
wrap_callback: (f) =>
function () {
var n = arguments.length;
Expand Down Expand Up @@ -537,6 +543,8 @@
caml_handle_uncaught_exception,
caml_buffer,
caml_extract_string,
string_get,
string_set,
_initialize,
} = wasmModule.instance.exports;

Expand Down

0 comments on commit 6f790a7

Please sign in to comment.