Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/Microsoft.Data.Analysis/DataFrame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,25 @@ public GroupBy GroupBy(string columnName)
DataFrameColumn column = _columnCollection[columnIndex];
return column.GroupBy(columnIndex, this);
}

/// <summary>
/// Groups the rows of the <see cref="DataFrame"/> by the unique values of the <paramref name="columnName"/> column
/// and returns the result as a grouping keyed by <typeparamref name="TKey"/>.
/// </summary>
/// <typeparam name="TKey">Type of the column used for grouping</typeparam>
/// <param name="columnName">Name of the column whose unique values define the groups</param>
/// <returns>A GroupBy object that stores the group information.</returns>
/// <exception cref="InvalidCastException">The grouping created for the column is not keyed by <typeparamref name="TKey"/>.</exception>
public GroupBy<TKey> GroupBy<TKey>(string columnName)
{
    // The non-generic overload builds the grouping; accept it only when it is keyed by TKey.
    if (GroupBy(columnName) is GroupBy<TKey> typedGroup)
    {
        return typedGroup;
    }

    // Look the column up again so the error can report its actual data type.
    DataFrameColumn groupedColumn = this[columnName];
    string message = String.Format(Strings.BadColumnCastDuringGrouping, columnName, groupedColumn.DataType, typeof(TKey));
    throw new InvalidCastException(message);
}

// In GroupBy and ReadCsv calls, columns get resized. We need to set the RowCount to reflect the true Length of the DataFrame. This does internal validation
internal void SetTableRowCount(long rowCount)
Expand Down
39 changes: 39 additions & 0 deletions src/Microsoft.Data.Analysis/GroupBy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.Data.Analysis
{
Expand Down Expand Up @@ -72,6 +74,33 @@ public abstract class GroupBy

public class GroupBy<TKey> : GroupBy
{
#region Internal class that implements IGrouping LINQ interface
private class Grouping : IGrouping<TKey, DataFrameRow>
{
    // Rows of this group; every enumeration is delegated to this sequence.
    private readonly IEnumerable<DataFrameRow> _members;

    public Grouping(TKey key, IEnumerable<DataFrameRow> rows)
    {
        Key = key;
        _members = rows;
    }

    /// <summary>
    /// Key this group was created for.
    /// </summary>
    public TKey Key { get; }

    public IEnumerator<DataFrameRow> GetEnumerator() => _members.GetEnumerator();

    // The non-generic enumeration goes through the same typed enumerator.
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

#endregion

private int _groupByColumnIndex;
private IDictionary<TKey, ICollection<long>> _keyToRowIndicesMap;
private DataFrame _dataFrame;
Expand Down Expand Up @@ -464,5 +493,15 @@ public override DataFrame Mean(params string[] columnNames)
return ret;
}

/// <summary>
/// Returns a collection of Grouping objects, one per key, where each object represents the set of DataFrameRows having that Key
/// </summary>
public IEnumerable<IGrouping<TKey, DataFrameRow>> Groupings =>
    _keyToRowIndicesMap.Select(entry =>
    {
        // Translate the stored row indices of this key into the corresponding rows of the DataFrame.
        IEnumerable<DataFrameRow> groupRows = entry.Value.Select(rowIndex => _dataFrame.Rows[rowIndex]);
        return new Grouping(entry.Key, groupRows);
    });
}
}
11 changes: 10 additions & 1 deletion src/Microsoft.Data.Analysis/Strings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/Microsoft.Data.Analysis/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@
<data name="BadColumnCast" xml:space="preserve">
<value>Cannot cast column holding {0} values to type {1}</value>
</data>
<data name="BadColumnCastDuringGrouping" xml:space="preserve">
<value>Cannot cast elements of column '{0}' type of {1} to type {2} used as TKey in grouping </value>
</data>
<data name="CannotParseWithDelimiters" xml:space="preserve">
<value>Line {0} cannot be parsed with the current Delimiters.</value>
</data>
Expand Down
116 changes: 116 additions & 0 deletions test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Xunit;

