|
12 | 12 |
|
13 | 13 | namespace Microsoft.ML.TimeSeriesProcessing
|
14 | 14 | {
|
15 |
| - /// <summary> |
16 |
| - /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes |
17 |
| - /// the input sequence represents the raw anomaly score which might have been computed via another process. |
18 |
| - /// </summary> |
19 |
| - public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State> |
| 15 | + public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel |
20 | 16 | {
|
21 |
| - public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env) |
22 |
| - : base(args, name, env) |
23 |
| - { |
24 |
| - InitialWindowSize = 0; |
25 |
| - StateRef = new State(); |
26 |
| - StateRef.InitState(WindowSize, InitialWindowSize, this, Host); |
27 |
| - } |
| 17 | + public bool IsRowToRowMapper => Base.IsRowToRowMapper; |
28 | 18 |
|
29 |
| - public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) |
30 |
| - : base(env, ctx, name) |
31 |
| - { |
32 |
| - Host.CheckDecode(InitialWindowSize == 0); |
33 |
| - StateRef = new State(ctx.Reader); |
34 |
| - StateRef.InitState(this, Host); |
35 |
| - } |
| 19 | + IStatefulTransformer IStatefulTransformer.Clone() => Base.Clone(); |
| 20 | + |
| 21 | + public Schema GetOutputSchema(Schema inputSchema) => Base.GetOutputSchema(inputSchema); |
| 22 | + |
| 23 | + public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) => Base.GetRowToRowMapper(inputSchema); |
| 24 | + |
| 25 | + public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) => ((IStatefulTransformer)Base).GetStatefulRowToRowMapper(inputSchema); |
| 26 | + |
| 27 | + public IDataView Transform(IDataView input) => Base.Transform(input); |
36 | 28 |
|
37 |
| - public override Schema GetOutputSchema(Schema inputSchema) |
| 29 | + public virtual void Save(ModelSaveContext ctx) |
38 | 30 | {
|
39 |
| - Host.CheckValue(inputSchema, nameof(inputSchema)); |
| 31 | + Base.SaveThis(ctx); |
| 32 | + } |
40 | 33 |
|
41 |
| - if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) |
42 |
| - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); |
| 34 | + internal IStatefulRowMapper MakeRowMapper(Schema schema) => Base.MakeRowMapper(schema); |
43 | 35 |
|
44 |
| - var colType = inputSchema[col].Type; |
45 |
| - if (colType != NumberType.R4) |
46 |
| - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString()); |
| 36 | + internal IDataTransform MakeDataTransform(IDataView input) => Base.MakeDataTransform(input); |
47 | 37 |
|
48 |
| - return Transform(new EmptyDataView(Host, inputSchema)).Schema; |
| 38 | + internal IidAnomalyDetectionBase Base; |
| 39 | + public IidAnomalyDetectionBaseWrapper(ArgumentsBase args, string name, IHostEnvironment env) |
| 40 | + { |
| 41 | + Base = new IidAnomalyDetectionBase(args, name, env, this); |
49 | 42 | }
|
50 | 43 |
|
51 |
| - public override void Save(ModelSaveContext ctx) |
| 44 | + public IidAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name) |
52 | 45 | {
|
53 |
| - ctx.CheckAtModel(); |
54 |
| - Host.Assert(InitialWindowSize == 0); |
55 |
| - base.Save(ctx); |
56 |
| - |
57 |
| - // *** Binary format *** |
58 |
| - // <base> |
59 |
| - // State: StateRef |
60 |
| - StateRef.Save(ctx.Writer); |
| 46 | + Base = new IidAnomalyDetectionBase(env, ctx, name, this); |
61 | 47 | }
|
62 | 48 |
|
63 |
| - public sealed class State : AnomalyDetectionStateBase |
| 49 | + /// <summary> |
| 50 | + /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes |
| 51 | + /// the input sequence represents the raw anomaly score which might have been computed via another process. |
| 52 | + /// </summary> |
| 53 | + internal class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State> |
64 | 54 | {
|
65 |
| - public State() |
66 |
| - { |
67 |
| - } |
| 55 | + internal IidAnomalyDetectionBaseWrapper Parent; |
68 | 56 |
|
69 |
| - internal State(BinaryReader reader) : base(reader) |
| 57 | + public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IidAnomalyDetectionBaseWrapper parent) |
| 58 | + : base(args, name, env) |
70 | 59 | {
|
71 |
| - WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); |
72 |
| - InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); |
| 60 | + InitialWindowSize = 0; |
| 61 | + StateRef = new State(); |
| 62 | + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); |
| 63 | + Parent = parent; |
73 | 64 | }
|
74 | 65 |
|
75 |
| - internal override void Save(BinaryWriter writer) |
| 66 | + public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IidAnomalyDetectionBaseWrapper parent) |
| 67 | + : base(env, ctx, name) |
76 | 68 | {
|
77 |
| - base.Save(writer); |
78 |
| - TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); |
79 |
| - TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); |
| 69 | + Host.CheckDecode(InitialWindowSize == 0); |
| 70 | + StateRef = new State(ctx.Reader); |
| 71 | + StateRef.InitState(this, Host); |
| 72 | + Parent = parent; |
80 | 73 | }
|
81 | 74 |
|
82 |
| - private protected override void CloneCore(StateBase state) |
| 75 | + public override Schema GetOutputSchema(Schema inputSchema) |
83 | 76 | {
|
84 |
| - base.CloneCore(state); |
85 |
| - Contracts.Assert(state is State); |
86 |
| - var stateLocal = state as State; |
87 |
| - stateLocal.WindowedBuffer = WindowedBuffer.Clone(); |
88 |
| - stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); |
89 |
| - } |
| 77 | + Host.CheckValue(inputSchema, nameof(inputSchema)); |
90 | 78 |
|
91 |
| - private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> data) |
92 |
| - { |
93 |
| - // This method is empty because there is no need for initial tuning for this transform. |
| 79 | + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) |
| 80 | + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); |
| 81 | + |
| 82 | + var colType = inputSchema[col].Type; |
| 83 | + if (colType != NumberType.R4) |
| 84 | + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString()); |
| 85 | + |
| 86 | + return Transform(new EmptyDataView(Host, inputSchema)).Schema; |
94 | 87 | }
|
95 | 88 |
|
96 |
| - private protected override void InitializeAnomalyDetector() |
| 89 | + public override void Save(ModelSaveContext ctx) |
97 | 90 | {
|
98 |
| - // This method is empty because there is no need for any extra initialization for this transform. |
| 91 | + Parent.Save(ctx); |
99 | 92 | }
|
100 | 93 |
|
101 |
| - private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue<Single> windowedBuffer, long iteration) |
| 94 | + internal void SaveThis(ModelSaveContext ctx) |
102 | 95 | {
|
103 |
| - // This transform treats the input sequenence as the raw anomaly score. |
104 |
| - return (double)input; |
| 96 | + ctx.CheckAtModel(); |
| 97 | + Host.Assert(InitialWindowSize == 0); |
| 98 | + base.Save(ctx); |
| 99 | + |
| 100 | + // *** Binary format *** |
| 101 | + // <base> |
| 102 | + // State: StateRef |
| 103 | + StateRef.Save(ctx.Writer); |
105 | 104 | }
|
106 | 105 |
|
107 |
| - public override void Consume(float value) |
| 106 | + internal sealed class State : AnomalyDetectionStateBase |
108 | 107 | {
|
| 108 | + public State() |
| 109 | + { |
| 110 | + } |
| 111 | + |
| 112 | + internal State(BinaryReader reader) : base(reader) |
| 113 | + { |
| 114 | + WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); |
| 115 | + InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); |
| 116 | + } |
| 117 | + |
| 118 | + internal override void Save(BinaryWriter writer) |
| 119 | + { |
| 120 | + base.Save(writer); |
| 121 | + TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); |
| 122 | + TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); |
| 123 | + } |
| 124 | + |
| 125 | + private protected override void CloneCore(State state) |
| 126 | + { |
| 127 | + base.CloneCore(state); |
| 128 | + Contracts.Assert(state is State); |
| 129 | + var stateLocal = state as State; |
| 130 | + stateLocal.WindowedBuffer = WindowedBuffer.Clone(); |
| 131 | + stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); |
| 132 | + } |
| 133 | + |
| 134 | + private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> data) |
| 135 | + { |
| 136 | + // This method is empty because there is no need for initial tuning for this transform. |
| 137 | + } |
| 138 | + |
| 139 | + private protected override void InitializeAnomalyDetector() |
| 140 | + { |
| 141 | + // This method is empty because there is no need for any extra initialization for this transform. |
| 142 | + } |
| 143 | + |
| 144 | + private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue<Single> windowedBuffer, long iteration) |
| 145 | + { |
| 146 | + // This transform treats the input sequenence as the raw anomaly score. |
| 147 | + return (double)input; |
| 148 | + } |
| 149 | + |
| 150 | + public override void Consume(float value) |
| 151 | + { |
| 152 | + } |
109 | 153 | }
|
110 | 154 | }
|
111 | 155 | }
|
|
0 commit comments