Skip to content

Commit 90c81b2

Browse files
authored
Add TensorPrimitives.Max/MinNumber (#101435)
* Add TensorPrimitives.Max/MinNumber * Address PR feedback plus some cleanup
1 parent e92b7d0 commit 90c81b2

File tree

7 files changed

+340
-90
lines changed

7 files changed

+340
-90
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,18 @@ public static void MaxMagnitude<T>(System.ReadOnlySpan<T> x, T y, System.Span<T>
8989
public static T Max<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
9090
public static void Max<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
9191
public static void Max<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
92+
public static T MaxNumber<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
93+
public static void MaxNumber<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
94+
public static void MaxNumber<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
9295
public static T MinMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
9396
public static void MinMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
9497
public static void MinMagnitude<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
9598
public static T Min<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
9699
public static void Min<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
97100
public static void Min<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
101+
public static T MinNumber<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
102+
public static void MinNumber<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
103+
public static void MinNumber<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
98104
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
99105
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
100106
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }

src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum)</TargetFrameworks>
@@ -82,8 +82,10 @@
8282
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.LogP1.cs" />
8383
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Max.cs" />
8484
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MaxMagnitude.cs" />
85+
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MaxNumber.cs" />
8586
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Min.cs" />
8687
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MinMagnitude.cs" />
88+
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MinNumber.cs" />
8789
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Multiply.cs" />
8890
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MultiplyAdd.cs" />
8991
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.MultiplyAddEstimate.cs" />

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Max.cs

