Skip to content

Commit

Permalink
Improvements to the sort routine (#5776)
Browse files Browse the repository at this point in the history
* Improvements to the sort routine

* Fix unit test

* Fold into existing API
  • Loading branch information
Prashanth Govindarajan authored May 10, 2021
1 parent 750956d commit 43c49f6
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 27 deletions.
12 changes: 8 additions & 4 deletions src/Microsoft.Data.Analysis/DataFrame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ public DataFrame Sample(int numberOfRows)

int shuffleLowerLimit = 0;
int shuffleUpperLimit = (int)Math.Min(Int32.MaxValue, Rows.Count);

int[] shuffleArray = Enumerable.Range(0, shuffleUpperLimit).ToArray();
Random rand = new Random();
while (shuffleLowerLimit < numberOfRows)
Expand All @@ -349,7 +349,7 @@ public DataFrame Sample(int numberOfRows)
ArraySegment<int> segment = new ArraySegment<int>(shuffleArray, 0, shuffleLowerLimit);

PrimitiveDataFrameColumn<int> indices = new PrimitiveDataFrameColumn<int>("indices", segment);

return Clone(indices);
}

Expand Down Expand Up @@ -623,12 +623,16 @@ private void OnColumnsChanged()
private DataFrame Sort(string columnName, bool isAscending)
{
DataFrameColumn column = Columns[columnName];
DataFrameColumn sortIndices = column.GetAscendingSortIndices();
PrimitiveDataFrameColumn<long> sortIndices = column.GetAscendingSortIndices(out Int64DataFrameColumn nullIndices);
for (long i = 0; i < nullIndices.Length; i++)
{
sortIndices.Append(nullIndices[i]);
}
List<DataFrameColumn> newColumns = new List<DataFrameColumn>(Columns.Count);
for (int i = 0; i < Columns.Count; i++)
{
DataFrameColumn oldColumn = Columns[i];
DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending, oldColumn.NullCount);
DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending);
Debug.Assert(newColumn.NullCount == oldColumn.NullCount);
newColumns.Add(newColumn);
}
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.Data.Analysis/DataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ public object this[long rowIndex]
/// <param name="ascending"></param>
public virtual DataFrameColumn Sort(bool ascending = true)
{
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(sortIndices, !ascending, NullCount);
}

Expand Down Expand Up @@ -336,7 +336,11 @@ public virtual StringDataFrameColumn Info()
/// </summary>
public virtual DataFrameColumn Description() => throw new NotImplementedException();

internal virtual PrimitiveDataFrameColumn<long> GetAscendingSortIndices() => throw new NotImplementedException();
/// <summary>
/// Returns the indices of non-null values that, when applied, result in this column being sorted in ascending order. Also returns the indices of null values in <paramref name="nullIndices"/>.
/// </summary>
/// <param name="nullIndices">Indices of values that are <see langword="null"/>.</param>
internal virtual PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) => throw new NotImplementedException();

