Description
Summary
For each `Vector*<T>` API, we introduce a corresponding `VectorMask*<T>`, which abstracts away low-level bit-masking and instead allows conditional SIMD processing to be expressed as boolean logic over the Vector APIs. In particular, `VectorMask<T>` allows masking operations and conditional SIMD processing to be performed on the variable-length `Vector<T>` API, which lets `Vector<T>` be used more performantly and closer to the SIMD processing done with `Vector64<T>`/`Vector128<T>`/`Vector256<T>`.
Please see dotnet/designs#268 and https://github.com/anthonycanino/designs/blob/main/accepted/2022/enable-512-vectors.md#vectormask-usage for a detailed discussion of the rationale behind `VectorMask<T>`; the APIs posted here reflect the most recent discussion of the proposal in dotnet/designs#268.
The API proposal focuses on `Vector128<T>` and `Vector<T>`, with the associated `VectorMask128<T>` and `VectorMask<T>` APIs, but we propose a corresponding `VectorMaskX<T>` for each `VectorX<T>` API, e.g., `Vector64<T>`, `Vector256<T>`, etc.
API Proposal
namespace System.Runtime.Intrinsics
{
public static partial class Vector64
{
    /// <summary>
    /// Extracts a per-element mask from <paramref name="vector"/>, with one mask entry per element.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rule that decides whether an element's entry is set (for example the
    /// element's most-significant bit, as with <c>ExtractMostSignificantBits</c>) is not specified
    /// by this proposal — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector.</typeparam>
    /// <param name="vector">The vector to extract the mask from.</param>
    /// <returns>A <see cref="VectorMask64{T}"/> describing <paramref name="vector"/>.</returns>
    // Extension methods must be static (CS1106); the constraint matches VectorMask64<T>'s own.
    public static VectorMask64<T> ExtractMask<T>(this Vector64<T> vector) where T : struct;
}
public static partial class Vector128
{
    /// <summary>
    /// Extracts a per-element mask from <paramref name="vector"/>, with one mask entry per element.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rule that decides whether an element's entry is set (for example the
    /// element's most-significant bit, as with <c>ExtractMostSignificantBits</c>) is not specified
    /// by this proposal — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector.</typeparam>
    /// <param name="vector">The vector to extract the mask from.</param>
    /// <returns>A <see cref="VectorMask128{T}"/> describing <paramref name="vector"/>.</returns>
    // Extension methods must be static (CS1106); the constraint matches VectorMask128<T>'s own.
    public static VectorMask128<T> ExtractMask<T>(this Vector128<T> vector) where T : struct;
}
public static partial class Vector256
{
    /// <summary>
    /// Extracts a per-element mask from <paramref name="vector"/>, with one mask entry per element.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rule that decides whether an element's entry is set (for example the
    /// element's most-significant bit, as with <c>ExtractMostSignificantBits</c>) is not specified
    /// by this proposal — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector.</typeparam>
    /// <param name="vector">The vector to extract the mask from.</param>
    /// <returns>A <see cref="VectorMask256{T}"/> describing <paramref name="vector"/>.</returns>
    // Returns the 256-bit mask type. A VectorMask128<T> cannot hold the up-to-32 entries
    // of a 256-bit vector. Extension methods must also be static (CS1106).
    public static VectorMask256<T> ExtractMask<T>(this Vector256<T> vector) where T : struct;
}
public static partial class Vector512
{
    /// <summary>
    /// Extracts a per-element mask from <paramref name="vector"/>, with one mask entry per element.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rule that decides whether an element's entry is set (for example the
    /// element's most-significant bit, as with <c>ExtractMostSignificantBits</c>) is not specified
    /// by this proposal — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector.</typeparam>
    /// <param name="vector">The vector to extract the mask from.</param>
    /// <returns>A <see cref="VectorMask512{T}"/> describing <paramref name="vector"/>.</returns>
    // Extension methods must be static (CS1106); the constraint matches VectorMask512<T>'s own.
    public static VectorMask512<T> ExtractMask<T>(this Vector512<T> vector) where T : struct;
}
/// <summary>
/// Provides static helpers for creating, reinterpreting, combining, and querying
/// <see cref="VectorMask64{T}"/> values.
/// </summary>
public static class VectorMask64
{
    /// <summary>Gets a value indicating whether <see cref="VectorMask64{T}"/> operations are hardware accelerated.</summary>
    // Must be static: a static class cannot declare instance members (CS0708).
    public static bool IsHardwareAccelerated { get; }

