Skip to content

Commit 9c2de1e

Browse files
tannergooding authored and Eric Erhardt committed
Adding the initial prototype of a DatabaseLoader (#4035)
* Initial implementation of DatabaseLoader. * More work in progress. Infer source index if not specified. Add initial public API. Fix bug in GetIdGetter * Moving the DatabaseLoader to Microsoft.ML.Experimental * Adding getter delegate support for the remaining internal data kinds. * Use DbDataReader instead of IDataReader * Fixing the DatabaseLoader tests to use public surface area. * Return a ReadOnlyMemory<char> for the string value getter * An DbDataReader to a DbDataReader * Creating a barebones MockDbConnection which wraps a TextLoader, for testing. * Removing the System.Data.SqlClient PackageReference * Ensuring the MockCommand filters to the listed columns. * Updating the name of the DatabaseLoader test and using the correct fact attribute.
1 parent 59699a5 commit 9c2de1e

File tree

7 files changed

+1311
-0
lines changed

7 files changed

+1311
-0
lines changed

src/Microsoft.ML.Core/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)]
2222
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)]
2323
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
24+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]
2425
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)]
2526
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Mkl.Components" + PublicKey.Value)]
2627
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)]
Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Data;
8+
using System.Data.Common;
9+
using System.Linq;
10+
using Microsoft.ML;
11+
using Microsoft.ML.CommandLine;
12+
using Microsoft.ML.Data;
13+
using Microsoft.ML.Internal.Utilities;
14+
using Microsoft.ML.Runtime;
15+
16+
[assembly: LoadableClass(DatabaseLoader.Summary, typeof(DatabaseLoader), null, typeof(SignatureLoadModel),
17+
"Database Loader", DatabaseLoader.LoaderSignature)]
18+
19+
namespace Microsoft.ML.Data
20+
{
21+
public sealed partial class DatabaseLoader : IDataLoader<Func<DbDataReader>>
22+
{
23+
internal const string Summary = "Loads data from a DbDataReader.";
24+
internal const string LoaderSignature = "DatabaseLoader";
25+
26+
private static VersionInfo GetVersionInfo()
27+
{
28+
return new VersionInfo(
29+
modelSignature: "DBLOADER",
30+
verWrittenCur: 0x00010001, // Initial
31+
verReadableCur: 0x00010001,
32+
verWeCanReadBack: 0x00010001,
33+
loaderSignature: LoaderSignature,
34+
loaderAssemblyName: typeof(DatabaseLoader).Assembly.FullName);
35+
}
36+
37+
private readonly Bindings _bindings;
38+
39+
private readonly IHost _host;
40+
private const string RegistrationName = "DatabaseLoader";
41+
42+
internal DatabaseLoader(IHostEnvironment env, Options options)
43+
{
44+
options = options ?? new Options();
45+
46+
Contracts.CheckValue(env, nameof(env));
47+
_host = env.Register(RegistrationName);
48+
_host.CheckValue(options, nameof(options));
49+
50+
var cols = options.Columns;
51+
if (Utils.Size(cols) == 0)
52+
{
53+
throw _host.Except("DatabaseLoader requires at least one Column");
54+
}
55+
56+
_bindings = new Bindings(this, cols);
57+
}
58+
59+
private DatabaseLoader(IHost host, ModelLoadContext ctx)
60+
{
61+
Contracts.AssertValue(host, "host");
62+
host.AssertValue(ctx);
63+
64+
_host = host;
65+
66+
_bindings = new Bindings(ctx, this);
67+
}
68+
69+
internal static DatabaseLoader Create(IHostEnvironment env, ModelLoadContext ctx)
70+
{
71+
Contracts.CheckValue(env, nameof(env));
72+
IHost h = env.Register(RegistrationName);
73+
74+
h.CheckValue(ctx, nameof(ctx));
75+
ctx.CheckAtModel(GetVersionInfo());
76+
77+
return h.Apply("Loading Model", ch => new DatabaseLoader(h, ctx));
78+
}
79+
80+
void ICanSaveModel.Save(ModelSaveContext ctx)
81+
{
82+
_host.CheckValue(ctx, nameof(ctx));
83+
ctx.CheckAtModel();
84+
ctx.SetVersionInfo(GetVersionInfo());
85+
86+
// *** Binary format ***
87+
// bindings
88+
_bindings.Save(ctx);
89+
}
90+
91+
/// <summary>
92+
/// The output <see cref="DataViewSchema"/> that will be produced by the loader.
93+
/// </summary>
94+
public DataViewSchema GetOutputSchema() => _bindings.OutputSchema;
95+
96+
/// <summary>
97+
/// Loads data from <paramref name="input"/> into an <see cref="IDataView"/>.
98+
/// </summary>
99+
/// <param name="input">A function that returns a DbDataReader from which to load data.</param>
100+
public IDataView Load(Func<DbDataReader> input) => new BoundLoader(this, input);
101+
102+
/// <summary>
103+
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
/// Each instance names one output column and gives its item type, an optional source index,
/// and an optional key range.
104+
/// </summary>
105+
public sealed class Column
106+
{
107+
/// <summary>
108+
/// Name of the column. It must not be null or whitespace; the loader trims it before use.
109+
/// </summary>
110+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
111+
public string Name;
112+
113+
/// <summary>
114+
/// <see cref="DbType"/> of the items in the column. Defaults to <see cref="DbType.Single"/>.
115+
/// </summary>
116+
[Argument(ArgumentType.AtMostOnce, HelpText = "Type of the items in the column")]
117+
public DbType Type = DbType.Single;
118+
119+
/// <summary>
120+
/// Source index of the column. May be left null; a negative value is rejected when the
/// loader's bindings are built.
// NOTE(review): how a null index is resolved is not visible in this part of the class — confirm in the cursor code.
121+
/// </summary>
122+
[Argument(ArgumentType.Multiple, HelpText = "Source index of the column", ShortName = "src")]
123+
public int? Source;
124+
125+
/// <summary>
126+
/// For a key column, this defines the range of values. When non-null, the column is
/// exposed as a key type built from <see cref="Type"/>, which must be a valid key data type.
127+
/// </summary>
128+
[Argument(ArgumentType.Multiple, HelpText = "For a key column, this defines the range of values", ShortName = "key")]
129+
public KeyCount KeyCount;
130+
}
131+
132+
/// <summary>
133+
/// The settings for <see cref="DatabaseLoader"/>. The loader requires at least one entry
/// in <see cref="Columns"/> and throws at construction time otherwise.
134+
/// </summary>
135+
public sealed class Options
136+
{
137+
/// <summary>
138+
/// Specifies the input columns that should be mapped to <see cref="IDataView"/> columns.
/// Output columns appear in the schema in the same order as this array.
// NOTE(review): the HelpText below describes a "name:type:numeric-ranges" syntax, which does not
// obviously match the single-index Column type declared in this class — confirm it is intended.
139+
/// </summary>
140+
[Argument(ArgumentType.Multiple, HelpText = "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40",
141+
Name = "Column", ShortName = "col", SortOrder = 1)]
142+
public Column[] Columns;
143+
}
144+
145+
/// <summary>
146+
/// Information for an output column.
147+
/// </summary>
148+
private sealed class ColInfo
149+
{
150+
public readonly string Name;
151+
public readonly int? SourceIndex;
152+
public readonly DataViewType ColType;
153+
154+
public ColInfo(string name, int? sourceIndex, DataViewType colType)
155+
{
156+
Contracts.AssertNonEmpty(name);
157+
Contracts.Assert(!sourceIndex.HasValue || sourceIndex >= 0);
158+
Contracts.AssertValue(colType);
159+
160+
Name = name;
161+
SourceIndex = sourceIndex;
162+
ColType = colType;
163+
}
164+
}
165+
166+
private sealed class Bindings
167+
{
168+
/// <summary>
169+
/// <see cref="Infos"/>[i] stores the i-th column's name, source index, and type. Columns are loaded from the input <see cref="DbDataReader"/>.
170+
/// </summary>
171+
public readonly ColInfo[] Infos;
172+
173+
public DataViewSchema OutputSchema { get; }
174+
175+
public Bindings(DatabaseLoader parent, Column[] cols)
176+
{
177+
Contracts.AssertNonEmpty(cols);
178+
179+
using (var ch = parent._host.Start("Binding"))
180+
{
181+
// Make sure all columns have at least one source range.
182+
foreach (var col in cols)
183+
{
184+
if (col.Source < 0)
185+
throw ch.ExceptUserArg(nameof(Column.Source), "Source column index must be non-negative");
186+
}
187+
188+
Infos = new ColInfo[cols.Length];
189+
190+
// This dictionary is used only for detecting duplicated column names specified by user.
191+
var nameToInfoIndex = new Dictionary<string, int>(Infos.Length);
192+
193+
for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
194+
{
195+
var col = cols[iinfo];
196+
197+
ch.CheckNonWhiteSpace(col.Name, nameof(col.Name));
198+
string name = col.Name.Trim();
199+
if (iinfo == nameToInfoIndex.Count && nameToInfoIndex.ContainsKey(name))
200+
ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");
201+
202+
PrimitiveDataViewType itemType;
203+
if (col.KeyCount != null)
204+
{
205+
itemType = ConstructKeyType(col.Type, col.KeyCount);
206+
}
207+
else
208+
{
209+
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(col.Type.ToType());
210+
}
211+
212+
Infos[iinfo] = new ColInfo(name, col.Source, itemType);
213+
214+
nameToInfoIndex[name] = iinfo;
215+
}
216+
}
217+
OutputSchema = ComputeOutputSchema();
218+
}
219+
220+
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
221+
{
222+
Contracts.AssertValue(ctx);
223+
224+
// *** Binary format ***
225+
// int: number of columns
226+
// foreach column:
227+
// int: id of column name
228+
// byte: DataKind
229+
// byte: bool of whether this is a key type
230+
// for a key type:
231+
// ulong: count for key range
232+
// byte: bool of whether the source index is valid
233+
// for a valid source index:
234+
// int: source index
235+
int cinfo = ctx.Reader.ReadInt32();
236+
Contracts.CheckDecode(cinfo > 0);
237+
Infos = new ColInfo[cinfo];
238+
239+
for (int iinfo = 0; iinfo < cinfo; iinfo++)
240+
{
241+
string name = ctx.LoadNonEmptyString();
242+
243+
PrimitiveDataViewType itemType;
244+
var kind = (InternalDataKind)ctx.Reader.ReadByte();
245+
Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
246+
bool isKey = ctx.Reader.ReadBoolByte();
247+
if (isKey)
248+
{
249+
ulong count;
250+
Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));
251+
252+
count = ctx.Reader.ReadUInt64();
253+
Contracts.CheckDecode(0 < count);
254+
255+
itemType = new KeyDataViewType(kind.ToType(), count);
256+
}
257+
else
258+
itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
259+
260+
int? sourceIndex = null;
261+
bool hasSourceIndex = ctx.Reader.ReadBoolByte();
262+
if (hasSourceIndex)
263+
{
264+
sourceIndex = ctx.Reader.ReadInt32();
265+
}
266+
267+
Infos[iinfo] = new ColInfo(name, sourceIndex, itemType);
268+
}
269+
270+
OutputSchema = ComputeOutputSchema();
271+
}
272+
273+
internal void Save(ModelSaveContext ctx)
274+
{
275+
Contracts.AssertValue(ctx);
276+
277+
// *** Binary format ***
278+
// int: number of columns
279+
// foreach column:
280+
// int: id of column name
281+
// byte: DataKind
282+
// byte: bool of whether this is a key type
283+
// for a key type:
284+
// ulong: count for key range
285+
// byte: bool of whether the source index is valid
286+
// for a valid source index:
287+
// int: source index
288+
ctx.Writer.Write(Infos.Length);
289+
for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
290+
{
291+
var info = Infos[iinfo];
292+
ctx.SaveNonEmptyString(info.Name);
293+
var type = info.ColType.GetItemType();
294+
InternalDataKind rawKind = type.GetRawKind();
295+
Contracts.Assert((InternalDataKind)(byte)rawKind == rawKind);
296+
ctx.Writer.Write((byte)rawKind);
297+
ctx.Writer.WriteBoolByte(type is KeyDataViewType);
298+
if (type is KeyDataViewType key)
299+
ctx.Writer.Write(key.Count);
300+
ctx.Writer.WriteBoolByte(info.SourceIndex.HasValue);
301+
if (info.SourceIndex.HasValue)
302+
ctx.Writer.Write(info.SourceIndex.GetValueOrDefault());
303+
}
304+
}
305+
306+
private DataViewSchema ComputeOutputSchema()
307+
{
308+
var schemaBuilder = new DataViewSchema.Builder();
309+
310+
// Iterate through all loaded columns. The index i indicates the i-th column loaded.
311+
for (int i = 0; i < Infos.Length; ++i)
312+
{
313+
var info = Infos[i];
314+
schemaBuilder.AddColumn(info.Name, info.ColType);
315+
}
316+
317+
return schemaBuilder.ToSchema();
318+
}
319+
320+
/// <summary>
321+
/// Construct a <see cref="KeyDataViewType"/> out of the DbType and the keyCount.
322+
/// </summary>
323+
private static KeyDataViewType ConstructKeyType(DbType dbType, KeyCount keyCount)
324+
{
325+
Contracts.CheckValue(keyCount, nameof(keyCount));
326+
327+
KeyDataViewType keyType;
328+
Type rawType = dbType.ToType();
329+
Contracts.CheckUserArg(KeyDataViewType.IsValidDataType(rawType), nameof(DatabaseLoader.Column.Type), "Bad item type for Key");
330+
331+
if (keyCount.Count == null)
332+
keyType = new KeyDataViewType(rawType, rawType.ToMaxInt());
333+
else
334+
keyType = new KeyDataViewType(rawType, keyCount.Count.GetValueOrDefault());
335+
336+
return keyType;
337+
}
338+
}
339+
340+
private sealed class BoundLoader : IDataView
341+
{
342+
private readonly DatabaseLoader _loader;
343+
private readonly IHost _host;
344+
private readonly Func<DbDataReader> _input;
345+
346+
public BoundLoader(DatabaseLoader loader, Func<DbDataReader> input)
347+
{
348+
_loader = loader;
349+
_host = loader._host.Register(nameof(BoundLoader));
350+
351+
_host.CheckValue(input, nameof(input));
352+
_input = input;
353+
}
354+
355+
public long? GetRowCount() => null;
356+
public bool CanShuffle => false;
357+
358+
public DataViewSchema Schema => _loader._bindings.OutputSchema;
359+
360+
public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
361+
{
362+
_host.CheckValueOrNull(rand);
363+
var active = Utils.BuildArray(_loader._bindings.OutputSchema.Count, columnsNeeded);
364+
return Cursor.Create(_loader, _input, active);
365+
}
366+
367+
public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
368+
{
369+
return new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) };
370+
}
371+
}
372+
}
373+
}

0 commit comments

Comments
 (0)