Skip to content

[API Proposal]: Arm64: FEAT_SVE: mask #93964

Closed
@a74nh

Description

@a74nh
namespace System.Runtime.Intrinsics.Arm;

/// VectorT Summary
///
/// Proposed API surface for the FEAT_SVE "mask" category: comparisons that produce
/// masks, mask creation and manipulation, and mask-driven element selection and extraction.
/// For all methods where the first argument is a vector mask, that mask is used to
/// determine the active elements in the input.
///
/// Reading the listing:
///  - The `/// T: ...` line above each signature lists the element types the signature is
///    proposed for; a bracketed pair such as `[sbyte, long]` gives the (T, T2) combination.
///  - The trailing `//` comment on each signature names the instruction(s) associated with
///    it, followed by tags such as `predicated` or `MOVPRFX`.
///    NOTE(review): the exact meaning of the `predicated` / `MOVPRFX` tags is not defined in
///    this listing -- confirm against the proposal's shared conventions.
public abstract partial class Sve : AdvSimd /// Feature: FEAT_SVE  Category: mask
{

  // NOTE(review): the LessThan / LessThanOrEqual variants below list the same instructions
  // as their GreaterThan / GreaterThanOrEqual counterparts (FACGT / FACGE) -- presumably
  // emitted with the operands swapped; confirm.

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareGreaterThan(Vector<T> left, Vector<T> right); // FACGT // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareGreaterThanOrEqual(Vector<T> left, Vector<T> right); // FACGE // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareLessThan(Vector<T> left, Vector<T> right); // FACGT // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareLessThanOrEqual(Vector<T> left, Vector<T> right); // FACGE // predicated

  /// Reads the active elements from the source vector and packs them into the
  /// lowest-numbered elements of the result, then sets any remaining elements to zero.
  /// T: float, double, int, long, uint, ulong
  public static unsafe Vector<T> Compact(Vector<T> mask, Vector<T> value); // COMPACT

  // Element-wise comparisons. Each comparison is listed twice: once with both operands of
  // element type T, and once with a right operand of a different element type T2.

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareEqual(Vector<T> left, Vector<T> right); // FCMEQ or CMPEQ // predicated

  /// T: [sbyte, long], [short, long], [int, long]
  public static unsafe Vector<T> CompareEqual(Vector<T> left, Vector<T2> right); // CMPEQ // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareGreaterThan(Vector<T> left, Vector<T> right); // FCMGT or CMPGT or CMPHI // predicated

  /// T: [sbyte, long], [short, long], [int, long], [byte, ulong], [ushort, ulong], [uint, ulong]
  public static unsafe Vector<T> CompareGreaterThan(Vector<T> left, Vector<T2> right); // CMPGT or CMPHI // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareGreaterThanOrEqual(Vector<T> left, Vector<T> right); // FCMGE or CMPGE or CMPHS // predicated

  /// T: [sbyte, long], [short, long], [int, long], [byte, ulong], [ushort, ulong], [uint, ulong]
  public static unsafe Vector<T> CompareGreaterThanOrEqual(Vector<T> left, Vector<T2> right); // CMPGE or CMPHS // predicated

  // NOTE(review): the same-type LessThan / LessThanOrEqual overloads list the GreaterThan
  // instruction forms (FCMGT/CMPGT/CMPHI, FCMGE/CMPGE/CMPHS) while the T2 overloads list
  // CMPLT/CMPLO and CMPLE/CMPLS -- presumably the same-type forms swap operands; confirm.

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareLessThan(Vector<T> left, Vector<T> right); // FCMGT or CMPGT or CMPHI // predicated

  /// T: [sbyte, long], [short, long], [int, long], [byte, ulong], [ushort, ulong], [uint, ulong]
  public static unsafe Vector<T> CompareLessThan(Vector<T> left, Vector<T2> right); // CMPLT or CMPLO // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareLessThanOrEqual(Vector<T> left, Vector<T> right); // FCMGE or CMPGE or CMPHS // predicated

  /// T: [sbyte, long], [short, long], [int, long], [byte, ulong], [ushort, ulong], [uint, ulong]
  public static unsafe Vector<T> CompareLessThanOrEqual(Vector<T> left, Vector<T2> right); // CMPLE or CMPLS // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareNotEqualTo(Vector<T> left, Vector<T> right); // FCMNE or CMPNE // predicated

