Skip to content

Commit 8e4f596

Browse files
committed
Make time series state internal.
1 parent ad1d222 commit 8e4f596

File tree

9 files changed

+970
-887
lines changed

9 files changed

+970
-887
lines changed

src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
namespace Microsoft.ML.StaticPipe
1111
{
12-
using IidBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase<float, Microsoft.ML.TimeSeriesProcessing.IidAnomalyDetectionBase.State>;
13-
using SsaBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase<float, Microsoft.ML.TimeSeriesProcessing.SsaAnomalyDetectionBase.State>;
1412

1513
/// <summary>
1614
/// Static API extension methods for <see cref="IidChangePointEstimator"/>.
@@ -25,7 +23,7 @@ public OutColumn(
2523
Scalar<float> input,
2624
int confidence,
2725
int changeHistoryLength,
28-
IidBase.MartingaleType martingale,
26+
MartingaleType martingale,
2927
double eps)
3028
: base(new Reconciler(confidence, changeHistoryLength, martingale, eps), input)
3129
{
@@ -37,13 +35,13 @@ private sealed class Reconciler : EstimatorReconciler
3735
{
3836
private readonly int _confidence;
3937
private readonly int _changeHistoryLength;
40-
private readonly IidBase.MartingaleType _martingale;
38+
private readonly MartingaleType _martingale;
4139
private readonly double _eps;
4240

4341
public Reconciler(
4442
int confidence,
4543
int changeHistoryLength,
46-
IidBase.MartingaleType martingale,
44+
MartingaleType martingale,
4745
double eps)
4846
{
4947
_confidence = confidence;
@@ -77,7 +75,7 @@ public static Vector<double> IidChangePointDetect(
7775
this Scalar<float> input,
7876
int confidence,
7977
int changeHistoryLength,
80-
IidBase.MartingaleType martingale = IidBase.MartingaleType.Power,
78+
MartingaleType martingale = MartingaleType.Power,
8179
double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, martingale, eps);
8280
}
8381

@@ -93,7 +91,7 @@ private sealed class OutColumn : Vector<double>
9391
public OutColumn(Scalar<float> input,
9492
int confidence,
9593
int pvalueHistoryLength,
96-
IidBase.AnomalySide side)
94+
AnomalySide side)
9795
: base(new Reconciler(confidence, pvalueHistoryLength, side), input)
9896
{
9997
Input = input;
@@ -104,12 +102,12 @@ private sealed class Reconciler : EstimatorReconciler
104102
{
105103
private readonly int _confidence;
106104
private readonly int _pvalueHistoryLength;
107-
private readonly IidBase.AnomalySide _side;
105+
private readonly AnomalySide _side;
108106

109107
public Reconciler(
110108
int confidence,
111109
int pvalueHistoryLength,
112-
IidBase.AnomalySide side)
110+
AnomalySide side)
113111
{
114112
_confidence = confidence;
115113
_pvalueHistoryLength = pvalueHistoryLength;
@@ -140,7 +138,7 @@ public static Vector<double> IidSpikeDetect(
140138
this Scalar<float> input,
141139
int confidence,
142140
int pvalueHistoryLength,
143-
IidBase.AnomalySide side = IidBase.AnomalySide.TwoSided
141+
AnomalySide side = AnomalySide.TwoSided
144142
) => new OutColumn(input, confidence, pvalueHistoryLength, side);
145143
}
146144

@@ -159,7 +157,7 @@ public OutColumn(Scalar<float> input,
159157
int trainingWindowSize,
160158
int seasonalityWindowSize,
161159
ErrorFunction errorFunction,
162-
SsaBase.MartingaleType martingale,
160+
MartingaleType martingale,
163161
double eps)
164162
: base(new Reconciler(confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps), input)
165163
{
@@ -174,7 +172,7 @@ private sealed class Reconciler : EstimatorReconciler
174172
private readonly int _trainingWindowSize;
175173
private readonly int _seasonalityWindowSize;
176174
private readonly ErrorFunction _errorFunction;
177-
private readonly SsaBase.MartingaleType _martingale;
175+
private readonly MartingaleType _martingale;
178176
private readonly double _eps;
179177

180178
public Reconciler(
@@ -183,7 +181,7 @@ public Reconciler(
183181
int trainingWindowSize,
184182
int seasonalityWindowSize,
185183
ErrorFunction errorFunction,
186-
SsaBase.MartingaleType martingale,
184+
MartingaleType martingale,
187185
double eps)
188186
{
189187
_confidence = confidence;
@@ -226,7 +224,7 @@ public static Vector<double> SsaChangePointDetect(
226224
int trainingWindowSize,
227225
int seasonalityWindowSize,
228226
ErrorFunction errorFunction = ErrorFunction.SignedDifference,
229-
SsaBase.MartingaleType martingale = SsaBase.MartingaleType.Power,
227+
MartingaleType martingale = MartingaleType.Power,
230228
double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps);
231229
}
232230

@@ -244,7 +242,7 @@ public OutColumn(Scalar<float> input,
244242
int pvalueHistoryLength,
245243
int trainingWindowSize,
246244
int seasonalityWindowSize,
247-
SsaBase.AnomalySide side,
245+
AnomalySide side,
248246
ErrorFunction errorFunction)
249247
: base(new Reconciler(confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction), input)
250248
{
@@ -258,15 +256,15 @@ private sealed class Reconciler : EstimatorReconciler
258256
private readonly int _pvalueHistoryLength;
259257
private readonly int _trainingWindowSize;
260258
private readonly int _seasonalityWindowSize;
261-
private readonly SsaBase.AnomalySide _side;
259+
private readonly AnomalySide _side;
262260
private readonly ErrorFunction _errorFunction;
263261

264262
public Reconciler(
265263
int confidence,
266264
int pvalueHistoryLength,
267265
int trainingWindowSize,
268266
int seasonalityWindowSize,
269-
SsaBase.AnomalySide side,
267+
AnomalySide side,
270268
ErrorFunction errorFunction)
271269
{
272270
_confidence = confidence;
@@ -306,7 +304,7 @@ public static Vector<double> SsaSpikeDetect(
306304
int changeHistoryLength,
307305
int trainingWindowSize,
308306
int seasonalityWindowSize,
309-
SsaBase.AnomalySide side = SsaBase.AnomalySide.TwoSided,
307+
AnomalySide side = AnomalySide.TwoSided,
310308
ErrorFunction errorFunction = ErrorFunction.SignedDifference
311309
) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction);
312310

src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs

Lines changed: 107 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,100 +12,144 @@
1212

1313
namespace Microsoft.ML.TimeSeriesProcessing
1414
{
15-
/// <summary>
16-
/// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes
17-
/// the input sequence represents the raw anomaly score which might have been computed via another process.
18-
/// </summary>
19-
public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State>
15+
public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel
2016
{
21-
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env)
22-
: base(args, name, env)
23-
{
24-
InitialWindowSize = 0;
25-
StateRef = new State();
26-
StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
27-
}
17+
public bool IsRowToRowMapper => Base.IsRowToRowMapper;
2818

29-
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
30-
: base(env, ctx, name)
31-
{
32-
Host.CheckDecode(InitialWindowSize == 0);
33-
StateRef = new State(ctx.Reader);
34-
StateRef.InitState(this, Host);
35-
}
19+
IStatefulTransformer IStatefulTransformer.Clone() => Base.Clone();
20+
21+
public Schema GetOutputSchema(Schema inputSchema) => Base.GetOutputSchema(inputSchema);
22+
23+
public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) => Base.GetRowToRowMapper(inputSchema);
24+
25+
public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) => ((IStatefulTransformer)Base).GetStatefulRowToRowMapper(inputSchema);
26+
27+
public IDataView Transform(IDataView input) => Base.Transform(input);
3628

37-
public override Schema GetOutputSchema(Schema inputSchema)
29+
public virtual void Save(ModelSaveContext ctx)
3830
{
39-
Host.CheckValue(inputSchema, nameof(inputSchema));
31+
Base.SaveThis(ctx);
32+
}
4033

41-
if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
42-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
34+
internal IStatefulRowMapper MakeRowMapper(Schema schema) => Base.MakeRowMapper(schema);
4335

44-
var colType = inputSchema[col].Type;
45-
if (colType != NumberType.R4)
46-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString());
36+
internal IDataTransform MakeDataTransform(IDataView input) => Base.MakeDataTransform(input);
4737

48-
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
38+
internal IidAnomalyDetectionBase Base;
39+
public IidAnomalyDetectionBaseWrapper(ArgumentsBase args, string name, IHostEnvironment env)
40+
{
41+
Base = new IidAnomalyDetectionBase(args, name, env, this);
4942
}
5043

51-
public override void Save(ModelSaveContext ctx)
44+
public IidAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name)
5245
{
53-
ctx.CheckAtModel();
54-
Host.Assert(InitialWindowSize == 0);
55-
base.Save(ctx);
56-
57-
// *** Binary format ***
58-
// <base>
59-
// State: StateRef
60-
StateRef.Save(ctx.Writer);
46+
Base = new IidAnomalyDetectionBase(env, ctx, name, this);
6147
}
6248

63-
public sealed class State : AnomalyDetectionStateBase
49+
/// <summary>
50+
/// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes
51+
/// the input sequence represents the raw anomaly score which might have been computed via another process.
52+
/// </summary>
53+
internal class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State>
6454
{
65-
public State()
66-
{
67-
}
55+
internal IidAnomalyDetectionBaseWrapper Parent;
6856

69-
internal State(BinaryReader reader) : base(reader)
57+
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IidAnomalyDetectionBaseWrapper parent)
58+
: base(args, name, env)
7059
{
71-
WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
72-
InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
60+
InitialWindowSize = 0;
61+
StateRef = new State();
62+
StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
63+
Parent = parent;
7364
}
7465

75-
internal override void Save(BinaryWriter writer)
66+
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IidAnomalyDetectionBaseWrapper parent)
67+
: base(env, ctx, name)
7668
{
77-
base.Save(writer);
78-
TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer);
79-
TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer);
69+
Host.CheckDecode(InitialWindowSize == 0);
70+
StateRef = new State(ctx.Reader);
71+
StateRef.InitState(this, Host);
72+
Parent = parent;
8073
}
8174

