Skip to content

Commit

Permalink
Dispatch to Integrate(Delta, ...) in `normalize_integrate_contracti…
Browse files Browse the repository at this point in the history
…on` (#551)

* dispatch to Integrate(Delta) and dice_factor as log_density

* delta_fresh; revert dice_factor changes

* dice_factor changes
  • Loading branch information
ordabayevy authored Sep 24, 2021
1 parent 570d290 commit 7bfbd63
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
5 changes: 3 additions & 2 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
tuple(sample_inputs)
+ tuple(inp for inp in self.inputs if inp in funsor_value.inputs)
)
result = funsor.delta.Delta(value_name, funsor_value)
if not raw_dist.has_rsample:
# scaling of dice_factor by num samples should already be handled by Funsor.sample
raw_log_prob = raw_dist.log_prob(raw_value)
Expand All @@ -241,7 +240,9 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
output=self.output,
dim_to_name=dim_to_name,
)
result = result + dice_factor
result = funsor.delta.Delta(value_name, funsor_value, dice_factor)
else:
result = funsor.delta.Delta(value_name, funsor_value)
return result

def enumerate_support(self, expand=False):
Expand Down
10 changes: 3 additions & 7 deletions funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,9 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
and t.fresh.intersection(reduced_names, integrand.inputs)
]
for delta in delta_terms:
integrand = integrand(
**{
name: point
for name, (point, log_density) in delta.terms
if name in reduced_names.intersection(integrand.inputs)
}
)
delta_fresh = frozenset(Variable(k, delta.inputs[k]) for k in delta.fresh)
args = delta, integrand, delta_fresh
integrand = eager.dispatch(Integrate, *args)(*args)
return normalize_integrate(log_measure, integrand, reduced_vars)


Expand Down
2 changes: 1 addition & 1 deletion test/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):

name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)}
rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0]
data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]

actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
Expand Down

0 comments on commit 7bfbd63

Please sign in to comment.