Skip to content

Commit d84bf38

Browse files
Added in DateTime type support for TimeSeriesImputer (#4812)
* added in DateTime type support for TSI * updates based on PR feedback * Fixes based on PR comments
1 parent 0d21742 commit d84bf38

File tree

3 files changed

+280
-76
lines changed

3 files changed

+280
-76
lines changed

src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs

Lines changed: 97 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Linq;
45
using System.Runtime.InteropServices;
56
using System.Security;
@@ -32,7 +33,7 @@ public static class TimeSeriesImputerExtensionClass
3233
/// purpose of this estimator. Other column types will have the default value placed if a row is imputed.
3334
/// </summary>
3435
/// <param name="catalog">The transform catalog.</param>
35-
/// <param name="timeSeriesColumn">Column representing the time series. Should be of type <see cref="long"/></param>
36+
/// <param name="timeSeriesColumn">Column representing the time series. Should be of type <see cref="long"/> or <see cref="System.DateTime"/></param>
3637
/// <param name="grainColumns">List of columns to use as grains</param>
3738
/// <param name="imputeMode">Mode of imputation for missing values in column. If not passed defaults to forward fill</param>
3839
public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns,
@@ -46,7 +47,7 @@ public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this Tra
4647
/// purpose of this estimator.
4748
/// </summary>
4849
/// <param name="catalog">The transform catalog.</param>
49-
/// <param name="timeSeriesColumn">Column representing the time series. Should be of type <see cref="long"/></param>
50+
/// <param name="timeSeriesColumn">Column representing the time series. Should be of type <see cref="long"/> or <see cref="System.DateTime"/></param>
5051
/// <param name="grainColumns">List of columns to use as grains</param>
5152
/// <param name="filterColumns">List of columns to filter. If <paramref name="filterMode"/> is <see cref="TimeSeriesImputerEstimator.FilterMode.Exclude"/> then columns in the list will be ignored.
5253
/// If <paramref name="filterMode"/> is <see cref="TimeSeriesImputerEstimator.FilterMode.Include"/> then values in the list are the only columns imputed.</param>
@@ -61,15 +62,7 @@ public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this Tra
6162
}
6263

6364
/// <summary>
64-
/// Imputes missing rows and column data per grain, based on the dates in the date column. This operation needs to happen to every column in the IDataView,
65-
/// If you "filter" a column using the filterColumns and filterMode parameters, if a row is imputed the default value for that type will be used.
66-
/// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the
67-
/// purpose of this estimator. A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a
68-
/// boolean value representing if the row was created as a result of this operation or not.
69-
///
70-
/// NOTE: It is not recommended to chain this multiple times. If a column is filtered, the default value is placed when a row is imputed, and the
71-
/// default value is not null. Thus any other TimeSeriesImputers will not be able to replace those values anymore causing essentially a very
72-
/// computationally expensive NO-OP.
65+
/// Imputes missing rows and column data per grain, based on the dates in the date column.
7366
///
7467
/// </summary>
7568
/// <remarks>
@@ -83,8 +76,18 @@ public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this Tra
8376
/// | Output column data type | All Types |
8477
/// | Exportable to ONNX | No |
8578
///
86-
/// The <xref:Microsoft.ML.Transforms.TimeSeriesImputerEstimator> is not a trivial estimator and needs training.
79+
/// The TimeSeriesImputer imputes missing rows and column data per grain (category), based on the dates in the date column. This operation needs to happen to every column in the IDataView.
80+
/// If you "filter" a column using the filterColumns and filterMode parameters, the default value for that type will be used whenever a row is imputed.
81+
/// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the
82+
/// purpose of this estimator. A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a
83+
/// boolean value representing if the row was created as a result of this operation or not.
8784
///
85+
/// The imputation strategies that are currently supported are ForwardFill, where the last good value is propagated forward, BackFill, where the next good value is propagated backwards,
86+
/// and Median, where the mathematical median is used to fill in missing values.
87+
///
88+
/// NOTE: It is not recommended to chain this multiple times. If a column is filtered, the default value is placed when a row is imputed, and the
89+
/// default value is not null. Thus any other TimeSeriesImputers will not be able to replace those values anymore causing essentially a very
90+
/// computationally expensive NO-OP.
8891
///
8992
/// ]]>
9093
/// </format>
@@ -98,7 +101,7 @@ public sealed class TimeSeriesImputerEstimator : IEstimator<TimeSeriesImputerTra
98101

