Expose VectorMask<T> to support generic masking for Vector<T> #74613

Closed
@anthonycanino

Description

Summary

For each Vector API, we introduce a corresponding VectorMask, which abstracts away low-level bit-masking and instead lets callers express conditional SIMD processing as boolean logic over the Vector APIs. In particular, VectorMask<T> enables masking operations and conditional SIMD processing on the variable-length Vector<T> API, so that Vector<T> code can perform better and come closer to the SIMD processing possible with Vector64/Vector128/Vector256.
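For context, here is a minimal sketch of how conditional processing looks with the existing Vector128 APIs (.NET 7 or later), where the "mask" is itself a full vector of all-ones or all-zero lanes. The class and method names (`MaskingToday`, `ClampNegativesToZero`, `PositiveLaneBits`) are illustrative only and are not part of this proposal.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskingToday
{
    // Keeps positive lanes and replaces all other lanes with zero.
    public static Vector128<int> ClampNegativesToZero(Vector128<int> values)
    {
        // The comparison yields an all-ones lane where values > 0, all-zero otherwise.
        Vector128<int> isPositive = Vector128.GreaterThan(values, Vector128<int>.Zero);

        // Takes bits from 'values' where the mask bits are set, else from Zero.
        return Vector128.ConditionalSelect(isPositive, values, Vector128<int>.Zero);
    }

    // Collapses the vector-shaped mask into one bit per lane (lane i -> bit i).
    public static uint PositiveLaneBits(Vector128<int> values)
    {
        Vector128<int> isPositive = Vector128.GreaterThan(values, Vector128<int>.Zero);
        return Vector128.ExtractMostSignificantBits(isPositive);
    }
}
```

The second method shows the kind of one-bit-per-lane value that a dedicated mask type would carry directly, rather than a vector whose lanes happen to be all ones or all zeros.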

Please see dotnet/designs#268 and https://github.com/anthonycanino/designs/blob/main/accepted/2022/enable-512-vectors.md#vectormask-usage for a detailed discussion of the rationale for VectorMask. The APIs posted here reflect the most recent discussion on the proposal at dotnet/designs#268.

The API Proposal focuses on Vector128 and Vector, with the associated VectorMask128 and VectorMask APIs, but we propose a corresponding VectorMaskX for each VectorX API, e.g., Vector64, Vector256, etc.

API Proposal

namespace System.Runtime.Intrinsics
{
    public static partial class Vector64
    {
        public static VectorMask64<T> ExtractMask<T>(this Vector64<T> vector);
    }

    public static partial class Vector128
    {
        public static VectorMask128<T> ExtractMask<T>(this Vector128<T> vector);
    }

    public static partial class Vector256
    {
        public static VectorMask256<T> ExtractMask<T>(this Vector256<T> vector);
    }

    public static partial class Vector512
    {
        public static VectorMask512<T> ExtractMask<T>(this Vector512<T> vector);
    }

    public static class VectorMask64
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask64<T> Create<T>(ushort mask);

        public static VectorMask64<TTo> As<TFrom, TTo>(this VectorMask64<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask64<byte>   AsByte  <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<double> AsDouble<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<short>  AsInt16 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<int>    AsInt32 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<long>   AsInt64 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<nint>   AsNInt  <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<nuint>  AsNUInt <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<sbyte>  AsSByte <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<float>  AsSingle<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<ushort> AsUInt16<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<uint>   AsUInt32<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<ulong>  AsUInt64<T>(this VectorMask64<T> vector) where T : struct;

        public static VectorMask64<T> BitwiseAnd<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> BitwiseOr<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> AndNot<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> OnesComplement<T>(VectorMask64<T> value);
        public static VectorMask64<T> Xor<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> Xnor<T>(VectorMask64<T> left, VectorMask64<T> right);

        public static VectorMask64<T> ShiftLeft<T>(VectorMask64<T> value, int count);
        public static VectorMask64<T> ShiftRight<T>(VectorMask64<T> value, int count);

        public static bool Equals<T>(VectorMask64<T> left, VectorMask64<T> right);

        public static int LeadingZeroCount<T>(VectorMask64<T> mask);
        public static int TrailingZeroCount<T>(VectorMask64<T> mask);
        public static int PopCount<T>(VectorMask64<T> mask);

        public static bool GetElement<T>(this VectorMask64<T> mask, int index) where T : struct;
        public static VectorMask64<T> WithElement<T>(this VectorMask64<T> mask, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask64<T> where T : struct 
    {
        private readonly byte _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask64<T> AllBitsSet { get; }
        public static VectorMask64<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask64<T> operator &(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> operator |(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> operator ~(VectorMask64<T> value);
        public static VectorMask64<T> operator ^(VectorMask64<T> left, VectorMask64<T> right);

        public static VectorMask64<T> operator <<(VectorMask64<T> value, int count);
        public static VectorMask64<T> operator >>(VectorMask64<T> value, int count);

        public static bool operator ==(VectorMask64<T> left, VectorMask64<T> right);
        public static bool operator !=(VectorMask64<T> left, VectorMask64<T> right);
    }

    public static class VectorMask128
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask128<T> Create<T>(ushort mask);

        public static VectorMask128<TTo> As<TFrom, TTo>(this VectorMask128<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask128<byte>   AsByte  <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<double> AsDouble<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<short>  AsInt16 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<int>    AsInt32 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<long>   AsInt64 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<nint>   AsNInt  <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<nuint>  AsNUInt <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<sbyte>  AsSByte <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<float>  AsSingle<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<ushort> AsUInt16<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<uint>   AsUInt32<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<ulong>  AsUInt64<T>(this VectorMask128<T> vector) where T : struct;

        public static VectorMask128<T> BitwiseAnd<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> BitwiseOr<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> AndNot<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> OnesComplement<T>(VectorMask128<T> value);
        public static VectorMask128<T> Xor<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> Xnor<T>(VectorMask128<T> left, VectorMask128<T> right);

        public static VectorMask128<T> ShiftLeft<T>(VectorMask128<T> value, int count);
        public static VectorMask128<T> ShiftRight<T>(VectorMask128<T> value, int count);

        public static bool Equals<T>(VectorMask128<T> left, VectorMask128<T> right);

        public static int LeadingZeroCount<T>(VectorMask128<T> mask);
        public static int TrailingZeroCount<T>(VectorMask128<T> mask);
        public static int PopCount<T>(VectorMask128<T> mask);

        public static bool GetElement<T>(this VectorMask128<T> mask, int index) where T : struct;
        public static VectorMask128<T> WithElement<T>(this VectorMask128<T> mask, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask128<T> where T : struct 
    {
        private readonly ushort _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask128<T> AllBitsSet { get; }
        public static VectorMask128<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask128<T> operator &(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> operator |(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> operator ~(VectorMask128<T> value);
        public static VectorMask128<T> operator ^(VectorMask128<T> left, VectorMask128<T> right);

        public static VectorMask128<T> operator <<(VectorMask128<T> value, int count);
        public static VectorMask128<T> operator >>(VectorMask128<T> value, int count);

        public static bool operator ==(VectorMask128<T> left, VectorMask128<T> right);
        public static bool operator !=(VectorMask128<T> left, VectorMask128<T> right);
    }

    public static class VectorMask256
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask256<T> Create<T>(uint mask);

        public static VectorMask256<TTo> As<TFrom, TTo>(this VectorMask256<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask256<byte>   AsByte  <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<double> AsDouble<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<short>  AsInt16 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<int>    AsInt32 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<long>   AsInt64 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<nint>   AsNInt  <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<nuint>  AsNUInt <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<sbyte>  AsSByte <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<float>  AsSingle<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<ushort> AsUInt16<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<uint>   AsUInt32<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<ulong>  AsUInt64<T>(this VectorMask256<T> vector) where T : struct;

        public static VectorMask256<T> BitwiseAnd<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> BitwiseOr<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> AndNot<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> OnesComplement<T>(VectorMask256<T> value);
        public static VectorMask256<T> Xor<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> Xnor<T>(VectorMask256<T> left, VectorMask256<T> right);

        public static VectorMask256<T> ShiftLeft<T>(VectorMask256<T> value, int count);
        public static VectorMask256<T> ShiftRight<T>(VectorMask256<T> value, int count);

        public static bool Equals<T>(VectorMask256<T> left, VectorMask256<T> right);

        public static int LeadingZeroCount<T>(VectorMask256<T> mask);
        public static int TrailingZeroCount<T>(VectorMask256<T> mask);
        public static int PopCount<T>(VectorMask256<T> mask);

        public static bool GetElement<T>(this VectorMask256<T> mask, int index) where T : struct;
        public static VectorMask256<T> WithElement<T>(this VectorMask256<T> mask, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask256<T> where T : struct 
    {
        private readonly uint _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask256<T> AllBitsSet { get; }
        public static VectorMask256<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask256<T> operator &(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> operator |(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> operator ~(VectorMask256<T> value);
        public static VectorMask256<T> operator ^(VectorMask256<T> left, VectorMask256<T> right);

        public static VectorMask256<T> operator <<(VectorMask256<T> value, int count);
        public static VectorMask256<T> operator >>(VectorMask256<T> value, int count);

        public static bool operator ==(VectorMask256<T> left, VectorMask256<T> right);
        public static bool operator !=(VectorMask256<T> left, VectorMask256<T> right);
    }

    public static class VectorMask512
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask512<T> Create<T>(ulong mask);

        public static VectorMask512<TTo> As<TFrom, TTo>(this VectorMask512<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask512<byte>   AsByte  <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<double> AsDouble<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<short>  AsInt16 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<int>    AsInt32 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<long>   AsInt64 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<nint>   AsNInt  <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<nuint>  AsNUInt <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<sbyte>  AsSByte <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<float>  AsSingle<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<ushort> AsUInt16<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<uint>   AsUInt32<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<ulong>  AsUInt64<T>(this VectorMask512<T> vector) where T : struct;

        public static VectorMask512<T> BitwiseAnd<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> BitwiseOr<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> AndNot<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> OnesComplement<T>(VectorMask512<T> value);
        public static VectorMask512<T> Xor<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> Xnor<T>(VectorMask512<T> left, VectorMask512<T> right);

        public static VectorMask512<T> ShiftLeft<T>(VectorMask512<T> value, int count);
        public static VectorMask512<T> ShiftRight<T>(VectorMask512<T> value, int count);

        public static bool Equals<T>(VectorMask512<T> left, VectorMask512<T> right);

        public static int LeadingZeroCount<T>(VectorMask512<T> mask);
        public static int TrailingZeroCount<T>(VectorMask512<T> mask);
        public static int PopCount<T>(VectorMask512<T> mask);

        public static bool GetElement<T>(this VectorMask512<T> mask, int index) where T : struct;
        public static VectorMask512<T> WithElement<T>(this VectorMask512<T> mask, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask512<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask512<T> AllBitsSet { get; }
        public static VectorMask512<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask512<T> operator &(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> operator |(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> operator ~(VectorMask512<T> value);
        public static VectorMask512<T> operator ^(VectorMask512<T> left, VectorMask512<T> right);

        public static VectorMask512<T> operator <<(VectorMask512<T> value, int count);
        public static VectorMask512<T> operator >>(VectorMask512<T> value, int count);

        public static bool operator ==(VectorMask512<T> left, VectorMask512<T> right);
        public static bool operator !=(VectorMask512<T> left, VectorMask512<T> right);
    }
}

namespace System.Numerics
{
    public static partial class Vector
    {
        public static VectorMask<T> ExtractMask<T>(this Vector<T> vector);
    }

    public static class VectorMask
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask<T> Create<T>(byte[] value);
        public static VectorMask<T> Create<T>(byte[] value, int index);
        public static VectorMask<T> Create<T>(ReadOnlySpan<byte> value);

        public static VectorMask<TTo> As<TFrom, TTo>(this VectorMask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask<byte>   AsByte  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<double> AsDouble<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<short>  AsInt16 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<int>    AsInt32 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<long>   AsInt64 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nint>   AsNInt  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nuint>  AsNUInt <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<sbyte>  AsSByte <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<float>  AsSingle<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ushort> AsUInt16<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<uint>   AsUInt32<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ulong>  AsUInt64<T>(this VectorMask<T> vector) where T : struct;

        public static VectorMask<T> BitwiseAnd<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> BitwiseOr<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> AndNot<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> OnesComplement<T>(VectorMask<T> value);
        public static VectorMask<T> Xor<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> Xnor<T>(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> ShiftLeft<T>(VectorMask<T> value, int count);
        public static VectorMask<T> ShiftRight<T>(VectorMask<T> value, int count);

        public static bool Equals<T>(VectorMask<T> left, VectorMask<T> right);

        public static int LeadingZeroCount<T>(VectorMask<T> mask);
        public static int TrailingZeroCount<T>(VectorMask<T> mask);
        public static int PopCount<T>(VectorMask<T> mask);

        public static bool GetElement<T>(this VectorMask<T> mask, int index) where T : struct;
        public static VectorMask<T> WithElement<T>(this VectorMask<T> mask, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask<T> AllBitsSet { get; }
        public static VectorMask<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask<T> operator &(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator |(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator ~(VectorMask<T> value);
        public static VectorMask<T> operator ^(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> operator <<(VectorMask<T> value, int count);
        public static VectorMask<T> operator >>(VectorMask<T> value, int count);

        public static bool operator ==(VectorMask<T> left, VectorMask<T> right);
        public static bool operator !=(VectorMask<T> left, VectorMask<T> right);
    }
}
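To make the scalar-backed design in the listing concrete, the following is a rough sketch of how a mask for a four-lane vector could behave. `Int32Mask128Sketch` is a hypothetical stand-in written for illustration; it is not the proposed VectorMask128<int>, and the real type would presumably map to hardware mask registers where available rather than to scalar arithmetic.

```csharp
using System;
using System.Numerics;

// Hypothetical stand-in: one bit per lane of a 4-lane vector (lane i -> bit i).
public readonly struct Int32Mask128Sketch : IEquatable<Int32Mask128Sketch>
{
    public const int Count = 4;              // lanes in a Vector128<int>
    private const ushort LaneBits = 0b1111;  // only the low Count bits are meaningful
    private readonly ushort _value;

    // Bits above the lane count are discarded so equality stays well defined.
    public Int32Mask128Sketch(ushort bits) => _value = (ushort)(bits & LaneBits);

    public static Int32Mask128Sketch AllBitsSet => new Int32Mask128Sketch(LaneBits);
    public static Int32Mask128Sketch Zero => default;

    public bool this[int index] =>
        (uint)index < Count
            ? ((_value >> index) & 1) != 0
            : throw new ArgumentOutOfRangeException(nameof(index));

    public static Int32Mask128Sketch operator &(Int32Mask128Sketch l, Int32Mask128Sketch r) =>
        new Int32Mask128Sketch((ushort)(l._value & r._value));
    public static Int32Mask128Sketch operator |(Int32Mask128Sketch l, Int32Mask128Sketch r) =>
        new Int32Mask128Sketch((ushort)(l._value | r._value));
    public static Int32Mask128Sketch operator ^(Int32Mask128Sketch l, Int32Mask128Sketch r) =>
        new Int32Mask128Sketch((ushort)(l._value ^ r._value));
    // The constructor masks off the bits that the complement sets above the lane count.
    public static Int32Mask128Sketch operator ~(Int32Mask128Sketch v) =>
        new Int32Mask128Sketch((ushort)~v._value);

    public static bool operator ==(Int32Mask128Sketch l, Int32Mask128Sketch r) => l._value == r._value;
    public static bool operator !=(Int32Mask128Sketch l, Int32Mask128Sketch r) => l._value != r._value;

    public int PopCount() => BitOperations.PopCount((uint)_value);

    // Returns Count for an empty mask, i.e. "no lane is set".
    public int TrailingZeroCount() =>
        _value == 0 ? Count : BitOperations.TrailingZeroCount((uint)_value);

    public bool Equals(Int32Mask128Sketch other) => _value == other._value;
    public override bool Equals(object? obj) => obj is Int32Mask128Sketch m && Equals(m);
    public override int GetHashCode() => _value;
}
```

With this stand-in, `new Int32Mask128Sketch(0b1001) & new Int32Mask128Sketch(0b0011)` leaves only lane 0 set, and `TrailingZeroCount()` gives the index of the first set lane, which is the typical way to locate the first match in a search loop.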

API Usage

A few points require further discussion:

VectorMask<T> does not have a Create method because, like Vector, its size is unknown. It would be nice to have VectorMask<byte>.Create(0xFF00) or VectorMask<byte>.Create(0xFFFF0000) (the first would cover the case VectorMask == VectorMask128, the second VectorMask == VectorMask256), but since VectorMask and Vector are technically variable length, this breaks the abstraction a bit. My proposed alternative is CreateUnsafe, where a boolean array sets each bit, and it is an error if the length of the boolean array != VectorMask.Count.
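As a sketch of the boolean-array idea, the helper below packs one bool per lane into a scalar bit mask and rejects a length mismatch. `BoolMaskSketch.PackLanes` and its `laneCount` parameter are hypothetical names; `laneCount` stands in for VectorMask.Count, and the lane i to bit i ordering is an assumption rather than something the proposal specifies.

```csharp
using System;

public static class BoolMaskSketch
{
    // Packs one bool per lane into a bit mask: lane i becomes bit i.
    // 'laneCount' stands in for VectorMask<T>.Count, which is only known at run time.
    public static ulong PackLanes(ReadOnlySpan<bool> lanes, int laneCount)
    {
        if (laneCount < 0 || laneCount > 64)
            throw new ArgumentOutOfRangeException(nameof(laneCount));

        // Mirrors the proposed rule: a length mismatch is an error.
        if (lanes.Length != laneCount)
            throw new ArgumentException("Length must equal the lane count.", nameof(lanes));

        ulong bits = 0;
        for (int i = 0; i < lanes.Length; i++)
        {
            if (lanes[i])
                bits |= 1UL << i;
        }
        return bits;
    }
}
```

For example, `BoolMaskSketch.PackLanes(new[] { true, false, false, true }, 4)` sets bits 0 and 3, giving 0x9.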

It would also be good to support this via a byte array to reduce user effort, e.g., instead of VectorMask<int>.CreateUnsafe([true, false, false, true]) we would have VectorMask<int>.CreateUnsafe([0x09]). We might then want to relax the constraint a bit and instead say: if (the length of the byte array) * 8 < VectorMask.Count, zero extend; if greater, truncate; etc.
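The relaxed byte-array rule could look roughly like the following. This is only a sketch under assumptions: `ByteMaskSketch.PackBytes` is a hypothetical helper, `laneCount` stands in for VectorMask.Count, the bytes are read in little-endian order (byte 0 holds lanes 0 to 7), and masks are limited to 64 lanes.

```csharp
using System;

public static class ByteMaskSketch
{
    // Builds a bit mask from packed bytes: too few bits are zero extended,
    // surplus bits beyond 'laneCount' are truncated.
    public static ulong PackBytes(ReadOnlySpan<byte> bytes, int laneCount)
    {
        if (laneCount < 1 || laneCount > 64)
            throw new ArgumentOutOfRangeException(nameof(laneCount));

        ulong bits = 0;
        // Bytes past the eighth cannot contribute to a 64-bit mask.
        int usable = Math.Min(bytes.Length, sizeof(ulong));
        for (int i = 0; i < usable; i++)
            bits |= (ulong)bytes[i] << (8 * i);

        // Truncate to the lane count; shifting by 64 is avoided explicitly.
        ulong laneMask = laneCount == 64 ? ulong.MaxValue : (1UL << laneCount) - 1;
        return bits & laneMask;
    }
}
```

Under these assumptions, `PackBytes(new byte[] { 0x09 }, 4)` yields 0x9, the same mask as the boolean array [true, false, false, true], while `PackBytes(new byte[] { 0xFF }, 4)` is truncated to 0xF.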
