Skip to content

Adding torch.nn.functional.normalize #1405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ The argument defaults for `torch.diagonal()` and `Tensor.diagonal()` arguments h

__Bug Fixes__:

#1400 There may be an error in torchvision.transforms.GaussianBlur
#1402 diagonal() has incorrect default
#1400 There may be an error in torchvision.transforms.GaussianBlur<br/>
#1402 diagonal() has incorrect default<br/>

__API Changes__:

#1382: Add support for torch.nn.functional.normalize<br/>

# NuGet Version 0.103.1

Expand Down
27 changes: 13 additions & 14 deletions TorchSharp.sln
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "TorchSharp", "TorchSharp",
pkg\TorchSharp\TorchSharp.symbols.nupkgproj = pkg\TorchSharp\TorchSharp.symbols.nupkgproj
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Debug\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{2B359162-062E-3C52-91D3-027A8542A58C}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Debug\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{CAD9DB7F-3223-3324-884D-FA2381593DA7}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{E4C0DBEE-0815-311B-9065-137BB50BD793}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Native-Debug", "Native-Debug", "{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}"
ProjectSection(SolutionItems) = preProject
Expand Down Expand Up @@ -66,9 +66,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
azure-pipelines.yml = azure-pipelines.yml
build\BranchInfo.props = build\BranchInfo.props
DEVGUIDE.md = DEVGUIDE.md
global.json = global.json
README.md = README.md
RELEASENOTES.md = RELEASENOTES.md
global.json = global.json
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TorchVision", "src\TorchVision\TorchVision.csproj", "{DCF01EE5-6431-4115-85E0-1FC4C3DE86A2}"
Expand Down Expand Up @@ -107,14 +107,14 @@ Global
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|Any CPU.Build.0 = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.ActiveCfg = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.Build.0 = Release|Any CPU
{2B359162-062E-3C52-91D3-027A8542A58C}.Debug|Any CPU.ActiveCfg = Debug|x64
{2B359162-062E-3C52-91D3-027A8542A58C}.Debug|x64.ActiveCfg = Debug|x64
{2B359162-062E-3C52-91D3-027A8542A58C}.Release|Any CPU.ActiveCfg = Release|x64
{2B359162-062E-3C52-91D3-027A8542A58C}.Release|x64.ActiveCfg = Release|x64
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Debug|Any CPU.ActiveCfg = Debug|x64
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Debug|x64.ActiveCfg = Debug|x64
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Release|Any CPU.ActiveCfg = Release|x64
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Release|x64.ActiveCfg = Release|x64
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Debug|Any CPU.ActiveCfg = Debug|x64
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Debug|x64.ActiveCfg = Debug|x64
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|Any CPU.ActiveCfg = Release|x64
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|x64.ActiveCfg = Release|x64
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|Any CPU.ActiveCfg = Debug|x64
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|x64.ActiveCfg = Debug|x64
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|Any CPU.ActiveCfg = Release|x64
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|x64.ActiveCfg = Release|x64
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|x64.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -148,7 +148,6 @@ Global
{95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.ActiveCfg = Release|Any CPU
{95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.Build.0 = Release|Any CPU
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.ActiveCfg = Debug|Any CPU
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.Build.0 = Debug|Any CPU
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand Down Expand Up @@ -181,8 +180,8 @@ Global
{6C323B05-9028-4B09-911C-3C03AE058BEE} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{42B45168-476D-4BFA-87B8-81A34E6295CD} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{567456AD-B026-4CB6-B98D-4FC930C90223} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{2B359162-062E-3C52-91D3-027A8542A58C} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
{E4C0DBEE-0815-311B-9065-137BB50BD793} = {4DB9E84D-324C-408F-87A6-246E86205540}
{CAD9DB7F-3223-3324-884D-FA2381593DA7} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB} = {4DB9E84D-324C-408F-87A6-246E86205540}
{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D8C60CD8-8429-45F2-A755-47B6CD10FDF8} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{4DB9E84D-324C-408F-87A6-246E86205540} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
Expand Down
1 change: 1 addition & 0 deletions src/Native/LibTorchSharp/THSNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ EXPORT_API(void) THSNN_GroupNorm_set_weight(const NNModule module, const Ten
EXPORT_API(NNModule) THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor);

EXPORT_API(Tensor) THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps);
EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);
EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, int64_t num_groups, const Tensor weight, const Tensor bias, const double eps);
EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps);
Expand Down
9 changes: 9 additions & 0 deletions src/Native/LibTorchSharp/THSNormalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,15 @@ Tensor THSNN_batch_norm(const Tensor input, Tensor running_mean, const Tensor ru
CATCH_TENSOR(torch::nn::functional::batch_norm(*input, *running_mean, *running_var, opts));
}

Tensor THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps)
{
auto opts = torch::nn::functional::NormalizeFuncOptions()
.p(p)
.dim(dim)
.eps(eps);
CATCH_TENSOR(torch::nn::functional::normalize(*input, opts));
}

Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps)
{
auto opts = torch::nn::functional::GroupNormFuncOptions(num_groups)
Expand Down
26 changes: 22 additions & 4 deletions src/TorchSharp/NN/Normalization/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using static TorchSharp.PInvoke.NativeMethods;

#nullable enable
namespace TorchSharp
{
public static partial class torch
Expand All @@ -10,10 +11,27 @@ public static partial class nn
{
public static partial class functional
{
/// <summary>
/// Perform normalization of inputs over specified dimension.
/// </summary>
/// <param name="input">Input tensor of any shape.</param>
/// <param name="p">the exponent value in the norm formulation</param>
/// <param name="dim">the dimension to reduce</param>
/// <param name="eps">small value to avoid division by zero</param>
public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
{
var res = THSNN_normalize(
input.Handle,
p, dim, eps);
if (res == IntPtr.Zero)
torch.CheckForErrors();
return new Tensor(res);
}

/// <summary>
/// Applies Batch Normalization for each channel across a batch of data.
/// </summary>
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor weight = null, Tensor bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor? weight = null, Tensor? bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
{
var res = THSNN_batch_norm(
input.Handle,
Expand All @@ -31,7 +49,7 @@ public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor runnin
/// <summary>
/// Applies Group Normalization for last certain number of dimensions.
/// </summary>
public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
public static Tensor group_norm(Tensor input, long num_groups, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
{
var res = THSNN_group_norm(
input.Handle,
Expand All @@ -47,7 +65,7 @@ public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = n
/// <summary>
/// Applies Instance Normalization for each channel in each data sample in a batch.
/// </summary>
public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Tensor running_var = null, Tensor weight = null, Tensor bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
public static Tensor instance_norm(Tensor input, Tensor? running_mean = null, Tensor? running_var = null, Tensor? weight = null, Tensor? bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
{
var res = THSNN_instance_norm(
input.Handle,
Expand All @@ -65,7 +83,7 @@ public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Ten
/// <summary>
/// Applies Layer Normalization for last certain number of dimensions.
/// </summary>
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
{
IntPtr res;
unsafe {
Expand Down
3 changes: 3 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,9 @@ internal static extern IntPtr THSNN_custom_module(
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_Unflatten_ctor(long dim, IntPtr shape, long shape_len, out IntPtr pBoxedModule);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_normalize(IntPtr input, double p, long dim, double eps);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_batch_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, [MarshalAs(UnmanagedType.U1)] bool training, double momentum, double eps);

Expand Down
32 changes: 31 additions & 1 deletion test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Linq;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -4817,6 +4817,36 @@ private Tensor NormalizeTensor(Tensor x, Tensor x_mean, Tensor x_var, double eps
return (x - x_mean) / torch.sqrt(eps + x_var);
}

[Fact]
public void TestNormalizeFunc()
{
foreach (var device in TestUtils.AvailableDevices()) {
var x = torch.from_array(new double[]
{ -1.0786, 0.3455, 1.2929, 0.5030,
-0.2930, 1.0420, -0.1082, -0.2943,
-0.3989, -0.8311, 0.7103, -1.5878,
0.6331, 1.0106, 0.5128, -2.2565,
1.2044, -0.6916, -0.1242, 0.6808,
0.1672, 0.1105, -1.7364, 0.0669
}).reshape(3,2,4);
var y = torch.nn.functional.normalize(x);
Assert.Equal(x.shape, y.shape);
Assert.Equal(x.device_type, y.device_type);

var expected = torch.from_array(new double[]
{ -0.9650, 0.3147, 0.9965, 0.8631,
-0.2621, 0.9492, -0.0834, -0.5050,
-0.5331, -0.6352, 0.8108, -0.5755,
0.8460, 0.7724, 0.5853, -0.8178,
0.9905, -0.9875, -0.0713, 0.9952,
0.1375, 0.1577, -0.9975, 0.0978
}).reshape(3, 2, 4);


Assert.True(y.allclose(expected, rtol: 0.005, atol: 0.005));
}
}

[Fact]
public void TestBatchNormFunc()
{
Expand Down