@@ -251,3 +251,176 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
251251 VarName {Symbol(Prefix, PREFIX_SEPARATOR, Sym)} (vn. indexing)
252252 end
253253end
254+
255+ struct ConditionContext{Names,Values,Ctx<: AbstractContext } <: AbstractContext
256+ values:: Values
257+ context:: Ctx
258+
259+ function ConditionContext {Values} (
260+ values:: Values , context:: AbstractContext
261+ ) where {names,Values<: NamedTuple{names} }
262+ return new {names,typeof(values),typeof(context)} (values, context)
263+ end
264+ end
265+
266+ function ConditionContext (values:: NamedTuple )
267+ return ConditionContext (values, DefaultContext ())
268+ end
269+ function ConditionContext (values:: NamedTuple , context:: AbstractContext )
270+ return ConditionContext {typeof(values)} (values, context)
271+ end
272+
273+ # Try to avoid nested `ConditionContext`.
274+ function ConditionContext (
275+ values:: NamedTuple{Names} , context:: ConditionContext
276+ ) where {Names}
277+ # Note that this potentially overrides values from `context`, thus giving
278+ # precedence to the outmost `ConditionContext`.
279+ return ConditionContext (merge (context. values, values), childcontext (context))
280+ end
281+
282+ function Base. show (io:: IO , context:: ConditionContext )
283+ return print (io, " ConditionContext($(context. values) , $(childcontext (context)) )" )
284+ end
285+
286+ NodeTrait (context:: ConditionContext ) = IsParent ()
287+ childcontext (context:: ConditionContext ) = context. context
288+ setchildcontext (parent:: ConditionContext , child) = ConditionContext (parent. values, child)
289+
290+ """
291+ hasvalue(context, vn)
292+
293+ Return `true` if `vn` is found in `context`.
294+ """
295+ hasvalue (context, vn) = false
296+
297+ function hasvalue (context:: ConditionContext{vars} , vn:: VarName{sym} ) where {vars,sym}
298+ return sym in vars
299+ end
300+ function hasvalue (
301+ context:: ConditionContext{vars} , vn:: AbstractArray{<:VarName{sym}}
302+ ) where {vars,sym}
303+ return sym in vars
304+ end
305+
306+ """
307+ getvalue(context, vn)
308+
309+ Return value of `vn` in `context`.
310+ """
311+ function getvalue (context:: AbstractContext , vn)
312+ return error (" context $(context) does not contain value for $vn " )
313+ end
314+ getvalue (context:: ConditionContext , vn) = _getvalue (context. values, vn)
315+
316+ """
317+ hasvalue_nested(context, vn)
318+
319+ Return `true` if `vn` is found in `context` or any of its descendants.
320+
321+ This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`,
322+ not recursively checking if `vn` is in any of its descendants.
323+ """
324+ function hasvalue_nested (context:: AbstractContext , vn)
325+ return hasvalue_nested (NodeTrait (hasvalue_nested, context), context, vn)
326+ end
327+ hasvalue_nested (:: IsLeaf , context, vn) = hasvalue (context, vn)
328+ function hasvalue_nested (:: IsParent , context, vn)
329+ return hasvalue (context, vn) || hasvalue_nested (childcontext (context), vn)
330+ end
331+ function hasvalue_nested (context:: PrefixContext , vn)
332+ return hasvalue_nested (childcontext (context), prefix (context, vn))
333+ end
334+
335+ """
336+ getvalue_nested(context, vn)
337+
338+ Return the value of the parameter corresponding to `vn` from `context` or its descendants.
339+
340+ This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`,
341+ not recursively looking into its descendants.
342+ """
343+ function getvalue_nested (context:: AbstractContext , vn)
344+ return getvalue_nested (NodeTrait (getvalue_nested, context), context, vn)
345+ end
346+ function getvalue_nested (:: IsLeaf , context, vn)
347+ return error (" context $(context) does not contain value for $vn " )
348+ end
349+ function getvalue_nested (context:: PrefixContext , vn)
350+ return getvalue_nested (childcontext (context), prefix (context, vn))
351+ end
352+ function getvalue_nested (:: IsParent , context, vn)
353+ return if hasvalue (context, vn)
354+ getvalue (context, vn)
355+ else
356+ getvalue_nested (childcontext (context), vn)
357+ end
358+ end
359+
360+ """
361+ condition([context::AbstractContext,] values::NamedTuple)
362+ condition([context::AbstractContext]; values...)
363+
364+ Return `ConditionContext` with `values` and `context` if `values` is non-empty,
365+ otherwise return `context` which is [`DefaultContext`](@ref) by default.
366+
367+ See also: [`decondition`](@ref)
368+ """
369+ condition (; values... ) = condition (DefaultContext (), NamedTuple (values))
370+ condition (values:: NamedTuple ) = condition (DefaultContext (), values)
371+ condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
372+ condition (context:: AbstractContext , values:: NamedTuple ) = ConditionContext (values, context)
373+ condition (context:: AbstractContext ; values... ) = condition (context, NamedTuple (values))
374+
375+ """
376+ decondition(context::AbstractContext, syms...)
377+
378+ Return `context` but with `syms` no longer conditioned on.
379+
380+ Note that this recursively traverses contexts, deconditioning all along the way.
381+
382+ See also: [`condition`](@ref)
383+ """
384+ decondition (:: IsLeaf , context, args... ) = context
385+ function decondition (:: IsParent , context, args... )
386+ return setchildcontext (context, decondition (childcontext (context), args... ))
387+ end
388+ decondition (context, args... ) = decondition (NodeTrait (context), context, args... )
389+ function decondition (context:: ConditionContext )
390+ return decondition (childcontext (context))
391+ end
392+ function decondition (context:: ConditionContext , sym)
393+ return condition (
394+ decondition (childcontext (context), sym), BangBang. delete!! (context. values, sym)
395+ )
396+ end
397+ function decondition (context:: ConditionContext , sym, syms... )
398+ return decondition (
399+ condition (
400+ decondition (childcontext (context), syms... ),
401+ BangBang. delete!! (context. values, sym),
402+ ),
403+ syms... ,
404+ )
405+ end
406+
407+ """
408+ conditioned(context::AbstractContext)
409+
410+ Return `NamedTuple` of values that are conditioned on under context`.
411+
412+ Note that this will recursively traverse the context stack and return
413+ a merged version of the condition values.
414+ """
415+ function conditioned (context:: AbstractContext )
416+ return conditioned (NodeTrait (conditioned, context), context)
417+ end
418+ conditioned (:: IsLeaf , context) = ()
419+ conditioned (:: IsParent , context) = conditioned (childcontext (context))
420+ function conditioned (context:: ConditionContext )
421+ # Note the order of arguments to `merge`. The behavior of the rest of DPPL
422+ # is that the outermost `context` takes precendence, hence when resolving
423+ # the `conditioned` variables we need to ensure that `context.values` takes
424+ # precedence over decendants of `context`.
425+ return merge (context. values, conditioned (childcontext (context)))
426+ end
0 commit comments