Lines changed: 23 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -96,35 +96,21 @@ public static T Invoke(T x, T y)
9696
[MethodImpl(MethodImplOptions.AggressiveInlining)]
9797
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y)
9898
{
99-
if (AdvSimd.IsSupported)
100-
{
101-
if (typeof(T) == typeof(byte)) return AdvSimd.Max(x.AsByte(), y.AsByte()).As<byte, T>();
102-
if (typeof(T) == typeof(sbyte)) return AdvSimd.Max(x.AsSByte(), y.AsSByte()).As<sbyte, T>();
103-
if (typeof(T) == typeof(short)) return AdvSimd.Max(x.AsInt16(), y.AsInt16()).As<short, T>();
104-
if (typeof(T) == typeof(ushort)) return AdvSimd.Max(x.AsUInt16(), y.AsUInt16()).As<ushort, T>();
105-
if (typeof(T) == typeof(int)) return AdvSimd.Max(x.AsInt32(), y.AsInt32()).As<int, T>();
106-
if (typeof(T) == typeof(uint)) return AdvSimd.Max(x.AsUInt32(), y.AsUInt32()).As<uint, T>();
107-
if (typeof(T) == typeof(float)) return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As<float, T>();
108-
}
109-
110-
if (AdvSimd.Arm64.IsSupported)
99+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
111100
{
112-
if (typeof(T) == typeof(double)) return AdvSimd.Arm64.Max(x.AsDouble(), y.AsDouble()).As<double, T>();
113-
}
101+
if (AdvSimd.IsSupported && typeof(T) == typeof(float))
102+
{
103+
return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As<float, T>();
104+
}
114105

115-
if (typeof(T) == typeof(float))
116-
{
117-
return
118-
Vector128.ConditionalSelect(Vector128.Equals(x, y),
119-
Vector128.ConditionalSelect(IsNegative(x.AsSingle()).As<float, T>(), y, x),
120-
Vector128.Max(x, y));
121-
}
106+
if (AdvSimd.Arm64.IsSupported && typeof(T) == typeof(double))
107+
{
108+
return AdvSimd.Arm64.Max(x.AsDouble(), y.AsDouble()).As<double, T>();
109+
}
122110

123-
if (typeof(T) == typeof(double))
124-
{
125111
return
126112
Vector128.ConditionalSelect(Vector128.Equals(x, y),
127-
Vector128.ConditionalSelect(IsNegative(x.AsDouble()).As<double, T>(), y, x),
113+
Vector128.ConditionalSelect(IsNegative(x), y, x),
128114
Vector128.Max(x, y));
129115
}
130116

@@ -134,19 +120,11 @@ public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y)
134120
[MethodImpl(MethodImplOptions.AggressiveInlining)]
135121
public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y)
136122
{
137-
if (typeof(T) == typeof(float))
138-
{
139-
return
140-
Vector256.ConditionalSelect(Vector256.Equals(x, y),
141-
Vector256.ConditionalSelect(IsNegative(x.AsSingle()).As<float, T>(), y, x),
142-
Vector256.Max(x, y));
143-
}
144-
145-
if (typeof(T) == typeof(double))
123+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
146124
{
147125
return
148126
Vector256.ConditionalSelect(Vector256.Equals(x, y),
149-
Vector256.ConditionalSelect(IsNegative(x.AsDouble()).As<double, T>(), y, x),
127+
Vector256.ConditionalSelect(IsNegative(x), y, x),
150128
Vector256.Max(x, y));
151129
}
152130

@@ -156,19 +134,11 @@ public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y)
156134
[MethodImpl(MethodImplOptions.AggressiveInlining)]
157135
public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y)
158136
{
159-
if (typeof(T) == typeof(float))
160-
{
161-
return
162-
Vector512.ConditionalSelect(Vector512.Equals(x, y),
163-
Vector512.ConditionalSelect(IsNegative(x.AsSingle()).As<float, T>(), y, x),
164-
Vector512.Max(x, y));
165-
}
166-
167-
if (typeof(T) == typeof(double))
137+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
168138
{
169139
return
170140
Vector512.ConditionalSelect(Vector512.Equals(x, y),
171-
Vector512.ConditionalSelect(IsNegative(x.AsDouble()).As<double, T>(), y, x),
141+
Vector512.ConditionalSelect(IsNegative(x), y, x),
172142
Vector512.Max(x, y));
173143
}
174144

@@ -192,24 +162,18 @@ public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y)
192162
[MethodImpl(MethodImplOptions.AggressiveInlining)]
193163
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y)
194164
{
195-
if (AdvSimd.IsSupported)
165+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
196166
{
197-
if (typeof(T) == typeof(byte)) return AdvSimd.Max(x.AsByte(), y.AsByte()).As<byte, T>();
198-
if (typeof(T) == typeof(sbyte)) return AdvSimd.Max(x.AsSByte(), y.AsSByte()).As<sbyte, T>();
199-
if (typeof(T) == typeof(ushort)) return AdvSimd.Max(x.AsUInt16(), y.AsUInt16()).As<ushort, T>();
200-
if (typeof(T) == typeof(short)) return AdvSimd.Max(x.AsInt16(), y.AsInt16()).As<short, T>();
201-
if (typeof(T) == typeof(uint)) return AdvSimd.Max(x.AsUInt32(), y.AsUInt32()).As<uint, T>();
202-
if (typeof(T) == typeof(int)) return AdvSimd.Max(x.AsInt32(), y.AsInt32()).As<int, T>();
203-
if (typeof(T) == typeof(float)) return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As<float, T>();
204-
}
167+
if (AdvSimd.IsSupported && typeof(T) == typeof(float))
168+
{
169+
return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As<float, T>();
170+
}
205171

206-
if (AdvSimd.Arm64.IsSupported)
207-
{
208-
if (typeof(T) == typeof(double)) return AdvSimd.Arm64.Max(x.AsDouble(), y.AsDouble()).As<double, T>();
209-
}
172+
if (AdvSimd.Arm64.IsSupported && typeof(T) == typeof(double))
173+
{
174+
return AdvSimd.Arm64.Max(x.AsDouble(), y.AsDouble()).As<double, T>();
175+
}
210176

211-
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
212-
{
213177
return
214178
Vector128.ConditionalSelect(Vector128.Equals(x, x),
215179
Vector128.ConditionalSelect(Vector128.Equals(y, y),
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Runtime.CompilerServices;
5+
using System.Runtime.InteropServices;
6+
using System.Runtime.Intrinsics;
7+
using System.Runtime.Intrinsics.Arm;
8+
9+
namespace System.Numerics.Tensors
10+
{
11+
public static partial class TensorPrimitives
12+
{
13+
/// <summary>Searches for the largest number in the specified tensor.</summary>
14+
/// <param name="x">The tensor, represented as a span.</param>
15+
/// <returns>The maximum element in <paramref name="x"/>.</returns>
16+
/// <exception cref="ArgumentException">Length of <paramref name="x" /> must be greater than zero.</exception>
17+
/// <remarks>
18+
/// <para>
19+
/// The determination of the maximum element matches the IEEE 754:2019 `maximumNumber` function. Positive 0 is considered greater than negative 0.
20+
/// </para>
21+
/// <para>
22+
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
23+
/// operating systems or architectures.
24+
/// </para>
25+
/// </remarks>
26+
public static T MaxNumber<T>(ReadOnlySpan<T> x)
27+
where T : INumber<T> =>
28+
MinMaxCore<T, MaxNumberOperator<T>>(x);
29+
30+
/// <summary>Computes the element-wise maximum of the numbers in the specified tensors.</summary>
31+
/// <param name="x">The first tensor, represented as a span.</param>
32+
/// <param name="y">The second tensor, represented as a span.</param>
33+
/// <param name="destination">The destination tensor, represented as a span.</param>
34+
/// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception>
35+
/// <exception cref="ArgumentException">Destination is too short.</exception>
36+
/// <exception cref="ArgumentException"><paramref name="x"/> and <paramref name="destination"/> reference overlapping memory locations and do not begin at the same location.</exception>
37+
/// <exception cref="ArgumentException"><paramref name="y"/> and <paramref name="destination"/> reference overlapping memory locations and do not begin at the same location.</exception>
38+
/// <remarks>
39+
/// <para>
40+
/// This method effectively computes <c><paramref name="destination" />[i] = <typeparamref name="T"/>.MaxNumber(<paramref name="x" />[i], <paramref name="y" />[i])</c>.
41+
/// </para>
42+
/// <para>
43+
/// The determination of the maximum element matches the IEEE 754:2019 `maximumNumber` function. If either value is <see cref="IFloatingPointIeee754{TSelf}.NaN"/>
44+
/// the other is returned. Positive 0 is considered greater than negative 0.
45+
/// </para>
46+
/// <para>
47+
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
48+
/// operating systems or architectures.
49+
/// </para>
50+
/// </remarks>
51+
public static void MaxNumber<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination)
52+
where T : INumber<T> =>
53+
InvokeSpanSpanIntoSpan<T, MaxNumberOperator<T>>(x, y, destination);
54+
55+
/// <summary>Computes the element-wise maximum of the numbers in the specified tensors.</summary>
56+
/// <param name="x">The first tensor, represented as a span.</param>
57+
/// <param name="y">The second tensor, represented as a scalar.</param>
58+
/// <param name="destination">The destination tensor, represented as a span.</param>
59+
/// <exception cref="ArgumentException">Destination is too short.</exception>
60+
/// <exception cref="ArgumentException"><paramref name="x"/> and <paramref name="destination"/> reference overlapping memory locations and do not begin at the same location.</exception>
61+
/// <remarks>
62+
/// <para>
63+
/// This method effectively computes <c><paramref name="destination" />[i] = <typeparamref name="T"/>.MaxNumber(<paramref name="x" />[i], <paramref name="y" />)</c>.
64+
/// </para>
65+
/// <para>
66+
/// The determination of the maximum element matches the IEEE 754:2019 `maximumNumber` function. If either value is <see cref="IFloatingPointIeee754{TSelf}.NaN"/>
67+
/// the other is returned. Positive 0 is considered greater than negative 0.
68+
/// </para>
69+
/// <para>
70+
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
71+
/// operating systems or architectures.
72+
/// </para>
73+
/// </remarks>
74+
public static void MaxNumber<T>(ReadOnlySpan<T> x, T y, Span<T> destination)
75+
where T : INumber<T> =>
76+
InvokeSpanScalarIntoSpan<T, MaxNumberOperator<T>>(x, y, destination);
77+
78+
/// <summary>T.MaxNumber(x, y)</summary>
79+
internal readonly struct MaxNumberOperator<T> : IAggregationOperator<T> where T : INumber<T>
80+
{
81+
public static bool Vectorizable => true;
82+
83+
public static T Invoke(T x, T y) => T.MaxNumber(x, y);
84+
85+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
86+
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y)
87+
{
88+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
89+
{
90+
if (AdvSimd.IsSupported && typeof(T) == typeof(float))
91+
{
92+
return AdvSimd.MaxNumber(x.AsSingle(), y.AsSingle()).As<float, T>();
93+
}
94+
95+
if (AdvSimd.Arm64.IsSupported && typeof(T) == typeof(double))
96+
{
97+
return AdvSimd.Arm64.MaxNumber(x.AsDouble(), y.AsDouble()).As<double, T>();
98+
}
99+
100+
return
101+
Vector128.ConditionalSelect(Vector128.Equals(x, y),
102+
Vector128.ConditionalSelect(IsNegative(y), x, y),
103+
Vector128.ConditionalSelect(Vector128.Equals(y, y), Vector128.Max(x, y), x));
104+
}
105+
106+
return Vector128.Max(x, y);
107+
}
108+
109+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
110+
public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y)
111+
{
112+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
113+
{
114+
return
115+
Vector256.ConditionalSelect(Vector256.Equals(x, y),
116+
Vector256.ConditionalSelect(IsNegative(y), x, y),
117+
Vector256.ConditionalSelect(Vector256.Equals(y, y), Vector256.Max(x, y), x));
118+
}
119+
120+
return Vector256.Max(x, y);
121+
}
122+
123+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
124+
public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y)
125+
{
126+
if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
127+
{
128+
return
129+
Vector512.ConditionalSelect(Vector512.Equals(x, y),
130+
Vector512.ConditionalSelect(IsNegative(y), x, y),
131+
Vector512.ConditionalSelect(Vector512.Equals(y, y), Vector512.Max(x, y), x));
132+
}
133+
134+
return Vector512.Max(x, y);
135+
}
136+
137+
public static T Invoke(Vector128<T> x) => HorizontalAggregate<T, MaxNumberOperator<T>>(x);
138+
public static T Invoke(Vector256<T> x) => HorizontalAggregate<T, MaxNumberOperator<T>>(x);
139+
public static T Invoke(Vector512<T> x) => HorizontalAggregate<T, MaxNumberOperator<T>>(x);
140+
}
141+
}
142+
}

0 commit comments

Comments
 (0)