Skip to content

Commit 9c80608

Browse files
author
Prashanth Govindarajan
authored
Support for Exploded columns types in Arrow and IO scenarios (dotnet#2885)
* Support for Exploded columns types in Arrow and IO scenarios * Unit tests * Address feedback
1 parent a6c34d0 commit 9c80608

File tree

4 files changed

+133
-30
lines changed

4 files changed

+133
-30
lines changed

src/Microsoft.Data.Analysis/DataFrame.Arrow.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,43 +37,43 @@ public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
3737
BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
3838
ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
3939
ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
40-
dataFrameColumn = new PrimitiveDataFrameColumn<bool>(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
40+
dataFrameColumn = new BooleanDataFrameColumn(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
4141
break;
4242
case ArrowTypeId.Double:
4343
PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
4444
ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
4545
ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
46-
dataFrameColumn = new PrimitiveDataFrameColumn<double>(field.Name, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
46+
dataFrameColumn = new DoubleDataFrameColumn(field.Name, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
4747
break;
4848
case ArrowTypeId.Float:
4949
PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
5050
ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
5151
ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
52-
dataFrameColumn = new PrimitiveDataFrameColumn<float>(field.Name, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
52+
dataFrameColumn = new SingleDataFrameColumn(field.Name, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
5353
break;
5454
case ArrowTypeId.Int8:
5555
PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
5656
ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
5757
ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
58-
dataFrameColumn = new PrimitiveDataFrameColumn<sbyte>(field.Name, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
58+
dataFrameColumn = new SByteDataFrameColumn(field.Name, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
5959
break;
6060
case ArrowTypeId.Int16:
6161
PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
6262
ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
6363
ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
64-
dataFrameColumn = new PrimitiveDataFrameColumn<short>(field.Name, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
64+
dataFrameColumn = new Int16DataFrameColumn(field.Name, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
6565
break;
6666
case ArrowTypeId.Int32:
6767
PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
6868
ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
6969
ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
70-
dataFrameColumn = new PrimitiveDataFrameColumn<int>(field.Name, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
70+
dataFrameColumn = new Int32DataFrameColumn(field.Name, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
7171
break;
7272
case ArrowTypeId.Int64:
7373
PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
7474
ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
7575
ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
76-
dataFrameColumn = new PrimitiveDataFrameColumn<long>(field.Name, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
76+
dataFrameColumn = new Int64DataFrameColumn(field.Name, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
7777
break;
7878
case ArrowTypeId.String:
7979
StringArray stringArray = (StringArray)arrowArray;
@@ -86,25 +86,25 @@ public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
8686
PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
8787
ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
8888
ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
89-
dataFrameColumn = new PrimitiveDataFrameColumn<byte>(field.Name, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
89+
dataFrameColumn = new ByteDataFrameColumn(field.Name, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
9090
break;
9191
case ArrowTypeId.UInt16:
9292
PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
9393
ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
9494
ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
95-
dataFrameColumn = new PrimitiveDataFrameColumn<ushort>(field.Name, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
95+
dataFrameColumn = new UInt16DataFrameColumn(field.Name, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
9696
break;
9797
case ArrowTypeId.UInt32:
9898
PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
9999
ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
100100
ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
101-
dataFrameColumn = new PrimitiveDataFrameColumn<uint>(field.Name, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
101+
dataFrameColumn = new UInt32DataFrameColumn(field.Name, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
102102
break;
103103
case ArrowTypeId.UInt64:
104104
PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
105105
ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
106106
ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
107-
dataFrameColumn = new PrimitiveDataFrameColumn<ulong>(field.Name, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
107+
dataFrameColumn = new UInt64DataFrameColumn(field.Name, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
108108
break;
109109
case ArrowTypeId.Decimal:
110110
case ArrowTypeId.Binary:

src/Microsoft.Data.Analysis/DataFrame.IO.cs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -106,69 +106,69 @@ public static DataFrame LoadCsv(string filename,
106106
}
107107
}
108108

109+
private static string GetColumnName(string[] columnNames, int columnIndex)
110+
{
111+
return columnNames == null ? "Column" + columnIndex.ToString() : columnNames[columnIndex];
112+
}
113+
109114
private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
110115
{
111-
PrimitiveDataFrameColumn<T> CreatePrimitiveDataFrameColumn<T>()
112-
where T : unmanaged
113-
{
114-
return new PrimitiveDataFrameColumn<T>(columnNames == null ? "Column" + columnIndex.ToString() : columnNames[columnIndex]);
115-
}
116116
DataFrameColumn ret;
117117
if (kind == typeof(bool))
118118
{
119-
ret = CreatePrimitiveDataFrameColumn<bool>();
119+
ret = new BooleanDataFrameColumn(GetColumnName(columnNames, columnIndex));
120120
}
121121
else if (kind == typeof(int))
122122
{
123-
ret = CreatePrimitiveDataFrameColumn<int>();
123+
ret = new Int32DataFrameColumn(GetColumnName(columnNames, columnIndex));
124124
}
125125
else if (kind == typeof(float))
126126
{
127-
ret = CreatePrimitiveDataFrameColumn<float>();
127+
ret = new SingleDataFrameColumn(GetColumnName(columnNames, columnIndex));
128128
}
129129
else if (kind == typeof(string))
130130
{
131-
ret = new StringDataFrameColumn(columnNames == null ? "Column" + columnIndex.ToString() : columnNames[columnIndex], 0);
131+
ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0);
132132
}
133133
else if (kind == typeof(long))
134134
{
135-
ret = CreatePrimitiveDataFrameColumn<long>();
135+
ret = new Int64DataFrameColumn(GetColumnName(columnNames, columnIndex));
136136
}
137137
else if (kind == typeof(decimal))
138138
{
139-
ret = CreatePrimitiveDataFrameColumn<decimal>();
139+
ret = new DecimalDataFrameColumn(GetColumnName(columnNames, columnIndex));
140140
}
141141
else if (kind == typeof(byte))
142142
{
143-
ret = CreatePrimitiveDataFrameColumn<byte>();
143+
ret = new ByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
144144
}
145145
else if (kind == typeof(char))
146146
{
147-
ret = CreatePrimitiveDataFrameColumn<char>();
147+
ret = new CharDataFrameColumn(GetColumnName(columnNames, columnIndex));
148148
}
149149
else if (kind == typeof(double))
150150
{
151-
ret = CreatePrimitiveDataFrameColumn<double>();
151+
ret = new DoubleDataFrameColumn(GetColumnName(columnNames, columnIndex));
152152
}
153153
else if (kind == typeof(sbyte))
154154
{
155-
ret = CreatePrimitiveDataFrameColumn<sbyte>();
155+
ret = new SByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
156156
}
157157
else if (kind == typeof(short))
158158
{
159-
ret = CreatePrimitiveDataFrameColumn<short>();
159+
ret = new Int16DataFrameColumn(GetColumnName(columnNames, columnIndex));
160160
}
161161
else if (kind == typeof(uint))
162162
{
163-
ret = CreatePrimitiveDataFrameColumn<uint>();
163+
ret = new UInt32DataFrameColumn(GetColumnName(columnNames, columnIndex));
164164
}
165165
else if (kind == typeof(ulong))
166166
{
167-
ret = CreatePrimitiveDataFrameColumn<ulong>();
167+
ret = new UInt64DataFrameColumn(GetColumnName(columnNames, columnIndex));
168168
}
169169
else if (kind == typeof(ushort))
170170
{
171-
ret = CreatePrimitiveDataFrameColumn<ushort>();
171+
ret = new UInt16DataFrameColumn(GetColumnName(columnNames, columnIndex));
172172
}
173173
else
174174
{

tests/Microsoft.Data.Analysis.Tests/ArrowIntegrationTests.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public void TestArrowIntegration()
5050
.Build();
5151

5252
DataFrame df = DataFrame.FromArrowRecordBatch(originalBatch);
53+
DataFrameTests.VerifyColumnTypes(df, testArrowStringColumn: true);
5354

5455
IEnumerable<RecordBatch> recordBatches = df.ToArrowRecordBatches();
5556

tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,101 @@
55
using System;
66
using System.IO;
77
using System.Text;
8+
using Apache.Arrow;
89
using Xunit;
910

1011
namespace Microsoft.Data.Analysis.Tests
1112
{
1213
public partial class DataFrameTests
1314
{
15+
internal static void VerifyColumnTypes(DataFrame df, bool testArrowStringColumn = false)
16+
{
17+
foreach (DataFrameColumn column in df.Columns)
18+
{
19+
Type dataType = column.DataType;
20+
if (dataType == typeof(bool))
21+
{
22+
Assert.IsType<BooleanDataFrameColumn>(column);
23+
24+
}
25+
else if (dataType == typeof(decimal))
26+
{
27+
Assert.IsType<DecimalDataFrameColumn>(column);
28+
29+
}
30+
else if (dataType == typeof(byte))
31+
{
32+
Assert.IsType<ByteDataFrameColumn>(column);
33+
34+
}
35+
else if (dataType == typeof(char))
36+
{
37+
Assert.IsType<CharDataFrameColumn>(column);
38+
39+
}
40+
else if (dataType == typeof(double))
41+
{
42+
Assert.IsType<DoubleDataFrameColumn>(column);
43+
44+
}
45+
else if (dataType == typeof(float))
46+
{
47+
Assert.IsType<SingleDataFrameColumn>(column);
48+
49+
}
50+
else if (dataType == typeof(int))
51+
{
52+
Assert.IsType<Int32DataFrameColumn>(column);
53+
54+
}
55+
else if (dataType == typeof(long))
56+
{
57+
58+
Assert.IsType<Int64DataFrameColumn>(column);
59+
}
60+
else if (dataType == typeof(sbyte))
61+
{
62+
Assert.IsType<SByteDataFrameColumn>(column);
63+
64+
}
65+
else if (dataType == typeof(short))
66+
{
67+
Assert.IsType<Int16DataFrameColumn>(column);
68+
69+
}
70+
else if (dataType == typeof(uint))
71+
{
72+
Assert.IsType<UInt32DataFrameColumn>(column);
73+
74+
}
75+
else if (dataType == typeof(ulong))
76+
{
77+
78+
Assert.IsType<UInt64DataFrameColumn>(column);
79+
}
80+
else if (dataType == typeof(ushort))
81+
{
82+
Assert.IsType<UInt16DataFrameColumn>(column);
83+
84+
}
85+
else if (dataType == typeof(string))
86+
{
87+
if (!testArrowStringColumn)
88+
{
89+
Assert.IsType<StringDataFrameColumn>(column);
90+
}
91+
else
92+
{
93+
Assert.IsType<ArrowStringDataFrameColumn>(column);
94+
}
95+
}
96+
else
97+
{
98+
throw new NotImplementedException("Unit test has to be updated");
99+
}
100+
}
101+
}
102+
14103
[Fact]
15104
public void TestReadCsvWithHeader()
16105
{
@@ -28,11 +117,13 @@ Stream GetStream(string streamData)
28117
Assert.Equal(4, df.Rows.Count);
29118
Assert.Equal(7, df.Columns.Count);
30119
Assert.Equal("CMT", df.Columns["vendor_id"][3]);
120+
VerifyColumnTypes(df);
31121

32122
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), numberOfRowsToRead: 3);
33123
Assert.Equal(3, reducedRows.Rows.Count);
34124
Assert.Equal(7, reducedRows.Columns.Count);
35125
Assert.Equal("CMT", reducedRows.Columns["vendor_id"][2]);
126+
VerifyColumnTypes(df);
36127
}
37128

38129
[Fact]
@@ -51,11 +142,13 @@ Stream GetStream(string streamData)
51142
Assert.Equal(4, df.Rows.Count);
52143
Assert.Equal(7, df.Columns.Count);
53144
Assert.Equal("CMT", df.Columns["Column0"][3]);
145+
VerifyColumnTypes(df);
54146

55147
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), header: false, numberOfRowsToRead: 3);
56148
Assert.Equal(3, reducedRows.Rows.Count);
57149
Assert.Equal(7, reducedRows.Columns.Count);
58150
Assert.Equal("CMT", reducedRows.Columns["Column0"][2]);
151+
VerifyColumnTypes(df);
59152
}
60153

61154
[Fact]
@@ -83,6 +176,7 @@ Stream GetStream(string streamData)
83176
Assert.True(typeof(float) == df.Columns[4].DataType);
84177
Assert.True(typeof(string) == df.Columns[5].DataType);
85178
Assert.True(typeof(double) == df.Columns[6].DataType);
179+
VerifyColumnTypes(df);
86180

87181
foreach (var column in df.Columns)
88182
{
@@ -124,11 +218,13 @@ Stream GetStream(string streamData)
124218
Assert.Equal(5, df.Rows.Count);
125219
Assert.Equal(7, df.Columns.Count);
126220
Assert.Equal("CMT", df.Columns["vendor_id"][4]);
221+
VerifyColumnTypes(df);
127222

128223
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), separator: '|', numberOfRowsToRead: 3);
129224
Assert.Equal(3, reducedRows.Rows.Count);
130225
Assert.Equal(7, reducedRows.Columns.Count);
131226
Assert.Equal("CMT", reducedRows.Columns["vendor_id"][2]);
227+
VerifyColumnTypes(df);
132228

133229
var nullRow = df.Rows[3];
134230
Assert.Equal("", nullRow[0]);
@@ -159,11 +255,13 @@ Stream GetStream(string streamData)
159255
Assert.Equal(5, df.Rows.Count);
160256
Assert.Equal(7, df.Columns.Count);
161257
Assert.Equal("CMT", df.Columns["vendor_id"][4]);
258+
VerifyColumnTypes(df);
162259

163260
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), separator: ';', numberOfRowsToRead: 3);
164261
Assert.Equal(3, reducedRows.Rows.Count);
165262
Assert.Equal(7, reducedRows.Columns.Count);
166263
Assert.Equal("CMT", reducedRows.Columns["vendor_id"][2]);
264+
VerifyColumnTypes(df);
167265

168266
var nullRow = df.Rows[3];
169267
Assert.Equal("", nullRow[0]);
@@ -193,11 +291,13 @@ Stream GetStream(string streamData)
193291
Assert.Equal(4, df.Rows.Count);
194292
Assert.Equal(7, df.Columns.Count);
195293
Assert.Equal("CMT", df.Columns["vendor_id"][3]);
294+
VerifyColumnTypes(df);
196295

197296
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), numberOfRowsToRead: 3);
198297
Assert.Equal(3, reducedRows.Rows.Count);
199298
Assert.Equal(7, reducedRows.Columns.Count);
200299
Assert.Equal("CMT", reducedRows.Columns["vendor_id"][2]);
300+
VerifyColumnTypes(df);
201301
}
202302

203303
[Fact]
@@ -235,11 +335,13 @@ Stream GetStream(string streamData)
235335
Assert.Equal(4, df.Rows.Count);
236336
Assert.Equal(6, df.Columns.Count);
237337
Assert.Equal("CMT", df.Columns["vendor_id"][3]);
338+
VerifyColumnTypes(df);
238339

239340
DataFrame reducedRows = DataFrame.LoadCsv(GetStream(data), numberOfRowsToRead: 3);
240341
Assert.Equal(3, reducedRows.Rows.Count);
241342
Assert.Equal(6, reducedRows.Columns.Count);
242343
Assert.Equal("CMT", reducedRows.Columns["vendor_id"][2]);
344+
VerifyColumnTypes(df);
243345

244346
}
245347
}

0 commit comments

Comments
 (0)