internal delegate long GetBufferSortIndex(int bufferIndex, int sortIndex);
internal delegate ValueTuple<T, int> GetValueAndBufferSortIndexAtBuffer<T>(int bufferIndex, int valueIndex);
Expand Down
28 changes: 20 additions & 8 deletions src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,46 @@ public partial class PrimitiveDataFrameColumn<T> : DataFrameColumn
{
/// <summary>
/// Returns the result of cloning this column through its ascending sort indices.
/// Declared with <c>new</c> so it hides <see cref="DataFrameColumn.Sort(bool)"/> and returns the strongly typed column.
/// </summary>
/// <param name="ascending">Requested sort direction; its negation is forwarded to <c>Clone</c> as the second argument.</param>
public new PrimitiveDataFrameColumn<T> Sort(bool ascending = true)
{
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
// The null indices reported through the out parameter are discarded here; NullCount is handed to Clone instead.
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(sortIndices, !ascending, NullCount);
}

internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices()
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
{
// The return sortIndices contains only the non null indices.
GetSortIndices(Comparer<T>.Default, out PrimitiveDataFrameColumn<long> sortIndices);
Int64DataFrameColumn sortIndices = GetSortIndices(Comparer<T>.Default, out nullIndices);
return sortIndices;
}

private void GetSortIndices(IComparer<T> comparer, out PrimitiveDataFrameColumn<long> columnSortIndices)
private Int64DataFrameColumn GetSortIndices(IComparer<T> comparer, out Int64DataFrameColumn columnNullIndices)
{
List<List<int>> bufferSortIndices = new List<List<int>>(_columnContainer.Buffers.Count);
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
long nullIndicesSlot = 0;
// Sort each buffer first
for (int b = 0; b < _columnContainer.Buffers.Count; b++)
{
ReadOnlyDataFrameBuffer<T> buffer = _columnContainer.Buffers[b];
ReadOnlySpan<byte> nullBitMapSpan = _columnContainer.NullBitMapBuffers[b].ReadOnlySpan;
int[] sortIndices = new int[buffer.Length];
for (int i = 0; i < buffer.Length; i++)
{
sortIndices[i] = i;
}
IntrospectiveSort(buffer.ReadOnlySpan, buffer.Length, sortIndices, comparer);
// Bug fix: QuickSort is not stable. When PrimitiveDataFrameColumn has null values and default values, they move around
List<int> nonNullSortIndices = new List<int>();
for (int i = 0; i < sortIndices.Length; i++)
{
if (_columnContainer.IsValid(nullBitMapSpan, sortIndices[i]))
int localSortIndex = sortIndices[i];
if (_columnContainer.IsValid(nullBitMapSpan, localSortIndex))
{
nonNullSortIndices.Add(sortIndices[i]);

}
else
{
    // Record the column-wide row index of this null entry.
    // The buffer offset is computed in 64-bit arithmetic: b * Buffers[0].Length as int * int
    // wraps silently (unchecked context) once the product exceeds int.MaxValue, which would
    // store a wrong or negative index in the Int64 null-index column.
    columnNullIndices[nullIndicesSlot] = localSortIndex + (long)b * _columnContainer.Buffers[0].Length;
    nullIndicesSlot++;
}
}
bufferSortIndices.Add(nonNullSortIndices);
}
Expand Down Expand Up @@ -90,11 +100,13 @@ ValueTuple<T, int> GetFirstNonNullValueAndBufferIndexStartingAtIndex(int bufferI
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferIndex.Item2, i) });
}
}
columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
Int64DataFrameColumn columnSortIndices = new Int64DataFrameColumn("SortIndices");
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Count);
GetValueAndBufferSortIndexAtBuffer<T> getValueAndBufferSortIndexAtBuffer = new GetValueAndBufferSortIndexAtBuffer<T>((int bufferIndex, int sortIndex) => GetFirstNonNullValueAndBufferIndexStartingAtIndex(bufferIndex, sortIndex));
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Count);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAndBufferSortIndexAtBuffer, getBufferLengthAtIndex);

return columnSortIndices;
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public override double Median()
// Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn)
if (Length == 0)
return 0;
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
long middle = sortIndices.Length / 2;
double middleValue = (double)Convert.ChangeType(this[sortIndices[middle].Value].Value, typeof(double));
if (Length % 2 == 0)
Expand Down
18 changes: 13 additions & 5 deletions src/Microsoft.Data.Analysis/StringDataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,25 +171,32 @@ public IEnumerator<string> GetEnumerator()

/// <summary>
/// Returns the result of cloning this column through its ascending sort indices.
/// Declared with <c>new</c>, so it hides the inherited <c>Sort</c> and returns a <see cref="StringDataFrameColumn"/>.
/// </summary>
/// <param name="ascending">Requested sort direction; its negation is forwarded to <c>Clone</c> as <c>invertMapIndices</c>.</param>
public new StringDataFrameColumn Sort(bool ascending = true)
{
PrimitiveDataFrameColumn<long> columnSortIndices = GetAscendingSortIndices();
// The null indices reported through the out parameter are discarded here; NullCount is passed to Clone as numberOfNullsToAppend.
PrimitiveDataFrameColumn<long> columnSortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(columnSortIndices, !ascending, NullCount);
}

internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices()
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
{
GetSortIndices(Comparer<string>.Default, out PrimitiveDataFrameColumn<long> columnSortIndices);
PrimitiveDataFrameColumn<long> columnSortIndices = GetSortIndices(Comparer<string>.Default, out nullIndices);
return columnSortIndices;
}

