Skip to content

Commit

Permalink
Added ADO.NET importing/exporting functionality to DataFrame (#5975)
Browse files Browse the repository at this point in the history
* refactoring - removed copy/paste in DataFrame.CreateColumn()

* added a universal loading method and export to DataTable

* added tests for new loading/saving methods in DataFrame

* improved error handling - DataFrame.LoadFrom()

* DataFrame - importing and exporting data using ADO.NET providers

* DataFrame.LoadFrom() - use async

* DataFrame.LoadFrom() - minor refactorings

* Update Microsoft.Data.Analysis.Tests.csproj

Changed version of System.Data.SQLite

* Update Microsoft.Data.Analysis.Tests.csproj

* fixed chown command

* sql db test path change

* sql db test path change

* sql db test fix

* sql db test fix

---------

Co-authored-by: Michael Sharp <51342856+michaelgsharp@users.noreply.github.com>
Co-authored-by: Michael Sharp <misharp@microsoft.com>
  • Loading branch information
3 people authored May 9, 2023
1 parent ff3b1b9 commit 3d705bf
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 18 deletions.
2 changes: 1 addition & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.112.2</SystemDataSQLiteCoreVersion>
<SystemDataSQLiteCoreVersion>1.0.113</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
<XUnitVersion>2.4.2</XUnitVersion>
<!-- Opt-out repo features -->
Expand Down
2 changes: 1 addition & 1 deletion eng/helix.proj
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
</MSBuild>

<PropertyGroup>
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $(whoami) $HELIX_WORKITEM_ROOT</HelixPreCommands>
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $USER $HELIX_WORKITEM_ROOT</HelixPreCommands>
<HelixPreCommands Condition="!$(IsPosixShell)">$(HelixPreCommands);set ML_TEST_DATADIR=%HELIX_CORRELATION_PAYLOAD%;set MICROSOFTML_RESOURCE_PATH=%HELIX_WORKITEM_ROOT%</HelixPreCommands>

<HelixPreCommands Condition="$(HelixTargetQueues.ToLowerInvariant().Contains('osx'))">$(HelixPreCommands);install_name_tool -change "/usr/local/opt/libomp/lib/libomp.dylib" "@loader_path/libomp.dylib" libSymSgdNative.dylib</HelixPreCommands>
Expand Down
186 changes: 170 additions & 16 deletions src/Microsoft.Data.Analysis/DataFrame.IO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.Data.Analysis
{
Expand Down Expand Up @@ -109,12 +112,158 @@ public static DataFrame LoadCsv(string filename,
}
}

/// <summary>
/// Builds a <see cref="DataFrame"/> from an in-memory sequence of rows.
/// </summary>
/// <param name="vals">The rows to load. Each row is appended to the result as-is.</param>
/// <param name="columnInfos">The name and type of every column, in column order.</param>
/// <returns>A new <see cref="DataFrame"/> with one column per entry of <paramref name="columnInfos"/> and one row per entry of <paramref name="vals"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="vals"/> or <paramref name="columnInfos"/> is null.</exception>
public static DataFrame LoadFrom(IEnumerable<IList<object>> vals, IList<(string, Type)> columnInfos)
{
    if (vals == null)
    {
        throw new ArgumentNullException(nameof(vals));
    }
    if (columnInfos == null)
    {
        throw new ArgumentNullException(nameof(columnInfos));
    }

    var columns = new List<DataFrameColumn>(columnInfos.Count);
    foreach (var (name, type) in columnInfos)
    {
        columns.Add(CreateColumn(type, name));
    }

    var res = new DataFrame(columns);

    // Rows are handed to Append untouched. The former per-cell self-assignment
    // (items[c] = items[c]) changed nothing and only made read-only row lists throw.
    foreach (var items in vals)
    {
        res.Append(items, inPlace: true);
    }

    return res;
}

/// <summary>
/// Appends every row of this <see cref="DataFrame"/> to <paramref name="table"/>.
/// </summary>
/// <remarks>
/// When the table has no columns yet, one column is added per DataFrame column using the
/// column's name and data type. Otherwise the table must already have the same number of
/// columns with the same data types; columns are compared by position and names are not checked.
/// Null cells are stored as <see cref="DBNull.Value"/>.
/// </remarks>
/// <param name="table">The table that receives the rows.</param>
/// <exception cref="ArgumentNullException"><paramref name="table"/> is null.</exception>
/// <exception cref="ArgumentException">The existing table schema does not match the columns of this DataFrame.</exception>
public void SaveTo(DataTable table)
{
    if (table == null)
    {
        throw new ArgumentNullException(nameof(table));
    }

    var columnsCount = Columns.Count;

    if (table.Columns.Count == 0)
    {
        foreach (var column in Columns)
        {
            table.Columns.Add(column.Name, column.DataType);
        }
    }
    else
    {
        if (table.Columns.Count != columnsCount)
        {
            throw new ArgumentException(
                $"The table has {table.Columns.Count} column(s) but the DataFrame has {columnsCount}.",
                nameof(table));
        }

        for (var c = 0; c < columnsCount; c++)
        {
            if (table.Columns[c].DataType != Columns[c].DataType)
            {
                throw new ArgumentException(
                    $"Column {c} of the table has type {table.Columns[c].DataType} but the DataFrame column has type {Columns[c].DataType}.",
                    nameof(table));
            }
        }
    }

    // One buffer is reused for every row; it is passed to Rows.Add again on each iteration.
    var items = new object[columnsCount];
    foreach (var row in Rows)
    {
        for (var c = 0; c < columnsCount; c++)
        {
            items[c] = row[c] ?? DBNull.Value;
        }
        table.Rows.Add(items);
    }
}

/// <summary>
/// Creates a new <see cref="DataTable"/> and fills it through <see cref="SaveTo(DataTable)"/>.
/// </summary>
/// <returns>The newly created table holding the contents of this <see cref="DataFrame"/>.</returns>
public DataTable ToTable()
{
    var table = new DataTable();

    // The table starts without columns, so SaveTo takes its "add columns" branch.
    SaveTo(table);

    return table;
}

/// <summary>
/// Creates an empty <see cref="DataFrame"/> whose columns mirror the fields exposed by <paramref name="reader"/>.
/// </summary>
/// <param name="reader">The reader whose field names and field types define the columns.</param>
/// <returns>A <see cref="DataFrame"/> with one column per reader field, in field order.</returns>
public static DataFrame FromSchema(DbDataReader reader)
{
    var schemaColumns = new DataFrameColumn[reader.FieldCount];

    for (var index = 0; index < schemaColumns.Length; index++)
    {
        // Field type first, then field name, matching the reader's ordinal.
        schemaColumns[index] = CreateColumn(reader.GetFieldType(index), reader.GetName(index));
    }

    return new DataFrame(schemaColumns);
}

/// <summary>
/// Reads all remaining rows of <paramref name="reader"/> into a new <see cref="DataFrame"/>.
/// </summary>
/// <remarks>
/// The column layout is taken from the reader via <see cref="FromSchema(DbDataReader)"/>.
/// Database nulls are stored as null. The reader is closed before this method completes,
/// whether loading succeeds or fails.
/// </remarks>
/// <param name="reader">The data reader to consume.</param>
/// <returns>A task producing the loaded <see cref="DataFrame"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="reader"/> is null.</exception>
public static async Task<DataFrame> LoadFrom(DbDataReader reader)
{
    if (reader == null)
    {
        throw new ArgumentNullException(nameof(reader));
    }

    try
    {
        var res = FromSchema(reader);
        var columnsCount = reader.FieldCount;

        // One buffer is reused for every row that is appended.
        var items = new object[columnsCount];
        while (await reader.ReadAsync().ConfigureAwait(false))
        {
            for (var c = 0; c < columnsCount; c++)
            {
                items[c] = reader.IsDBNull(c)
                    ? null
                    : reader[c];
            }
            res.Append(items, inPlace: true);
        }

        return res;
    }
    finally
    {
        // This method consumes the reader, so release it on the failure path as well.
        reader.Close();
    }
}

/// <summary>
/// Executes the select command of <paramref name="adapter"/> and loads the result set into a new <see cref="DataFrame"/>.
/// </summary>
/// <remarks>
/// If the command's connection is closed it is opened for the duration of the call and closed
/// again afterwards; a connection that is already open is left open.
/// </remarks>
/// <param name="adapter">The adapter whose <see cref="DbDataAdapter.SelectCommand"/> supplies the query.</param>
/// <returns>A task producing the loaded <see cref="DataFrame"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="adapter"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="adapter"/> has no select command.</exception>
public static async Task<DataFrame> LoadFrom(DbDataAdapter adapter)
{
    if (adapter == null)
    {
        throw new ArgumentNullException(nameof(adapter));
    }

    var command = adapter.SelectCommand;
    if (command == null)
    {
        throw new ArgumentException("The adapter must have a SelectCommand.", nameof(adapter));
    }

    // Only a connection opened here is closed here. A missing connection is left for
    // ExecuteReaderAsync to report.
    var connection = command.Connection;
    var needClose = false;
    if (connection != null && connection.State == ConnectionState.Closed)
    {
        await connection.OpenAsync().ConfigureAwait(false);
        needClose = true;
    }

    try
    {
        using var reader = await command.ExecuteReaderAsync().ConfigureAwait(false);
        return await LoadFrom(reader).ConfigureAwait(false);
    }
    finally
    {
        if (needClose)
        {
            connection.Close();
        }
    }
}

/// <summary>
/// Inserts the rows of this <see cref="DataFrame"/> into the database table targeted by <paramref name="dataAdapter"/>.
/// </summary>
/// <remarks>
/// Insert, update and delete commands are generated from the adapter's select command with a
/// command builder obtained from <paramref name="factory"/> and assigned to the adapter.
/// The data is written inside a single transaction that is rolled back if the update fails.
/// If the connection is closed it is opened for the duration of the call and closed afterwards.
/// </remarks>
/// <param name="dataAdapter">The adapter whose select command identifies the target table and connection.</param>
/// <param name="factory">The provider factory used to create the command builder.</param>
/// <exception cref="ArgumentNullException"><paramref name="dataAdapter"/> or <paramref name="factory"/> is null.</exception>
/// <exception cref="ArgumentException">
/// The adapter has no select command bound to a connection, or the factory cannot create a command builder.
/// </exception>
public void SaveTo(DbDataAdapter dataAdapter, DbProviderFactory factory)
{
    if (dataAdapter == null)
    {
        throw new ArgumentNullException(nameof(dataAdapter));
    }
    if (factory == null)
    {
        throw new ArgumentNullException(nameof(factory));
    }

    var selectCommand = dataAdapter.SelectCommand;
    var connection = selectCommand?.Connection;
    if (connection == null)
    {
        throw new ArgumentException("The adapter must have a SelectCommand that is bound to a connection.", nameof(dataAdapter));
    }

    using var commandBuilder = factory.CreateCommandBuilder();
    if (commandBuilder == null)
    {
        throw new ArgumentException("The provider factory does not support command builders.", nameof(factory));
    }

    commandBuilder.DataAdapter = dataAdapter;
    dataAdapter.InsertCommand = commandBuilder.GetInsertCommand();
    dataAdapter.UpdateCommand = commandBuilder.GetUpdateCommand();
    dataAdapter.DeleteCommand = commandBuilder.GetDeleteCommand();

    using var table = ToTable();

    // Commands must be enlisted in the transaction explicitly; some providers refuse to execute
    // a command without one while the connection has a pending local transaction.
    var commands = new[] { selectCommand, dataAdapter.InsertCommand, dataAdapter.UpdateCommand, dataAdapter.DeleteCommand };
    var previousTransactions = new DbTransaction[commands.Length];
    for (var i = 0; i < commands.Length; i++)
    {
        previousTransactions[i] = commands[i]?.Transaction;
    }

    var needClose = connection.TryOpen();

    try
    {
        using var transaction = connection.BeginTransaction();
        foreach (var command in commands)
        {
            if (command != null)
            {
                command.Transaction = transaction;
            }
        }

        try
        {
            dataAdapter.Update(table);
        }
        catch
        {
            // The using declaration disposes the transaction; only the rollback is explicit.
            transaction.Rollback();
            throw;
        }
        transaction.Commit();
    }
    finally
    {
        // Do not leave a finished transaction attached to the caller's commands.
        for (var i = 0; i < commands.Length; i++)
        {
            if (commands[i] != null)
            {
                commands[i].Transaction = previousTransactions[i];
            }
        }

        if (needClose)
        {
            connection.Close();
        }
    }
}

/// <summary>
/// Returns the <paramref name="columnIndex"/>-th entry of <paramref name="columnNames"/> if not null or empty; otherwise returns "Column{i}", where i is <paramref name="columnIndex"/>.
/// </summary>
/// <param name="columnNames">column names.</param>
/// <param name="columnIndex">column index.</param>
/// <returns></returns>

private static string GetColumnName(string[] columnNames, int columnIndex)
{
var defaultColumnName = "Column" + columnIndex.ToString();
Expand All @@ -126,68 +275,68 @@ private static string GetColumnName(string[] columnNames, int columnIndex)
return defaultColumnName;
}

private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
private static DataFrameColumn CreateColumn(Type kind, string columnName)
{
DataFrameColumn ret;
if (kind == typeof(bool))
{
ret = new BooleanDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new BooleanDataFrameColumn(columnName);
}
else if (kind == typeof(int))
{
ret = new Int32DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int32DataFrameColumn(columnName);
}
else if (kind == typeof(float))
{
ret = new SingleDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new SingleDataFrameColumn(columnName);
}
else if (kind == typeof(string))
{
ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0);
ret = new StringDataFrameColumn(columnName, 0);
}
else if (kind == typeof(long))
{
ret = new Int64DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int64DataFrameColumn(columnName);
}
else if (kind == typeof(decimal))
{
ret = new DecimalDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new DecimalDataFrameColumn(columnName);
}
else if (kind == typeof(byte))
{
ret = new ByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new ByteDataFrameColumn(columnName);
}
else if (kind == typeof(char))
{
ret = new CharDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new CharDataFrameColumn(columnName);
}
else if (kind == typeof(double))
{
ret = new DoubleDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new DoubleDataFrameColumn(columnName);
}
else if (kind == typeof(sbyte))
{
ret = new SByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new SByteDataFrameColumn(columnName);
}
else if (kind == typeof(short))
{
ret = new Int16DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new Int16DataFrameColumn(columnName);
}
else if (kind == typeof(uint))
{
ret = new UInt32DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt32DataFrameColumn(columnName);
}
else if (kind == typeof(ulong))
{
ret = new UInt64DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt64DataFrameColumn(columnName);
}
else if (kind == typeof(ushort))
{
ret = new UInt16DataFrameColumn(GetColumnName(columnNames, columnIndex));
ret = new UInt16DataFrameColumn(columnName);
}
else if (kind == typeof(DateTime))
{
ret = new PrimitiveDataFrameColumn<DateTime>(GetColumnName(columnNames, columnIndex));
ret = new PrimitiveDataFrameColumn<DateTime>(columnName);
}
else
{
Expand All @@ -196,6 +345,11 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
return ret;
}

/// <summary>
/// Creates a column of the given type, taking its name from <see cref="GetColumnName"/>.
/// </summary>
/// <param name="kind">The element type of the column.</param>
/// <param name="columnNames">Candidate column names passed to <see cref="GetColumnName"/>.</param>
/// <param name="columnIndex">The index of the column being created.</param>
/// <returns>The newly created column.</returns>
private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
    => CreateColumn(kind, GetColumnName(columnNames, columnIndex));

private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
char separator = ',', bool header = true,
string[] columnNames = null, Type[] dataTypes = null,
Expand Down
37 changes: 37 additions & 0 deletions src/Microsoft.Data.Analysis/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Text;

namespace Microsoft.Data.Analysis
{
/// <summary>
/// ADO.NET helper extensions for creating data adapters and opening connections.
/// </summary>
public static class Extensions
{
    /// <summary>
    /// Creates a data adapter whose select command is <c>SELECT * FROM {tableName}</c> on <paramref name="connection"/>.
    /// </summary>
    /// <remarks>
    /// <paramref name="tableName"/> is inserted into the SQL text verbatim, with no quoting or escaping.
    /// It must come from a trusted source and must never be built from untrusted input.
    /// </remarks>
    /// <param name="factory">The provider factory used to create the adapter.</param>
    /// <param name="connection">The connection the select command is created on.</param>
    /// <param name="tableName">The name of the table to select from.</param>
    /// <returns>A new adapter with its <see cref="DbDataAdapter.SelectCommand"/> set.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="factory"/> or <paramref name="connection"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="tableName"/> is null, empty or whitespace, or the factory cannot create data adapters.
    /// </exception>
    public static DbDataAdapter CreateDataAdapter(this DbProviderFactory factory, DbConnection connection, string tableName)
    {
        if (factory == null)
        {
            throw new ArgumentNullException(nameof(factory));
        }
        if (connection == null)
        {
            throw new ArgumentNullException(nameof(connection));
        }
        if (string.IsNullOrWhiteSpace(tableName))
        {
            throw new ArgumentException("A table name is required.", nameof(tableName));
        }

        // Create the adapter first so that no command is left behind if the factory cannot supply one.
        var res = factory.CreateDataAdapter();
        if (res == null)
        {
            throw new ArgumentException("The provider factory does not support data adapters.", nameof(factory));
        }

        var query = connection.CreateCommand();
        // NOTE(security): identifiers cannot be passed as command parameters, so the name is
        // spliced into the text as-is. Callers are responsible for passing a trusted name.
        query.CommandText = $"SELECT * FROM {tableName}";
        res.SelectCommand = query;
        return res;
    }

    /// <summary>
    /// Opens <paramref name="connection"/> if its state is <see cref="ConnectionState.Closed"/>.
    /// </summary>
    /// <param name="connection">The connection to open.</param>
    /// <returns>
    /// true if this call opened the connection, in which case the caller should close it again;
    /// false if the connection was not in the closed state and was left untouched.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
    public static bool TryOpen(this DbConnection connection)
    {
        if (connection == null)
        {
            throw new ArgumentNullException(nameof(connection));
        }

        if (connection.State != ConnectionState.Closed)
        {
            return false;
        }

        connection.Open();
        return true;
    }
}
}
Loading

0 comments on commit 3d705bf

Please sign in to comment.