|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
| 7 | +using System.Data; |
| 8 | +using System.Data.Common; |
| 9 | +using System.Linq; |
| 10 | +using Microsoft.ML; |
| 11 | +using Microsoft.ML.CommandLine; |
| 12 | +using Microsoft.ML.Data; |
| 13 | +using Microsoft.ML.Internal.Utilities; |
| 14 | +using Microsoft.ML.Runtime; |
| 15 | + |
| 16 | +[assembly: LoadableClass(DatabaseLoader.Summary, typeof(DatabaseLoader), null, typeof(SignatureLoadModel), |
| 17 | + "Database Loader", DatabaseLoader.LoaderSignature)] |
| 18 | + |
| 19 | +namespace Microsoft.ML.Data |
| 20 | +{ |
| 21 | + public sealed partial class DatabaseLoader : IDataLoader<Func<DbDataReader>> |
| 22 | + { |
| 23 | + internal const string Summary = "Loads data from a DbDataReader."; |
| 24 | + internal const string LoaderSignature = "DatabaseLoader"; |
| 25 | + |
/// <summary>
/// Builds the <see cref="VersionInfo"/> that this loader stamps on saved models
/// and checks against when a model is loaded.
/// </summary>
private static VersionInfo GetVersionInfo() =>
    new VersionInfo(
        modelSignature: "DBLOADER",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: typeof(DatabaseLoader).Assembly.FullName);
| 36 | + |
| 37 | + private readonly Bindings _bindings; |
| 38 | + |
| 39 | + private readonly IHost _host; |
| 40 | + private const string RegistrationName = "DatabaseLoader"; |
| 41 | + |
/// <summary>
/// Creates a loader from user-supplied settings.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="options">
/// The loader settings. A null value is replaced by a default <see cref="Options"/> instance.
/// </param>
internal DatabaseLoader(IHostEnvironment env, Options options)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(RegistrationName);

    // A missing options object is treated as "all defaults".
    if (options == null)
        options = new Options();
    _host.CheckValue(options, nameof(options));

    // Without any column description there is nothing to bind.
    Column[] columns = options.Columns;
    if (Utils.Size(columns) == 0)
        throw _host.Except("DatabaseLoader requires at least one Column");

    _bindings = new Bindings(this, columns);
}
| 58 | + |
/// <summary>
/// Deserialization constructor: stores <paramref name="host"/> and rebuilds the
/// column bindings from <paramref name="ctx"/>. Called from <see cref="Create"/>.
/// </summary>
private DatabaseLoader(IHost host, ModelLoadContext ctx)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(ctx);

    // The host must be assigned before the bindings are created with a reference to this loader.
    _host = host;
    _bindings = new Bindings(ctx, this);
}
| 68 | + |
/// <summary>
/// Factory used when loading a saved model: registers a host, checks the model
/// context against this loader's version info, and constructs the loader from it.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The model context to read from; must not be null.</param>
internal static DatabaseLoader Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));

    var host = env.Register(RegistrationName);
    host.CheckValue(ctx, nameof(ctx));

    VersionInfo expected = GetVersionInfo();
    ctx.CheckAtModel(expected);

    return host.Apply("Loading Model", channel => new DatabaseLoader(host, ctx));
}
| 79 | + |
/// <summary>
/// Saves this loader into <paramref name="ctx"/>: sets the version info and then
/// writes the column bindings.
/// </summary>
void ICanSaveModel.Save(ModelSaveContext ctx)
{
    _host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();

    VersionInfo versionInfo = GetVersionInfo();
    ctx.SetVersionInfo(versionInfo);

    // *** Binary format ***
    // bindings
    _bindings.Save(ctx);
}
| 90 | + |
/// <summary>
/// Gets the <see cref="DataViewSchema"/> of the data this loader produces,
/// as computed by its column bindings.
/// </summary>
public DataViewSchema GetOutputSchema()
{
    return _bindings.OutputSchema;
}
| 95 | + |
/// <summary>
/// Wraps <paramref name="input"/> in an <see cref="IDataView"/> bound to this loader.
/// </summary>
/// <param name="input">A function that returns a DbDataReader from which to load data.</param>
/// <returns>A data view whose schema is this loader's output schema.</returns>
public IDataView Load(Func<DbDataReader> input)
{
    return new BoundLoader(this, input);
}
| 101 | + |
/// <summary>
/// User-facing description of a single input column and how it maps to an
/// <see cref="IDataView"/> column.
/// </summary>
public sealed class Column
{
    /// <summary>
    /// Name of the output column. The bindings require it to be non-whitespace and trim it.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
    public string Name;

    /// <summary>
    /// The <see cref="DbType"/> of the items in the column. Defaults to <see cref="DbType.Single"/>.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Type of the items in the column")]
    public DbType Type = DbType.Single;

    /// <summary>
    /// Source index of the column, or null when none is given.
    /// The bindings reject negative values.
    /// </summary>
    [Argument(ArgumentType.Multiple, HelpText = "Source index of the column", ShortName = "src")]
    public int? Source;