    /// <summary>
    /// Creates a mask from a bit pattern in which bit <c>i</c> (least-significant bit first)
    /// is the entry for element <c>i</c>.
    /// </summary>
    /// <remarks>
    /// A 64-bit vector has at most 8 elements, so only the low <see cref="VectorMask64{T}.Count"/>
    /// bits of <paramref name="mask"/> are meaningful. The parameter stays <c>ushort</c> for compatibility.
    /// NOTE(review): whether bits at or above <c>Count</c> are ignored or rejected is not specified — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
    /// <param name="mask">The packed bit pattern.</param>
    /// <returns>The resulting mask.</returns>
    public static VectorMask64<T> Create<T>(ushort mask) where T : struct;

    // Reinterpret the mask as a mask over a different element type.
    // NOTE(review): how entries map when element counts differ (e.g. byte -> int) is not specified — confirm.
    public static VectorMask64<TTo> As<TFrom, TTo>(this VectorMask64<TFrom> vector) where TFrom : struct where TTo : struct;
    public static VectorMask64<byte> AsByte<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<double> AsDouble<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<short> AsInt16<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<int> AsInt32<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<long> AsInt64<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<nint> AsNInt<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<nuint> AsNUInt<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<sbyte> AsSByte<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<float> AsSingle<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<ushort> AsUInt16<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<uint> AsUInt32<T>(this VectorMask64<T> vector) where T : struct;
    public static VectorMask64<ulong> AsUInt64<T>(this VectorMask64<T> vector) where T : struct;

    // Entry-wise boolean logic over two masks.
    public static VectorMask64<T> BitwiseAnd<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;
    public static VectorMask64<T> BitwiseOr<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;
    // NOTE(review): presumably left & ~right, matching Vector64.AndNot — confirm operand order.
    public static VectorMask64<T> AndNot<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;
    public static VectorMask64<T> OnesComplement<T>(VectorMask64<T> value) where T : struct;
    public static VectorMask64<T> Xor<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;
    public static VectorMask64<T> Xnor<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;

    // Shift the mask entries by count positions.
    public static VectorMask64<T> ShiftLeft<T>(VectorMask64<T> value, int count) where T : struct;
    public static VectorMask64<T> ShiftRight<T>(VectorMask64<T> value, int count) where T : struct;

    /// <summary>Determines whether every entry of <paramref name="left"/> matches <paramref name="right"/>.</summary>
    public static bool Equals<T>(VectorMask64<T> left, VectorMask64<T> right) where T : struct;

    // Entry counting. NOTE(review): whether leading zeros are counted relative to Count or to the
    // storage width is not specified — confirm.
    public static int LeadingZeroCount<T>(VectorMask64<T> mask) where T : struct;
    public static int TrailingZeroCount<T>(VectorMask64<T> mask) where T : struct;
    public static int PopCount<T>(VectorMask64<T> mask) where T : struct;

    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    // Extends the mask type; the original extended Vector64<T> and declared no T.
    public static bool GetElement<T>(this VectorMask64<T> vector, int index) where T : struct;

    /// <summary>Returns a mask equal to <paramref name="vector"/> except that entry <paramref name="index"/> is <paramref name="value"/>.</summary>
    // Return type corrected from the non-existent Vector64Mask<T>.
    public static VectorMask64<T> WithElement<T>(this VectorMask64<T> vector, int index, bool value) where T : struct;
}
/// <summary>
/// A per-element boolean mask for a <see cref="Vector64{T}"/>, with one bit per element
/// (bit <c>i</c> is the entry for element <c>i</c>).
/// </summary>
/// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
public readonly struct VectorMask64<T> : IEquatable<VectorMask64<T>>
    where T : struct
{
    // One bit per element: a byte suffices because a 64-bit vector has at most 8 elements.
    private readonly byte _value;

    /// <summary>Gets a value indicating whether <typeparamref name="T"/> is supported (assumed to mirror <c>Vector64&lt;T&gt;.IsSupported</c> — confirm).</summary>
    public static bool IsSupported { get; }

    /// <summary>Gets the number of mask entries, one per element of <typeparamref name="T"/>.</summary>
    public static int Count { get; }

    /// <summary>Gets a mask with every entry set.</summary>
    public static VectorMask64<T> AllBitsSet { get; }

    /// <summary>Gets a mask with every entry cleared.</summary>
    public static VectorMask64<T> Zero { get; }

    // Indexers cannot be static in C# (CS0106), so this is an instance indexer.
    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    public bool this[int index] { get; }

    public static VectorMask64<T> operator &(VectorMask64<T> left, VectorMask64<T> right);
    public static VectorMask64<T> operator |(VectorMask64<T> left, VectorMask64<T> right);
    public static VectorMask64<T> operator ~(VectorMask64<T> value);
    public static VectorMask64<T> operator ^(VectorMask64<T> left, VectorMask64<T> right);
    public static VectorMask64<T> operator <<(VectorMask64<T> value, int count);
    public static VectorMask64<T> operator >>(VectorMask64<T> value, int count);
    public static bool operator ==(VectorMask64<T> left, VectorMask64<T> right);
    public static bool operator !=(VectorMask64<T> left, VectorMask64<T> right);

    // Defining ==/!= requires matching Equals/GetHashCode overrides (CS0660/CS0661).
    /// <summary>Determines whether this mask has the same entries as <paramref name="other"/>.</summary>
    public bool Equals(VectorMask64<T> other);

    /// <inheritdoc/>
    public override bool Equals(object? obj);

    /// <inheritdoc/>
    public override int GetHashCode();
}
/// <summary>
/// Provides static helpers for creating, reinterpreting, combining, and querying
/// <see cref="VectorMask128{T}"/> values.
/// </summary>
public static class VectorMask128
{
    /// <summary>Gets a value indicating whether <see cref="VectorMask128{T}"/> operations are hardware accelerated.</summary>
    // Must be static: a static class cannot declare instance members (CS0708).
    public static bool IsHardwareAccelerated { get; }