private void GetSortIndices(Comparer<string> comparer, out PrimitiveDataFrameColumn<long> columnSortIndices)
private PrimitiveDataFrameColumn<long> GetSortIndices(Comparer<string> comparer, out Int64DataFrameColumn columnNullIndices)
{
List<int[]> bufferSortIndices = new List<int[]>(_stringBuffers.Count);
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
long nullIndicesSlot = 0;
foreach (List<string> buffer in _stringBuffers)
{
var sortIndices = new int[buffer.Count];
for (int i = 0; i < buffer.Count; i++)
{
    // Start from the identity permutation for this buffer.
    sortIndices[i] = i;
    if (buffer[i] == null)
    {
        // Record the column-wide row index of this null entry.
        // bufferSortIndices.Count presumably equals the index of the buffer being scanned
        // (the matching Add is outside the visible lines) - confirm.
        // The offset must be computed as long: Count * int.MaxValue as int * int wraps
        // silently for any buffer after the first, producing a wrong or negative index.
        columnNullIndices[nullIndicesSlot] = i + (long)bufferSortIndices.Count * int.MaxValue;
        nullIndicesSlot++;
    }
}
// TODO: Refactor the sort routine to also work with IList?
string[] array = buffer.ToArray();
Expand Down Expand Up @@ -227,11 +234,12 @@ ValueTuple<string, int> GetFirstNonNullValueStartingAtIndex(int stringBufferInde
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferSortIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferSortIndex.Item2, i) });
}
}
columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
PrimitiveDataFrameColumn<long> columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Length);
GetValueAndBufferSortIndexAtBuffer<string> getValueAtBuffer = new GetValueAndBufferSortIndexAtBuffer<string>((int bufferIndex, int sortIndex) => GetFirstNonNullValueStartingAtIndex(bufferIndex, sortIndex));
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Length);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAtBuffer, getBufferLengthAtIndex);
return columnSortIndices;
}

public new StringDataFrameColumn Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend)
Expand Down
51 changes: 44 additions & 7 deletions test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -815,10 +815,10 @@ public void TestOrderBy()

// Sort by "Int" in descending order
sortedDf = df.OrderByDescending("Int");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(-1, sortedDf.Columns["Int"][18]);
Assert.Equal(100, sortedDf.Columns["Int"][1]);
Assert.Equal(2000, sortedDf.Columns["Int"][0]);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(-1, sortedDf.Columns["Int"][19]);
Assert.Equal(100, sortedDf.Columns["Int"][2]);
Assert.Equal(2000, sortedDf.Columns["Int"][1]);

// Sort by "String" in ascending order
sortedDf = df.OrderBy("String");
Expand All @@ -829,9 +829,9 @@ public void TestOrderBy()

// Sort by "String" in descending order
sortedDf = df.OrderByDescending("String");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(8, sortedDf.Columns["Int"][1]);
Assert.Equal(9, sortedDf.Columns["Int"][0]);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(8, sortedDf.Columns["Int"][2]);
Assert.Equal(9, sortedDf.Columns["Int"][1]);
}

[Fact]
Expand Down Expand Up @@ -920,6 +920,43 @@ public void TestPrimitiveColumnSort(int numberOfNulls)
Assert.Null(sortedIntColumn[9]);
}

[Fact]
public void TestSortWithDifferentNullCountsInColumns()
{
    // Null out row 3 in "Int" and "String" only, so those two columns carry one more
    // null than the remaining columns of the frame.
    DataFrame dataFrame = MakeDataFrameWithAllMutableColumnTypes(10);
    dataFrame["Int"][3] = null;
    dataFrame["String"][3] = null;

    // Sorting by either of the two columns must yield the same trailing rows.
    AssertTrailingRows(dataFrame.OrderBy("Int"));
    AssertTrailingRows(dataFrame.OrderBy("String"));

    // Checks the row count, that the last row is entirely null, and that the
    // second-to-last row is original row 3 (null in "Int"/"String", intact elsewhere).
    void AssertTrailingRows(DataFrame sortedDataFrame)
    {
        Assert.Equal(10, sortedDataFrame.Rows.Count);

        DataFrameRow lastRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 1];
        DataFrameRow penultimateRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 2];

        foreach (object cell in lastRow)
        {
            Assert.Null(cell);
        }

        for (int columnIndex = 0; columnIndex < sortedDataFrame.Columns.Count; columnIndex++)
        {
            string columnName = sortedDataFrame.Columns[columnIndex].Name;
            bool wasNulledAtRowThree = columnName == "String" || columnName == "Int";

            if (wasNulledAtRowThree)
            {
                Assert.Null(penultimateRow[columnIndex]);
            }
            else
            {
                Assert.Equal(dataFrame[columnName][3], penultimateRow[columnIndex]);
            }
        }
    }
}

private void VerifyJoin(DataFrame join, DataFrame left, DataFrame right, JoinAlgorithm joinAlgorithm)
{
Int64DataFrameColumn mapIndices = new Int64DataFrameColumn("map", join.Rows.Count);
Expand Down

0 comments on commit 43c49f6

Please sign in to comment.