Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the initial prototype of a DatabaseLoader #4035

Merged
merged 12 commits into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Mkl.Components" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)]
Expand Down
373 changes: 373 additions & 0 deletions src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(DatabaseLoader.Summary, typeof(DatabaseLoader), null, typeof(SignatureLoadModel),
"Database Loader", DatabaseLoader.LoaderSignature)]

namespace Microsoft.ML.Data
{
public sealed partial class DatabaseLoader : IDataLoader<Func<DbDataReader>>
{
/// <summary>
/// Human-readable description of this loader; passed to the assembly-level LoadableClass registration.
/// </summary>
internal const string Summary = "Loads data from a DbDataReader.";
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// Signature string used in the assembly-level LoadableClass registration and recorded in the loader's <see cref="VersionInfo"/>.
/// </summary>
internal const string LoaderSignature = "DatabaseLoader";

/// <summary>
/// Builds the <see cref="VersionInfo"/> that is written when the loader is saved and checked
/// against the model context when the loader is created from a saved model.
/// </summary>
private static VersionInfo GetVersionInfo() =>
    new VersionInfo(
        modelSignature: "DBLOADER",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: typeof(DatabaseLoader).Assembly.FullName);

// Column bindings: per-column name, type and optional source index, plus the resulting output schema.
private readonly Bindings _bindings;

// Host obtained by registering this component with the environment; used for argument checks and exceptions.
private readonly IHost _host;
// Name under which the loader registers itself with the host environment.
private const string RegistrationName = "DatabaseLoader";

/// <summary>
/// Creates a loader from user-supplied settings.
/// </summary>
/// <param name="env">Host environment the loader registers with; must not be null.</param>
/// <param name="options">
/// Loader settings. A null value is replaced by a default <see cref="Options"/> instance,
/// which has no columns and therefore fails the column check below.
/// </param>
internal DatabaseLoader(IHostEnvironment env, Options options)
{
    Contracts.CheckValue(env, nameof(env));

    // Substitute an empty settings object so that a null argument surfaces
    // as the "at least one Column" error rather than a null check failure.
    if (options == null)
        options = new Options();

    _host = env.Register(RegistrationName);
    _host.CheckValue(options, nameof(options));

    Column[] columns = options.Columns;
    if (Utils.Size(columns) == 0)
        throw _host.Except("DatabaseLoader requires at least one Column");

    _bindings = new Bindings(this, columns);
}

/// <summary>
/// Deserialization constructor: adopts an already-registered host and rebuilds the column
/// bindings from the saved model in <paramref name="ctx"/>.
/// </summary>
/// <param name="host">Host this loader will use; asserted to be non-null.</param>
/// <param name="ctx">Model context positioned at the loader's data; asserted to be non-null.</param>
private DatabaseLoader(IHost host, ModelLoadContext ctx)
{
    Contracts.AssertValue(host, nameof(host));
    host.AssertValue(ctx);

    _host = host;
    _bindings = new Bindings(ctx, this);
}

/// <summary>
/// Creates a <see cref="DatabaseLoader"/> from a saved model.
/// </summary>
/// <param name="env">Host environment to register with; must not be null.</param>
/// <param name="ctx">Model context to read from; must not be null and must match this loader's version info.</param>
/// <returns>The loader reconstructed from <paramref name="ctx"/>.</returns>
internal static DatabaseLoader Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));

    IHost host = env.Register(RegistrationName);
    host.CheckValue(ctx, nameof(ctx));

    // Verify the context holds a model written by this loader before reading from it.
    VersionInfo expected = GetVersionInfo();
    ctx.CheckAtModel(expected);

    return host.Apply("Loading Model", channel => new DatabaseLoader(host, ctx));
}

/// <summary>
/// Writes the loader to <paramref name="ctx"/>: the version info followed by the column bindings.
/// </summary>
/// <param name="ctx">Model context to write to; must not be null.</param>
void ICanSaveModel.Save(ModelSaveContext ctx)
{
    _host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();

    VersionInfo version = GetVersionInfo();
    ctx.SetVersionInfo(version);

    // *** Binary format ***
    // bindings
    _bindings.Save(ctx);
}