  /// T: [sbyte, long], [short, long], [int, long]
  public static unsafe Vector<T> CompareNotEqualTo(Vector<T> left, Vector<T2> right); // CMPNE // predicated

  /// T: float, double
  public static unsafe Vector<T> CompareUnordered(Vector<T> left, Vector<T> right); // FCMUO // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> ConditionalExtractAfterLastActiveElement(Vector<T> mask, Vector<T> fallback, Vector<T> data); // CLASTA // MOVPRFX

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> ConditionalExtractLastActiveElement(Vector<T> mask, Vector<T> fallback, Vector<T> data); // CLASTB // MOVPRFX

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> ConditionalSelect(Vector<T> mask, Vector<T> left, Vector<T> right); // SEL

  /// Break after first true condition: returns a mask containing true up to and including
  /// the first active element in `from` that is also true; all subsequent elements are false.
  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateBreakAfterMask(Vector<T> mask, Vector<T> from); // BRKA // predicated

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateBreakAfterPropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPA

  /// NOTE(review): presumably the counterpart of CreateBreakAfterMask that stops before the
  /// first true element rather than after it -- not described in this proposal; confirm.
  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateBreakBeforeMask(Vector<T> mask, Vector<T> from); // BRKB // predicated

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateBreakBeforePropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPB

  // CreateWhileLessThanMask: generates a mask that, starting from the lowest-numbered
  // element, is true while the incrementing value of `left` is less than `right`, and false
  // thereafter up to the highest-numbered element. Used to create the mask for the current
  // iteration of a loop whose final iteration may not fill a full vector width.

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanMask(int left, int right); // WHILELT

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanMask(long left, long right); // WHILELT

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanMask(uint left, uint right); // WHILELO

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanMask(ulong left, ulong right); // WHILELO

  // CreateWhileLessThanOrEqualMask: similar to CreateWhileLessThanMask.
  // NOTE(review): presumably the comparison is `left <= right` instead of `left < right`; confirm.

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanOrEqualMask(int left, int right); // WHILELE

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanOrEqualMask(long left, long right); // WHILELE

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanOrEqualMask(uint left, uint right); // WHILELS

  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> CreateWhileLessThanOrEqualMask(ulong left, ulong right); // WHILELS

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe T ExtractAfterLast(Vector<T> value); // LASTA // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe T ExtractLast(Vector<T> value); // LASTB // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> ExtractVector(Vector<T> upper, Vector<T> lower, ulong index); // EXT // MOVPRFX

  public static unsafe Vector<byte> FalseMask(); // PFALSE

  /// Find next active predicate: used to construct a loop which iterates over all true
  /// elements in `mask`. If all elements in `from` are false, finds the first true element
  /// in `mask`; otherwise finds the next true element in `mask` that follows the last true
  /// element in `from`. Returns a mask with the found element true and all others false.
  /// T: byte, ushort, uint, ulong
  public static unsafe Vector<T> MaskGetFirstSet(Vector<T> mask, Vector<T> from); // PNEXT

  /// Returns a copy of the mask `from`, with the first active element set to true.
  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> MaskSetFirst(Vector<T> mask, Vector<T> from); // PFIRST

  // The three MaskTest* methods below all list the same instruction (PTEST).

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe bool MaskTestAnyTrue(Vector<T> mask, Vector<T> from); // PTEST

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe bool MaskTestFirstTrue(Vector<T> mask, Vector<T> from); // PTEST

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe bool MaskTestLastTrue(Vector<T> mask, Vector<T> from); // PTEST

  /// T: sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> PropagateBreak(Vector<T> left, Vector<T> right); // BRKN // predicated

  // TrueMask: sets elements of the result to true if the element number satisfies the
  // constraint pattern, or to false otherwise. If the pattern specifies more elements than
  // are available at the current vector length, all elements are set to false.
  // Returns a byte mask, which can safely be cast to a mask of any Vector type.
  // NOTE(review): the second overload defaults `pattern` to SveMaskPattern.All, so both
  // overloads accept a zero-argument call -- confirm the parameterless one is still needed.

