|
39 | 39 | x = rand(10_000) |
40 | 40 |
|
41 | 41 | @model function wthreads(x) |
| 42 | + global vi_ = _varinfo |
42 | 43 | x[1] ~ Normal(0, 1) |
43 | 44 | Threads.@threads for i in 2:length(x) |
44 | 45 | x[i] ~ Normal(x[i-1], 1) |
|
48 | 49 | vi = VarInfo() |
49 | 50 | wthreads(x)(vi) |
50 | 51 | lp_w_threads = getlogp(vi) |
| 52 | + if Threads.nthreads() == 1 |
| 53 | + @test vi_ isa VarInfo |
| 54 | + else |
| 55 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
| 56 | + end |
51 | 57 |
|
52 | 58 | println("With `@threads`:") |
53 | 59 | println(" default:") |
54 | 60 | @time wthreads(x)(vi) |
55 | 61 |
|
56 | | - # Ensure that we use `ThreadSafeVarInfo`. |
| 62 | + # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded assume statements. |
| 63 | + DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, |
| 64 | + SampleFromPrior(), DefaultContext()) |
57 | 65 | @test getlogp(vi) ≈ lp_w_threads |
58 | | - DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi, |
59 | | - SampleFromPrior(), DefaultContext()) |
| 66 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
60 | 67 |
|
61 | | - println(" evaluate_multithreaded:") |
62 | | - @time DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi, |
63 | | - SampleFromPrior(), DefaultContext()) |
| 68 | + println(" evaluate_threadsafe:") |
| 69 | + @time DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, |
| 70 | + SampleFromPrior(), DefaultContext()) |
64 | 71 |
|
65 | 72 | @model function wothreads(x) |
| 73 | + global vi_ = _varinfo |
66 | 74 | x[1] ~ Normal(0, 1) |
67 | 75 | for i in 2:length(x) |
68 | 76 | x[i] ~ Normal(x[i-1], 1) |
|
72 | 80 | vi = VarInfo() |
73 | 81 | wothreads(x)(vi) |
74 | 82 | lp_wo_threads = getlogp(vi) |
| 83 | + if Threads.nthreads() == 1 |
| 84 | + @test vi_ isa VarInfo |
| 85 | + else |
| 86 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
| 87 | + end |
75 | 88 |
|
76 | 89 | println("Without `@threads`:") |
77 | 90 | println(" default:") |
|
80 | 93 | @test lp_w_threads ≈ lp_wo_threads |
81 | 94 |
|
82 | 95 | # Ensure that we use `VarInfo`. |
83 | | - DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi, |
84 | | - SampleFromPrior(), DefaultContext()) |
| 96 | + DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, |
| 97 | + SampleFromPrior(), DefaultContext()) |
85 | 98 | @test getlogp(vi) ≈ lp_w_threads |
| 99 | + @test vi_ isa VarInfo |
86 | 100 |
|
87 | | - println(" evaluate_singlethreaded:") |
88 | | - @time DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi, |
89 | | - SampleFromPrior(), DefaultContext()) |
| 101 | + println(" evaluate_threadunsafe:") |
| 102 | + @time DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, |
| 103 | + SampleFromPrior(), DefaultContext()) |
90 | 104 | end |
91 | 105 | end |
0 commit comments