From b88d3b9cba8467e3f50b7af9591f43c70053b5c9 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 29 Sep 2023 20:05:53 -0400 Subject: [PATCH] Throw exception in TensorPrimitives for unsupported span overlaps (#92838) --- .../src/Resources/Strings.resx | 5 +- .../Numerics/Tensors/TensorPrimitives.cs | 203 +++++--------- .../Tensors/TensorPrimitives.netcore.cs | 17 ++ .../Tensors/TensorPrimitives.netstandard.cs | 17 ++ .../src/System/ThrowHelper.cs | 4 + .../tests/TensorPrimitivesTests.cs | 258 ++++++++++++++++++ 6 files changed, 374 insertions(+), 130 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index 45f0d8fa17893..86b9f4d82b1f6 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -126,4 +126,7 @@ Input span arguments must all have the same length. - \ No newline at end of file + + The destination span may only overlap with an input span if the two spans start at the same memory location. + + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 41fe81416b27a..cd4a33f8d60a9 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. @@ -10,6 +13,7 @@ public static partial class TensorPrimitives /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = MathF.Abs([i]). @@ -21,10 +25,6 @@ public static partial class TensorPrimitives /// If a value is equal to or , the result stored into the corresponding destination location is set to . /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. /// - /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// /// public static void Abs(ReadOnlySpan x, Span destination) => InvokeSpanIntoSpan(x, destination); @@ -35,16 +35,13 @@ public static void Abs(ReadOnlySpan x, Span destination) => /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] + [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -56,16 +53,12 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span /// The second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] + . /// /// - /// and may overlap, but only if they start at the same memory location; - /// otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters, such as to perform - /// an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -79,16 +72,14 @@ public static void Add(ReadOnlySpan x, float y, Span destination) /// The destination tensor, represented as a span. /// Length of must be same as length of and the length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] + [i]) * [i]. /// /// - /// , , and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -102,16 +93,13 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, Rea /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] + [i]) * . /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -125,16 +113,13 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, flo /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] + ) * [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -145,15 +130,12 @@ public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpanThe tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Cosh([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value is equal to or , the result stored into the corresponding destination location is set to . /// If a value is equal to , the result stored into the corresponding destination location is also NaN. /// @@ -172,6 +154,8 @@ public static void Cosh(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = MathF.Cosh(x[i]); @@ -257,16 +241,13 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y) /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] / [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -278,15 +259,12 @@ public static void Divide(ReadOnlySpan x, ReadOnlySpan y, SpanThe second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] / . /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -330,15 +308,12 @@ public static float Dot(ReadOnlySpan x, ReadOnlySpan y) /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Exp([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value equals or , the result stored into the corresponding destination location is set to NaN. /// If a value equals , the result stored into the corresponding destination location is set to 0. /// @@ -354,6 +329,8 @@ public static void Exp(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = MathF.Exp(x[i]); @@ -570,15 +547,12 @@ public static unsafe int IndexOfMinMagnitude(ReadOnlySpan x) /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Log([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value equals 0, the result stored into the corresponding destination location is set to . /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. /// If a value is positive infinity, the result stored into the corresponding destination location is set to . @@ -596,6 +570,8 @@ public static void Log(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = MathF.Log(x[i]); @@ -606,15 +582,12 @@ public static void Log(ReadOnlySpan x, Span destination) /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Log2([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value equals 0, the result stored into the corresponding destination location is set to . /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. /// If a value is positive infinity, the result stored into the corresponding destination location is set to . @@ -632,6 +605,8 @@ public static void Log2(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = Log2(x[i]); @@ -661,16 +636,13 @@ public static float Max(ReadOnlySpan x) => /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = MathF.Max([i], [i]). /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , /// that value is stored as the result. Positive 0 is considered greater than negative 0. /// @@ -706,14 +678,11 @@ public static float MaxMagnitude(ReadOnlySpan x) => /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different /// operating systems or architectures. /// @@ -744,6 +713,8 @@ public static float Min(ReadOnlySpan x) => /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = MathF.Max([i], [i]). @@ -753,11 +724,6 @@ public static float Min(ReadOnlySpan x) => /// that value is stored as the result. Positive 0 is considered greater than negative 0. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different /// operating systems or architectures. /// @@ -789,6 +755,8 @@ public static float MinMagnitude(ReadOnlySpan x) => /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// This method effectively computes [i] = MathF.MinMagnitude([i], [i]). /// /// @@ -797,11 +765,6 @@ public static float MinMagnitude(ReadOnlySpan x) => /// the negative value is considered to have the smaller magnitude. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different /// operating systems or architectures. /// @@ -815,16 +778,13 @@ public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Sp /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] * [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, - /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -836,16 +796,13 @@ public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, SpanThe second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] * . /// It corresponds to the scal method defined by BLAS1. /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -859,21 +816,17 @@ public static void Multiply(ReadOnlySpan x, float y, Span destinat /// The destination tensor, represented as a span. /// Length of must be same as length of and length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] * [i]) + [i]. /// /// - /// , , and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// - - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); @@ -884,17 +837,14 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, Rea /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] * [i]) + . /// It corresponds to the axpy method defined by BLAS1. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -908,16 +858,13 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, flo /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = ([i] * ) + [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -928,15 +875,12 @@ public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpanThe tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = -[i]. /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -1063,15 +1007,12 @@ public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) /// The destination tensor. /// Destination is too short. /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different /// operating systems or architectures. /// @@ -1088,6 +1029,8 @@ public static void Sigmoid(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = 1f / (1f + MathF.Exp(-x[i])); @@ -1098,15 +1041,12 @@ public static void Sigmoid(ReadOnlySpan x, Span destination) /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Sinh([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value is equal to , , or , /// the corresponding destination location is set to that value. /// @@ -1125,6 +1065,8 @@ public static void Sinh(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = MathF.Sinh(x[i]); @@ -1136,16 +1078,13 @@ public static void Sinh(ReadOnlySpan x, Span destination) /// The destination tensor. /// Destination is too short. /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes a sum of MathF.Exp(x[i]) for all elements in . /// It then effectively computes [i] = MathF.Exp([i]) / sum. /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different /// operating systems or architectures. /// @@ -1162,6 +1101,8 @@ public static void SoftMax(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + float expSum = 0f; for (int i = 0; i < x.Length; i++) @@ -1181,16 +1122,13 @@ public static void SoftMax(ReadOnlySpan x, Span destination) /// The destination tensor, represented as a span. /// Length of must be same as length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] - [i]. /// /// - /// and may overlap arbitrarily, but they may only overlap with - /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. - /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -1202,15 +1140,12 @@ public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, SpanThe second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = [i] - . /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// @@ -1278,15 +1213,12 @@ public static float SumOfSquares(ReadOnlySpan x) => /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// /// /// This method effectively computes [i] = .Tanh([i]). /// /// - /// may overlap with , but only if the input and the output span begin at the same memory - /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. - /// - /// /// If a value is equal to , the corresponding destination location is set to -1. /// If a value is equal to , the corresponding destination location is set to 1. /// If a value is equal to , the corresponding destination location is set to NaN. @@ -1306,12 +1238,25 @@ public static void Tanh(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + for (int i = 0; i < x.Length; i++) { destination[i] = MathF.Tanh(x[i]); } } + /// Throws an exception if the and spans overlap and don't begin at the same memory location. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan input, Span output) + { + if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) && + input.Overlaps(output)) + { + ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap(); + } + } + /// Mask used to handle remaining elements after vectorized handling of the input. /// /// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 4cc29c70ce0bd..6773771602212 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1201,6 +1201,8 @@ private static unsafe void InvokeSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; @@ -1313,6 +1315,9 @@ private static unsafe void InvokeSpanSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); @@ -1428,6 +1433,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; @@ -1553,6 +1560,10 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float zRef = ref MemoryMarshal.GetReference(z); @@ -1681,6 +1692,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); @@ -1814,6 +1828,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 70207a5c8995b..ac4ea2dfe9bef 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -298,6 +298,8 @@ private static void InvokeSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; @@ -354,6 +356,9 @@ private static void InvokeSpanSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); @@ -408,6 +413,8 @@ private static void InvokeSpanScalarIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; @@ -467,6 +474,10 @@ private static void InvokeSpanSpanSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float zRef = ref MemoryMarshal.GetReference(z); @@ -531,6 +542,9 @@ private static void InvokeSpanSpanScalarIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); @@ -596,6 +610,9 @@ private static void InvokeSpanScalarSpanIntoSpan( ThrowHelper.ThrowArgument_DestinationTooShort(); } + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + ref float xRef = ref MemoryMarshal.GetReference(x); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 902b27787e856..272991aed44ab 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -18,5 +18,9 @@ public static void ThrowArgument_SpansMustHaveSameLength() => [DoesNotReturn] public static void ThrowArgument_SpansMustBeNonEmpty() => throw new ArgumentException(SR.Argument_SpansMustBeNonEmpty); + + [DoesNotReturn] + public static void ThrowArgument_InputAndDestinationSpanMustNotOverlap() => + throw new ArgumentException(SR.Argument_InputAndDestinationSpanMustNotOverlap, "destination"); } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 751e352dd1da5..137f2fd9070de 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -99,6 +99,14 @@ public static void Abs_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(x, destination)); } + + [Fact] + public static void Abs_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); + } #endregion #region Add @@ -164,6 +172,16 @@ public static void Add_TwoTensors_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } + [Fact] + public static void Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void Add_TensorScalar(int tensorLength) @@ -206,6 +224,14 @@ public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLengt AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } + + [Fact] + public static void Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } #endregion #region AddMultiply @@ -267,6 +293,18 @@ public static void AddMultiply_ThreeTensors_ThrowsForTooShortDestination(int ten AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void AddMultiply_TensorTensorScalar(int tensorLength) @@ -325,6 +363,16 @@ public static void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination(i AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void AddMultiply_TensorScalarTensor(int tensorLength) @@ -382,6 +430,16 @@ public static void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination(i AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + + [Fact] + public static void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region Cosh @@ -424,6 +482,14 @@ public static void Cosh_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); } + + [Fact] + public static void Cosh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region CosineSimilarity @@ -575,6 +641,16 @@ public static void Divide_TwoTensors_ThrowsForTooShortDestination(int tensorLeng AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } + [Fact] + public static void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void Divide_TensorScalar(int tensorLength) @@ -617,6 +693,16 @@ public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLe AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } + + [Fact] + public static void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region Dot @@ -698,6 +784,14 @@ public static void Exp_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); } + + [Fact] + public static void Exp_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region IndexOfMax @@ -921,6 +1015,14 @@ public static void Log_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); } + + [Fact] + public static void Log_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region Log2 @@ -963,6 +1065,14 @@ public static void Log2_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(x, destination)); } + + [Fact] + public static void Log2_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region Max @@ -1109,6 +1219,16 @@ public static void Max_TwoTensors_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Max(x, y, destination)); } + + [Fact] + public static void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region MaxMagnitude @@ -1256,6 +1376,16 @@ public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tens AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination)); } + + [Fact] + public static void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region Min @@ -1402,6 +1532,16 @@ public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Min(x, y, destination)); } + + [Fact] + public static void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region MinMagnitude @@ -1547,6 +1687,16 @@ public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tens AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); } + + [Fact] + public static void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region Multiply @@ -1604,6 +1754,16 @@ public static void Multiply_TwoTensors_ThrowsForTooShortDestination(int tensorLe AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } + [Fact] + public static void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void Multiply_TensorScalar(int tensorLength) @@ -1646,6 +1806,14 @@ public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensor AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } + + [Fact] + public static void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } #endregion #region MultiplyAdd @@ -1707,6 +1875,18 @@ public static void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination(int ten AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } + [Fact] + public static void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void MultiplyAdd_TensorTensorScalar(int tensorLength) @@ -1752,6 +1932,16 @@ public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(i AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } + [Fact] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void MultiplyAdd_TensorScalarTensor(int tensorLength) @@ -1796,6 +1986,16 @@ public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(i AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } + + [Fact] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } #endregion #region Negate @@ -1838,6 +2038,14 @@ public static void Negate_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); } + + [Fact] + public static void Negate_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region Norm @@ -2064,6 +2272,14 @@ public static void Sigmoid_ThrowsForEmptyInput() { AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); } + + [Fact] + public static void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region Sinh @@ -2106,6 +2322,14 @@ public static void Sinh_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); } + + [Fact] + public static void Sinh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region SoftMax @@ -2186,6 +2410,14 @@ public static void SoftMax_ThrowsForEmptyInput() { AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); } + + [Fact] + public static void SoftMax_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion #region Subtract @@ -2243,6 +2475,16 @@ public static void Subtract_TwoTensors_ThrowsForTooShortDestination(int tensorLe AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } + [Fact] + public static void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] public static void Subtract_TensorScalar(int tensorLength) @@ -2285,6 +2527,14 @@ public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensor AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } + + [Fact] + public static void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } #endregion #region Sum @@ -2411,6 +2661,14 @@ public static void Tanh_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); } + + [Fact] + public static void Tanh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } #endregion } }