    /// <summary>
    /// For a key column, this defines the range of values. When non-null, the column
    /// is given a key type built from <see cref="Type"/> and this count.
    /// </summary>
    [Argument(ArgumentType.Multiple, HelpText = "For a key column, this defines the range of values", ShortName = "key")]
    public KeyCount KeyCount;
}
| 131 | + |
/// <summary>
/// Settings used to construct a <see cref="DatabaseLoader"/>.
/// </summary>
public sealed class Options
{
    /// <summary>
    /// The input columns to map to <see cref="IDataView"/> columns.
    /// The loader's constructor throws when this is null or empty.
    /// </summary>
    // NOTE(review): the help text describes numeric ranges, but Column.Source is a
    // single nullable index - confirm whether the text should be updated.
    [Argument(
        ArgumentType.Multiple,
        HelpText = "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40",
        Name = "Column",
        ShortName = "col",
        SortOrder = 1)]
    public Column[] Columns;
}
| 144 | + |
/// <summary>
/// Immutable description of one output column: its name, optional source index, and type.
/// </summary>
private sealed class ColInfo
{
    /// <summary>Name of the output column.</summary>
    public readonly string Name;

    /// <summary>Source index of the column, or null when none was specified.</summary>
    public readonly int? SourceIndex;

    /// <summary>Type of the output column.</summary>
    public readonly DataViewType ColType;

    /// <summary>
    /// Creates the column description. The arguments are only verified through assertions:
    /// the name must be non-empty, a present source index must be non-negative, and the
    /// type must be non-null.
    /// </summary>
    public ColInfo(string name, int? sourceIndex, DataViewType colType)
    {
        Contracts.AssertNonEmpty(name);
        Contracts.Assert(sourceIndex == null || sourceIndex.Value >= 0);
        Contracts.AssertValue(colType);

        Name = name;
        SourceIndex = sourceIndex;
        ColType = colType;
    }
}
| 165 | + |
| 166 | + private sealed class Bindings |
| 167 | + { |
| 168 | + /// <summary> |
| 169 | + /// <see cref="Infos"/>[i] stores the i-th column's name, source index and type. Columns are loaded from the input <see cref="DbDataReader"/>. |
| 170 | + /// </summary> |
| 171 | + public readonly ColInfo[] Infos; |
| 172 | + |
| 173 | + public DataViewSchema OutputSchema { get; } |
| 174 | + |
/// <summary>
/// Builds the bindings from user-specified column descriptions.
/// </summary>
/// <param name="parent">The owning loader; its host provides the channel used for validation and messages.</param>
/// <param name="cols">The column descriptions; asserted to be non-empty.</param>
public Bindings(DatabaseLoader parent, Column[] cols)
{
    Contracts.AssertNonEmpty(cols);

    using (var ch = parent._host.Start("Binding"))
    {
        // Validate every column before building anything.
        foreach (var col in cols)
        {
            // Report a null entry as a user-argument error instead of letting it
            // surface as a NullReferenceException on the member access below.
            if (col == null)
                throw ch.ExceptUserArg(nameof(Options.Columns), "Column specification must not be null");

            // A missing (null) source index is allowed; a negative one is not.
            if (col.Source < 0)
                throw ch.ExceptUserArg(nameof(Column.Source), "Source column index must be non-negative");
        }

        Infos = new ColInfo[cols.Length];

        // Used only for detecting duplicated column names specified by the user.
        var seenNames = new HashSet<string>();
        bool duplicateReported = false;

        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            var col = cols[iinfo];

            ch.CheckNonWhiteSpace(col.Name, nameof(col.Name));
            string name = col.Name.Trim();

            // The message is emitted once, on the first duplicate encountered.
            if (!seenNames.Add(name) && !duplicateReported)
            {
                duplicateReported = true;
                ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");
            }

            PrimitiveDataViewType itemType;
            if (col.KeyCount != null)
            {
                itemType = ConstructKeyType(col.Type, col.KeyCount);
            }
            else
            {
                itemType = ColumnTypeExtensions.PrimitiveTypeFromType(col.Type.ToType());
            }

            Infos[iinfo] = new ColInfo(name, col.Source, itemType);
        }
    }

    OutputSchema = ComputeOutputSchema();
}
| 219 | + |
/// <summary>
/// Rebuilds the bindings from a saved model, validating the decoded values as they are read.
/// </summary>
/// <param name="ctx">The model context to read from; asserted to be non-null.</param>
/// <param name="parent">The owning loader. Not used by this constructor.</param>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   byte: bool of whether the source index is valid
    //   for a valid source index:
    //     int: source index
    int cinfo = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));

        PrimitiveDataViewType itemType;
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            Type rawType = kind.ToType();
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(rawType));

            ulong count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);

            itemType = new KeyDataViewType(rawType, count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        int? sourceIndex = null;
        bool hasSourceIndex = ctx.Reader.ReadBoolByte();
        if (hasSourceIndex)
        {
            int index = ctx.Reader.ReadInt32();
            // The user-facing constructor rejects negative source indices; enforce the
            // same rule on decoded data, since ColInfo only asserts it.
            Contracts.CheckDecode(index >= 0);
            sourceIndex = index;
        }

        Infos[iinfo] = new ColInfo(name, sourceIndex, itemType);
    }

    OutputSchema = ComputeOutputSchema();
}
| 272 | + |
/// <summary>
/// Writes the bindings to <paramref name="ctx"/> in the binary layout described below.
/// </summary>
/// <param name="ctx">The model context to write to; asserted to be non-null.</param>
internal void Save(ModelSaveContext ctx)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   byte: bool of whether the source index is valid
    //   for a valid source index:
    //     int: source index
    ctx.Writer.Write(Infos.Length);
    foreach (var column in Infos)
    {
        ctx.SaveNonEmptyString(column.Name);

        var itemType = column.ColType.GetItemType();
        InternalDataKind rawKind = itemType.GetRawKind();
        // The kind is stored as a single byte, so it must survive the narrowing cast.
        Contracts.Assert((InternalDataKind)(byte)rawKind == rawKind);
        ctx.Writer.Write((byte)rawKind);

        var keyType = itemType as KeyDataViewType;
        ctx.Writer.WriteBoolByte(keyType != null);
        if (keyType != null)
            ctx.Writer.Write(keyType.Count);

        bool hasSource = column.SourceIndex.HasValue;
        ctx.Writer.WriteBoolByte(hasSource);
        if (hasSource)
            ctx.Writer.Write(column.SourceIndex.GetValueOrDefault());
    }
}
| 305 | + |
/// <summary>
/// Builds the output schema by adding one column per entry of <see cref="Infos"/>, in order.
/// </summary>
private DataViewSchema ComputeOutputSchema()
{
    var builder = new DataViewSchema.Builder();

    // Columns are added in the order in which they were specified.
    foreach (var column in Infos)
        builder.AddColumn(column.Name, column.ColType);

    return builder.ToSchema();
}
| 319 | + |
/// <summary>
/// Creates a <see cref="KeyDataViewType"/> from a <see cref="DbType"/> and a key count.
/// When <paramref name="keyCount"/> carries no explicit count, the value returned by
/// <c>ToMaxInt()</c> for the raw type is used instead.
/// </summary>
private static KeyDataViewType ConstructKeyType(DbType dbType, KeyCount keyCount)
{
    Contracts.CheckValue(keyCount, nameof(keyCount));

    Type rawType = dbType.ToType();
    Contracts.CheckUserArg(KeyDataViewType.IsValidDataType(rawType), nameof(DatabaseLoader.Column.Type), "Bad item type for Key");

    if (keyCount.Count.HasValue)
        return new KeyDataViewType(rawType, keyCount.Count.GetValueOrDefault());

    return new KeyDataViewType(rawType, rawType.ToMaxInt());
}
| 338 | + } |
| 339 | + |
/// <summary>
/// The <see cref="IDataView"/> returned by <see cref="DatabaseLoader.Load"/>: pairs a loader
/// with the function that supplies the <see cref="DbDataReader"/>.
/// </summary>
private sealed class BoundLoader : IDataView
{
    private readonly DatabaseLoader _loader;
    private readonly IHost _host;
    private readonly Func<DbDataReader> _input;

