Add functional normalizations #707

Merged · 4 commits · Aug 24, 2022
6 changes: 6 additions & 0 deletions src/Native/LibTorchSharp/THSNN.h
@@ -235,6 +235,12 @@ EXPORT_API(void) THSNN_GroupNorm_set_weight(const NNModule module, const Tensor weight);
EXPORT_API(NNModule) THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor);
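
// Functional (module-free) normalization entry points; implementations are in THSNormalization.cpp below.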

EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);
EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps);
EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps);
EXPORT_API(Tensor) THSNN_layer_norm(const Tensor input, const int64_t* normalized_shape, const int64_t normalized_shape_len, const Tensor weight, const Tensor bias, const double eps);
EXPORT_API(Tensor) THSNN_local_response_norm(const Tensor input, const int64_t size, const double alpha, const double beta, const double k);

// Dropout

EXPORT_API(NNModule) THSNN_Dropout_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule);
52 changes: 52 additions & 0 deletions src/Native/LibTorchSharp/THSNormalization.cpp
@@ -397,3 +397,55 @@ void THSNN_BatchNorm3d_set_weight(const NNModule module, const Tensor weight)
{
set_weight<torch::nn::BatchNorm3d>(module, weight);
}
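
// Functional (module-free) variants of the normalization ops. Optional tensors arrive
// as nullptr and are set on the options object only when present.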

Tensor THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps)
{
auto opts = torch::nn::functional::BatchNormFuncOptions()
.training(training)
.momentum(momentum)
.eps(eps);
if (weight != nullptr) opts.weight(*weight);
if (bias != nullptr) opts.bias(*bias);
CATCH_TENSOR(torch::nn::functional::batch_norm(*input, *running_mean, *running_var, opts));
}

Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps)
{
auto opts = torch::nn::functional::GroupNormFuncOptions(num_groups)
.eps(eps);
if (weight != nullptr) opts.weight(*weight);
if (bias != nullptr) opts.bias(*bias);
CATCH_TENSOR(torch::nn::functional::group_norm(*input, opts));
}

Tensor THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps)
{
auto opts = torch::nn::functional::InstanceNormFuncOptions()
.use_input_stats(use_input_stats)
.momentum(momentum)
.eps(eps);
if (running_mean != nullptr) opts.running_mean(*running_mean);
if (running_var != nullptr) opts.running_var(*running_var);
if (weight != nullptr) opts.weight(*weight);
if (bias != nullptr) opts.bias(*bias);
CATCH_TENSOR(torch::nn::functional::instance_norm(*input, opts));
}

Tensor THSNN_layer_norm(const Tensor input, const int64_t* normalized_shape, const int64_t normalized_shape_len, const Tensor weight, const Tensor bias, const double eps)
{
auto opts = torch::nn::functional::LayerNormFuncOptions(
std::vector<int64_t>(normalized_shape, normalized_shape + normalized_shape_len))
.eps(eps);
if (weight != nullptr) opts.weight(*weight);
if (bias != nullptr) opts.bias(*bias);
CATCH_TENSOR(torch::nn::functional::layer_norm(*input, opts));
}

Tensor THSNN_local_response_norm(const Tensor input, const int64_t size, const double alpha, const double beta, const double k)
{
auto opts = torch::nn::functional::LocalResponseNormFuncOptions(size)
.alpha(alpha)
.beta(beta)
.k(k);
CATCH_TENSOR(torch::nn::functional::local_response_norm(*input, opts));
}
115 changes: 115 additions & 0 deletions src/TorchSharp/NN/Normalization/Functional.cs
@@ -0,0 +1,115 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Runtime.InteropServices;

namespace TorchSharp
{
public static partial class torch
{
public static partial class nn
{
public static partial class functional
{
[DllImport("LibTorchSharp")]
extern static IntPtr THSNN_batch_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, bool training, double momentum, double eps);

/// <summary>
/// Applies Batch Normalization for each channel across a batch of data.
/// </summary>
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor weight = null, Tensor bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
{
var res = THSNN_batch_norm(
input.Handle,
running_mean.Handle,
running_var.Handle,
weight is not null ? weight.Handle : IntPtr.Zero,
bias is not null ? bias.Handle : IntPtr.Zero,
training,
momentum, eps);
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}
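
// A minimal usage sketch (illustrative only; shapes mirror the unit tests below):
//   var x = torch.randn(3, 2, 4);                        // (N, C, L)
//   var y = torch.nn.functional.batch_norm(x, torch.zeros(2), torch.ones(2));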

[DllImport("LibTorchSharp")]
extern static IntPtr THSNN_group_norm(IntPtr input, long num_groups, IntPtr weight, IntPtr bias, double eps);

/// <summary>
/// Applies Group Normalization over a mini-batch of inputs, with the channels separated into groups.
/// </summary>
public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
{
var res = THSNN_group_norm(
input.Handle,
num_groups,
weight is not null ? weight.Handle : IntPtr.Zero,
bias is not null ? bias.Handle : IntPtr.Zero,
eps);
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}
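
// Usage sketch (illustrative): 12 channels split into 4 groups of 3.
//   var x = torch.randn(3, 12, 5);
//   var y = torch.nn.functional.group_norm(x, 4);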

[DllImport("LibTorchSharp")]
extern static IntPtr THSNN_instance_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, bool use_input_stats, double momentum, double eps);

/// <summary>
/// Applies Instance Normalization for each channel in each data sample in a batch.
/// </summary>
public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Tensor running_var = null, Tensor weight = null, Tensor bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
{
var res = THSNN_instance_norm(
input.Handle,
running_mean is not null ? running_mean.Handle : IntPtr.Zero,
running_var is not null ? running_var.Handle : IntPtr.Zero,
weight is not null ? weight.Handle : IntPtr.Zero,
bias is not null ? bias.Handle : IntPtr.Zero,
use_input_stats,
momentum, eps);
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}
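
// Usage sketch (illustrative): by default the statistics are computed from the input itself.
//   var x = torch.randn(3, 2, 5);
//   var y = torch.nn.functional.instance_norm(x);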

[DllImport("LibTorchSharp")]
extern static unsafe IntPtr THSNN_layer_norm(IntPtr input, long* normalized_shape, long normalized_shape_len, IntPtr weight, IntPtr bias, double eps);

/// <summary>
/// Applies Layer Normalization over the last dimensions of the input, as specified by normalized_shape.
/// </summary>
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
{
IntPtr res;
unsafe {
fixed (long* normalized_shape_ptr = normalized_shape) {
res = THSNN_layer_norm(
input.Handle,
normalized_shape_ptr,
normalized_shape.LongLength,
weight is not null ? weight.Handle : IntPtr.Zero,
bias is not null ? bias.Handle : IntPtr.Zero,
eps);
}
}
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}
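
// Usage sketch (illustrative): normalize over a trailing dimension of size 12.
//   var x = torch.randn(3, 5, 12);
//   var y = torch.nn.functional.layer_norm(x, new long[] { 12 });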

[DllImport("LibTorchSharp")]
extern static IntPtr THSNN_local_response_norm(IntPtr input, long size, double alpha, double beta, double k);

/// <summary>
/// Applies Local Response Normalization over an input signal composed of several input planes.
/// </summary>
public static Tensor local_response_norm(Tensor input, long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0)
{
var res = THSNN_local_response_norm(input.Handle, size, alpha, beta, k);
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}
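
// Usage sketch (illustrative): a window of 5 neighboring channels.
//   var x = torch.randn(3, 6, 4);
//   var y = torch.nn.functional.local_response_norm(x, 5);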
}
}
}
}
2 changes: 1 addition & 1 deletion src/TorchSharp/Tensor/Tensor.cs
@@ -7091,7 +7091,6 @@ public Tensor atleast_3d()
public Tensor stft(long n_fft, long hop_length = -1, long win_length = -1, Tensor? window = null, bool center = true, PaddingModes pad_mode = PaddingModes.Reflect, bool normalized = false, bool? onesided = null, bool? return_complex = null)
{
IntPtr _input = Handle;
IntPtr _window = (window is null) ? IntPtr.Zero : window.Handle;

long _onesided = -1; // encoding of null
if (onesided.HasValue) {
@@ -7116,6 +7115,7 @@ public Tensor stft(long n_fft, long hop_length = -1, long win_length = -1, Tensor? window = null, bool center = true, PaddingModes pad_mode = PaddingModes.Reflect, bool normalized = false, bool? onesided = null, bool? return_complex = null)
}
}

IntPtr _window = (window is null) ? IntPtr.Zero : window.Handle;
var res = THSTensor_stft(_input, n_fft, hop_length, win_length, _window, normalized, _onesided, _return_complex);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
5 changes: 5 additions & 0 deletions src/TorchSharp/Tensor/Tensor.torch.cs
@@ -113,6 +113,11 @@ public static Tensor cat(IList<Tensor> tensors, long dimension)
}
}

/// <summary>
/// Returns a new tensor which indexes the input tensor along the dimension dim, using the entries of index (a LongTensor).
/// </summary>
public static Tensor index_select(Tensor input, long dim, Tensor index) => input.index_select(dim, index);
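
// Usage sketch (illustrative; assumes torch.tensor over a long[] yields a LongTensor):
//   var rows = torch.index_select(t, 0, torch.tensor(new long[] { 0, 2 }));   // rows 0 and 2 of t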

/// <summary>
/// Roll the tensor along the given dimension(s).
/// Elements that are shifted beyond the last position are re-introduced at the first position.
94 changes: 94 additions & 0 deletions test/TorchSharpTest/NN.cs
@@ -2836,6 +2836,100 @@ public void TestGroupNorm()
Assert.Throws<ArgumentException>(() => pool.forward(torch.ones(new long[] { 2, 2 })));
}
}
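
// Reference implementation shared by the tests below: (x - mean) / sqrt(var + eps).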

private Tensor NormalizeTensor(Tensor x, long[] dim, double eps = 1e-5)
{
var x_mean = torch.mean(x, dimensions: dim, keepDimension: true);
var x_var = torch.var(x, unbiased: false, dimensions: dim, keepDimension: true);
return NormalizeTensor(x, x_mean, x_var, eps);
}

private Tensor NormalizeTensor(Tensor x, Tensor x_mean, Tensor x_var, double eps = 1e-5)
{
return (x - x_mean) / torch.sqrt(eps + x_var);
}

[Fact]
public void TestBatchNormFunc()
{
var x = torch.randn(3, 2, 4);
var running_mean = torch.randn(2);
var running_var = torch.square(torch.randn(2));
var y = torch.nn.functional.batch_norm(x, running_mean, running_var);
var z = NormalizeTensor(x, torch.unsqueeze(running_mean, 1), torch.unsqueeze(running_var, 1));
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

var weight = torch.randn(2);
var bias = torch.randn(2);
y = torch.nn.functional.batch_norm(x, running_mean, running_var, weight, bias);
z = torch.unsqueeze(weight, 1) * z + torch.unsqueeze(bias, 1);
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

y = torch.nn.functional.batch_norm(x, running_mean, running_var, weight, bias, training: true);
Assert.Equal(x.shape, y.shape);
}

[Fact]
public void TestGroupNormFunc()
{
var x = torch.randn(3, 12, 5);
var y = torch.nn.functional.group_norm(x, 4);
y = y[TensorIndex.Colon, TensorIndex.Slice(3, 6)];
var z = NormalizeTensor(x[TensorIndex.Colon, TensorIndex.Slice(3, 6)], new long[] { 1, 2 });
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

var weight = torch.randn(12);
var bias = torch.randn(12);
y = torch.nn.functional.group_norm(x, 4, weight, bias);
y = y[TensorIndex.Colon, TensorIndex.Slice(3, 6)];
z = weight[TensorIndex.Slice(3, 6), TensorIndex.None] * z + bias[TensorIndex.Slice(3, 6), TensorIndex.None];
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);
}

[Fact]
public void TestInstanceNormFunc()
{
var x = torch.randn(3, 2, 5);
var y = torch.nn.functional.instance_norm(x);
var z = NormalizeTensor(x, new long[] { 2 });
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

var running_mean = torch.randn(2);
var running_var = torch.square(torch.randn(2));
y = torch.nn.functional.instance_norm(x, running_mean, running_var);
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

var weight = torch.randn(2);
var bias = torch.randn(2);
y = torch.nn.functional.instance_norm(x, running_mean, running_var, weight, bias);
z = torch.unsqueeze(weight, 1) * z + torch.unsqueeze(bias, 1);
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);
}

[Fact]
public void TestLayerNormFunc()
{
var x = torch.randn(3, 5, 12);
var y = torch.nn.functional.layer_norm(x, new long[] { 12 });
var z = NormalizeTensor(x, new long[] { 2 });
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);

var weight = torch.randn(12);
var bias = torch.randn(12);
y = torch.nn.functional.layer_norm(x, new long[] { 12 }, weight, bias);
z = weight * z + bias;
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);
}

[Fact]
public void TestLocalResponseNormFunc()
{
var x = torch.randn(3, 6, 4);
var y = torch.nn.functional.local_response_norm(x, 5, alpha: 0.5);
y = y[TensorIndex.Colon, 3];
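// Manual LRN at channel 3: b_c = a_c * (k + alpha/size * sum of a^2 over the window)^(-beta),
// here with k = 1, size = 5, alpha = 0.5, beta = 0.75, and the window spanning channels 1..5.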
var z = x[TensorIndex.Colon, 3] * torch.pow(torch.square(x[TensorIndex.Colon, TensorIndex.Slice(1, 6)]).sum(dim: 1) * 0.5 / 5 + 1, torch.tensor(-0.75f));
Assert.InRange(torch.mean(torch.square(z - y)).item<float>(), 0, 1e-5);
}
#endregion

#region Embedding, Encoding, Transformer