    /// <summary>
    /// Creates a mask from a bit pattern in which bit <c>i</c> (least-significant bit first)
    /// is the entry for element <c>i</c>.
    /// </summary>
    /// <remarks>
    /// A 128-bit vector has at most 16 elements, so a 16-bit pattern covers every element type;
    /// only the low <see cref="VectorMask128{T}.Count"/> bits are meaningful.
    /// NOTE(review): whether bits at or above <c>Count</c> are ignored or rejected is not specified — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
    /// <param name="mask">The packed bit pattern.</param>
    /// <returns>The resulting mask.</returns>
    public static VectorMask128<T> Create<T>(ushort mask) where T : struct;

    // Reinterpret the mask as a mask over a different element type.
    // NOTE(review): how entries map when element counts differ (e.g. byte -> int) is not specified — confirm.
    public static VectorMask128<TTo> As<TFrom, TTo>(this VectorMask128<TFrom> vector) where TFrom : struct where TTo : struct;
    public static VectorMask128<byte> AsByte<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<double> AsDouble<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<short> AsInt16<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<int> AsInt32<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<long> AsInt64<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<nint> AsNInt<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<nuint> AsNUInt<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<sbyte> AsSByte<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<float> AsSingle<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<ushort> AsUInt16<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<uint> AsUInt32<T>(this VectorMask128<T> vector) where T : struct;
    public static VectorMask128<ulong> AsUInt64<T>(this VectorMask128<T> vector) where T : struct;

    // Entry-wise boolean logic over two masks.
    public static VectorMask128<T> BitwiseAnd<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;
    public static VectorMask128<T> BitwiseOr<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;
    // NOTE(review): presumably left & ~right, matching Vector128.AndNot — confirm operand order.
    public static VectorMask128<T> AndNot<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;
    public static VectorMask128<T> OnesComplement<T>(VectorMask128<T> value) where T : struct;
    public static VectorMask128<T> Xor<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;
    public static VectorMask128<T> Xnor<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;

    // Shift the mask entries by count positions.
    public static VectorMask128<T> ShiftLeft<T>(VectorMask128<T> value, int count) where T : struct;
    public static VectorMask128<T> ShiftRight<T>(VectorMask128<T> value, int count) where T : struct;

    /// <summary>Determines whether every entry of <paramref name="left"/> matches <paramref name="right"/>.</summary>
    public static bool Equals<T>(VectorMask128<T> left, VectorMask128<T> right) where T : struct;

    // Entry counting. NOTE(review): whether leading zeros are counted relative to Count or to the
    // storage width is not specified — confirm.
    public static int LeadingZeroCount<T>(VectorMask128<T> mask) where T : struct;
    public static int TrailingZeroCount<T>(VectorMask128<T> mask) where T : struct;
    public static int PopCount<T>(VectorMask128<T> mask) where T : struct;

    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    // Extends the mask type; the original extended Vector128<T> and declared no T.
    public static bool GetElement<T>(this VectorMask128<T> vector, int index) where T : struct;

