Skip to content

Commit

Permalink
[ new ] atomic modification of arrays using CAS loops (#51)
Browse files Browse the repository at this point in the history
* [ new ] atomic modification of arrays using CAS loops

* [ test ] preparing CAS tests

* [ test ] single threaded CAS updates

* [ test ] CAS updates under contention
  • Loading branch information
stefan-hoeck authored Jan 6, 2025
1 parent 8a5346c commit 3f14963
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-lib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ jobs:
run: pack --cg node test array -n 1000
- name: Build docs
run: pack typecheck array-docs
- name: Test concurrent counter
run: pack exec test/src/Concurrent.idr
- name: Test concurrent queue
run: pack exec test/src/ConcurrentQueue.idr
64 changes: 64 additions & 0 deletions src/Data/Array/Core.idr
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ prim__emptyArray : Bits32 -> PrimIO AnyPtr
%extern prim__arrayGet : forall a . AnyPtr -> Bits32 -> PrimIO a
%extern prim__arraySet : forall a . AnyPtr -> Bits32 -> a -> PrimIO ()

%foreign "scheme:(lambda (a x i v w) (if (vector-cas! x i v w) 1 0))"
"javascript:lambda:(a,x,i,v,w) => {if (x[i] === v) {x[i] = w; return 1;} else {return 0;}}"
prim__casSet : AnyPtr -> Bits32 -> (prev,val : a) -> Bits8

--------------------------------------------------------------------------------
-- Immutable Arrays
--------------------------------------------------------------------------------
Expand Down Expand Up @@ -133,6 +137,66 @@ export %inline
modify : (r : MArray' t n a) -> (0 p : Res r rs) => Fin n -> (a -> a) -> F1' rs
modify r ix f t = let v # t1 := get r ix t in set r ix (f v) t1

||| Atomically writes `val` at the given position of the mutable array
||| if its current value is equal to `pre`.
|||
||| This is supported and has been tested on the Chez and Racket backends.
||| It trivially works on the JavaScript backends, which are single-threaded
||| anyway.
export %inline
casset :
(r : MArray' t n a)
-> {auto 0 p : Res r rs}
-> Fin n
-> (pre,val : a)
-> F1 rs Bool
casset (MA arr) x pre val t =
case prim__casSet arr (cast $ finToNat x) pre val of
0 => False # t
_ => True # t

||| Atomic modification of an array position using a CAS-loop internally.
|||
||| This is supported and has been tested on the Chez and Racket backends.
||| It trivially works on the JavaScript backends, which are single-threaded
||| anyway.
export
casupdate :
(r : MArray' t n a)
-> Fin n
-> (a -> (a,b))
-> {auto 0 p : Res r rs}
-> F1 rs b
casupdate r x f t = assert_total (loop t)
where
covering loop : F1 rs b
loop t =
let cur # t := get r x t
(new,v) := f cur
True # t := casset r x cur new t | _ # t => loop t
in v # t

||| Atomic modification of an array position reference using a CAS-loop
||| internally.
|||
||| This is supported and has been tested on the Chez and Racket backends.
||| It trivially works on the JavaScript backends, which are single-threaded
||| anyway.
export
casmodify :
(r : MArray' t n a)
-> Fin n
-> (a -> a)
-> {auto 0 p : Res r rs}
-> F1' rs
casmodify r x f t = assert_total (loop t)
where
covering loop : F1' rs
loop t =
let cur # t := get r x t
True # t := casset r x cur (f cur) t | _ # t => loop t
in () # t

||| Wraps a mutable array in a shorter one.
export %inline
mtake :
Expand Down
59 changes: 49 additions & 10 deletions test/src/Array.idr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ module Array

import Control.Monad.Identity
import Data.Array
import Data.SOP
import Data.Array.Mutable
import Data.List
import Data.List.Quantifiers
import Data.SnocList
import Data.Vect
import Hedgehog
Expand All @@ -22,17 +24,17 @@ prop_eq_refl = property $ do

prop_eq_sym : Property
prop_eq_sym = property $ do
[vs,ws] <- forAll $ np [arrBits,arrBits]
[vs,ws] <- forAll $ hlist [arrBits,arrBits]
(vs == ws) === (ws == vs)

prop_eq_trans : Property
prop_eq_trans = property $ do
[us,vs,ws] <- forAll $ np [arrBits,arrBits,arrBits]
[us,vs,ws] <- forAll $ hlist [arrBits,arrBits,arrBits]
when (us == vs && vs == ws) (us === ws)

prop_eq_eq : Property
prop_eq_eq = property $ do
[vs,ws] <- forAll $ np [arrBits,arrBits]
[vs,ws] <- forAll $ hlist [arrBits,arrBits]
when (vs == ws) $ do
assert (vs <= ws)
assert (vs >= ws)
Expand All @@ -42,13 +44,13 @@ prop_eq_eq = property $ do

prop_eq_neq : Property
prop_eq_neq = property $ do
[vs,ws] <- forAll $ np [arrBits,arrBits]
[vs,ws] <- forAll $ hlist [arrBits,arrBits]
when (vs /= ws) $ do
assert (vs < ws || ws < vs)

prop_lt : Property
prop_lt = property $ do
[vs,ws] <- forAll $ np [arrBits,arrBits]
[vs,ws] <- forAll $ hlist [arrBits,arrBits]
((vs < ws) === (ws > vs))
when (vs < ws) $ do
assert (vs /= ws)
Expand All @@ -57,7 +59,7 @@ prop_lt = property $ do

prop_lte : Property
prop_lte = property $ do
[vs,ws] <- forAll $ np [arrBits,arrBits]
[vs,ws] <- forAll $ hlist [arrBits,arrBits]
((vs <= ws) === (ws >= vs))

prop_map_id : Property
Expand Down Expand Up @@ -88,7 +90,7 @@ prop_foldl = property $ do
prop_foldr : Property
prop_foldr = property $ do
vs <- forAll arrBits
foldr (::) [] vs === foldr (::) [] (toList vs)
foldr (::) Prelude.Nil vs === foldr (::) [] (toList vs)

prop_null : Property
prop_null = property $ do
Expand Down Expand Up @@ -138,12 +140,12 @@ prop_traverse_id = property $ do

prop_append : Property
prop_append = property $ do
[x,y] <- forAll $ np [arrBits,arrBits]
[x,y] <- forAll $ hlist [arrBits,arrBits]
toList (x <+> y) === (toList x ++ toList y)

prop_semigroup_assoc : Property
prop_semigroup_assoc = property $ do
[x,y,z] <- forAll $ np [arrBits,arrBits,arrBits]
[x,y,z] <- forAll $ hlist [arrBits,arrBits,arrBits]
(x <+> (y <+> z)) === ((x <+> y) <+> z)

prop_monoid_left_neutral : Property
Expand All @@ -156,6 +158,39 @@ prop_monoid_right_neutral = property $ do
x <- forAll arrBits
(x <+> empty) === x

casWriteGet :
(r : MArray' t 3 a)
-> (pre,new : a)
-> F1 [r] (Bool,a)
casWriteGet r pre new t =
let b # t := casset r 2 pre new t
v # t := Core.get r 2 t
in (b,v) # t

prop_casset : Property
prop_casset =
property $ do
[x,y] <- forAll $ hlist [anyBits8, anyBits8]
(True,y) === alloc 3 x (\r => casWriteGet r x y)

prop_casset_diff : Property
prop_casset_diff =
property $ do
[x,y] <- forAll $ hlist [anyBits8, anyBits8]
(False,x) === alloc 3 x (\r => casWriteGet r (x+1) y)

prop_casupdate : Property
prop_casupdate =
property $ do
[x,y] <- forAll $ hlist [anyBits8, anyBits8]
x === alloc 3 x (\r => casupdate r 2 (\v => (v+y,v)))

prop_casmodify : Property
prop_casmodify =
property $ do
[x,y] <- forAll $ hlist [anyBits8, anyBits8]
(x+y) === alloc 3 x (\r,t => let _ # t := casmodify r 2 (+y) t in get r 2 t)

export
props : Group
props = MkGroup "Array"
Expand Down Expand Up @@ -185,5 +220,9 @@ props = MkGroup "Array"
, ("prop_semigroup_assoc", prop_semigroup_assoc)
, ("prop_monoid_left_neutral", prop_monoid_left_neutral)
, ("prop_monoid_right_neutral", prop_monoid_right_neutral)
, ("prop_casset", prop_casset)
, ("prop_casset_diff", prop_casset_diff)
, ("prop_casupdate", prop_casupdate)
, ("prop_casmodify", prop_casmodify)
]

50 changes: 50 additions & 0 deletions test/src/Concurrent.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module Concurrent

import Data.Vect as V
import Data.Array
import Data.Array.Mutable
import System.Concurrency
import System

%default total

public export
ITER : Nat
ITER = 1_000_000

data Prog = Unsafe | CAS | Mut

inc : (r : IOArray 1 Nat) -> F1' [World]
inc r = modify r 0 S

casinc : (r : IOArray 1 Nat) -> F1' [World]
casinc r = casmodify r 0 S

mutinc : Mutex -> IOArray 1 Nat -> Nat -> IO ()
mutinc m r 0 = pure ()
mutinc m r (S k) = do
mutexAcquire m
runIO (inc r)
mutexRelease m
mutinc m r k

prog : Prog -> Mutex -> IOArray 1 Nat -> IO ()
prog Unsafe m ref = runIO (forN ITER $ inc ref)
prog CAS m ref = runIO (forN ITER $ casinc ref)
prog Mut m ref = mutinc m ref ITER

runProg : Prog -> Nat -> IO Nat
runProg prg n = do
mut <- makeMutex
ref <- newIOArray 1 Z
ts <- sequence $ V.replicate n (fork $ prog prg mut ref)
traverse_ (\t => threadWait t) ts
runIO (get ref 0)

main : IO ()
main = do
u <- runProg Unsafe 4
c <- runProg CAS 4
when (u >= c) (die "no race condition")
when (c /= 4 * ITER) (die "CAS synchronization failed")
putStrLn "Concurrent counter succeeded!"
64 changes: 64 additions & 0 deletions test/src/ConcurrentQueue.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
module ConcurrentQueue

import Data.Queue
import Data.Vect as V
import Data.Array
import Data.Array.Mutable
import System.Concurrency
import System

%default total

record State where
constructor ST
cur : Nat
queue : Queue Nat

next : State -> State
next (ST n q) = ST (S n) (enqueue q n)

ITER : Nat
ITER = 10_000

DELAY : Nat
DELAY = 100_000

data Prog = Unsafe | CAS | Mut

inc : (r : IOArray 1 State) -> Nat -> F1' [World]
inc r 0 = modify r 0 next
inc r (S k) = inc r k

casinc : (r : IOArray 1 State) -> Nat -> F1' [World]
casinc r 0 = casmodify r 0 next
casinc r (S k) = casinc r k

mutinc : Mutex -> IOArray 1 State -> Nat -> Nat -> IO ()
mutinc m r n (S k) = mutinc m r n k
mutinc m r 0 0 = pure ()
mutinc m r (S k) 0 = do
mutexAcquire m
runIO (inc r 0)
mutexRelease m
mutinc m r k DELAY

prog : Prog -> Mutex -> IOArray 1 State -> IO ()
prog Unsafe m ref = runIO (forN ITER $ inc ref DELAY)
prog CAS m ref = runIO (forN ITER $ casinc ref DELAY)
prog Mut m ref = mutinc m ref ITER DELAY

runProg : Prog -> Nat -> IO (List Nat)
runProg prg n = do
mut <- makeMutex
ref <- newIOArray 1 (ST 0 empty)
ts <- sequence $ V.replicate n (fork $ prog prg mut ref)
traverse_ (\t => threadWait t) ts
toList . queue <$> runIO (get ref 0)

main : IO ()
main = do
us <- runProg Unsafe 4
cs <- runProg CAS 4
when (length us >= length cs) (die "no race condition")
when (cs /= [0 .. pred (4 * ITER)]) (die "CAS synchronization failed")
putStrLn "Concurrent queue succeeded!"
1 change: 1 addition & 0 deletions test/test.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ executable = "array-test"
sourcedir = "src"

depends = array
, containers
, hedgehog

0 comments on commit 3f14963

Please sign in to comment.