Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ADO.NET importing/exporting functionality to DataFrame #5975

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.112.2</SystemDataSQLiteCoreVersion>
<SystemDataSQLiteCoreVersion>1.0.113</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
<XUnitVersion>2.4.2</XUnitVersion>
<!-- Opt-out repo features -->
Expand Down
2 changes: 1 addition & 1 deletion eng/helix.proj
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
</MSBuild>

<PropertyGroup>
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $(whoami) $HELIX_WORKITEM_ROOT</HelixPreCommands>
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $USER $HELIX_WORKITEM_ROOT</HelixPreCommands>
<HelixPreCommands Condition="!$(IsPosixShell)">$(HelixPreCommands);set ML_TEST_DATADIR=%HELIX_CORRELATION_PAYLOAD%;set MICROSOFTML_RESOURCE_PATH=%HELIX_WORKITEM_ROOT%</HelixPreCommands>

<HelixPreCommands Condition="$(HelixTargetQueues.ToLowerInvariant().Contains('osx'))">$(HelixPreCommands);install_name_tool -change "/usr/local/opt/libomp/lib/libomp.dylib" "@loader_path/libomp.dylib" libSymSgdNative.dylib</HelixPreCommands>
Expand Down
186 changes: 170 additions & 16 deletions src/Microsoft.Data.Analysis/DataFrame.IO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.Data.Analysis
{
Expand Down Expand Up @@ -109,12 +112,158 @@ public static DataFrame LoadCsv(string filename,
}
}

public static DataFrame LoadFrom(IEnumerable<IList<object>> vals, IList<(string, Type)> columnInfos)
{
var columnsCount = columnInfos.Count;
var columns = new List<DataFrameColumn>(columnsCount);

foreach (var (name, type) in columnInfos)
{
var column = CreateColumn(type, name);
columns.Add(column);
}

var res = new DataFrame(columns);

foreach (var items in vals)
{
for (var c = 0; c < items.Count; c++)
{
items[c] = items[c];
}
res.Append(items, inPlace: true);
}

return res;
}

public void SaveTo(DataTable table)
{
var columnsCount = Columns.Count;

if (table.Columns.Count == 0)
{
foreach (var column in Columns)
{
table.Columns.Add(column.Name, column.DataType);
}
}
else
{
if (table.Columns.Count != columnsCount)
throw new ArgumentException();
for (var c = 0; c < columnsCount; c++)
{
if (table.Columns[c].DataType != Columns[c].DataType)
throw new ArgumentException();
}
}

var items = new object[columnsCount];
foreach (var row in Rows)
{
for (var c = 0; c < columnsCount; c++)
{
items[c] = row[c] ?? DBNull.Value;
}
table.Rows.Add(items);
}
}

public DataTable ToTable()
{
var res = new DataTable();
SaveTo(res);
return res;
}

public static DataFrame FromSchema(DbDataReader reader)
{
var columnsCount = reader.FieldCount;
var columns = new DataFrameColumn[columnsCount];

for (var c = 0; c < columnsCount; c++)
{
var type = reader.GetFieldType(c);
var name = reader.GetName(c);
var column = CreateColumn(type, name);
columns[c] = column;
}

var res = new DataFrame(columns);
return res;
}

public static async Task<DataFrame> LoadFrom(DbDataReader reader)
{
var res = FromSchema(reader);
var columnsCount = reader.FieldCount;

var items = new object[columnsCount];
while (await reader.ReadAsync())
{
for (var c = 0; c < columnsCount; c++)
{
items[c] = reader.IsDBNull(c)
? null
: reader[c];
}
res.Append(items, inPlace: true);
}

reader.Close();

return res;
}

public static async Task<DataFrame> LoadFrom(DbDataAdapter adapter)
{
using var reader = await adapter.SelectCommand.ExecuteReaderAsync();
return await LoadFrom(reader);
}