    /// <summary>Returns a mask equal to <paramref name="vector"/> except that entry <paramref name="index"/> is <paramref name="value"/>.</summary>
    // Return type corrected from the non-existent Vector128Mask<T>.
    public static VectorMask128<T> WithElement<T>(this VectorMask128<T> vector, int index, bool value) where T : struct;
}
/// <summary>
/// A per-element boolean mask for a <see cref="Vector128{T}"/>, with one bit per element
/// (bit <c>i</c> is the entry for element <c>i</c>).
/// </summary>
/// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
public readonly struct VectorMask128<T> : IEquatable<VectorMask128<T>>
    where T : struct
{
    // One bit per element: a ushort holds the up-to-16 entries of a 128-bit vector.
    private readonly ushort _value;

    /// <summary>Gets a value indicating whether <typeparamref name="T"/> is supported (assumed to mirror <c>Vector128&lt;T&gt;.IsSupported</c> — confirm).</summary>
    public static bool IsSupported { get; }

    /// <summary>Gets the number of mask entries, one per element of <typeparamref name="T"/>.</summary>
    public static int Count { get; }

    /// <summary>Gets a mask with every entry set.</summary>
    public static VectorMask128<T> AllBitsSet { get; }

    /// <summary>Gets a mask with every entry cleared.</summary>
    public static VectorMask128<T> Zero { get; }

    // Indexers cannot be static in C# (CS0106), so this is an instance indexer.
    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    public bool this[int index] { get; }

    public static VectorMask128<T> operator &(VectorMask128<T> left, VectorMask128<T> right);
    public static VectorMask128<T> operator |(VectorMask128<T> left, VectorMask128<T> right);
    public static VectorMask128<T> operator ~(VectorMask128<T> value);
    public static VectorMask128<T> operator ^(VectorMask128<T> left, VectorMask128<T> right);
    public static VectorMask128<T> operator <<(VectorMask128<T> value, int count);
    public static VectorMask128<T> operator >>(VectorMask128<T> value, int count);
    public static bool operator ==(VectorMask128<T> left, VectorMask128<T> right);
    public static bool operator !=(VectorMask128<T> left, VectorMask128<T> right);

    // Defining ==/!= requires matching Equals/GetHashCode overrides (CS0660/CS0661).
    /// <summary>Determines whether this mask has the same entries as <paramref name="other"/>.</summary>
    public bool Equals(VectorMask128<T> other);

    /// <inheritdoc/>
    public override bool Equals(object? obj);

    /// <inheritdoc/>
    public override int GetHashCode();
}
/// <summary>
/// Provides static helpers for creating, reinterpreting, combining, and querying
/// <see cref="VectorMask256{T}"/> values.
/// </summary>
public static class VectorMask256
{
    /// <summary>Gets a value indicating whether <see cref="VectorMask256{T}"/> operations are hardware accelerated.</summary>
    // Must be static: a static class cannot declare instance members (CS0708).
    public static bool IsHardwareAccelerated { get; }

    /// <summary>
    /// Creates a mask from a bit pattern in which bit <c>i</c> (least-significant bit first)
    /// is the entry for element <c>i</c>.
    /// </summary>
    /// <remarks>
    /// Widened from <c>ushort</c> to <c>uint</c>. A 256-bit vector of bytes has 32 elements, which a
    /// 16-bit pattern cannot express; existing <c>ushort</c> arguments still convert implicitly.
    /// NOTE(review): whether bits at or above <c>Count</c> are ignored or rejected is not specified — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
    /// <param name="mask">The packed bit pattern.</param>
    /// <returns>The resulting mask.</returns>
    public static VectorMask256<T> Create<T>(uint mask) where T : struct;

    // Reinterpret the mask as a mask over a different element type.
    // NOTE(review): how entries map when element counts differ (e.g. byte -> int) is not specified — confirm.
    public static VectorMask256<TTo> As<TFrom, TTo>(this VectorMask256<TFrom> vector) where TFrom : struct where TTo : struct;
    public static VectorMask256<byte> AsByte<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<double> AsDouble<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<short> AsInt16<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<int> AsInt32<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<long> AsInt64<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<nint> AsNInt<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<nuint> AsNUInt<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<sbyte> AsSByte<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<float> AsSingle<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<ushort> AsUInt16<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<uint> AsUInt32<T>(this VectorMask256<T> vector) where T : struct;
    public static VectorMask256<ulong> AsUInt64<T>(this VectorMask256<T> vector) where T : struct;

    // Entry-wise boolean logic over two masks.
    public static VectorMask256<T> BitwiseAnd<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;
    public static VectorMask256<T> BitwiseOr<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;
    // NOTE(review): presumably left & ~right, matching Vector256.AndNot — confirm operand order.
    public static VectorMask256<T> AndNot<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;
    public static VectorMask256<T> OnesComplement<T>(VectorMask256<T> value) where T : struct;
    public static VectorMask256<T> Xor<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;
    public static VectorMask256<T> Xnor<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;