    /// <summary>
    /// Binds <paramref name="loader"/> to <paramref name="input"/>; the input must not be null.
    /// </summary>
    public BoundLoader(DatabaseLoader loader, Func<DbDataReader> input)
    {
        _loader = loader;
        _host = loader._host.Register(nameof(BoundLoader));

        _host.CheckValue(input, nameof(input));
        _input = input;
    }

    /// <summary>Always returns null: no row count is reported.</summary>
    public long? GetRowCount()
    {
        return null;
    }

    /// <summary>Always false.</summary>
    public bool CanShuffle
    {
        get { return false; }
    }

    /// <summary>The output schema of the owning loader's bindings.</summary>
    public DataViewSchema Schema
    {
        get { return _loader._bindings.OutputSchema; }
    }

    /// <summary>
    /// Creates a cursor over the input with the columns in <paramref name="columnsNeeded"/> active.
    /// <paramref name="rand"/> may be null and is not passed on to the cursor.
    /// </summary>
    public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
    {
        _host.CheckValueOrNull(rand);

        int columnCount = _loader._bindings.OutputSchema.Count;
        var active = Utils.BuildArray(columnCount, columnsNeeded);
        return Cursor.Create(_loader, _input, active);
    }

    /// <summary>
    /// Returns a single-element set containing the cursor from <see cref="GetRowCursor"/>;
    /// <paramref name="n"/> is ignored.
    /// </summary>
    public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
    {
        DataViewRowCursor single = GetRowCursor(columnsNeeded, rand);
        return new DataViewRowCursor[] { single };
    }
}
| 372 | + } |
| 373 | +} |
0 commit comments