/// <summary>
/// Gets the <see cref="DataViewSchema"/> of the data this loader produces, as computed by its column bindings.
/// </summary>
/// <returns>The output schema held by the bindings.</returns>
public DataViewSchema GetOutputSchema()
{
    return _bindings.OutputSchema;
}

/// <summary>
/// Wraps <paramref name="input"/> in an <see cref="IDataView"/> bound to this loader.
/// </summary>
/// <param name="input">A function that returns a <see cref="DbDataReader"/> from which to load data.</param>
/// <returns>A data view backed by this loader and <paramref name="input"/>.</returns>
public IDataView Load(Func<DbDataReader> input)
{
    return new BoundLoader(this, input);
}

/// <summary>
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
/// </summary>
/// <remarks>
/// When bindings are built from these descriptions, <see cref="Name"/> is checked to be non-whitespace and
/// trimmed, a negative <see cref="Source"/> is rejected, and a non-null <see cref="KeyCount"/> makes the
/// column a key type.
/// </remarks>
public sealed class Column
{
/// <summary>
/// Name of the column. Must not be null or whitespace; surrounding whitespace is trimmed when binding.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
public string Name;

/// <summary>
/// <see cref="DbType"/> of the items in the column. Defaults to <see cref="DbType.Single"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Type of the items in the column")]
public DbType Type = DbType.Single;

/// <summary>
/// Source index of the column. Optional; when specified it must be non-negative.
/// </summary>
// NOTE(review): this is a single nullable int but is declared ArgumentType.Multiple - confirm that is intended.
[Argument(ArgumentType.Multiple, HelpText = "Source index of the column", ShortName = "src")]
public int? Source;

/// <summary>
/// For a key column, this defines the range of values. Leave null for a non-key column.
/// </summary>
[Argument(ArgumentType.Multiple, HelpText = "For a key column, this defines the range of values", ShortName = "key")]
public KeyCount KeyCount;
}