    // Shift the mask entries by count positions.
    public static VectorMask256<T> ShiftLeft<T>(VectorMask256<T> value, int count) where T : struct;
    public static VectorMask256<T> ShiftRight<T>(VectorMask256<T> value, int count) where T : struct;

    /// <summary>Determines whether every entry of <paramref name="left"/> matches <paramref name="right"/>.</summary>
    public static bool Equals<T>(VectorMask256<T> left, VectorMask256<T> right) where T : struct;

    // Entry counting. NOTE(review): whether leading zeros are counted relative to Count or to the
    // storage width is not specified — confirm.
    public static int LeadingZeroCount<T>(VectorMask256<T> mask) where T : struct;
    public static int TrailingZeroCount<T>(VectorMask256<T> mask) where T : struct;
    public static int PopCount<T>(VectorMask256<T> mask) where T : struct;

    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    // Extends the mask type; the original extended Vector256<T> and declared no T.
    public static bool GetElement<T>(this VectorMask256<T> vector, int index) where T : struct;

    /// <summary>Returns a mask equal to <paramref name="vector"/> except that entry <paramref name="index"/> is <paramref name="value"/>.</summary>
    // Return type corrected from the non-existent Vector256Mask<T>.
    public static VectorMask256<T> WithElement<T>(this VectorMask256<T> vector, int index, bool value) where T : struct;
}
/// <summary>
/// A per-element boolean mask for a <see cref="Vector256{T}"/>, with one bit per element
/// (bit <c>i</c> is the entry for element <c>i</c>).
/// </summary>
/// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
public readonly struct VectorMask256<T> : IEquatable<VectorMask256<T>>
    where T : struct
{
    // One bit per element: a uint holds the up-to-32 entries of a 256-bit vector.
    private readonly uint _value;

    /// <summary>Gets a value indicating whether <typeparamref name="T"/> is supported (assumed to mirror <c>Vector256&lt;T&gt;.IsSupported</c> — confirm).</summary>
    public static bool IsSupported { get; }

    /// <summary>Gets the number of mask entries, one per element of <typeparamref name="T"/>.</summary>
    public static int Count { get; }

    /// <summary>Gets a mask with every entry set.</summary>
    public static VectorMask256<T> AllBitsSet { get; }

    /// <summary>Gets a mask with every entry cleared.</summary>
    public static VectorMask256<T> Zero { get; }

    // Indexers cannot be static in C# (CS0106), so this is an instance indexer.
    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    public bool this[int index] { get; }

    public static VectorMask256<T> operator &(VectorMask256<T> left, VectorMask256<T> right);
    public static VectorMask256<T> operator |(VectorMask256<T> left, VectorMask256<T> right);
    public static VectorMask256<T> operator ~(VectorMask256<T> value);
    public static VectorMask256<T> operator ^(VectorMask256<T> left, VectorMask256<T> right);
    public static VectorMask256<T> operator <<(VectorMask256<T> value, int count);
    public static VectorMask256<T> operator >>(VectorMask256<T> value, int count);
    public static bool operator ==(VectorMask256<T> left, VectorMask256<T> right);
    public static bool operator !=(VectorMask256<T> left, VectorMask256<T> right);

    // Defining ==/!= requires matching Equals/GetHashCode overrides (CS0660/CS0661).
    /// <summary>Determines whether this mask has the same entries as <paramref name="other"/>.</summary>
    public bool Equals(VectorMask256<T> other);

    /// <inheritdoc/>
    public override bool Equals(object? obj);

    /// <inheritdoc/>
    public override int GetHashCode();
}
/// <summary>
/// Provides static helpers for creating, reinterpreting, combining, and querying
/// <see cref="VectorMask512{T}"/> values.
/// </summary>
public static class VectorMask512
{
    /// <summary>Gets a value indicating whether <see cref="VectorMask512{T}"/> operations are hardware accelerated.</summary>
    // Must be static: a static class cannot declare instance members (CS0708).
    public static bool IsHardwareAccelerated { get; }

    /// <summary>
    /// Creates a mask from a bit pattern in which bit <c>i</c> (least-significant bit first)
    /// is the entry for element <c>i</c>.
    /// </summary>
    /// <remarks>
    /// Widened from <c>ushort</c> to <c>ulong</c>. A 512-bit vector of bytes has 64 elements, which a
    /// 16-bit pattern cannot express; existing <c>ushort</c> arguments still convert implicitly.
    /// NOTE(review): whether bits at or above <c>Count</c> are ignored or rejected is not specified — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
    /// <param name="mask">The packed bit pattern.</param>
    /// <returns>The resulting mask.</returns>
    public static VectorMask512<T> Create<T>(ulong mask) where T : struct;

