Skip to content

Commit

Permalink
Fix ONNX generation
Browse files Browse the repository at this point in the history
  • Loading branch information
w4123 committed Aug 24, 2022
1 parent 6698218 commit c17a8cf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
9 changes: 3 additions & 6 deletions attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,12 @@ def _matmul_with_relative_keys(self, x, y):
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
pad_length = torch.clamp_min(length - (self.window_size + 1), 0)
slice_start_position = torch.clamp_min((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
return used_relative_embeddings

Expand Down
8 changes: 4 additions & 4 deletions transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def piecewise_rational_quadratic_transform(inputs,


def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
bin_locations[..., bin_locations.size(-1)-1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
Expand All @@ -72,7 +72,7 @@ def unconstrained_rational_quadratic_spline(inputs,
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
unnormalized_derivatives[..., unnormalized_derivatives.size(-1)-1] = constant

outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
Expand Down Expand Up @@ -118,7 +118,7 @@ def rational_quadratic_spline(inputs,
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
cumwidths[..., cumwidths.size(-1)-1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]

derivatives = min_derivative + F.softplus(unnormalized_derivatives)
Expand All @@ -129,7 +129,7 @@ def rational_quadratic_spline(inputs,
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
cumheights[..., cumheights.size(-1)-1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]

if inverse:
Expand Down

0 comments on commit c17a8cf

Please sign in to comment.