From c17a8cf97cfc16a31b6957c0e342123acbba896b Mon Sep 17 00:00:00 2001 From: w4123 <1840686745@qq.com> Date: Wed, 24 Aug 2022 14:03:59 +0800 Subject: [PATCH] Fix ONNX generation --- attentions.py | 9 +++------ transforms.py | 8 ++++---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/attentions.py b/attentions.py index 4e0b0c1f..0eeea464 100644 --- a/attentions.py +++ b/attentions.py @@ -199,15 +199,12 @@ def _matmul_with_relative_keys(self, x, y): def _get_relative_embeddings(self, relative_embeddings, length): max_relative_position = 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. - pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) + pad_length = torch.clamp_min(length - (self.window_size + 1), 0) + slice_start_position = torch.clamp_min((self.window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( + padded_relative_embeddings = F.pad( relative_embeddings, commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) - else: - padded_relative_embeddings = relative_embeddings used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] return used_relative_embeddings diff --git a/transforms.py b/transforms.py index 4793d67c..b42f5fb4 100644 --- a/transforms.py +++ b/transforms.py @@ -45,7 +45,7 @@ def piecewise_rational_quadratic_transform(inputs, def searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps + bin_locations[..., bin_locations.size(-1)-1] += eps return torch.sum( inputs[..., None] >= bin_locations, dim=-1 @@ -72,7 +72,7 @@ def unconstrained_rational_quadratic_spline(inputs, unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant + unnormalized_derivatives[..., unnormalized_derivatives.size(-1)-1] = constant outputs[outside_interval_mask] = inputs[outside_interval_mask] logabsdet[outside_interval_mask] = 0 @@ -118,7 +118,7 @@ def rational_quadratic_spline(inputs, cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) cumwidths = (right - left) * cumwidths + left cumwidths[..., 0] = left - cumwidths[..., -1] = right + cumwidths[..., cumwidths.size(-1)-1] = right widths = cumwidths[..., 1:] - cumwidths[..., :-1] derivatives = min_derivative + F.softplus(unnormalized_derivatives) @@ -129,7 +129,7 @@ def rational_quadratic_spline(inputs, cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) cumheights = (top - bottom) * cumheights + bottom cumheights[..., 0] = bottom - cumheights[..., -1] = top + cumheights[..., cumheights.size(-1)-1] = top heights = cumheights[..., 1:] - cumheights[..., :-1] if inverse: