forked from dotnet/TorchSharp
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathNormalize.cs
84 lines (74 loc) · 3.55 KB
/
Normalize.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using static TorchSharp.torch;
#nullable enable
namespace TorchSharp
{
public static partial class torchvision
{
/// <summary>
/// Transform that normalizes an image tensor channel-wise by computing
/// <c>(input - means) / stdevs</c>.
/// </summary>
/// <remarks>
/// The per-channel statistics are held as tensors of shape [1, C, 1, 1], so the input
/// is assumed to be laid out as NxCxHxW. The instance owns those tensors and releases
/// them when <see cref="Dispose()"/> is called.
/// </remarks>
internal class Normalize : ITransform, IDisposable
{
    /// <summary>
    /// Builds the broadcastable mean and standard-deviation tensors.
    /// </summary>
    /// <param name="means">Per-channel means; must contain 1 or 3 values.</param>
    /// <param name="stdevs">Per-channel standard deviations; same length as <paramref name="means"/>.</param>
    /// <param name="dtype">Element type of the internal mean/stdev tensors.</param>
    /// <param name="device">Device for the internal tensors; when null or a CPU device, the tensors are not moved.</param>
    /// <exception cref="ArgumentNullException"><paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length, or their length is neither 1 nor 3.</exception>
    internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
    {
        if (means is null) throw new ArgumentNullException(nameof(means));
        if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
        if (means.Length != stdevs.Length)
            throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");
        if (means.Length != 1 && means.Length != 3)
            throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long");

        this.means = CreateChannelTensor(means, dtype, device);
        try {
            this.stdevs = CreateChannelTensor(stdevs, dtype, device);
        } catch {
            // Do not strand the already-built means tensor if the second one fails.
            this.means.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Normalizes <paramref name="input"/> using the stored means and standard deviations.
    /// </summary>
    /// <param name="input">Tensor whose dimension 1 is the channel dimension (NxCxHxW assumed).</param>
    /// <returns>A new tensor holding <c>(input - means) / stdevs</c>.</returns>
    /// <exception cref="ArgumentException">Dimension 1 of <paramref name="input"/> does not match the number of means/stdevs.</exception>
    public Tensor call(Tensor input)
    {
        if (means.size(1) != input.size(1)) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
        return (input - means) / stdevs;
    }

    /// <summary>
    /// Creates a [1, C, 1, 1] tensor from <paramref name="values"/>, converts it to
    /// <paramref name="dtype"/> and moves it to <paramref name="device"/> when required,
    /// disposing every intermediate tensor that gets superseded along the way.
    /// </summary>
    private static Tensor CreateChannelTensor(double[] values, ScalarType dtype, torch.Device? device)
    {
        Tensor result = values.ToTensor(new long[] { 1, values.Length, 1, 1 }); // Assumes NxCxHxW
        try {
            // The tensor built from double[] is expected to already be Float64,
            // so a conversion is only requested for other element types.
            if (dtype != ScalarType.Float64) {
                result = ReplaceTensor(result, result.to_type(dtype));
            }
            if (device != null && device.type != DeviceType.CPU) {
                result = ReplaceTensor(result, result.to(device));
            }
            return result;
        } catch {
            result.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Disposes <paramref name="previous"/> (unless it is the very same object as
    /// <paramref name="replacement"/>) and returns <paramref name="replacement"/>.
    /// </summary>
    private static Tensor ReplaceTensor(Tensor previous, Tensor replacement)
    {
        if (!ReferenceEquals(previous, replacement)) {
            previous.Dispose();
        }
        return replacement;
    }

    private readonly Tensor means;
    private readonly Tensor stdevs;
    private bool disposedValue;

    /// <summary>
    /// Releases the owned tensors.
    /// </summary>
    /// <param name="disposing">
    /// True when called from <see cref="Dispose()"/>; false when called from the finalizer.
    /// </param>
    protected virtual void Dispose(bool disposing)
    {
        if (!disposedValue) {
            if (disposing) {
                // Other managed objects must only be touched on the explicit Dispose path.
                // NOTE(review): on the finalizer path the tensors are presumed to release
                // their native handles through their own finalizers - confirm against Tensor.
                means?.Dispose();
                stdevs?.Dispose();
            }
            disposedValue = true;
        }
    }

    ~Normalize()
    {
        // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
        Dispose(disposing: false);
    }

    /// <summary>
    /// Disposes the mean and standard-deviation tensors and suppresses finalization.
    /// </summary>
    public void Dispose()
    {
        // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
public static partial class transforms
{
    /// <summary>
    /// Normalize a float tensor image with mean and standard deviation.
    /// The returned transform computes <c>(input - means) / stdevs</c> per channel.
    /// </summary>
    /// <param name="means">Sequence of means for each channel (1 or 3 values).</param>
    /// <param name="stdevs">Sequence of standard deviations for each channel; same length as <paramref name="means"/>.</param>
    /// <param name="dtype">Element type of the internal mean and standard-deviation tensors.</param>
    /// <param name="device">Device on which the internal mean and standard-deviation tensors are placed; when null or a CPU device, they are not moved.</param>
    /// <returns>A transform that normalizes tensors whose dimension 1 is the channel dimension.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length, or their length is neither 1 nor 3.</exception>
    public static ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
    {
        return new Normalize(means, stdevs, dtype, device);
    }
}
}
}