Fixed implementation of score function baselines
HEmile committed Oct 22, 2021
1 parent 9de4709 commit f9add03
Showing 4 changed files with 60 additions and 24 deletions.
41 changes: 40 additions & 1 deletion docs/source/examples.discrete_vae.rst
@@ -81,4 +81,43 @@ Discrete Variational Autoencoder
    # Go backward through both deterministic and stochastic nodes, and optimize
    average_ELBO, _ = storch.backward()
    optimizer.step()
    optimizer.step()
.. code-block:: python
import torch
import storch
from vae import minibatches, encode, decode, KLD

method = ScoreFunctionLOO("z", 8, baseline="batch_average")
for data in minibatches():
    optimizer.zero_grad()
    # Denote the minibatch dimension as being independent
    data = storch.denote_independent(data.view(-1, 784), 0, "data")

    # Define variational distribution given data, and sample latent variables
    q = torch.distributions.OneHotCategorical(logits=encode(data))
    z = method(q)

    # Compute and register the KL divergence and reconstruction losses to form the ELBO
    reconstruction = decode(z)
    storch.add_cost(KLD(q))
    storch.add_cost(storch.nn.b_binary_cross_entropy(reconstruction, data))

    # Backward pass through deterministic and stochastic nodes, and optimize
    ELBO = storch.backward()
    optimizer.step()
.. code-block:: python
class ScoreFunctionLOO(Method):
    def proposal_dist(
        self, distr: Distribution, amt_samples: int,
    ) -> torch.Tensor:
        return distr.sample((amt_samples,))

    def estimator(
        self, tensor: StochasticTensor, cost: CostTensor
    ) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
        # Compute the gradient function (the log-probability of the sample)
        log_prob = tensor.distribution.log_prob(tensor)
        # Leave-one-out baseline: the average cost of the other samples in the plate
        sum_costs = storch.sum(cost.detach(), tensor.name)
        baseline = (sum_costs - cost) / (tensor.n - 1)
        return log_prob, (1.0 - log_prob) * baseline
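To make the leave-one-out baseline in the snippet above concrete, here is a small plain-PyTorch sketch of the same arithmetic on made-up per-sample costs (an illustration only, with no storch plate bookkeeping):

.. code-block:: python
import torch

# Hypothetical costs for n = 3 samples of z
cost = torch.tensor([2.0, 4.0, 6.0])
n = cost.shape[0]

# Each entry of the baseline is the mean cost of the *other* samples
sum_costs = cost.sum()
baseline = (sum_costs - cost) / (n - 1)
print(baseline)  # tensor([5., 4., 3.]), e.g. (4 + 6) / 2 = 5 for the first sample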
19 changes: 10 additions & 9 deletions storch/inference.py
@@ -100,7 +100,7 @@ def add_cost(cost: Tensor, name: str):
return cost


def surrogate_loss(debug: bool=False) -> storch.Tensor:
def surrogate_loss(debug: bool = False) -> storch.Tensor:
costs: [storch.Tensor] = storch.inference._cost_tensors
if not costs:
raise RuntimeError("No cost nodes registered for backward call.")
@@ -179,18 +179,18 @@ def surrogate_loss(debug: bool=False) -> storch.Tensor:
new_parent._children = parent._children

# Compute the estimator
(
gradient_function,
control_variate,
) = parent.method._estimator(new_parent, reduced_cost)
(gradient_function, control_variate,) = parent.method._estimator(
new_parent, reduced_cost
)

if gradient_function is not None:
L = L + gradient_function
# Compute control variate
if control_variate is not None:
final_A = magic_box(L) * control_variate
final_A = storch.reduce_plates(
final_A, detach_weights=False, # TODO: Should this boolean be false or true?
final_A,
detach_weights=False, # TODO: Should this boolean be false or true?
)
if final_A.ndim == 1:
final_A = final_A.squeeze(0)
@@ -202,6 +202,7 @@ def surrogate_loss(debug: bool=False) -> storch.Tensor:
SL = torch.sum(torch.stack(surrogate_losses))
return SL


def backward(
debug: bool = False,
create_graph: bool = False,
@@ -232,10 +233,10 @@ def backward(
if isinstance(parent, StochasticTensor):
stochastic_nodes.add(parent)
if parent.requires_grad and parent.method:
create_higher_order_graph = parent.method.should_create_higher_order_graph()
_create_graph = (
create_higher_order_graph or _create_graph
create_higher_order_graph = (
parent.method.should_create_higher_order_graph()
)
_create_graph = create_higher_order_graph or _create_graph
if isinstance(SL, storch.Tensor) and SL._tensor.requires_grad:
SL._tensor.backward(create_graph=_create_graph)

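For intuition about how the surrogate loss above combines the two terms an estimator returns (an editor's illustration with stand-in names and values, not code from the diff): for a single stochastic node, the gradient function log_prob enters through the magic-box operator and multiplies the cost, and the control variate is added on top. The forward value is then just the cost, while the gradient is the baseline-corrected score-function estimate.

.. code-block:: python
import torch

def magic_box(x):
    # DiCE-style magic box: evaluates to 1, but carries the gradient of x
    return torch.exp(x - x.detach())

theta = torch.tensor(0.3, requires_grad=True)
log_prob = -theta ** 2           # stand-in for log q_theta(z)
cost = torch.tensor(5.0)         # stand-in for a sampled (detached) cost
baseline = torch.tensor(4.0)     # stand-in for a (detached) baseline

surrogate = magic_box(log_prob) * cost + (1.0 - magic_box(log_prob)) * baseline
print(surrogate.item())  # 5.0: the value is just the cost, so the estimate stays unbiased
surrogate.backward()
print(theta.grad)        # (cost - baseline) * d(log_prob)/d(theta) = 1.0 * -0.6 = -0.6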
23 changes: 9 additions & 14 deletions storch/method/method.py
@@ -10,6 +10,7 @@
get_distr_parameters,
rsample_gumbel_softmax,
rsample_gumbel,
magic_box,
)
import storch
from storch.method.baseline import MovingAverageBaseline, BatchAverageBaseline, Baseline
@@ -21,7 +22,6 @@
import entmax



class Method(ABC, torch.nn.Module):
"""
Base class of gradient estimation methods.
@@ -34,6 +34,7 @@ class Method(ABC, torch.nn.Module):
plate_name (str): The name of the :class:`.Plate` that samples of this method will use.
sampling_method (storch.sampling.SamplingMethod): The method to sample tensors with given an input distribution.
"""

def __init__(self, plate_name: str, sampling_method: SamplingMethod):
super().__init__()
self._estimation_pairs = []
@@ -58,6 +59,7 @@ def forward(self, distr: Distribution) -> StochasticTensor:
def _create_hook(sample: StochasticTensor, name: str, plates: List[Plate]):
accum_grads = sample.param_grads
del sample # Remove from hook closure for GC reasons

def hook(*args: Tuple[any]):
# For some reason, this args unpacking is required for compatibility with registering on a .grad_fn...?
# TODO: I'm sure there could be something wrong here
@@ -85,6 +87,7 @@ def clean_graph() -> None:
del accum_grads
except (UnboundLocalError, NameError):
pass

return hook, clean_graph

def sample(self, distr: Distribution) -> storch.tensor.StochasticTensor:
@@ -188,17 +191,13 @@ def sample(self, distr: Distribution) -> storch.tensor.StochasticTensor:

def _estimator(
self, tensor: StochasticTensor, cost_node: CostTensor
) -> Tuple[
Optional[storch.Tensor], Optional[storch.Tensor]
]:
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
self._estimation_pairs.append((tensor, cost_node))
return self.estimator(tensor, cost_node)

def estimator(
self, tensor: StochasticTensor, cost_node: CostTensor
) -> Tuple[
Optional[storch.Tensor], Optional[storch.Tensor]
]:
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
"""
Returns two terms that will be used for inferring higher-order gradient estimates.
- The first return is the gradient function. It will be multiplied with the cost function.
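A minimal estimator that satisfies this contract is the plain score function, which supplies only the gradient function and no control variate (an illustrative sketch mirroring the existing ScoreFunction class further down, not a line from this diff):

.. code-block:: python
def estimator(
    self, tensor: StochasticTensor, cost_node: CostTensor
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
    # Gradient function: the log-probability of the sample, which the
    # surrogate loss multiplies with the cost through the magic-box operator
    log_prob = tensor.distribution.log_prob(tensor)
    # No control variate for the plain score function
    return log_prob, None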
@@ -286,9 +285,7 @@ def __init__(

def estimator(
self, tensor: StochasticTensor, cost_node: CostTensor
) -> Tuple[
Optional[storch.Tensor], Optional[storch.Tensor]
]:
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
return self._score_method.estimator(tensor, cost_node)

def update_parameters(
@@ -553,9 +550,7 @@ def __init__(

def estimator(
self, tensor: StochasticTensor, cost: CostTensor
) -> Tuple[
Optional[storch.Tensor], Optional[storch.Tensor]
]:
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
log_prob = tensor.distribution.log_prob(tensor)
if len(log_prob.shape) > tensor.plate_dims:
# Sum out over the event shape
@@ -568,7 +563,7 @@ def estimator(
setattr(self, baseline_name, self.baseline_factory(tensor, cost))
baseline = getattr(self, baseline_name)
baseline = baseline.compute_baseline(tensor, cost)
return log_prob, (1.0 - log_prob) * baseline
return log_prob, (1.0 - magic_box(log_prob)) * baseline
return log_prob, None


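The last change above, (1.0 - log_prob) becoming (1.0 - magic_box(log_prob)), is the fix the commit message refers to. A short sketch of why it matters (the magic_box definition below is the usual DiCE formulation and is an assumption, not quoted from storch): the corrected control variate evaluates to zero in the forward pass, so it does not bias the cost estimate, while its gradient is exactly the baseline subtraction -baseline * d(log_prob)/d(parameters).

.. code-block:: python
import torch

def magic_box(x):
    # Forward value 1, gradient equal to the gradient of x
    return torch.exp(x - x.detach())

logits = torch.tensor([0.2, -0.5], requires_grad=True)
log_prob = torch.distributions.Categorical(logits=logits).log_prob(torch.tensor(0))
baseline = torch.tensor(3.0)

control_variate = (1.0 - magic_box(log_prob)) * baseline
print(control_variate.item())  # 0.0: adds nothing to the cost estimate
control_variate.backward()
print(logits.grad)             # equals -baseline * d(log_prob)/d(logits)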
1 change: 1 addition & 0 deletions storch/sampling/method.py
@@ -64,6 +64,7 @@ def mc_sample(
plates: [Plate],
amt_samples: int,
) -> torch.Tensor:
# TODO: Why does this ignore amt_samples?
return distr.sample((self.n_samples,))

def set_mc_sample(