82-
private protected override void CloneCore(StateBase state)
75+
public override Schema GetOutputSchema(Schema inputSchema)
8376
{
84-
base.CloneCore(state);
85-
Contracts.Assert(state is State);
86-
var stateLocal = state as State;
87-
stateLocal.WindowedBuffer = WindowedBuffer.Clone();
88-
stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
89-
}
77+
Host.CheckValue(inputSchema, nameof(inputSchema));
9078

91-
private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> data)
92-
{
93-
// This method is empty because there is no need for initial tuning for this transform.
79+
if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
80+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
81+
82+
var colType = inputSchema[col].Type;
83+
if (colType != NumberType.R4)
84+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString());
85+
86+
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
9487
}
9588

96-
private protected override void InitializeAnomalyDetector()
89+
public override void Save(ModelSaveContext ctx)
9790
{
98-
// This method is empty because there is no need for any extra initialization for this transform.
91+
Parent.Save(ctx);
9992
}
10093

101-
private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue<Single> windowedBuffer, long iteration)
94+
internal void SaveThis(ModelSaveContext ctx)
10295
{
103-
// This transform treats the input sequenence as the raw anomaly score.
104-
return (double)input;
96+
ctx.CheckAtModel();
97+
Host.Assert(InitialWindowSize == 0);
98+
base.Save(ctx);
99+
100+
// *** Binary format ***
101+
// <base>
102+
// State: StateRef
103+
StateRef.Save(ctx.Writer);
105104
}
106105

