Skip to content

[API Proposal]: Example usages of a VectorSVE API #88140

Closed
@a74nh

Description

@a74nh

Background and motivation

Adding a vector API for Arm SVE/SVE2 would be useful. SVE is a mandatory feature from Armv9.0-A onwards and is an alternative to NEON. Code written for SVE is vector-length agnostic: it automatically scales to the vector length of the machine it runs on, so each routine needs only a single implementation. Predication in SVE removes the need for separate loop head and tail code, making code shorter, simpler and easier to write.

This issue provides examples of how such an API might be used.

API Proposal

None provided.

API Usage

  /*
    Sum all the values in an int array.
  */
  public static unsafe int sum_sve(ref int* srcBytes, int length)
  {
    // srcBytes: address of the first int of the array; length: number of ints to sum.
    VectorSVE<int> total = Sve.Create((int)0);
    int* src = srcBytes;

    // Assigned by the `out` argument of WhileLessThan in the loop condition, which always
    // runs before the loop body reads it.
    VectorSVEPred pred;

    /*
      WhileLessThan comes in two variants:
        VectorSVEPred WhileLessThan(int val, int limit)
        VectorSVEComparison WhileLessThan(VectorSVEPred out predicate, int val, int limit)

      A VectorSVEComparison can be tested using the SVE condition codes (none, any, first, last, nlast etc).
      `if (cmp.nlast) ....`
      `if (Sve.WhileLessThan(out pred, i, length).first) ....`

      `if (cmp)` is the same as doing `if (cmp.any)`

      Ideally the following will not be allowed (a comparison result should not be stored):
        var f = Sve.WhileLessThan(out pred, i, length).first
    */

    /*
      Always using a function call for the vector length instead of assigning to a variable will allow
      easier optimisation to INCW (which is faster than incrementing by a variable).
    */

    for (int i = 0; Sve.WhileLessThan(out pred, i, length); i += Sve.VectorLength<int>())
    {
      VectorSVE<int> vec = Sve.LoadUnsafe(pred, ref *src, i);

      /*
        This is the standard sve `add` instruction which uses a merge predicate.
        For each lane in the predicate, add the two vectors. For all other lanes use the first vector.
       */
      total = Sve.MergeAdd(pred, total, vec);
    }

    // No tail loop required: the last iteration's predicate covers only the remaining elements.
    return Sve.AddAcross(total).ToScalar();
  }


  /*
    Sum all the values in an int array, without predication.
    For performance reasons, it may be better to use an unpredicated loop, followed by a tail.
    Ideally, the user would write the predicated version and the Jit would optimise to this if required.
  */
  public static unsafe int sum_sve_unpredicated_loop(ref int* srcBytes, int length)
  {
    // srcBytes: address of the first int of the array; length: number of ints to sum.
    VectorSVE<int> total = Sve.Create((int)0);
    int* src = srcBytes;

    // Main loop: whole vectors only, so no predicate is needed.
    int i = 0;
    for (; i + Sve.VectorLength<int>() <= length; i += Sve.VectorLength<int>())
    {
      VectorSVE<int> chunk = Sve.LoadUnsafe(ref *src, i);
      total = Sve.Add(total, chunk);
    }

    // Predicated tail: covers the remaining (length - i) elements, possibly none.
    VectorSVEPred pred = Sve.WhileLessThan(i, length);
    VectorSVE<int> tail = Sve.LoadUnsafe(pred, ref *src, i);

    // MergeAdd keeps the FIRST vector in inactive lanes, so `total` must be the first operand;
    // otherwise the sums accumulated in those lanes by the main loop would be replaced by `tail`.
    total = Sve.MergeAdd(pred, total, tail);

    return Sve.AddAcross(total).ToScalar();
  }


  /*
    Count all the non zero elements in an int array.
  */
  public static unsafe int CountNonZero_sve(ref int* srcBytes, int length)
  {
    // srcBytes: address of the first int of the array; length: number of ints to inspect.
    // Returns how many of those ints are != 0 (negative values count as non-zero).
    int count = 0;
    int* src = srcBytes;

    // Assigned by the `out` argument of WhileLessThan in the loop condition.
    VectorSVEPred pred;

    for (int i = 0; Sve.WhileLessThan(out pred, i, length); i += Sve.VectorLength<int>())
    {
      VectorSVE<int> vec = Sve.LoadUnsafe(pred, ref *src, i);

      // Active lanes that hold zero. Testing equality (rather than greater-than) keeps
      // negative values on the non-zero side.
      VectorSVEPred zeros = Sve.CompareEquals(pred, vec, 0);

      // Non-zero elements in this chunk = active lanes - zero lanes (CNTP / INCP).
      // Counting lanes, not adding `vec`, so the result is a count rather than a sum of values.
      count += Sve.PredCountTrue(pred) - Sve.PredCountTrue(zeros);
    }

    // No tail loop required: the last iteration's predicate covers only the remaining elements.
    return count;
  }


  /*
    Count all the non zero elements in an int array, without predication.
  */
  public static unsafe int CountNonZero_sve_unpredicated_loop(ref int* srcBytes, int length)
  {
    // srcBytes: address of the first int of the array; length: number of ints to inspect.
    // Returns how many of those ints are != 0 (negative values count as non-zero).
    int* src = srcBytes;

    // SVE comparisons require predicates. Therefore, for a truly non-predicated main loop, use Neon.
    // The Neon loop keeps its own Vector128 accumulator; it cannot be mixed with a VectorSVE one.
    Vector128<int> neonTotal = Vector128<int>.Zero;
    Vector128<int> one = Vector128.Create((int)1);
    int vector_length = 16 / sizeof(int);

    // `i` lives outside the loop because the tail below starts where the loop stopped.
    int i = 0;
    for (; i + vector_length <= length; i += vector_length)
    {
      Vector128<int> chunk = AdvSimd.LoadVector128(src + i);

      // CMTST: a lane becomes all ones when (a & b) != 0, i.e. when the lane is non-zero.
      Vector128<int> nonZeroMask = AdvSimd.CompareTest(chunk, chunk);

      // Turn each all-ones lane into 1 and accumulate per-lane counts.
      neonTotal = AdvSimd.Add(neonTotal, AdvSimd.And(nonZeroMask, one));
    }

    // Predicated tail: covers the remaining (length - i) elements, possibly none.
    VectorSVEPred pred = Sve.WhileLessThan(i, length);
    VectorSVE<int> tail = Sve.LoadUnsafe(pred, ref *src, i);
    VectorSVEPred zeros = Sve.CompareEquals(pred, tail, 0);
    int tailCount = Sve.PredCountTrue(pred) - Sve.PredCountTrue(zeros);

    return AdvSimd.Arm64.AddAcross(neonTotal).ToScalar() + tailCount;
  }


  /*
    Count all the elements in a null terminated array of unknown size.
  */
  public static unsafe int CountLength_sve(ref int* srcBytes)
  {
    // srcBytes: address of the first int of a zero-terminated array.
    // Returns the number of elements before the first zero.
    int* src = srcBytes;
    VectorSVEPred pred = Sve.CreatePred(true);
    int ret = 0;

    // First-fault loads only clear FFR lanes, they never set them, so the FFR must be
    // initialised to all-true before the first LD1FF.
    Sve.ClearFaultPredicate(); // SETFFR

    while (true)
    {
      // `ret` is the number of elements already consumed, so load from the first unread element.
      VectorSVE<int> vec = Sve.LoadUnsafeUntilFault(pred, ref *(src + ret)); // LD1FF

      /*
        Reading the fault predicate via RDFFRS will also set the condition flags:
          VectorSVEComparison GetFaultPredicate(VectorSVEPred out fault, VectorSVEPred pred)
       */
      VectorSVEPred fault_pred;

      if (Sve.GetFaultPredicate(out fault_pred, pred).last)
      {
        // Last element is set in fault_pred, therefore the load did not fault.

        /*
          Like WhileLessThan, comparisons come in two variants:
            VectorSVEPred CompareEquals(VectorSVEPred pred, VectorSVE a, VectorSVE b)
            VectorSVEComparison CompareEquals(VectorSVEPred out cmp_result, VectorSVEPred pred, VectorSVE a, VectorSVE b)
         */

        // Look for any zeros across the entire vector.
        VectorSVEPred cmp_zero;
        if (Sve.CompareEquals(out cmp_zero, pred, vec, 0).none)
        {
          // No zeroes found. Continue loop; the FFR is still all-true.
          ret += Sve.VectorLength<int>();
        }
        else
        {
          // Zero found. Count up to it and return.
          VectorSVEPred matches = Sve.PredFillUpToFirstMatch(pred, cmp_zero); // BRKB
          ret += Sve.PredCountTrue(matches); // INCP
          return ret;
        }
      }
      else
      {
        // Load caused a fault.

        // Look for any zeros across the vector up until the fault.
        VectorSVEPred cmp_zero;
        if (Sve.CompareEquals(out cmp_zero, fault_pred, vec, 0).none)
        {
          // No zeroes found. Count the successfully loaded lanes, reset the FFR and continue loop.
          ret += Sve.PredCountTrue(fault_pred); // INCP
          Sve.ClearFaultPredicate(); // SETFFR
        }
        else
        {
          // Zero found. Count up to it and return.
          VectorSVEPred matches = Sve.PredFillUpToFirstMatch(pred, cmp_zero); // BRKB
          ret += Sve.PredCountTrue(matches); // INCP
          return ret;
        }
      }
    }
  }

Alternative Designs

No response

Risks

References

SVE Programming Examples
A64 -- SVE Instructions (alphabetic order)

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions