
Commit

Merge pull request #11 from uofgravity/add-scale-activation-to-masked-affine

ENH: add `scale_activation` to `MaskedAffineAutoregressiveTransform`
mj-will authored Jun 7, 2024
2 parents df47ee5 + a2d9870 commit 4c1bb65
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -18,7 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [macOS, Ubuntu, Windows]
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
     runs-on: ${{ matrix.os }}-latest

     steps:
13 changes: 9 additions & 4 deletions nflows/transforms/autoregressive.py
@@ -72,6 +72,7 @@ def __init__(
         activation=F.relu,
         dropout_probability=0.0,
         use_batch_norm=False,
+        scale_activation=None,
     ):
         self.features = features
         made = made_module.MADE(
@@ -87,6 +88,12 @@ def __init__(
             use_batch_norm=use_batch_norm,
         )
         self._epsilon = 1e-3
+
+        if scale_activation is None:
+            self.scale_activation = lambda x: F.softplus(x) + self._epsilon
+        else:
+            self.scale_activation = scale_activation
+
         super(MaskedAffineAutoregressiveTransform, self).__init__(made)

     def _output_dim_multiplier(self):
@@ -96,8 +103,7 @@ def _elementwise_forward(self, inputs, autoregressive_params):
         unconstrained_scale, shift = self._unconstrained_scale_and_shift(
             autoregressive_params
         )
-        # scale = torch.sigmoid(unconstrained_scale + 2.0) + self._epsilon
-        scale = F.softplus(unconstrained_scale) + self._epsilon
+        scale = self.scale_activation(unconstrained_scale)
         log_scale = torch.log(scale)
         outputs = scale * inputs + shift
         logabsdet = torchutils.sum_except_batch(log_scale, num_batch_dims=1)
@@ -107,8 +113,7 @@ def _elementwise_inverse(self, inputs, autoregressive_params):
         unconstrained_scale, shift = self._unconstrained_scale_and_shift(
             autoregressive_params
         )
-        # scale = torch.sigmoid(unconstrained_scale + 2.0) + self._epsilon
-        scale = F.softplus(unconstrained_scale) + self._epsilon
+        scale = self.scale_activation(unconstrained_scale)
         log_scale = torch.log(scale)
         outputs = (inputs - shift) / scale
         logabsdet = -torchutils.sum_except_batch(log_scale, num_batch_dims=1)
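
Usage note (not part of the diff): a minimal sketch of the new parameter. When scale_activation is None, the default reproduces the previous behaviour, scale = softplus(x) + 1e-3; passing any positive callable overrides it. The feature sizes below are illustrative assumptions, and the sigmoid activation simply mirrors the commented-out line removed above; neither is taken from the PR itself.

    import torch
    from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform

    # Default: scale = F.softplus(x) + 1e-3, identical to the pre-PR behaviour.
    default_transform = MaskedAffineAutoregressiveTransform(features=4, hidden_features=16)

    # Custom: a positive activation of the caller's choosing. This one mirrors
    # the old commented-out sigmoid line; the added epsilon keeps the scale
    # bounded away from zero, so log(scale) stays finite.
    custom_transform = MaskedAffineAutoregressiveTransform(
        features=4,
        hidden_features=16,
        scale_activation=lambda x: torch.sigmoid(x + 2.0) + 1e-3,
    )

    x = torch.randn(8, 4)
    y, logabsdet = custom_transform(x)                  # forward: y = scale * x + shift
    x_rec, inv_logabsdet = custom_transform.inverse(y)  # inverse: x = (y - shift) / scale
    assert torch.allclose(x, x_rec, atol=1e-5)          # round trip recovers x
    assert torch.allclose(logabsdet, -inv_logabsdet, atol=1e-5)

Because both _elementwise_forward and _elementwise_inverse now route through self.scale_activation, a custom activation changes the forward map and its inverse consistently, and the log-abs-determinants of the two directions remain exact negatives of each other.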