namespace Microsoft.Data.Analysis.Tests
{
public class DataFrameGroupByTests
{
[Fact]
public void TestGroupingWithTKeyTypeofString()
{
    const int length = 11;

    // Rows 0..length-1 split into "even" and "odd": the odd group gets the
    // rounded-down half, the even group gets the remainder.
    const int expectedOddCount = length / 2;
    const int expectedEvenCount = length / 2 + length % 2;

    // Create test dataframe (numbers starting from 0 up to length)
    DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);

    var groupsByParity = df.GroupBy<string>("Parity").Groupings;

    // Check groups count
    Assert.Equal(2, groupsByParity.Count());

    // Check number of elements in each group
    var oddRows = groupsByParity.FirstOrDefault(g => g.Key == "odd");
    Assert.NotNull(oddRows);
    Assert.Equal(expectedOddCount, oddRows.Count());

    var evenRows = groupsByParity.FirstOrDefault(g => g.Key == "even");
    Assert.NotNull(evenRows);
    Assert.Equal(expectedEvenCount, evenRows.Count());
}

[Fact]
public void TestGroupingWithTKey_CornerCases()
{
    // Corner case 1: a dataframe without rows produces no groups
    DataFrame emptyFrame = MakeTestDataFrameWithParityAndTensColumns(0);
    var emptyFrameGroups = emptyFrame.GroupBy<string>("Parity").Groupings;
    Assert.Empty(emptyFrameGroups);

    // Corner case 2: a single row (number 0) produces exactly one group, keyed "even"
    DataFrame singleRowFrame = MakeTestDataFrameWithParityAndTensColumns(1);
    var singleRowGroups = singleRowFrame.GroupBy<string>("Parity").Groupings;
    Assert.Single(singleRowGroups);
    Assert.Equal("even", singleRowGroups.First().Key);
}


[Fact]
public void TestGroupingWithTKeyPrimitiveType()
{
    const int length = 55;

    // Create test dataframe (numbers starting from 0 up to length)
    DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);

    // Group elements by the int "Tens" column, which contains the amount of full tens in each number
    var groupings = df.GroupBy<int>("Tens").Groupings.ToDictionary(g => g.Key, g => g.ToList());

    // Get the amount of all number based columns
    int numberColumnsCount = df.Columns.Count - 2; // except "Parity" and "Tens" columns

    // Check each full group
    for (int i = 0; i < length / 10; i++)
    {
        var rows = groupings[i];
        Assert.Equal(10, rows.Count);

        for (int colIndex = 0; colIndex < numberColumnsCount; colIndex++)
        {
            // Materialize once: a deferred query would repeat the conversion of every row
            // for each of the Assert.Contains calls below.
            var values = rows.Select(row => Convert.ToInt32(row[colIndex])).ToList();

            for (int j = 0; j < 10; j++)
            {
                Assert.Contains(i * 10 + j, values);
            }
        }
    }

    // Last group should contain a smaller amount of items
    Assert.Equal(length % 10, groupings[length / 10].Count);
}

[Fact]
public void TestGroupingWithTKeyOfWrongType()
{
    // Create test dataframe (numbers starting from 0 up to length)
    DataFrame df = MakeTestDataFrameWithParityAndTensColumns(1);

    // "Tens" is created as an Int32 column, so asking for double keys must be rejected
    Assert.Throws<InvalidCastException>(() => df.GroupBy<double>("Tens"));
}


/// <summary>
/// Builds the numeric test dataframe with <paramref name="length"/> rows and appends two extra columns:
/// "Parity" ("even"/"odd" per row number) and "Tens" (row number divided by ten).
/// </summary>
private DataFrame MakeTestDataFrameWithParityAndTensColumns(int length)
{
    DataFrame frame = DataFrameTests.MakeDataFrameWithNumericColumns(length, false);

    // String column labelling each row number 0..length-1 as "even" or "odd"
    var parityLabels = Enumerable.Range(0, length).Select(n => n % 2 == 0 ? "even" : "odd");
    DataFrameColumn parity = new StringDataFrameColumn("Parity", parityLabels);
    frame.Columns.Insert(frame.Columns.Count, parity);

    // Int32 column holding the integer quotient of each row number by ten
    var tensValues = Enumerable.Range(0, length).Select(n => n / 10);
    DataFrameColumn tens = new Int32DataFrameColumn("Tens", tensValues);
    frame.Columns.Insert(frame.Columns.Count, tens);

    return frame;
}
}
}