99102
private readonly IHost _host;
100103
private static readonly List<Type> _currentSupportedTypes = new List<Type> { typeof(sbyte), typeof(byte), typeof(short), typeof(ushort), typeof(int), typeof(uint),
101-
typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory<char>)};
104+
typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory<char>), typeof(DateTime)};
102105

103106
#region Options
104107
internal sealed class Options : TransformInputBase
@@ -127,18 +130,52 @@ internal sealed class Options : TransformInputBase
127130

128131
#region Class Enums
129132

133+
/// <summary>
134+
/// This is the representation of which Imputation Strategy to use.
135+
/// ForwardFill takes the value from the last good row and propagates it forward anytime a row is imputed or a missing value is found.
136+
/// BackFill is the same as ForwardFill, except it takes from the next good row and propagates backwards.
137+
/// Median only supports float/double; it takes the median value found during training and uses that to replace missing values.
138+
/// </summary>
130139
public enum ImputationStrategy : byte
131140
{
141+
/// <summary>
142+
/// Takes the value from the last good row and propagates it forward anytime a row is imputed or a missing value is found.
143+
/// </summary>
132144
ForwardFill = 1,
145+
146+
/// <summary>
147+
/// Takes the value from the next good row and propagates it backwards anytime a row is imputed or a missing value is found.
148+
/// </summary>
133149
BackFill = 2,
150+
151+
/// <summary>
152+
/// Takes the median found during training and propagates that anytime a row is imputed or a missing value is found.
153+
/// </summary>
134154
Median = 3,
135155
// Interpolate = 4, interpolate not currently supported in the native code.
136156
};
137157

158+
/// <summary>
159+
/// Method by which columns are selected for imputing values.
160+
/// NoFilter takes all of the columns so you don't have to specify anything.
161+
/// Include only does the specified ImputationStrategy on the columns you specify. The other columns will get a default value.
162+
/// Exclude is the exact opposite of Include, and does the ImputationStrategy on all columns but the ones you specify, which will get the default value.
163+
/// </summary>
138164
public enum FilterMode : byte
139165
{
166+
/// <summary>
167+
/// Takes all of the columns so you don't have to specify anything.
168+
/// </summary>
140169
NoFilter = 1,
170+
171+
/// <summary>
172+
/// Only does the specified ImputationStrategy on the columns you specify. The other columns will get a default value.
173+
/// </summary>
141174
Include = 2,
175+
176+
/// <summary>
177+
/// Does the ImputationStrategy on all columns but the ones you specify, which will get the default value.
178+
/// </summary>
142179
Exclude = 3
143180
};
144181

@@ -331,7 +368,8 @@ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDa
331368
var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name);
332369

333370
// Create buffer to hold binary data
334-
var columnBuffer = new byte[4096];
371+
var memoryStream = new MemoryStream(4096);
372+
var binaryWriter = new BinaryWriter(memoryStream, Encoding.UTF8);
335373

336374
// Create TypeId[] for types of grain and data columns;
337375
var dataColumnTypes = new TypeId[_dataColumns.Length];
@@ -365,15 +403,17 @@ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDa
365403

366404
while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext())
367405
{
368-
BuildColumnByteArray(allColumns, ref columnBuffer, out int serializedDataLength);
406+
BuildColumnByteArray(allColumns, ref binaryWriter);
369407

370-
fixed (byte* bufferPointer = columnBuffer)
408+
fixed (byte* bufferPointer = memoryStream.GetBuffer())
371409
{
372-
var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(serializedDataLength) };
410+
var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(memoryStream.Position) };
373411
success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle);
374412
}
375413
if (!success)
376414
throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
415+
416+
memoryStream.Position = 0;
377417
}
378418

379419
success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle);
@@ -390,18 +430,11 @@ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDa
390430
}
391431
}
392432

393-
private void BuildColumnByteArray(Dictionary<string, TypedColumn> allColumns, ref byte[] columnByteBuffer, out int serializedDataLength)
433+
private void BuildColumnByteArray(Dictionary<string, TypedColumn> allColumns, ref BinaryWriter binaryWriter)
394434
{
395-
serializedDataLength = 0;
396435
foreach (var column in _allColumnNames)
397436
{
398-
var bytes = allColumns[column].GetSerializedValue();
399-
var byteLength = bytes.Length;
400-
if (byteLength + serializedDataLength >= columnByteBuffer.Length)
401-
Array.Resize(ref columnByteBuffer, columnByteBuffer.Length * 2);
402-
403-
Array.Copy(bytes, 0, columnByteBuffer, serializedDataLength, byteLength);
404-
serializedDataLength += byteLength;
437+
allColumns[column].SerializeValue(ref binaryWriter);
405438
}
406439
}
407440

