Skip to content

Commit 3688b59

Browse files
Use IRootFunctions in Tensor.StdDev (#110641)
* Use IRootFunctions in Tensor.StdDev * Try fix api compat issue * Replace Pow+Sum with SumOfSquares * Drop IPowerFunctions constraint * Try fix compatability supression * Fix StdDev stride issue * Add regression test * fix test * Use FlattenedLength in StdDev * Try add byref to api compat supressions * Try fix type constraint --------- Co-authored-by: Tanner Gooding <tagoo@outlook.com>
1 parent 7de730a commit 3688b59

File tree

4 files changed

+44
-19
lines changed

4 files changed

+44
-19
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ public static void ResizeTo<T>(scoped in System.Numerics.Tensors.Tensor<T> tenso
487487
public static ref readonly System.Numerics.Tensors.TensorSpan<T> StackAlongDimension<T>(scoped System.ReadOnlySpan<System.Numerics.Tensors.Tensor<T>> tensors, in System.Numerics.Tensors.TensorSpan<T> destination, int dimension) { throw null; }
488488
public static System.Numerics.Tensors.Tensor<T> Stack<T>(params scoped System.ReadOnlySpan<System.Numerics.Tensors.Tensor<T>> tensors) { throw null; }
489489
public static ref readonly System.Numerics.Tensors.TensorSpan<T> Stack<T>(scoped in System.ReadOnlySpan<System.Numerics.Tensors.Tensor<T>> tensors, in System.Numerics.Tensors.TensorSpan<T> destination) { throw null; }
490-
public static T StdDev<T>(in System.Numerics.Tensors.ReadOnlyTensorSpan<T> x) where T : System.Numerics.IFloatingPoint<T>, System.Numerics.IPowerFunctions<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
490+
public static T StdDev<T>(in System.Numerics.Tensors.ReadOnlyTensorSpan<T> x) where T : System.Numerics.IFloatingPoint<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IRootFunctions<T> { throw null; }
491491
public static System.Numerics.Tensors.Tensor<T> Subtract<T>(in System.Numerics.Tensors.ReadOnlyTensorSpan<T> x, in System.Numerics.Tensors.ReadOnlyTensorSpan<T> y) where T : System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
492492
public static ref readonly System.Numerics.Tensors.TensorSpan<T> Subtract<T>(scoped in System.Numerics.Tensors.ReadOnlyTensorSpan<T> x, scoped in System.Numerics.Tensors.ReadOnlyTensorSpan<T> y, in System.Numerics.Tensors.TensorSpan<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
493493
public static System.Numerics.Tensors.Tensor<T> Subtract<T>(in System.Numerics.Tensors.ReadOnlyTensorSpan<T> x, T y) where T : System.Numerics.ISubtractionOperators<T, T, T> { throw null; }

src/libraries/System.Numerics.Tensors/src/CompatibilitySuppressions.xml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,4 +253,18 @@
253253
<Right>lib/net9.0/System.Numerics.Tensors.dll</Right>
254254
<IsBaselineSuppression>true</IsBaselineSuppression>
255255
</Suppression>
256+
<Suppression>
257+
<DiagnosticId>CP0021</DiagnosticId>
258+
<Target>M:System.Numerics.Tensors.Tensor.StdDev``1(System.Numerics.Tensors.ReadOnlyTensorSpan{``0}@)``0:T:System.Numerics.IRootFunctions{``0}</Target>
259+
<Left>lib/net8.0/System.Numerics.Tensors.dll</Left>
260+
<Right>lib/net8.0/System.Numerics.Tensors.dll</Right>
261+
<IsBaselineSuppression>true</IsBaselineSuppression>
262+
</Suppression>
263+
<Suppression>
264+
<DiagnosticId>CP0021</DiagnosticId>
265+
<Target>M:System.Numerics.Tensors.Tensor.StdDev``1(System.Numerics.Tensors.ReadOnlyTensorSpan{``0}@)``0:T:System.Numerics.IRootFunctions{``0}</Target>
266+
<Left>lib/net9.0/System.Numerics.Tensors.dll</Left>
267+
<Right>lib/net9.0/System.Numerics.Tensors.dll</Right>
268+
<IsBaselineSuppression>true</IsBaselineSuppression>
269+
</Suppression>
256270
</Suppressions>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,26 +3512,15 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
35123512
/// <param name="x">The <see cref="TensorSpan{T}"/> to take the standard deviation of.</param>
35133513
/// <returns><typeparamref name="T"/> representing the standard deviation.</returns>
35143514
public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
3515-
where T : IFloatingPoint<T>, IPowerFunctions<T>, IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>
3515+
where T : IFloatingPoint<T>, IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IRootFunctions<T>
35163516
{
35173517
T mean = Average(x);
3518-
Span<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape._memoryLength);
3519-
Span<T> output = new T[x.FlattenedLength];
3520-
TensorPrimitives.Subtract(span, mean, output);
3521-
TensorPrimitives.Abs(output, output);
3522-
TensorPrimitives.Pow((ReadOnlySpan<T>)output, T.CreateChecked(2), output);
3523-
T sum = TensorPrimitives.Sum((ReadOnlySpan<T>)output);
3524-
T variance = sum / T.CreateChecked(x._shape._memoryLength);
3525-
3526-
if (typeof(T) == typeof(float))
3527-
{
3528-
return T.CreateChecked(MathF.Sqrt(float.CreateChecked(variance)));
3529-
}
3530-
if (typeof(T) == typeof(double))
3531-
{
3532-
return T.CreateChecked(Math.Sqrt(double.CreateChecked(variance)));
3533-
}
3534-
return T.Pow(variance, T.CreateChecked(0.5));
3518+
Tensor<T> temp = CreateUninitialized<T>(x.Lengths);
3519+
Subtract(x, mean, temp);
3520+
Abs<T>(temp, temp);
3521+
T sum = SumOfSquares<T>(temp);
3522+
T variance = sum / T.CreateChecked(x.FlattenedLength);
3523+
return T.Sqrt(variance);
35353524
}
35363525
#endregion
35373526

@@ -6664,6 +6653,19 @@ public static T Sum<T>(scoped in ReadOnlyTensorSpan<T> x)
66646653
}
66656654
#endregion
66666655

6656+
#region SumOfSquares
6657+
/// <summary>
6658+
/// Sums the squared elements of the specified tensor.
6659+
/// </summary>
6660+
/// <param name="x">Tensor to sum squares of</param>
6661+
/// <returns></returns>
6662+
internal static T SumOfSquares<T>(scoped in ReadOnlyTensorSpan<T> x)
6663+
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>
6664+
{
6665+
return TensorPrimitivesHelperSpanInTOut(x, TensorPrimitives.SumOfSquares);
6666+
}
6667+
#endregion
6668+
66676669
#region Tan
66686670
/// <summary>Computes the element-wise tangent of the value in the specified tensor.</summary>
66696671
/// <param name="x">The <see cref="ReadOnlyTensorSpan{T}"/> to take the sin of.</param>

src/libraries/System.Numerics.Tensors/tests/TensorTests.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,15 @@ public static void TensorStdDevTests()
11241124
Tensor<float> t0 = Tensor.Create<float>((Enumerable.Range(0, 4).Select(i => (float)i)), [2, 2]);
11251125

11261126
Assert.Equal(StdDev([0, 1, 2, 3]), Tensor.StdDev<float>(t0), .1);
1127+
1128+
// Test that non-contiguous calculations work
1129+
Tensor<float> fourByFour = Tensor.Create<float>([4, 4]);
1130+
fourByFour[[0, 0]] = 1f;
1131+
fourByFour[[0, 1]] = 1f;
1132+
fourByFour[[1, 0]] = 1f;
1133+
fourByFour[[1, 1]] = 1f;
1134+
ReadOnlyTensorSpan<float> upperLeft = fourByFour.AsReadOnlyTensorSpan().Slice([0..2, 0..2]);
1135+
Assert.Equal(0f, Tensor.StdDev(upperLeft));
11271136
}
11281137

11291138
public static float StdDev(float[] values)

0 commit comments

Comments
 (0)