@@ -52,32 +52,31 @@ Sample a node from a tree according to the sampler `sampler`.
5252function rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,Nothing} ) where {N,F}
5353 n = count (sampler. filter, sampler. tree; sampler. break_sharing)
5454 idx = rand (rng, 1 : n)
55- i = Ref (0 )
56- out = Ref (sampler. tree)
57- foreach (sampler. tree; sampler. break_sharing) do node
58- if @inline (sampler. filter (node)) && (i[] += 1 ) == idx
59- out[] = node
60- end
61- nothing
62- end
63- return out[]
55+ return _get_node (sampler. tree, sampler. filter, idx, sampler. break_sharing)
6456end
6557function rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,W} ) where {N,F,W<: Function }
6658 weights = filter_map (
6759 sampler. filter, sampler. weighting, sampler. tree, Float64; sampler. break_sharing
6860 )
69- idx = sample_idx (rng, weights)
61+ idx = _sample_idx (rng, weights)
62+ return _get_node (sampler. tree, sampler. filter, idx, sampler. break_sharing)
63+ end
64+
65+ function _get_node (
66+ tree, filter_f:: F , idx:: Int , :: Val{break_sharing}
67+ ) where {F,break_sharing}
7068 i = Ref (0 )
71- out = Ref (sampler . tree)
72- foreach (sampler . tree; sampler . break_sharing) do node
73- if @inline (sampler . filter (node)) && (i[] += 1 ) == idx
69+ out = Ref (tree)
70+ foreach (tree; break_sharing= Val (break_sharing) ) do node
71+ if @inline (filter_f (node)) && (i[] += 1 ) == idx
7472 out[] = node
7573 end
7674 nothing
7775 end
7876 return out[]
7977end
80- function sample_idx (rng:: AbstractRNG , weights)
78+
79+ function _sample_idx (rng:: AbstractRNG , weights)
8180 csum = cumsum (weights)
8281 r = rand (rng, eltype (weights)) * csum[end ]
8382 return findfirst (ci -> ci > r, csum):: Int
0 commit comments