/// <summary>
/// The settings for <see cref="DatabaseLoader"/>.
/// </summary>
public sealed class Options
{
/// <summary>
/// Specifies the input columns that should be mapped to <see cref="IDataView"/> columns.
/// The loader requires at least one entry.
/// </summary>
// NOTE(review): the help text describes "numeric-ranges", while Column.Source here is a single index -
// presumably carried over from another loader; confirm and update the text if so.
[Argument(ArgumentType.Multiple, HelpText = "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40",
Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}

/// <summary>
/// Immutable description of one output column: its name, its optional source index and its type.
/// </summary>
private sealed class ColInfo
{
    public readonly string Name;
    public readonly int? SourceIndex;
    public readonly DataViewType ColType;

    /// <summary>
    /// Creates the description. Arguments are validated with assertions only.
    /// </summary>
    /// <param name="name">Column name; asserted to be non-empty.</param>
    /// <param name="sourceIndex">Optional source index; asserted to be non-negative when present.</param>
    /// <param name="colType">Column type; asserted to be non-null.</param>
    public ColInfo(string name, int? sourceIndex, DataViewType colType)
    {
        Contracts.AssertNonEmpty(name);
        // An absent index yields the default 0, so this passes for null and for any non-negative value.
        Contracts.Assert(sourceIndex.GetValueOrDefault() >= 0);
        Contracts.AssertValue(colType);

        ColType = colType;
        SourceIndex = sourceIndex;
        Name = name;
    }
}

private sealed class Bindings
{
/// <summary>
/// <see cref="Infos"/>[i] stores the i-th column's name, optional source index and type.
/// </summary>
public readonly ColInfo[] Infos;

/// <summary>
/// Schema computed from <see cref="Infos"/>: one output column per entry, in the same order.
/// </summary>
public DataViewSchema OutputSchema { get; }

/// <summary>
/// Builds the bindings from user-specified column descriptions.
/// </summary>
/// <param name="parent">Loader that owns these bindings; its host supplies the channel used for messages and errors.</param>
/// <param name="cols">Column descriptions; asserted to be non-empty.</param>
public Bindings(DatabaseLoader parent, Column[] cols)
{
    Contracts.AssertNonEmpty(cols);

    using (var ch = parent._host.Start("Binding"))
    {
        // Reject negative source indices up front, before any column is bound.
        for (int i = 0; i < cols.Length; i++)
        {
            if (cols[i].Source < 0)
                throw ch.ExceptUserArg(nameof(Column.Source), "Source column index must be non-negative");
        }

        int count = cols.Length;
        Infos = new ColInfo[count];

        // Maps each name to the last index that used it; only needed to detect duplicated names.
        var seenNames = new Dictionary<string, int>(count);

        for (int index = 0; index < count; index++)
        {
            Column column = cols[index];

            ch.CheckNonWhiteSpace(column.Name, nameof(column.Name));
            string trimmed = column.Name.Trim();

            // The dictionary holds exactly `index` entries only while every earlier name was unique,
            // so the message below is emitted for the first duplicate and never again.
            bool allUniqueSoFar = seenNames.Count == index;
            if (allUniqueSoFar && seenNames.ContainsKey(trimmed))
                ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");

            // A key count turns the column into a key type; otherwise the DbType maps to a primitive type.
            PrimitiveDataViewType itemType = column.KeyCount != null
                ? ConstructKeyType(column.Type, column.KeyCount)
                : ColumnTypeExtensions.PrimitiveTypeFromType(column.Type.ToType());

            Infos[index] = new ColInfo(trimmed, column.Source, itemType);
            seenNames[trimmed] = index;
        }
    }

    OutputSchema = ComputeOutputSchema();
}

/// <summary>
/// Rebuilds the bindings from a saved model, validating every decoded value.
/// </summary>
/// <param name="ctx">Model context to read from; asserted to be non-null.</param>
/// <param name="parent">Loader that owns these bindings (not read by this constructor).</param>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   byte: bool of whether the source index is valid
    //   for a valid source index:
    //     int: source index
    int cinfo = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        PrimitiveDataViewType itemType;
        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));

            ulong count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);

            itemType = new KeyDataViewType(kind.ToType(), count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        int? sourceIndex = null;
        bool hasSourceIndex = ctx.Reader.ReadBoolByte();
        if (hasSourceIndex)
        {
            int rawIndex = ctx.Reader.ReadInt32();
            // The user-facing constructor rejects negative indices, so a negative value here can only
            // come from a corrupt model; fail the decode instead of relying on ColInfo's debug-only assert.
            Contracts.CheckDecode(rawIndex >= 0);
            sourceIndex = rawIndex;
        }

        Infos[iinfo] = new ColInfo(name, sourceIndex, itemType);
    }

    OutputSchema = ComputeOutputSchema();
}

/// <summary>
/// Writes the column bindings to <paramref name="ctx"/> in the binary layout described below.
/// </summary>
/// <param name="ctx">Model context to write to; asserted to be non-null.</param>
internal void Save(ModelSaveContext ctx)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   byte: bool of whether the source index is valid
    //   for a valid source index:
    //     int: source index
    ctx.Writer.Write(Infos.Length);
    foreach (ColInfo column in Infos)
    {
        ctx.SaveNonEmptyString(column.Name);

        var itemType = column.ColType.GetItemType();
        InternalDataKind rawKind = itemType.GetRawKind();
        // The kind is stored in a single byte, so it must survive the round trip through byte.
        Contracts.Assert((InternalDataKind)(byte)rawKind == rawKind);
        ctx.Writer.Write((byte)rawKind);

        if (itemType is KeyDataViewType keyType)
        {
            ctx.Writer.WriteBoolByte(true);
            ctx.Writer.Write(keyType.Count);
        }
        else
        {
            ctx.Writer.WriteBoolByte(false);
        }

        int? source = column.SourceIndex;
        ctx.Writer.WriteBoolByte(source.HasValue);
        if (source.HasValue)
            ctx.Writer.Write(source.Value);
    }
}

