Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] compute_expectation wrt plate and markov variables #413

Closed
wants to merge 22 commits into from
Next Next commit
naive implementation
  • Loading branch information
Yerdos Ordabayev committed Dec 24, 2020
commit 62ecfd014d08ea10491e443fd721ecd130fa8407
219 changes: 219 additions & 0 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from funsor.domains import Bint
from funsor.ops import UNITS, AssociativeOp
from funsor.terms import Cat, Funsor, FunsorMeta, Number, Slice, Stack, Subs, Variable, eager, substitute, to_funsor
from funsor.tensor import Tensor
from funsor.util import quote


Expand Down Expand Up @@ -323,6 +324,142 @@ def modified_partial_sum_product(sum_op, prod_op, factors,
return results


def sequential_integral(sum_op, prod_op, factors, integrand,
                        eliminate=frozenset(), plate_to_step=dict()):
    """
    Multiply ``integrand`` by the factors it is connected to and eliminate
    the connected variables, handling markov dimensions in addition to plate
    dimensions.

    The bookkeeping (ordinals, markov sum/product variables) follows the
    tensor variable elimination scheme of
    :func:`funsor.sum_product.partial_sum_product`, extended to markov
    dimensions. Groups of factors are visited from the deepest plate ordinal
    outwards:

    - A group that shares variables with ``integrand`` is multiplied into
      ``integrand``. Its non-markov variables are summed out directly. Its
      markov variables are eliminated with :func:`naive_prefix_sum`,
      :func:`naive_suffix_sum` and :func:`_helper`.
    - Any other group is contracted on its own, markov variables being
      eliminated with ``MarkovProduct``.

    When a markov dimension is eliminated, ``factors`` has to contain pairs
    of initial factors and transition factors.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param factors: A collection of funsors.
    :type factors: tuple or list
    :param Funsor integrand: The funsor that connected factors are multiplied
        into; only its ``inputs`` are used to decide connectivity.
    :param frozenset eliminate: A set of free variables to eliminate,
        including both sum variables and product variables.
    :param dict plate_to_step: A dict mapping markov dimensions to
        ``step`` collections that contain ordered sequences of Markov variable names
        (e.g., ``{"time": frozenset({("x_0", "x_prev", "x_curr")})}``).
        Plates are passed with an empty ``step``.
    :return: ``integrand`` after all connected groups have been absorbed.
    :rtype: Funsor
    :raises ValueError: if the requested elimination is intractable.
    """
    assert callable(sum_op)
    assert callable(prod_op)
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plate_to_step, dict)
    # process plate_to_step
    # (work on a copy so neither the caller's dict nor the shared default
    # argument is ever modified)
    plate_to_step = plate_to_step.copy()
    prev_to_init = {}
    for key, step in plate_to_step.items():
        # map prev to init; works for any history > 0
        for chain in step:
            init, prev = chain[:len(chain)//2], chain[len(chain)//2:-1]
            prev_to_init.update(zip(prev, init))
        # convert step to dict type required for MarkovProduct
        plate_to_step[key] = {chain[1]: chain[2] for chain in step}

    plates = frozenset(plate_to_step.keys())
    sum_vars = eliminate - plates
    prod_vars = eliminate.intersection(plates)
    markov_sum_vars = frozenset()
    for step in plate_to_step.values():
        markov_sum_vars |= frozenset(step.keys()) | frozenset(step.values())
    markov_sum_vars &= sum_vars
    markov_prod_vars = frozenset(k for k, v in plate_to_step.items() if v and k in eliminate)
    markov_sum_to_prod = defaultdict(set)
    for markov_prod in markov_prod_vars:
        for k, v in plate_to_step[markov_prod].items():
            markov_sum_to_prod[k].add(markov_prod)
            markov_sum_to_prod[v].add(markov_prod)

    # Each sum variable lives at the intersection of the ordinals of all
    # factors that mention it.
    var_to_ordinal = {}
    ordinal_to_factors = defaultdict(list)
    for factor in factors:
        ordinal = plates.intersection(factor.inputs)
        ordinal_to_factors[ordinal].append(factor)
        for var in sum_vars.intersection(factor.inputs):
            var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

    ordinal_to_vars = defaultdict(set)
    for var, ordinal in var_to_ordinal.items():
        ordinal_to_vars[ordinal].add(var)

    # NOTE(review): ``results`` collects the contracted factors that are not
    # connected to ``integrand`` but is never returned; confirm whether they
    # should be part of the result.
    results = []
    while ordinal_to_factors:
        leaf = max(ordinal_to_factors, key=len)
        leaf_factors = ordinal_to_factors.pop(leaf)
        leaf_reduce_vars = ordinal_to_vars[leaf]
        for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars | markov_prod_vars):
            if group_vars.intersection(integrand.inputs):
                # eliminate non markov vars
                nonmarkov_vars = group_vars - markov_sum_vars - markov_prod_vars
                # eliminate markov vars
                markov_vars = group_vars.intersection(markov_sum_vars)
                nonmarkov_factors = [g for g in group_factors if nonmarkov_vars.intersection(g.inputs)]
                markov_factors = [g for g in group_factors if not nonmarkov_vars.intersection(g.inputs)]
                integrand = reduce(prod_op, nonmarkov_factors + [integrand]).reduce(sum_op, nonmarkov_vars)
                if markov_vars:
                    markov_prod_var = [markov_sum_to_prod[var] for var in markov_vars]
                    assert all(p == markov_prod_var[0] for p in markov_prod_var)
                    if len(markov_prod_var[0]) != 1:
                        raise ValueError("intractable!")
                    time = next(iter(markov_prod_var[0]))
                    # Build the markov factor product *before* inspecting its
                    # inputs; previously ``f`` was read here while still bound
                    # to a value left over from an earlier loop.
                    # NOTE(review): raises TypeError if ``markov_factors`` is
                    # empty (every factor also mentions a non-markov var).
                    f = reduce(prod_op, markov_factors)
                    for v in sum_vars.intersection(f.inputs):
                        if time in var_to_ordinal[v] and var_to_ordinal[v] < leaf:
                            raise ValueError("intractable!")
                    time_var = Variable(time, f.inputs[time])
                    group_step = {k: v for (k, v) in plate_to_step[time].items() if v in markov_vars}
                    betas = naive_suffix_sum(sum_op, prod_op, f, time_var, group_step)
                    alphas = naive_prefix_sum(sum_op, prod_op, f, time_var, group_step)
                    integrand = reduce(prod_op, [integrand, f])
                    integrand = _helper(sum_op, prod_op, integrand, alphas, betas, time_var, group_step)
                    integrand = integrand.reduce(sum_op, frozenset(group_step.values()))
                    integrand = integrand(**prev_to_init)
            else:
                # eliminate non markov vars
                nonmarkov_vars = group_vars - markov_sum_vars - markov_prod_vars
                f = reduce(prod_op, group_factors).reduce(sum_op, nonmarkov_vars)
                # eliminate markov vars
                markov_vars = group_vars.intersection(markov_sum_vars)
                if markov_vars:
                    markov_prod_var = [markov_sum_to_prod[var] for var in markov_vars]
                    assert all(p == markov_prod_var[0] for p in markov_prod_var)
                    if len(markov_prod_var[0]) != 1:
                        raise ValueError("intractable!")
                    time = next(iter(markov_prod_var[0]))
                    for v in sum_vars.intersection(f.inputs):
                        if time in var_to_ordinal[v] and var_to_ordinal[v] < leaf:
                            raise ValueError("intractable!")
                    time_var = Variable(time, f.inputs[time])
                    group_step = {k: v for (k, v) in plate_to_step[time].items() if v in markov_vars}
                    f = MarkovProduct(sum_op, prod_op, f, time_var, group_step)
                    f = f.reduce(sum_op, frozenset(group_step.values()))
                    f = f(**prev_to_init)

                remaining_sum_vars = sum_vars.intersection(f.inputs)

                if not remaining_sum_vars:
                    results.append(f.reduce(prod_op, leaf & prod_vars - markov_prod_vars))
                else:
                    # Push the partially contracted factor up to the ordinal
                    # where its remaining sum variables live.
                    new_plates = frozenset().union(
                        *(var_to_ordinal[v] for v in remaining_sum_vars))
                    if new_plates == leaf:
                        raise ValueError("intractable!")
                    f = f.reduce(prod_op, leaf - new_plates - markov_prod_vars)
                    ordinal_to_factors[new_plates].append(f)

    return integrand


def sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()):
"""
Performs sum-product contraction of a collection of factors.
Expand Down Expand Up @@ -360,6 +497,88 @@ def naive_sequential_sum_product(sum_op, prod_op, trans, time, step):
factors.append(xy)
return factors[0]

def _helper(sum_op, prod_op, integrand, alphas, betas, time, step):
    """
    Contract ``integrand`` with ``alphas`` and then ``betas`` along ``time``.

    First pass: time slices ``1 .. duration-1`` of ``integrand`` have the
    ``prev`` names of ``step`` tied to the ``curr`` names of ``alphas`` and
    contracted away; slice ``0`` is kept untouched and concatenated back in
    front.

    Second pass: time slices ``0 .. duration-2`` of that result have their
    ``curr`` names tied to the ``prev`` names of ``betas`` and contracted
    away; the last slice is kept untouched and concatenated back at the end.

    :param sum_op: A semiring sum operation.
    :param prod_op: A semiring product operation.
    :param Funsor integrand: The funsor to contract, with ``time`` as input.
    :param Funsor alphas: Terms contracted in on the ``prev`` side.
    :param Funsor betas: Terms contracted in on the ``curr`` side.
    :param Variable time: The time variable.
    :param dict step: A dict mapping ``prev`` names to ``curr`` names.
    :return: The contracted integrand, still indexed by ``time``.
    """
    ordered_step = OrderedDict(sorted(step.items()))
    drop_names = tuple("_drop_{}".format(i) for i in range(len(ordered_step)))
    rename_prev = dict(zip(ordered_step.keys(), drop_names))
    rename_curr = dict(zip(ordered_step.values(), drop_names))
    contracted = frozenset(drop_names)

    name = time.name
    duration = time.output.size

    # Pass 1: leave the first slice alone, contract the rest with alphas.
    head = integrand(**{name: Slice(name, 0, 1, 1, duration)})
    tail = integrand(**{name: Slice(name, 1, duration, 1, duration)}, **rename_prev)
    tail = Contraction(sum_op, prod_op, contracted, alphas(**rename_curr), tail)
    forward = Cat(name, (head, tail))

    # Pass 2: leave the last slice alone, contract the rest with betas.
    final = forward(**{name: Slice(name, duration-1, duration, 1, duration)})
    body = forward(**{name: Slice(name, 0, duration-1, 1, duration)}, **rename_curr)
    body = Contraction(sum_op, prod_op, contracted, betas(**rename_prev), body)
    return Cat(name, (body, final))

def naive_suffix_sum(sum_op, prod_op, trans, time, step):
    """
    Sequentially compute suffix contractions of ``trans`` along ``time``.

    With ``T = time.output.size``, the result is a ``Stack`` along ``time``
    of ``T - 1`` terms; the term at position ``i`` is the contraction of the
    time slices ``i+1 .. T-1`` of ``trans``, where neighbouring slices are
    joined by tying the ``curr`` names of the earlier slice to the ``prev``
    names of the later one and reducing them with ``sum_op``.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time variable.
    :param dict step: A dict mapping ``prev`` names to ``curr`` names.
    :return: The stacked suffix terms.
    """
    for op in (sum_op, prod_op):
        assert isinstance(op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(prev, str) for prev in step)
    assert all(isinstance(curr, str) for curr in step.values())
    assert time.name not in trans.inputs or time.output == trans.inputs[time.name]

    ordered_step = OrderedDict(sorted(step.items()))
    drop_names = tuple("_drop_{}".format(i) for i in range(len(ordered_step)))
    rename_prev = dict(zip(ordered_step.keys(), drop_names))
    rename_curr = dict(zip(ordered_step.values(), drop_names))
    contracted = frozenset(drop_names)

    name = time.name
    duration = time.output.size
    slices = [trans(**{name: t}) for t in range(duration)]

    # Fold from the last slice towards the first, recording every
    # intermediate contraction.
    running = slices[-1]
    suffixes = [running]
    for earlier in reversed(slices[:-1]):
        later = running(**rename_prev)
        earlier = earlier(**rename_curr)
        running = prod_op(earlier, later).reduce(sum_op, contracted)
        suffixes.append(running)

    # Put the terms in time order and drop the one covering every slice.
    in_time_order = tuple(reversed(suffixes[:-1]))
    return Stack(name, in_time_order)


def naive_prefix_sum(sum_op, prod_op, trans, time, step):
    """
    Sequentially compute prefix contractions of ``trans`` along ``time``.

    With ``T = time.output.size``, the result is a ``Stack`` along ``time``
    of ``T - 1`` terms; the term at position ``i`` is the contraction of the
    time slices ``0 .. i`` of ``trans``, where neighbouring slices are joined
    by tying the ``curr`` names of the earlier slice to the ``prev`` names of
    the later one and reducing them with ``sum_op``.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time variable.
    :param dict step: A dict mapping ``prev`` names to ``curr`` names.
    :return: The stacked prefix terms.
    """
    for op in (sum_op, prod_op):
        assert isinstance(op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(prev, str) for prev in step)
    assert all(isinstance(curr, str) for curr in step.values())
    assert time.name not in trans.inputs or time.output == trans.inputs[time.name]

    ordered_step = OrderedDict(sorted(step.items()))
    drop_names = tuple("_drop_{}".format(i) for i in range(len(ordered_step)))
    rename_prev = dict(zip(ordered_step.keys(), drop_names))
    rename_curr = dict(zip(ordered_step.values(), drop_names))
    contracted = frozenset(drop_names)

    name = time.name
    duration = time.output.size
    slices = [trans(**{name: t}) for t in range(duration)]

    # Fold from the first slice towards the last, recording every
    # intermediate contraction.
    running = slices[0]
    prefixes = [running]
    for later in slices[1:]:
        earlier = running(**rename_curr)
        later = later(**rename_prev)
        running = prod_op(earlier, later).reduce(sum_op, contracted)
        prefixes.append(running)

    # The final entry covers every slice and is not part of the result.
    return Stack(name, tuple(prefixes[:-1]))


def sequential_sum_product(sum_op, prod_op, trans, time, step):
"""
Expand Down