public void SaveTo(DbDataAdapter dataAdapter, DbProviderFactory factory)
{
using var commandBuilder = factory.CreateCommandBuilder();
commandBuilder.DataAdapter = dataAdapter;
dataAdapter.InsertCommand = commandBuilder.GetInsertCommand();
dataAdapter.UpdateCommand = commandBuilder.GetUpdateCommand();
dataAdapter.DeleteCommand = commandBuilder.GetDeleteCommand();

using var table = ToTable();

var connection = dataAdapter.SelectCommand.Connection;
var needClose = connection.TryOpen();

try
{
using var transaction = connection.BeginTransaction();
try
{
dataAdapter.Update(table);
}
catch
{
transaction.Rollback();
transaction.Dispose();
throw;
}
transaction.Commit();
}
finally
{
if (needClose)
connection.Close();
}
}

/// <summary>
/// return <paramref name="columnIndex"/> of <paramref name="columnNames"/> if not null or empty, otherwise return "Column{i}" where i is <paramref name="columnIndex"/>.
/// </summary>
/// <param name="columnNames">column names.</param>
/// <param name="columnIndex">column index.</param>
/// <returns></returns>

private static string GetColumnName(string[] columnNames, int columnIndex)
{
var defaultColumnName = "Column" + columnIndex.ToString();
Expand All @@ -126,68 +275,68 @@ private static string GetColumnName(string[] columnNames, int columnIndex)
return defaultColumnName;
}

private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
private static DataFrameColumn CreateColumn(Type kind, string columnName)
{
DataFrameColumn ret;
if (kind == typeof(bool))
{
ret = new BooleanDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new BooleanDataFrameColumn(columnName);
}
else if (kind == typeof(int))
{
ret = new Int32DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int32DataFrameColumn(columnName);
}
else if (kind == typeof(float))
{
ret = new SingleDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new SingleDataFrameColumn(columnName);
}
else if (kind == typeof(string))
{
ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0);
ret = new StringDataFrameColumn(columnName, 0);
}
else if (kind == typeof(long))
{
ret = new Int64DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int64DataFrameColumn(columnName);
}
else if (kind == typeof(decimal))
{
ret = new DecimalDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new DecimalDataFrameColumn(columnName);
}
else if (kind == typeof(byte))
{
ret = new ByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new ByteDataFrameColumn(columnName);
}
else if (kind == typeof(char))
{
ret = new CharDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new CharDataFrameColumn(columnName);
}
else if (kind == typeof(double))
{
ret = new DoubleDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new DoubleDataFrameColumn(columnName);
}
else if (kind == typeof(sbyte))
{
ret = new SByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new SByteDataFrameColumn(columnName);
}
else if (kind == typeof(short))
{
ret = new Int16DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int16DataFrameColumn(columnName);
}
else if (kind == typeof(uint))
{
ret = new UInt32DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt32DataFrameColumn(columnName);
}
else if (kind == typeof(ulong))
{
ret = new UInt64DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt64DataFrameColumn(columnName);
}
else if (kind == typeof(ushort))
{
ret = new UInt16DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt16DataFrameColumn(columnName);
}
else if (kind == typeof(DateTime))
{
ret = new PrimitiveDataFrameColumn<DateTime>(GetColumnName(columnNames, columnIndex));
ret = new PrimitiveDataFrameColumn<DateTime>(columnName);
}
else
{
Expand All @@ -196,6 +345,11 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
return ret;
}

private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
{
return CreateColumn(kind, GetColumnName(columnNames, columnIndex));
}

private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
char separator = ',', bool header = true,
string[] columnNames = null, Type[] dataTypes = null,
Expand Down
37 changes: 37 additions & 0 deletions src/Microsoft.Data.Analysis/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Text;

namespace Microsoft.Data.Analysis
{
public static class Extensions
{
public static DbDataAdapter CreateDataAdapter(this DbProviderFactory factory, DbConnection connection, string tableName)
{
var query = connection.CreateCommand();
query.CommandText = $"SELECT * FROM {tableName}";
var res = factory.CreateDataAdapter();
res.SelectCommand = query;
return res;
}

public static bool TryOpen(this DbConnection connection)
{
if (connection.State == ConnectionState.Closed)
{
connection.Open();
return true;
}
else
{
return false;
}
}
}
}
Loading