    // Reinterpret the mask as a mask over a different element type.
    // NOTE(review): how entries map when element counts differ (e.g. byte -> int) is not specified — confirm.
    public static VectorMask512<TTo> As<TFrom, TTo>(this VectorMask512<TFrom> vector) where TFrom : struct where TTo : struct;
    public static VectorMask512<byte> AsByte<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<double> AsDouble<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<short> AsInt16<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<int> AsInt32<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<long> AsInt64<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<nint> AsNInt<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<nuint> AsNUInt<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<sbyte> AsSByte<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<float> AsSingle<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<ushort> AsUInt16<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<uint> AsUInt32<T>(this VectorMask512<T> vector) where T : struct;
    public static VectorMask512<ulong> AsUInt64<T>(this VectorMask512<T> vector) where T : struct;

    // Entry-wise boolean logic over two masks.
    public static VectorMask512<T> BitwiseAnd<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;
    public static VectorMask512<T> BitwiseOr<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;
    // NOTE(review): presumably left & ~right, matching Vector512.AndNot — confirm operand order.
    public static VectorMask512<T> AndNot<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;
    public static VectorMask512<T> OnesComplement<T>(VectorMask512<T> value) where T : struct;
    public static VectorMask512<T> Xor<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;
    public static VectorMask512<T> Xnor<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;

    // Shift the mask entries by count positions.
    public static VectorMask512<T> ShiftLeft<T>(VectorMask512<T> value, int count) where T : struct;
    public static VectorMask512<T> ShiftRight<T>(VectorMask512<T> value, int count) where T : struct;

    /// <summary>Determines whether every entry of <paramref name="left"/> matches <paramref name="right"/>.</summary>
    public static bool Equals<T>(VectorMask512<T> left, VectorMask512<T> right) where T : struct;

    // Entry counting. NOTE(review): whether leading zeros are counted relative to Count or to the
    // storage width is not specified — confirm.
    public static int LeadingZeroCount<T>(VectorMask512<T> mask) where T : struct;
    public static int TrailingZeroCount<T>(VectorMask512<T> mask) where T : struct;
    public static int PopCount<T>(VectorMask512<T> mask) where T : struct;

    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    // Extends the mask type; the original extended Vector512<T> and declared no T.
    public static bool GetElement<T>(this VectorMask512<T> vector, int index) where T : struct;

    /// <summary>Returns a mask equal to <paramref name="vector"/> except that entry <paramref name="index"/> is <paramref name="value"/>.</summary>
    // Return type corrected from the non-existent Vector512Mask<T>.
    public static VectorMask512<T> WithElement<T>(this VectorMask512<T> vector, int index, bool value) where T : struct;
}
/// <summary>
/// A per-element boolean mask for a <see cref="Vector512{T}"/>, with one bit per element
/// (bit <c>i</c> is the entry for element <c>i</c>).
/// </summary>
/// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
public readonly struct VectorMask512<T> : IEquatable<VectorMask512<T>>
    where T : struct
{
    // One bit per element: a ulong holds the up-to-64 entries of a 512-bit vector.
    private readonly ulong _value;

    /// <summary>Gets a value indicating whether <typeparamref name="T"/> is supported (assumed to mirror <c>Vector512&lt;T&gt;.IsSupported</c> — confirm).</summary>
    public static bool IsSupported { get; }

    /// <summary>Gets the number of mask entries, one per element of <typeparamref name="T"/>.</summary>
    public static int Count { get; }

    /// <summary>Gets a mask with every entry set.</summary>
    public static VectorMask512<T> AllBitsSet { get; }

    /// <summary>Gets a mask with every entry cleared.</summary>
    public static VectorMask512<T> Zero { get; }

    // Indexers cannot be static in C# (CS0106), so this is an instance indexer.
    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    public bool this[int index] { get; }

    public static VectorMask512<T> operator &(VectorMask512<T> left, VectorMask512<T> right);
    public static VectorMask512<T> operator |(VectorMask512<T> left, VectorMask512<T> right);
    public static VectorMask512<T> operator ~(VectorMask512<T> value);
    public static VectorMask512<T> operator ^(VectorMask512<T> left, VectorMask512<T> right);
    public static VectorMask512<T> operator <<(VectorMask512<T> value, int count);
    public static VectorMask512<T> operator >>(VectorMask512<T> value, int count);
    public static bool operator ==(VectorMask512<T> left, VectorMask512<T> right);
    public static bool operator !=(VectorMask512<T> left, VectorMask512<T> right);

    // Defining ==/!= requires matching Equals/GetHashCode overrides (CS0660/CS0661).
    /// <summary>Determines whether this mask has the same entries as <paramref name="other"/>.</summary>
    public bool Equals(VectorMask512<T> other);

    /// <inheritdoc/>
    public override bool Equals(object? obj);

    /// <inheritdoc/>
    public override int GetHashCode();
}
}
namespace System.Numerics
{
public static partial class Vector
{
    /// <summary>
    /// Extracts a per-element mask from the variable-length <paramref name="vector"/>,
    /// with one mask entry per element.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rule that decides whether an element's entry is set (for example the
    /// element's most-significant bit) is not specified by this proposal — confirm.
    /// </remarks>
    /// <typeparam name="T">The element type of the vector.</typeparam>
    /// <param name="vector">The vector to extract the mask from.</param>
    /// <returns>A <see cref="VectorMask{T}"/> describing <paramref name="vector"/>.</returns>
    // Extension methods must be static (CS1106); the constraint matches VectorMask<T>'s own.
    public static VectorMask<T> ExtractMask<T>(this Vector<T> vector) where T : struct;
}
/// <summary>
/// Provides static helpers for creating, reinterpreting, combining, and querying
/// <see cref="VectorMask{T}"/> values for the variable-length <see cref="Vector{T}"/>.
/// </summary>
public static class VectorMask
{
    /// <summary>Gets a value indicating whether <see cref="VectorMask{T}"/> operations are hardware accelerated.</summary>
    // Must be static: a static class cannot declare instance members (CS0708).
    public static bool IsHardwareAccelerated { get; }