107-
public override void Consume(float value)
106+
internal sealed class State : AnomalyDetectionStateBase
108107
{
108+
public State()
109+
{
110+
}
111+
112+
internal State(BinaryReader reader) : base(reader)
113+
{
114+
WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
115+
InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
116+
}
117+
118+
internal override void Save(BinaryWriter writer)
119+
{
120+
base.Save(writer);
121+
TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer);
122+
TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer);
123+
}
124+
125+
private protected override void CloneCore(State state)
126+
{
127+
base.CloneCore(state);
128+
Contracts.Assert(state is State);
129+
var stateLocal = state as State;
130+
stateLocal.WindowedBuffer = WindowedBuffer.Clone();
131+
stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
132+
}
133+
134+
private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> data)
135+
{
136+
// This method is empty because there is no need for initial tuning for this transform.
137+
}
138+
139+
private protected override void InitializeAnomalyDetector()
140+
{
141+
// This method is empty because there is no need for any extra initialization for this transform.
142+
}
143+
144+
private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue<Single> windowedBuffer, long iteration)
145+
{
146+
// This transform treats the input sequenence as the raw anomaly score.
147+
return (double)input;
148+
}
149+
150+
public override void Consume(float value)
151+
{
152+
}
109153
}
110154
}
111155
}

0 commit comments

Comments
 (0)