@@ -531,7 +564,7 @@ internal TypedColumn(DataViewSchema.Column column)
531564
}
532565

533566
internal abstract void InitializeGetter(DataViewRowCursor cursor);
534-
internal abstract byte[] GetSerializedValue();
567+
internal abstract void SerializeValue(ref BinaryWriter binaryWriter);
535568
internal abstract TypeId GetTypeId();
536569

537570
internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns)
@@ -559,6 +592,8 @@ internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, stri
559592
return new NumericTypedColumn<double>(column, optionalColumns.Contains(column.Name));
560593
else if (type == typeof(ReadOnlyMemory<char>).ToString())
561594
return new StringTypedColumn(column, optionalColumns.Contains(column.Name));
595+
else if (type == typeof(DateTime).ToString())
596+
return new DateTimeTypedColumn(column, optionalColumns.Contains(column.Name));
562597

563598
throw new InvalidOperationException($"Unsupported type {type}");
564599
}
@@ -602,18 +637,14 @@ internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable = fals
602637
_isNullable = isNullable;
603638
}
604639

605-
internal override byte[] GetSerializedValue()
640+
internal override void SerializeValue(ref BinaryWriter binaryWriter)
606641
{
607642
dynamic value = GetValue();
608-
byte[] bytes;
609-
if (value.GetType() == typeof(byte))
610-
bytes = new byte[1] { value };
611-
bytes = BitConverter.GetBytes(value);
612643

613644
if (_isNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double))
614-
return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray();
615-
else
616-
return bytes;
645+
binaryWriter.Write(true);
646+
647+
binaryWriter.Write(value);
617648
}
618649
}
619650

@@ -627,13 +658,41 @@ internal StringTypedColumn(DataViewSchema.Column column, bool isNullable = false
627658
_isNullable = isNullable;
628659
}
629660

630-
internal override byte[] GetSerializedValue()
661+
internal override void SerializeValue(ref BinaryWriter binaryWriter)
631662
{
632663
var value = GetValue().ToString();
633664
var stringBytes = Encoding.UTF8.GetBytes(value);
665+
634666
if (_isNullable)
635-
return new byte[] { Convert.ToByte(true) }.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray();
636-
return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray();
667+
binaryWriter.Write(true);
668+
669+
binaryWriter.Write(stringBytes.Length);
670+
671+
binaryWriter.Write(stringBytes);
672+
}
673+
}
674+
675+
/// <summary>
/// Typed column wrapper for <see cref="DateTime"/> columns. Each value is written to the
/// supplied <see cref="BinaryWriter"/> as a 64-bit count of whole seconds relative to
/// 1970-01-01 00:00:00, optionally preceded by a "has value" flag.
/// </summary>
/// <remarks>
/// The <see cref="DateTime.Kind"/> of the value is not inspected and no time zone conversion
/// is performed here; the raw tick count is used as-is.
/// NOTE(review): presumably callers supply UTC values - confirm against the native consumer.
/// </remarks>
private class DateTimeTypedColumn : TypedColumn<DateTime>
{
    private static readonly DateTime _unixEpoch = new DateTime(1970, 1, 1);
    private readonly bool _isNullable;

    /// <summary>
    /// Creates the wrapper for <paramref name="column"/>.
    /// </summary>
    /// <param name="column">Schema column this wrapper is bound to.</param>
    /// <param name="isNullable">When true, a boolean <c>true</c> is written before every serialized value.</param>
    internal DateTimeTypedColumn(DataViewSchema.Column column, bool isNullable = false) :
        base(column)
    {
        _isNullable = isNullable;
    }

    /// <summary>
    /// Writes the current value of the column to <paramref name="binaryWriter"/>.
    /// </summary>
    /// <param name="binaryWriter">Writer that receives the optional flag followed by the seconds value.</param>
    internal override void SerializeValue(ref BinaryWriter binaryWriter)
    {
        var dateTime = GetValue();

        // DateTime.Ticks is never negative, so dividing the absolute tick counts BEFORE subtracting
        // floors to the containing second. Dividing the signed difference instead truncates toward
        // zero, which maps sub-second instants just before 1970 (e.g. 1969-12-31 23:59:59.5) onto
        // second 0 and makes them indistinguishable from the epoch itself.
        long value = dateTime.Ticks / TimeSpan.TicksPerSecond - _unixEpoch.Ticks / TimeSpan.TicksPerSecond;

        if (_isNullable)
            binaryWriter.Write(true);

        binaryWriter.Write(value);
    }
}
639698

0 commit comments

Comments
 (0)