    // Create a mask from packed bytes. Entries are packed least-significant bit first, so entry i
    // is bit (i % 8) of byte (i / 8). There is no scalar overload because Vector<T>'s width, and
    // therefore the number of entries, is not known at compile time.
    // NOTE(review): the behavior when the input is too short or too long (error vs.
    // zero-extend/truncate) is still under discussion — confirm.
    /// <summary>Creates a mask from the packed bits in <paramref name="value"/>.</summary>
    public static VectorMask<T> Create<T>(byte[] value) where T : struct;

    /// <summary>Creates a mask from the packed bits in <paramref name="value"/>, starting at <paramref name="index"/>.</summary>
    // NOTE(review): index is presumably a byte offset into value — confirm.
    public static VectorMask<T> Create<T>(byte[] value, int index) where T : struct;

    /// <summary>Creates a mask from the packed bits in <paramref name="value"/>.</summary>
    public static VectorMask<T> Create<T>(ReadOnlySpan<byte> value) where T : struct;

    // Reinterpret the mask as a mask over a different element type.
    // NOTE(review): how entries map when element counts differ (e.g. byte -> int) is not specified — confirm.
    public static VectorMask<TTo> As<TFrom, TTo>(this VectorMask<TFrom> vector) where TFrom : struct where TTo : struct;
    public static VectorMask<byte> AsByte<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<double> AsDouble<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<short> AsInt16<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<int> AsInt32<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<long> AsInt64<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<nint> AsNInt<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<nuint> AsNUInt<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<sbyte> AsSByte<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<float> AsSingle<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<ushort> AsUInt16<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<uint> AsUInt32<T>(this VectorMask<T> vector) where T : struct;
    public static VectorMask<ulong> AsUInt64<T>(this VectorMask<T> vector) where T : struct;

    // Entry-wise boolean logic over two masks.
    public static VectorMask<T> BitwiseAnd<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;
    public static VectorMask<T> BitwiseOr<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;
    // NOTE(review): presumably left & ~right, matching Vector.AndNot — confirm operand order.
    public static VectorMask<T> AndNot<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;
    public static VectorMask<T> OnesComplement<T>(VectorMask<T> value) where T : struct;
    public static VectorMask<T> Xor<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;
    public static VectorMask<T> Xnor<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;

    // Shift the mask entries by count positions.
    public static VectorMask<T> ShiftLeft<T>(VectorMask<T> value, int count) where T : struct;
    public static VectorMask<T> ShiftRight<T>(VectorMask<T> value, int count) where T : struct;

    /// <summary>Determines whether every entry of <paramref name="left"/> matches <paramref name="right"/>.</summary>
    public static bool Equals<T>(VectorMask<T> left, VectorMask<T> right) where T : struct;

