Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] compute_expectation wrt plate and markov variables #413

Closed
wants to merge 22 commits into from
Closed
Prev Previous commit
Next Next commit
misc
  • Loading branch information
Yerdos Ordabayev committed Dec 28, 2020
commit 8cdd879a9cb14043a38fc8c662bbb8430ecf34fc
27 changes: 5 additions & 22 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,29 +186,20 @@ def partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=
for f in factors:
ordinal = plates.intersection(f.inputs)
ordinal_to_factors[ordinal].append(f)
# for var in sum_vars.intersection(f.inputs):
for var in frozenset(f.inputs) - plates:
for var in sum_vars.intersection(f.inputs):
var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

ordinal_to_vars = defaultdict(set)
ordinal_to_not_summed = defaultdict(set)
for var, ordinal in var_to_ordinal.items():
if var in sum_vars:
ordinal_to_vars[ordinal].add(var)
else:
ordinal_to_not_summed[ordinal].add(var)
ordinal_to_vars[ordinal].add(var)

results = []
while ordinal_to_factors:
leaf = max(ordinal_to_factors, key=len)
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars):
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & sum_vars)
# not_summed_vars = frozenset(f.inputs) & ordinal_to_not_summed[leaf]
# parent_vars = sum_vars & frozenset(f.inputs)
# if not_summed_vars and parent_vars:
# raise ValueError("intractable!")
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars)
remaining_sum_vars = sum_vars.intersection(f.inputs)
if not remaining_sum_vars:
results.append(f.reduce(prod_op, leaf & eliminate))
Expand Down Expand Up @@ -284,30 +275,22 @@ def modified_partial_sum_product(sum_op, prod_op, factors,
for f in factors:
ordinal = plates.intersection(f.inputs)
ordinal_to_factors[ordinal].append(f)
# for var in sum_vars.intersection(f.inputs):
for var in frozenset(f.inputs) - plates:
for var in sum_vars.intersection(f.inputs):
var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

ordinal_to_vars = defaultdict(set)
ordinal_to_not_summed = defaultdict(set)
for var, ordinal in var_to_ordinal.items():
if var in sum_vars:
ordinal_to_vars[ordinal].add(var)
else:
ordinal_to_not_summed[ordinal].add(var)

results = []
while ordinal_to_factors:
leaf = max(ordinal_to_factors, key=len)
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars | markov_prod_vars):
# eliminate markov vars
nonmarkov_vars = group_vars - markov_sum_vars - markov_prod_vars
f = reduce(prod_op, group_factors).reduce(sum_op, nonmarkov_vars)
# not_summed_vars = frozenset(f.inputs) & ordinal_to_not_summed[leaf]
# if not_summed_vars and (sum_vars & frozenset(f.inputs)):
# raise ValueError("intractable!")
# eliminate markov vars
markov_vars = group_vars.intersection(markov_sum_vars)
# cond_vars |= frozenset(f.inputs) - plates
if markov_vars:
Expand Down
4 changes: 0 additions & 4 deletions test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ def test_partition(inputs, dims, expected_num_components):
@pytest.mark.parametrize('sum_op,prod_op', [(ops.add, ops.mul), (ops.logaddexp, ops.add)])
@pytest.mark.parametrize('inputs,plates', [('a,abi,bcij', 'ij')])
@pytest.mark.parametrize('vars1,vars2', [
# ('acj', 'bi'),
# ('a', 'bcij'),
('', 'abcij'),
('c', 'abij'),
('cj', 'abi'),
Expand Down Expand Up @@ -163,8 +161,6 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2,


@pytest.mark.parametrize('vars1,vars2', [
(frozenset({"time", "x_0", "x_prev", "x_curr"}),
frozenset({"y_0", "y_curr"})),
(frozenset(),
frozenset({"time", "x_0", "x_prev", "x_curr", "y_0", "y_curr"})),
(frozenset({"y_0", "y_curr"}),
Expand Down