|
125 | 125 | offset = :(0) |
126 | 126 | for f in names |
127 | 127 | mdf = :(metadata.$f) |
128 | | - if f in space || length(space) == 0 |
| 128 | + if inspace(f, space) || length(space) == 0 |
129 | 129 | len = :(length($mdf.vals)) |
130 | 130 | push!(exprs, :($f = Metadata($mdf.idcs, |
131 | 131 | $mdf.vns, |
@@ -330,13 +330,6 @@ setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) |
330 | 330 | return expr |
331 | 331 | end |
332 | 332 |
|
333 | | -""" |
334 | | - getsym(vn::VarName) |
335 | | -
|
336 | | -Return the symbol of the Julia variable used to generate `vn`. |
337 | | -""" |
338 | | -getsym(vn::VarName{sym}) where sym = sym |
339 | | - |
340 | 333 | """ |
341 | 334 | getgid(vi::VarInfo, vn::VarName) |
342 | 335 |
|
|
407 | 400 | # If the varname is in the sampler space |
408 | 401 | # or the sample space is empty (all variables) |
409 | 402 | # then return the indices for that variable. |
410 | | - if f in space || length(space) == 0 |
| 403 | + if inspace(f, space) || length(space) == 0 |
411 | 404 | push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) |
412 | 405 | end |
413 | 406 | end |
|
418 | 411 | # Get all the idcs of the vns in `space` and that belong to the selector `s` |
419 | 412 | return filter((i) -> |
420 | 413 | (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && |
421 | | - (isempty(space) || in(f_meta.vns[i], space)), 1:length(f_meta.gids)) |
| 414 | + (isempty(space) || inspace(f_meta.vns[i], space)), 1:length(f_meta.gids)) |
422 | 415 | end |
423 | 416 | @inline function findinds(f_meta) |
424 | 417 | # Get all the idcs of the vns |
|
488 | 481 | #### APIs for typed and untyped VarInfo |
489 | 482 | #### |
490 | 483 |
|
491 | | -# VarName |
492 | | - |
493 | | -""" |
494 | | - VarName(sym, indexing) |
495 | | - VarName{sym}(indexing::String) |
496 | | -
|
497 | | -Construct a new instance of `VarName{sym}` |
498 | | -""" |
499 | | -VarName(sym, indexing) = VarName{sym}(indexing) |
500 | | - |
501 | | -""" |
502 | | - VarName(vn::VarName, indexing::String) |
503 | | -
|
504 | | -Return a copy of `vn` with a new index `indexing`. |
505 | | -""" |
506 | | -function VarName(vn::VarName, indexing::String) |
507 | | - return VarName{getsym(vn)}(indexing) |
508 | | -end |
509 | | - |
510 | | -""" |
511 | | - uid(vn::VarName) |
512 | | -
|
513 | | -Return a unique tuple identifier for `vn`. |
514 | | -""" |
515 | | -uid(vn::VarName) = (getsym(vn), vn.indexing) |
516 | | - |
517 | | -hash(vn::VarName) = hash(uid(vn)) |
518 | | - |
519 | | -==(x::VarName, y::VarName) = hash(uid(x)) == hash(uid(y)) |
520 | | - |
521 | | -function string(vn::VarName) |
522 | | - return "$(getsym(vn))$(vn.indexing)" |
523 | | -end |
524 | | -function string(vns::Vector{<:VarName}) |
525 | | - return replace(string(map(string, vns)), "String" => "") |
526 | | -end |
527 | | - |
528 | | -""" |
529 | | - Symbol(vn::VarName) |
530 | | -
|
531 | | -Return a `Symbol` represenation of the variable identifier `VarName`. |
532 | | -""" |
533 | | -Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol |
534 | | - |
535 | | -""" |
536 | | - in(vn::VarName, space::Set) |
537 | | -
|
538 | | -Check whether `vn`'s symbol is in `space`. |
539 | | -""" |
540 | | -in(::VarName, ::Tuple{}) = true |
541 | | -in(vn::VarName, space::Tuple)::Bool = getsym(vn) in space || _in(string(vn), space) |
542 | | - |
543 | | -_in(::String, ::Tuple{}) = false |
544 | | -_in(vn_str::String, space::Tuple)::Bool = _in(vn_str, Base.tail(space)) |
545 | | -function _in(vn_str::String, space::Tuple{Expr,Vararg})::Bool |
546 | | - # Collect expressions from space |
547 | | - expr = first(space) |
548 | | - # Filter `(` and `)` out and get a string representation of `exprs` |
549 | | - expr_str = replace(string(expr), r"\(|\)" => "") |
550 | | - # Check if `vn_str` is in `expr_strs` |
551 | | - valid = occursin(expr_str, vn_str) |
552 | | - return valid || _in(vn_str, Base.tail(space)) |
553 | | -end |
554 | 484 |
|
555 | 485 | # VarInfo |
556 | 486 |
|
@@ -602,8 +532,7 @@ function TypedVarInfo(vi::UntypedVarInfo) |
602 | 532 | sym_vals = foldl(vcat, _vals) |
603 | 533 |
|
604 | 534 | push!(new_metas, Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, |
605 | | - sym_dists, sym_gids, sym_orders, sym_flags) |
606 | | - ) |
| 535 | + sym_dists, sym_gids, sym_orders, sym_flags)) |
607 | 536 | end |
608 | 537 | logp = getlogp(vi) |
609 | 538 | num_produce = get_num_produce(vi) |
|
764 | 693 | @generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} |
765 | 694 | expr = Expr(:block) |
766 | 695 | for f in names |
767 | | - if f in space || length(space) == 0 |
| 696 | + if inspace(f, space) || length(space) == 0 |
768 | 697 | push!(expr.args, quote |
769 | 698 | f_vns = vi.metadata.$f.vns |
770 | 699 | if ~istrans(vi, f_vns[1]) |
|
810 | 739 | @generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} |
811 | 740 | expr = Expr(:block) |
812 | 741 | for f in names |
813 | | - if f in space || length(space) == 0 |
| 742 | + if inspace(f, space) || length(space) == 0 |
814 | 743 | push!(expr.args, quote |
815 | 744 | f_vns = vi.metadata.$f.vns |
816 | 745 | if istrans(vi, f_vns[1]) |
@@ -1173,7 +1102,7 @@ Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selec |
1173 | 1102 | and `vn`'s symbol is in the space of `spl`. |
1174 | 1103 | """ |
1175 | 1104 | function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) |
1176 | | - if vn in getspace(spl) |
| 1105 | + if inspace(vn, getspace(spl)) |
1177 | 1106 | setgid!(vi, spl.selector, vn) |
1178 | 1107 | end |
1179 | 1108 | end |
0 commit comments