    // Entry counting. NOTE(review): whether leading zeros are counted relative to Count or to the
    // storage width is not specified — confirm.
    public static int LeadingZeroCount<T>(VectorMask<T> mask) where T : struct;
    public static int TrailingZeroCount<T>(VectorMask<T> mask) where T : struct;
    public static int PopCount<T>(VectorMask<T> mask) where T : struct;

    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    // Extends the mask type; the original extended Vector<T> and declared no T.
    public static bool GetElement<T>(this VectorMask<T> vector, int index) where T : struct;

    /// <summary>Returns a mask equal to <paramref name="vector"/> except that entry <paramref name="index"/> is <paramref name="value"/>.</summary>
    public static VectorMask<T> WithElement<T>(this VectorMask<T> vector, int index, bool value) where T : struct;
}
/// <summary>
/// A per-element boolean mask for the variable-length <see cref="Vector{T}"/>, with one bit per
/// element (bit <c>i</c> is the entry for element <c>i</c>).
/// </summary>
/// <typeparam name="T">The element type of the vector the mask applies to.</typeparam>
public readonly struct VectorMask<T> : IEquatable<VectorMask<T>>
    where T : struct
{
    // One bit per element. NOTE(review): a ulong holds at most 64 entries, which matches a 512-bit
    // Vector<byte>; a wider Vector<T> implementation would need more storage — confirm.
    private readonly ulong _value;

    /// <summary>Gets a value indicating whether <typeparamref name="T"/> is supported (assumed to mirror <c>Vector&lt;T&gt;.IsSupported</c> — confirm).</summary>
    public static bool IsSupported { get; }

    /// <summary>Gets the number of mask entries, presumably equal to <c>Vector&lt;T&gt;.Count</c> on the current hardware.</summary>
    public static int Count { get; }

    /// <summary>Gets a mask with every entry set.</summary>
    public static VectorMask<T> AllBitsSet { get; }

    /// <summary>Gets a mask with every entry cleared.</summary>
    public static VectorMask<T> Zero { get; }

    // Indexers cannot be static in C# (CS0106), so this is an instance indexer.
    /// <summary>Gets the mask entry for the element at <paramref name="index"/>.</summary>
    public bool this[int index] { get; }

    public static VectorMask<T> operator &(VectorMask<T> left, VectorMask<T> right);
    public static VectorMask<T> operator |(VectorMask<T> left, VectorMask<T> right);
    public static VectorMask<T> operator ~(VectorMask<T> value);
    public static VectorMask<T> operator ^(VectorMask<T> left, VectorMask<T> right);
    public static VectorMask<T> operator <<(VectorMask<T> value, int count);
    public static VectorMask<T> operator >>(VectorMask<T> value, int count);
    public static bool operator ==(VectorMask<T> left, VectorMask<T> right);
    public static bool operator !=(VectorMask<T> left, VectorMask<T> right);

    // Defining ==/!= requires matching Equals/GetHashCode overrides (CS0660/CS0661).
    /// <summary>Determines whether this mask has the same entries as <paramref name="other"/>.</summary>
    public bool Equals(VectorMask<T> other);

    /// <inheritdoc/>
    public override bool Equals(object? obj);

    /// <inheritdoc/>
    public override int GetHashCode();
}
}
API Usage
A few points require further discussion:
`VectorMask<T>` does not have a `Create` method that takes a scalar bit pattern because, like `Vector<T>`, its size is unknown at compile time. It would be nice to have `VectorMask<byte>.Create(0xFF00)` or `VectorMask<byte>.Create(0xFFFF0000)`. The first might cover the case where `VectorMask<T>` is backed by `VectorMask128<T>`, and the second the case where it is backed by `VectorMask256<T>`. However, since `VectorMask<T>` and `Vector<T>` are variable length, such overloads break the abstraction a bit. My proposed alternative is to have `CreateUnsafe`, where a boolean array sets each bit and it is an error if the length of the boolean array != `VectorMask<T>.Count`.
It would also be good to support a byte array to reduce the user effort a bit. For example, instead of `VectorMask<int>.CreateUnsafe([true, false, false, true])` we would write `VectorMask<int>.CreateUnsafe([0x09])`, since 0x09 = 0b1001 sets bits 0 and 3. This byte-array form appears to correspond to the `VectorMask.Create(byte[])` / `Create(ReadOnlySpan<byte>)` overloads in the API listing above. We might then want to relax the constraint a bit and instead say: if (the length of the byte array) * 8 < `VectorMask<T>.Count`, zero-extend; if greater, truncate.