Skip to content

Commit 405c0c7

Browse files
committed
Refactor Random.jl
1 parent 8146597 commit 405c0c7

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/Random.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,32 +52,31 @@ Sample a node from a tree according to the sampler `sampler`.
5252
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
5353
n = count(sampler.filter, sampler.tree; sampler.break_sharing)
5454
idx = rand(rng, 1:n)
55-
i = Ref(0)
56-
out = Ref(sampler.tree)
57-
foreach(sampler.tree; sampler.break_sharing) do node
58-
if @inline(sampler.filter(node)) && (i[] += 1) == idx
59-
out[] = node
60-
end
61-
nothing
62-
end
63-
return out[]
55+
return _get_node(sampler.tree, sampler.filter, idx, sampler.break_sharing)
6456
end
6557
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
6658
weights = filter_map(
6759
sampler.filter, sampler.weighting, sampler.tree, Float64; sampler.break_sharing
6860
)
69-
idx = sample_idx(rng, weights)
61+
idx = _sample_idx(rng, weights)
62+
return _get_node(sampler.tree, sampler.filter, idx, sampler.break_sharing)
63+
end
64+
65+
function _get_node(
66+
tree, filter_f::F, idx::Int, ::Val{break_sharing}
67+
) where {F,break_sharing}
7068
i = Ref(0)
71-
out = Ref(sampler.tree)
72-
foreach(sampler.tree; sampler.break_sharing) do node
73-
if @inline(sampler.filter(node)) && (i[] += 1) == idx
69+
out = Ref(tree)
70+
foreach(tree; break_sharing=Val(break_sharing)) do node
71+
if @inline(filter_f(node)) && (i[] += 1) == idx
7472
out[] = node
7573
end
7674
nothing
7775
end
7876
return out[]
7977
end
80-
function sample_idx(rng::AbstractRNG, weights)
78+
79+
function _sample_idx(rng::AbstractRNG, weights)
8180
csum = cumsum(weights)
8281
r = rand(rng, eltype(weights)) * csum[end]
8382
return findfirst(ci -> ci > r, csum)::Int

0 commit comments

Comments
 (0)