/// <summary>
/// Builds the output schema by adding one column per entry of <see cref="Infos"/>, in order,
/// using each entry's name and type.
/// </summary>
/// <returns>The schema produced by the builder.</returns>
private DataViewSchema ComputeOutputSchema()
{
    var builder = new DataViewSchema.Builder();

    // Columns are added in the same order in which they were bound.
    foreach (ColInfo column in Infos)
        builder.AddColumn(column.Name, column.ColType);

    return builder.ToSchema();
}

/// <summary>
/// Constructs a <see cref="KeyDataViewType"/> from a <see cref="DbType"/> and a key count.
/// </summary>
/// <param name="dbType">Database type of the column; its mapped type must be a valid key data type.</param>
/// <param name="keyCount">Key range description; must not be null.</param>
/// <returns>
/// A key type over the mapped raw type, whose count is <c>keyCount.Count</c> when specified
/// and the raw type's maximum otherwise.
/// </returns>
private static KeyDataViewType ConstructKeyType(DbType dbType, KeyCount keyCount)
{
    Contracts.CheckValue(keyCount, nameof(keyCount));

    Type rawType = dbType.ToType();
    Contracts.CheckUserArg(KeyDataViewType.IsValidDataType(rawType), nameof(DatabaseLoader.Column.Type), "Bad item type for Key");

    // An explicit count wins; without one the key spans the whole range of the raw type.
    if (keyCount.Count.HasValue)
        return new KeyDataViewType(rawType, keyCount.Count.Value);

    return new KeyDataViewType(rawType, rawType.ToMaxInt());
}
}

/// <summary>
/// <see cref="IDataView"/> that pairs a <see cref="DatabaseLoader"/> with the reader factory it loads from.
/// </summary>
private sealed class BoundLoader : IDataView
{
    private readonly DatabaseLoader _loader;
    private readonly IHost _host;
    private readonly Func<DbDataReader> _input;

    /// <summary>
    /// Binds <paramref name="loader"/> to <paramref name="input"/>.
    /// </summary>
    /// <param name="loader">Loader supplying the bindings and the parent host.</param>
    /// <param name="input">Factory for the <see cref="DbDataReader"/> to read from; must not be null.</param>
    public BoundLoader(DatabaseLoader loader, Func<DbDataReader> input)
    {
        _loader = loader;
        _host = loader._host.Register(nameof(BoundLoader));

        _host.CheckValue(input, nameof(input));
        _input = input;
    }

    /// <summary>Shuffling is not supported by this view.</summary>
    public bool CanShuffle
    {
        get { return false; }
    }

    /// <summary>The output schema of the loader's bindings.</summary>
    public DataViewSchema Schema
    {
        get { return _loader._bindings.OutputSchema; }
    }

    /// <summary>Always returns null: no row count is reported.</summary>
    public long? GetRowCount()
    {
        return null;
    }

    /// <summary>
    /// Creates a cursor over the requested columns.
    /// </summary>
    /// <param name="columnsNeeded">Columns the caller wants to be active.</param>
    /// <param name="rand">Optional random source; it is checked but not passed on to the cursor.</param>
    public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
    {
        _host.CheckValueOrNull(rand);

        DataViewSchema schema = _loader._bindings.OutputSchema;
        var activeColumns = Utils.BuildArray(schema.Count, columnsNeeded);
        return Cursor.Create(_loader, _input, activeColumns);
    }

    /// <summary>
    /// Returns a set containing exactly one cursor, regardless of <paramref name="n"/>.
    /// </summary>
    public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
    {
        DataViewRowCursor single = GetRowCursor(columnsNeeded, rand);
        return new DataViewRowCursor[] { single };
    }
}
}
}
Loading