  public static unsafe Vector<byte> TrueMask(); // PTRUE

  public static unsafe Vector<byte> TrueMask([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All); // PTRUE


  // All patterns used by PTRUE.
  // The numeric values are not contiguous: 14 through 28 are not assigned a name here.
  public enum SveMaskPattern : byte
  {
    LargestPowerOf2 = 0,   // The largest power of 2.
    VectorCount1 = 1,    // 1 element.
    VectorCount2 = 2,    // 2 elements.
    VectorCount3 = 3,    // 3 elements.
    VectorCount4 = 4,    // 4 elements.
    VectorCount5 = 5,    // 5 elements.
    VectorCount6 = 6,    // 6 elements.
    VectorCount7 = 7,    // 7 elements.
    VectorCount8 = 8,    // 8 elements.
    VectorCount16 = 9,   // 16 elements.
    VectorCount32 = 10,  // 32 elements.
    VectorCount64 = 11,  // 64 elements.
    VectorCount128 = 12, // 128 elements.
    VectorCount256 = 13, // 256 elements.
    LargestMultipleOf4 = 29,  // The largest multiple of 4.
    LargestMultipleOf3 = 30,  // The largest multiple of 3.
    All  = 31    // All available (implicitly a multiple of two).
  };

  /// total method signatures: 45


  /// Optional Entries:
  /// Overloads of the methods above that take a scalar (`T`, `long` or `ulong`) in place of
  /// the vector `right` / `fallback` operand.

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareGreaterThan(Vector<T> left, T right); // FACGT // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareGreaterThanOrEqual(Vector<T> left, T right); // FACGE // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareLessThan(Vector<T> left, T right); // FACGT // predicated

  /// T: float, double
  public static unsafe Vector<T> AbsoluteCompareLessThanOrEqual(Vector<T> left, T right); // FACGE // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareEqual(Vector<T> left, T right); // FCMEQ or CMPEQ // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareEqual(Vector<T> left, long right); // CMPEQ // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareGreaterThan(Vector<T> left, T right); // FCMGT or CMPGT or CMPHI // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareGreaterThan(Vector<T> left, long right); // CMPGT // predicated

  /// T: byte, ushort, uint
  public static unsafe Vector<T> CompareGreaterThan(Vector<T> left, ulong right); // CMPHI // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareGreaterThanOrEqual(Vector<T> left, T right); // FCMGE or CMPGE or CMPHS // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareGreaterThanOrEqual(Vector<T> left, long right); // CMPGE // predicated

  /// T: byte, ushort, uint
  public static unsafe Vector<T> CompareGreaterThanOrEqual(Vector<T> left, ulong right); // CMPHS // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareLessThan(Vector<T> left, T right); // FCMLT or FCMGT or CMPLT or CMPGT or CMPLO or CMPHI // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareLessThan(Vector<T> left, long right); // CMPLT // predicated

  /// T: byte, ushort, uint
  public static unsafe Vector<T> CompareLessThan(Vector<T> left, ulong right); // CMPLO // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareLessThanOrEqual(Vector<T> left, T right); // FCMLE or FCMGE or CMPLE or CMPGE or CMPLS or CMPHS // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareLessThanOrEqual(Vector<T> left, long right); // CMPLE // predicated

  /// T: byte, ushort, uint
  public static unsafe Vector<T> CompareLessThanOrEqual(Vector<T> left, ulong right); // CMPLS // predicated

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe Vector<T> CompareNotEqualTo(Vector<T> left, T right); // FCMNE or CMPNE // predicated

  /// T: sbyte, short, int
  public static unsafe Vector<T> CompareNotEqualTo(Vector<T> left, long right); // CMPNE // predicated

  /// T: float, double
  public static unsafe Vector<T> CompareUnordered(Vector<T> left, T right); // FCMUO // predicated

  // The two overloads below return a scalar T, unlike their vector-returning counterparts
  // in the main list, and carry no MOVPRFX tag.

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe T ConditionalExtractAfterLastActiveElement(Vector<T> mask, T fallback, Vector<T> data); // CLASTA

  /// T: float, double, sbyte, short, int, long, byte, ushort, uint, ulong
  public static unsafe T ConditionalExtractLastActiveElement(Vector<T> mask, T fallback, Vector<T> data); // CLASTB

  /// total optional method signatures: 23

}

For all methods where the first argument is a vector mask, the mask is used to determine the active elements in the input.

TrueMask

Set elements of the destination predicate to true if the element number satisfies the constraint pattern, or to false otherwise. If the constraint specifies more elements than are available at the current vector length then all elements of the destination predicate are set to false.

Returns a byte mask. This can safely be cast to a mask of any Vector type.

Compact

Shuffle active elements of vector to the right and fill with zero
Read the active elements from the source vector and pack them into the lowest-numbered elements of the destination vector. Then set any remaining elements of the destination vector to zero.

CreateWhileLessThanMask

While incrementing signed scalar less than scalar

Generate a mask that starting from the lowest numbered element is true while the incrementing value of the scalar operand left is less than the scalar right operand and false thereafter up to the highest numbered element.

This is used to create a mask for the current iteration of a loop, where the final iteration may not fill a full vector width. For example:

// Add all values in an array.
// Each iteration consumes one vector's worth of elements, so the index advances by the
// vector element count rather than by one.
for (i = 0; i < size; i += Sve.Count32BitElements())
{
  // True for every lane that is still inside the array; the final iteration may only
  // partially fill the vector.
  Vector<int> mask = Sve.CreateWhileLessThanMask(i, size);
  Vector<int> values = Sve.LoadVectorInt32(mask, address);
  // Accumulate only the active lanes; inactive lanes keep their running total.
  total = Sve.ConditionalSelect(mask, total + values, total);
  address += Sve.GetActiveElementCount(mask);
}
return Sve.AddReduce(total);

CreateWhileLessThanOrEqualMask

While incrementing signed scalar less than or equal to scalar. Similar to CreateWhileLessThanMask.

MaskSetFirst

Set the first active mask element to true

Returns a copy of the mask from, with the first active element set to true.

MaskGetFirstSet

Find next active predicate

Used to construct a loop which iterates over all true elements in mask.

If all elements in the mask from are false it finds the first true element in mask. Otherwise it finds the next true element in mask that follows the last true element in from. Returns a mask with the found element true and all other elements false.

TODO: add example.

CreateBreakAfterMask

Break after first true condition

Returns a mask containing true up to and including the first active element in from that is also true. All subsequent elements are set to false. Inactive elements in the destination predicate register remain unmodified or are set to zero, depending on whether merging or zeroing predication is selected.

// Scalar loop being vectorized:
//  for (int i = 0; i < n; i++)
//  {
//    res += a[i] * b[i];
//    if (a[i] == 512) { break; }
//  }
//
// p0 is the loop mask for the current iteration; the loop continues while its first
// element is still true, i.e. while at least one element remains to be processed.
for (int i = 0;
     Sve.MaskTestFirstTrue(Sve.TrueMask(), p0 = Sve.CreateWhileLessThanMask(i, n));
     i += Sve.Count32BitElements())
{
  Vector<uint> a_vec = Sve.LoadVector(p0, a);
  Vector<uint> b_vec = Sve.LoadVector(p0, b);
  // p1 marks the lanes whose value is 512. The test below is governed by p0, so only
  // lanes active in the loop mask are considered.
  Vector<uint> p1 = Sve.CompareEqual(a_vec, 512);
  if (Sve.MaskTestAnyTrue(p0, p1))
  {
    // One of the entries is 512. Get a mask that is true up to and including the first
    // match, and accumulate only those final entries.
    p1 = Sve.CreateBreakAfterMask(p0, p1);
    res_vec = Sve.ConditionalSelect(p1, Sve.MultiplyAdd(res_vec, a_vec, b_vec), res_vec);
    break;
  }
  else
  {
    // Use the loop mask p0. For the final iteration this may be a partial vector.
    res_vec = Sve.ConditionalSelect(p0, Sve.MultiplyAdd(res_vec, a_vec, b_vec), res_vec);
    a += Sve.Count32BitElements();
    b += Sve.Count32BitElements();
  }
}
return Sve.AddAcross(res_vec);

CreateBreakAfterPropagateMask, CreateBreakBeforePropagateMask

ExtractAfterLast, ConditionalExtractAfterLastActiveElement

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions