Skip to content

Commit 33cac0c

Browse files
committed
Fix up VectorWhitening after rebase.
1 parent 5073467 commit 33cac0c

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

src/Microsoft.ML.Transforms/VectorWhitening.cs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,19 @@ public enum SvdJob : byte
613613
MinOvr = (byte)'O',
614614
}
615615

616+
public static unsafe void Gemv(Layout layout, Transpose trans, int m, int n, float alpha,
617+
float[] a, int lda, ReadOnlySpan<float> x, int incx, float beta, Span<float> y, int incy)
618+
{
619+
fixed (float* pA = a)
620+
fixed (float* pX = x)
621+
fixed (float* pY = y)
622+
Gemv(layout, trans, m, n, alpha, pA, lda, pX, incx, beta, pY, incy);
623+
}
624+
616625
// See: https://software.intel.com/en-us/node/520750
617626
[DllImport(DllName, EntryPoint = "cblas_sgemv")]
618-
public static extern void Gemv(Layout layout, Transpose trans, int m, int n, float alpha,
619-
float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
627+
private static unsafe extern void Gemv(Layout layout, Transpose trans, int m, int n, float alpha,
628+
float* a, int lda, float* x, int incx, float beta, float* y, int incy);
620629

621630
// See: https://software.intel.com/en-us/node/520775
622631
[DllImport(DllName, EntryPoint = "cblas_sgemm")]
@@ -715,36 +724,34 @@ private ValueGetter<T> GetSrcGetter<T>(IRow input, int iinfo)
715724

716725
private static void FillValues(float[] model, ref VBuffer<float> src, ref VBuffer<float> dst, int cdst)
717726
{
718-
int count = src.Count;
727+
var values = src.GetValues();
728+
int count = values.Length;
719729
int length = src.Length;
720-
var values = src.Values;
721-
var indices = src.Indices;
722-
Contracts.Assert(Utils.Size(values) >= count);
723730

724731
// Since the whitening process produces dense vector, always use dense representation of dst.
725-
var a = Utils.Size(dst.Values) >= cdst ? dst.Values : new float[cdst];
732+
var mutation = VBufferMutationContext.Create(ref dst, cdst);
726733
if (src.IsDense)
727734
{
728735
Mkl.Gemv(Mkl.Layout.RowMajor, Mkl.Transpose.NoTrans, cdst, length,
729-
1, model, length, values, 1, 0, a, 1);
736+
1, model, length, values, 1, 0, mutation.Values, 1);
730737
}
731738
else
732739
{
733-
Contracts.Assert(Utils.Size(indices) >= count);
740+
var indices = src.GetIndices();
734741

735742
int offs = 0;
736743
for (int i = 0; i < cdst; i++)
737744
{
738745
// Returns a dot product of dense vector 'model' starting from offset 'offs' and sparse vector 'values'
739746
// with first 'count' valid elements and their corresponding 'indices'.
740-
a[i] = CpuMathUtils.DotProductSparse(model.AsSpan(offs), values, indices, count);
747+
mutation.Values[i] = CpuMathUtils.DotProductSparse(model.AsSpan(offs), values, indices, count);
741748
offs += length;
742749
}
743750
}
744-
dst = new VBuffer<float>(cdst, a, dst.Indices);
751+
dst = mutation.CreateBuffer();
745752
}
746753

747-
private static float DotProduct(float[] a, int aOffset, float[] b, int[] indices, int count)
754+
private static float DotProduct(float[] a, int aOffset, ReadOnlySpan<float> b, ReadOnlySpan<int> indices, int count)
748755
{
749756
Contracts.Assert(count <= indices.Length);
750757
return CpuMathUtils.DotProductSparse(a.AsSpan(aOffset), b, indices, count);

0 commit comments

Comments
 (0)