Skip to content

Scan by key #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 96 additions & 37 deletions src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common
open Brahma.FSharp
open FSharp.Quotations
open GraphBLAS.FSharp.Backend.Quotes
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
open GraphBLAS.FSharp.Backend.Objects.ClCell

module PrefixSum =
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
Expand Down Expand Up @@ -38,7 +40,7 @@ module PrefixSum =
)

processor.Post(Msg.CreateRunMsg<_, _> kernel)
processor.Post(Msg.CreateFreeMsg(mirror))
mirror.Free processor

let private scanGeneral
beforeLocalSumClear
Expand All @@ -48,10 +50,8 @@ module PrefixSum =
workGroupSize
=

let subSum = SubSum.treeSum opAdd

let scan =
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (inputArray: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->

let mirror = mirror.Value

Expand All @@ -62,46 +62,34 @@ module PrefixSum =
if mirror then
i <- inputArrayLength - 1 - i

let localID = ndRange.LocalID0
let lid = ndRange.LocalID0

let zero = zero.Value

if gid < inputArrayLength then
resultLocalBuffer.[localID] <- resultBuffer.[i]
resultLocalBuffer.[lid] <- inputArray.[i]
else
resultLocalBuffer.[localID] <- zero
resultLocalBuffer.[lid] <- zero

barrierLocal ()

(%subSum) workGroupSize localID resultLocalBuffer

if localID = workGroupSize - 1 then
if verticesLength <= 1 && localID = gid then
totalSumBuffer.Value <- resultLocalBuffer.[localID]

verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID]
(%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i
resultLocalBuffer.[localID] <- zero
// Local tree reduce
(%SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer

let mutable step = workGroupSize
if lid = workGroupSize - 1 then
// if last iteration
if verticesLength <= 1 && lid = gid then
totalSumBuffer.Value <- resultLocalBuffer.[lid]

while step > 1 do
barrierLocal ()
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[lid]
(%beforeLocalSumClear) inputArray resultLocalBuffer.[lid] inputArrayLength gid i
resultLocalBuffer.[lid] <- zero

if localID < workGroupSize / step then
let i = step * (localID + 1) - 1
let j = i - (step >>> 1)

let tmp = resultLocalBuffer.[i]
let buff = (%opAdd) tmp resultLocalBuffer.[j]
resultLocalBuffer.[i] <- buff
resultLocalBuffer.[j] <- tmp

step <- step >>> 1
(%SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer

barrierLocal ()

(%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
(%writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>

let program = clContext.Compile(scan)

Expand Down Expand Up @@ -132,13 +120,14 @@ module PrefixSum =
)

processor.Post(Msg.CreateRunMsg<_, _> kernel)
processor.Post(Msg.CreateFreeMsg(zero))
processor.Post(Msg.CreateFreeMsg(mirror))

zero.Free processor
mirror.Free processor

let private scanExclusive<'a when 'a: struct> =
scanGeneral
<@ fun (_: ClArray<'a>) (_: 'a) (_: int) (_: int) (_: int) -> () @>
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) ->
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (_: int) (gid: int) (i: int) (localID: int) ->

if gid < inputArrayLength then
resultBuffer.[i] <- resultLocalBuffer.[localID] @>
Expand Down Expand Up @@ -206,8 +195,8 @@ module PrefixSum =
verticesArrays <- swap verticesArrays
verticesLength <- (verticesLength - 1) / workGroupSize + 1

processor.Post(Msg.CreateFreeMsg(firstVertices))
processor.Post(Msg.CreateFreeMsg(secondVertices))
firstVertices.Free processor
secondVertices.Free processor

totalSum

Expand All @@ -226,7 +215,7 @@ module PrefixSum =
/// <code>
/// let arr = [| 1; 1; 1; 1 |]
/// let sum = [| 0 |]
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
/// |> ignore
/// ...
/// > val arr = [| 0; 1; 2; 3 |]
Expand All @@ -252,7 +241,7 @@ module PrefixSum =
/// <code>
/// let arr = [| 1; 1; 1; 1 |]
/// let sum = [| 0 |]
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
/// runIncludeInplace clContext workGroupSize processor arr sum (+) 0
/// |> ignore
/// ...
/// > val arr = [| 1; 2; 3; 4 |]
Expand All @@ -270,3 +259,73 @@ module PrefixSum =
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<int>) ->

scan processor inputArray 0

module ByKey =
/// <summary>
/// Builds a scan-by-key routine in which one work item scans one key segment sequentially.
/// </summary>
/// <param name="opWrite">Quotation that receives the running sum before and after the current element and returns the value written back into <c>values</c>.</param>
/// <param name="clContext">Context used to compile the kernel.</param>
/// <param name="workGroupSize">Work group size handed to <c>Range1D.CreateValid</c>.</param>
/// <param name="opAdd">Quotation that combines the running sum with the next value.</param>
/// <param name="zero">Initial running sum of every segment.</param>
/// <remarks>
/// The returned function posts the "set arguments" and "run" messages to the processor.
/// <c>values</c> is overwritten in place; <c>offsets.[k]</c> is read as the first position of segment <c>k</c>,
/// and a segment ends where the key changes or the array ends.
/// </remarks>
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =

    let scanSegment =
        <@ fun (ndRange: Range1D) valuesLength segmentsCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
            let segmentId = ndRange.GlobalID0

            // Work items beyond the number of segments have nothing to do.
            if segmentId < segmentsCount then
                let firstPosition = offsets.[segmentId]
                let segmentKey = keys.[firstPosition]

                let mutable position = firstPosition
                let mutable sumBefore = zero
                let mutable sumAfter = zero

                // Walk forward until the key changes or the array ends.
                while position < valuesLength
                      && keys.[position] = segmentKey do

                    sumBefore <- sumAfter
                    sumAfter <- (%opAdd) sumAfter values.[position]

                    // opWrite decides which of the two sums is stored.
                    values.[position] <- (%opWrite) sumBefore sumAfter

                    position <- position + 1 @>

    let program = clContext.Compile scanSegment

    fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->

        let kernel = program.GetKernel()

        let ndRange =
            Range1D.CreateValid(values.Length, workGroupSize)

        let setArguments () =
            kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets

        processor.Post(Msg.MsgSetArguments setArguments)
        processor.Post(Msg.CreateRunMsg<_, _> kernel)

/// <summary>
/// Exclusive scan by key: the stored value is the running sum of the segment
/// taken before the current element is added.
/// </summary>
/// <example>
/// <code>
/// let arr = [| 1; 1; 1; 1; 1; 1 |]
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
/// ...
/// > val result = [| 0; 0; 1; 2; 0; 1 |]
/// </code>
/// </example>
let sequentialExclude clContext =
    // Map.fst keeps its first argument, i.e. the sum accumulated before the current element.
    let writePrevious = Map.fst ()
    sequentialSegments writePrevious clContext

/// <summary>
/// Inclusive scan by key: the stored value is the running sum of the segment
/// taken after the current element is added.
/// </summary>
/// <example>
/// <code>
/// let arr = [| 1; 1; 1; 1; 1; 1 |]
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
/// ...
/// > val result = [| 1; 1; 2; 3; 1; 2 |]
/// </code>
/// </example>
let sequentialInclude clContext =
    // Map.snd keeps its second argument, i.e. the sum that already includes the current element.
    let writeCurrent = Map.snd ()
    sequentialSegments writeCurrent clContext
4 changes: 4 additions & 0 deletions src/GraphBLAS-sharp.Backend/Quotes/Map.fs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ module Map =
match (%map) item with
| Some _ -> 1
| None -> 0 @>

/// Quotation of a two-argument function that returns its first argument and ignores the second.
let fst () =
    <@ fun first _ -> first @>

/// Quotation of a two-argument function that ignores its first argument and returns the second.
let snd () =
    <@ fun _ second -> second @>
28 changes: 25 additions & 3 deletions src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,30 @@ module SubSum =

barrierLocal () @>

let sequentialSum<'a> opAdd =
sumGeneral<'a> <| sequentialAccess<'a> opAdd
let sequentialSum<'a> = sumGeneral<'a> << sequentialAccess<'a>

let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd
let upSweep<'a> = sumGeneral<'a> << treeAccess<'a>

/// <summary>
/// Quotation of the down-sweep pass over a work group's local buffer.
/// </summary>
/// <param name="opAdd">Quotation of the binary operation used to combine two buffer elements.</param>
/// <remarks>
/// The step starts at <c>wgSize</c> and is halved until it reaches 1. On every iteration a local
/// barrier is issued first; then each work item with <c>lid &lt; wgSize / step</c> takes the positions
/// <c>i = step * (lid + 1) - 1</c> and <c>j = i - step / 2</c>, stores <c>opAdd buffer[i] buffer[j]</c>
/// at <c>i</c> and moves the previous value of <c>i</c> to <c>j</c>.
/// </remarks>
let downSweep opAdd =
    <@ fun wgSize lid (localBuffer: 'a []) ->
        let mutable step = wgSize

        while step > 1 do
            // Every work item reaches the barrier, whether or not it is active on this step.
            barrierLocal ()

            if lid < wgSize / step then
                let i = step * (lid + 1) - 1
                let j = i - (step >>> 1)

                let tmp = localBuffer.[i]

                // The operand is read into its own binding because of a Brahma error.
                // NOTE(review): a reviewer asked whether that error has been reported upstream - confirm.
                let operand = localBuffer.[j]

                let buff = (%opAdd) tmp operand

                localBuffer.[i] <- buff
                localBuffer.[j] <- tmp

            step <- step >>> 1 @>

let localPrefixSum opAdd =
<@ fun (lid: int) (workGroupSize: int) (array: 'a []) ->
Expand All @@ -52,4 +72,6 @@ module SubSum =
barrierLocal ()
array.[lid] <- value @>



/// `localPrefixSum` applied to the quoted `(+)` operator.
let localIntPrefixSum = localPrefixSum <@ (+) @>
111 changes: 111 additions & 0 deletions tests/GraphBLAS-sharp.Tests/Common/Scan/ByKey.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.ByKey

open GraphBLAS.FSharp.Backend.Common
open GraphBLAS.FSharp.Backend.Objects.ClContext
open Expecto
open GraphBLAS.FSharp.Tests
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions

// OpenCL context of the default test context; used below to allocate device arrays and build the scans.
let context = Context.defaultContext.ClContext

// Queue of the default test context; passed as the `processor` argument to device operations below.
let processor = Context.defaultContext.Queue

/// Computes the expected result on the host with `HostPrimitives.scanByKey hostScan`
/// and compares it with `actual` element by element using `isEqual`.
let checkResult isEqual keysAndValues actual hostScan =

    let expected =
        keysAndValues
        |> HostPrimitives.scanByKey hostScan

    Utils.compareArrays isEqual actual expected "Results must be the same"

/// Property body: runs `scanDevice` on the key-sorted input and checks the values read back
/// from the device against the host scan. Empty inputs are accepted without running anything.
let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
    if keysAndValues.Length > 0 then
        // Equal keys have to be adjacent, so the pairs are ordered by key first.
        let keys, values =
            keysAndValues
            |> Array.sortBy fst
            |> Array.unzip

        // Positions at which each distinct key occurs for the first time.
        let offsets =
            keys
            |> HostPrimitives.getUniqueBitmapFirstOccurrence
            |> HostPrimitives.getBitPositions

        let uniqueKeysCount =
            keys |> Array.distinct |> Array.length

        let clKeys =
            context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)

        let clValues =
            context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)

        let clOffsets =
            context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)

        // The scan rewrites clValues in place.
        scanDevice processor uniqueKeysCount clValues clKeys clOffsets

        let actual = clValues.ToHostAndFree processor
        clKeys.Free processor
        clOffsets.Free processor

        // The host scan receives the same sorted pairs that were sent to the device.
        checkResult isEqual (Array.zip keys values) actual scanHost

/// Creates a property test for one element type: `hostScan` is specialised with `zero` and `opAdd`,
/// `deviceScan` with the shared context, the default work group size, the quoted operation and `zero`.
let createTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =

    let scanOnHost = hostScan zero opAdd

    let scanOnDevice =
        deviceScan context Utils.defaultWorkGroupSize opAddQ zero

    testPropertyWithConfig
        Utils.defaultConfig
        $"test on {typeof<'a>}"
        (makeTestSequentialSegments isEqual scanOnHost scanOnDevice)

// Test list for the sequential scan-by-key kernels: one group for the exclusive scan and one for
// the inclusive scan, each instantiated for int, float (only when the device reports float64
// support), float32, bool and uint32 element types.
let sequentialSegmentsTests =
// Exclusive variant: device `sequentialExclude` against host `prefixSumExclude`.
let excludeTests =
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude

// The float64 case is guarded by a device capability check.
if Utils.isFloat64Available context.ClDevice then
createTest
0.0
<@ (+) @>
(+)
Utils.floatIsEqual
PrefixSum.ByKey.sequentialExclude
HostPrimitives.prefixSumExclude

createTest
0.0f
<@ (+) @>
(+)
Utils.float32IsEqual
PrefixSum.ByKey.sequentialExclude
HostPrimitives.prefixSumExclude

// Boolean case uses logical OR with `false` as the zero element.
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
|> testList "exclude"

// Inclusive variant: device `sequentialInclude` against host `prefixSumInclude`.
let includeTests =
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude

// The float64 case is guarded by a device capability check.
if Utils.isFloat64Available context.ClDevice then
createTest
0.0
<@ (+) @>
(+)
Utils.floatIsEqual
PrefixSum.ByKey.sequentialInclude
HostPrimitives.prefixSumInclude

createTest
0.0f
<@ (+) @>
(+)
Utils.float32IsEqual
PrefixSum.ByKey.sequentialInclude
HostPrimitives.prefixSumInclude

// Boolean case uses logical OR with `false` as the zero element.
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]

|> testList "include"

// Both groups are published under a single parent list.
testList "Sequential segments" [ excludeTests; includeTests ]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.PrefixSum
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.PrefixSum

open Expecto
open Expecto.Logging
Expand Down Expand Up @@ -62,7 +62,7 @@ let makeTest plus zero isEqual scan (array: 'a []) =
let testFixtures plus plusQ zero isEqual name =
PrefixSum.runIncludeInplace plusQ context wgSize
|> makeTest plus zero isEqual
|> testPropertyWithConfig config (sprintf "Correctness on %s" name)
|> testPropertyWithConfig config $"Correctness on %s{name}"

let tests =
q.Error.Add(fun e -> failwithf "%A" e)
Expand Down
Loading