Skip to content

Commit 805096d

Browse files
authored
Add inference processor (#6031)
1 parent eba15fc commit 805096d

File tree

4 files changed

+287
-1
lines changed

4 files changed

+287
-1
lines changed

src/Nest/Ingest/ProcessorFormatter.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ internal class ProcessorFormatter : IJsonFormatter<IProcessor>
5050
{ "fingerprint", 34 },
5151
{ "community_id", 35 },
5252
{ "network_direction", 36 },
53-
{ "registered_domain", 37 }
53+
{ "registered_domain", 37 },
54+
{ "inference", 38 }
5455
};
5556

5657
public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver formatterResolver)
@@ -185,6 +186,9 @@ public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver form
185186
case 37:
186187
processor = Deserialize<RegisteredDomainProcessor>(ref reader, formatterResolver);
187188
break;
189+
case 38:
190+
processor = Deserialize<InferenceProcessor>(ref reader, formatterResolver);
191+
break;
188192
}
189193
}
190194
else
@@ -320,6 +324,9 @@ public void Serialize(ref JsonWriter writer, IProcessor value, IJsonFormatterRes
320324
case "registered_domain":
321325
Serialize<IRegisteredDomainProcessor>(ref writer, value, formatterResolver);
322326
break;
327+
case "inference":
328+
Serialize<IInferenceProcessor>(ref writer, value, formatterResolver);
329+
break;
323330
default:
324331
var formatter = DynamicObjectResolver.ExcludeNullCamelCase.GetFormatter<IProcessor>();
325332
formatter.Serialize(ref writer, value, formatterResolver);
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
// Licensed to Elasticsearch B.V under one or more agreements.
2+
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
// See the LICENSE file in the project root for more information
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq.Expressions;
8+
using System.Runtime.Serialization;
9+
using Elasticsearch.Net;
10+
using Elasticsearch.Net.Utf8Json;
11+
using Nest;
12+
13+
namespace Nest
14+
{
15+
/// <summary>
16+
/// Uses a pre-trained data frame analytics model to infer against the data that is being ingested in the pipeline.
17+
/// <para />
18+
/// Available in Elasticsearch 7.6.0+ with at least a basic license.
19+
/// </summary>
/// <remarks>
/// Serialized under the <c>inference</c> processor key. <see cref="InferenceProcessor" /> is the object
/// initializer form and <see cref="InferenceProcessorDescriptor{T}" /> is the fluent form.
/// </remarks>
20+
[InterfaceDataContract]
21+
public interface IInferenceProcessor : IProcessor
22+
{
23+
/// <summary>
24+
/// The ID of the model to load and infer against.
25+
/// </summary>
26+
[DataMember(Name = "model_id")]
27+
string ModelId { get; set; }
28+
29+
/// <summary>
30+
/// Field added to incoming documents to contain results objects.
31+
/// </summary>
32+
[DataMember(Name ="target_field")]
33+
Field TargetField { get; set; }
34+
35+
/// <summary>
36+
/// Maps the document field names to the known field names of the model.
37+
/// </summary>
38+
[DataMember(Name = "field_mappings")]
39+
IDictionary<Field, Field> FieldMappings { get; set; }
40+
41+
/// <summary>
42+
/// Contains the inference type and its options.
43+
/// </summary>
44+
[DataMember(Name = "inference_config")]
45+
IInferenceConfig InferenceConfig { get; set; }
46+
}
47+
48+
/// <inheritdoc cref="IInferenceProcessor" />
49+
public class InferenceProcessor : ProcessorBase, IInferenceProcessor
50+
{
51+
/// <inheritdoc />
52+
public string ModelId { get; set; }
53+
54+
/// <inheritdoc />
55+
public Field TargetField { get; set; }
56+
57+
/// <inheritdoc />
58+
public IDictionary<Field, Field> FieldMappings { get; set; }
59+
60+
/// <inheritdoc />
61+
public IInferenceConfig InferenceConfig { get; set; }
62+
63+
// Processor type name; the same "inference" key is registered in ProcessorFormatter.
protected override string Name => "inference";
64+
}
65+
66+
/// <inheritdoc cref="IInferenceProcessor" />
/// <typeparam name="T">The document type that the expression-based field overloads are written against.</typeparam>
67+
public class InferenceProcessorDescriptor<T>
68+
: ProcessorDescriptorBase<InferenceProcessorDescriptor<T>, IInferenceProcessor>, IInferenceProcessor
69+
where T : class
70+
{
71+
// Processor type name; the same "inference" key is registered in ProcessorFormatter.
protected override string Name => "inference";
72+
73+
// Explicit interface implementations; these are assigned by the fluent methods below.
Field IInferenceProcessor.TargetField { get; set; }
74+
string IInferenceProcessor.ModelId { get; set; }
75+
IInferenceConfig IInferenceProcessor.InferenceConfig { get; set; }
76+
IDictionary<Field, Field> IInferenceProcessor.FieldMappings { get; set; }
77+
78+
/// <inheritdoc cref="IInferenceProcessor.TargetField" />
79+
public InferenceProcessorDescriptor<T> TargetField(Field field) => Assign(field, (a, v) => a.TargetField = v);
80+
81+
/// <inheritdoc cref="IInferenceProcessor.TargetField" />
82+
public InferenceProcessorDescriptor<T> TargetField<TValue>(Expression<Func<T, TValue>> objectPath) =>
83+
Assign(objectPath, (a, v) => a.TargetField = v);
84+
85+
/// <inheritdoc cref="IInferenceProcessor.ModelId" />
86+
public InferenceProcessorDescriptor<T> ModelId(string modelId) =>
87+
Assign(modelId, (a, v) => a.ModelId = v);
88+
89+
/// <inheritdoc cref="IInferenceProcessor.InferenceConfig" />
90+
public InferenceProcessorDescriptor<T> InferenceConfig(Func<InferenceConfigDescriptor<T>, IInferenceConfig> selector) =>
91+
Assign(selector, (a, v) => a.InferenceConfig = v.InvokeOrDefault(new InferenceConfigDescriptor<T>()));
92+
93+
/// <inheritdoc cref="IInferenceProcessor.FieldMappings" />
// The selector is optional. Calling FieldMappings() with no argument is expected to yield an
// empty mapping (see the Inference case in ProcessorAssertions, which expects field_mappings = {}).
94+
public InferenceProcessorDescriptor<T> FieldMappings(Func<FluentDictionary<Field, Field>, FluentDictionary<Field, Field>> selector = null) =>
95+
Assign(selector, (a, v) => a.FieldMappings = v.InvokeOrDefault(new FluentDictionary<Field, Field>()));
96+
}
97+
98+
/// <summary>
/// Contains the inference type and its options, serialized as the <c>inference_config</c> object of an
/// <see cref="IInferenceProcessor" />.
/// </summary>
// NOTE(review): ReadAs presumably makes InferenceConfig the concrete type used on deserialization — confirm.
[ReadAs(typeof(InferenceConfig))]
99+
public interface IInferenceConfig
100+
{
101+
102+
/// <summary>
/// Options for a regression inference, serialized as <c>regression</c>.
/// </summary>
[DataMember(Name = "regression")]
103+
IRegressionInferenceConfig Regression { get; set; }
104+
105+
/// <summary>
/// Options for a classification inference, serialized as <c>classification</c>.
/// </summary>
[DataMember(Name = "classification")]
106+
IClassificationInferenceConfig Classification { get; set; }
107+
}
108+
109+
/// <summary>
/// Object initializer form of <see cref="IInferenceConfig" />: holds the regression or classification
/// options of an inference processor.
/// </summary>
public class InferenceConfig
110+
: IInferenceConfig
111+
{
112+
/// <summary>
/// Options for a regression inference.
/// </summary>
public IRegressionInferenceConfig Regression { get; set; }
113+
114+
/// <summary>
/// Options for a classification inference.
/// </summary>
public IClassificationInferenceConfig Classification { get; set; }
115+
}
116+
117+
/// <summary>
/// Fluent builder for <see cref="IInferenceConfig" />.
/// </summary>
/// <typeparam name="T">The document type passed through to the nested regression and classification descriptors.</typeparam>
public class InferenceConfigDescriptor<T> : DescriptorBase<InferenceConfigDescriptor<T>, IInferenceConfig>, IInferenceConfig
118+
{
119+
// Explicit interface implementations; these are assigned by the fluent methods below.
IRegressionInferenceConfig IInferenceConfig.Regression { get; set; }
120+
IClassificationInferenceConfig IInferenceConfig.Classification { get; set; }
121+
122+
/// <summary>
/// Sets the regression options, built by <paramref name="selector" /> from a new
/// <see cref="RegressionInferenceConfigDescriptor{T}" />.
/// </summary>
public InferenceConfigDescriptor<T> Regression(Func<RegressionInferenceConfigDescriptor<T>, IRegressionInferenceConfig> selector) =>
123+
Assign(selector, (a, v) => a.Regression = v.InvokeOrDefault(new RegressionInferenceConfigDescriptor<T>()));
124+
125+
/// <summary>
/// Sets the classification options, built by <paramref name="selector" /> from a new
/// <see cref="ClassificationInferenceConfigDescriptor{T}" />.
/// </summary>
public InferenceConfigDescriptor<T> Classification(Func<ClassificationInferenceConfigDescriptor<T>, IClassificationInferenceConfig> selector) =>
126+
Assign(selector, (a, v) => a.Classification = v.InvokeOrDefault(new ClassificationInferenceConfigDescriptor<T>()));
127+
}
128+
129+
/// <summary>
/// Options for a regression inference; this is the type of <see cref="IInferenceConfig.Regression" />.
/// </summary>
[ReadAs(typeof(RegressionInferenceConfig))]
130+
public interface IRegressionInferenceConfig
131+
{
132+
/// <summary>
133+
/// Specifies the field to which the inference prediction is written. Defaults to <c>predicted_value</c>.
134+
/// </summary>
135+
[DataMember(Name = "results_field")]
136+
Field ResultsField { get; set; }
137+
}
138+
139+
/// <summary>
/// Object initializer form of <see cref="IRegressionInferenceConfig" />.
/// </summary>
public class RegressionInferenceConfig : IRegressionInferenceConfig
140+
{
141+
/// <summary>
142+
/// Specifies the field to which the inference prediction is written. Defaults to <c>predicted_value</c>.
143+
/// </summary>
144+
public Field ResultsField { get; set; }
145+
}
146+
147+
/// <summary>
/// Fluent builder for <see cref="IRegressionInferenceConfig" />.
/// </summary>
/// <typeparam name="T">The document type that the expression-based field overload is written against.</typeparam>
public class RegressionInferenceConfigDescriptor<T>
148+
: DescriptorBase<RegressionInferenceConfigDescriptor<T>, IRegressionInferenceConfig>, IRegressionInferenceConfig
149+
{
150+
// Explicit interface implementation; assigned by the ResultsField overloads below.
Field IRegressionInferenceConfig.ResultsField { get; set; }
151+
152+
/// <inheritdoc cref="IRegressionInferenceConfig.ResultsField" />
153+
public RegressionInferenceConfigDescriptor<T> ResultsField(Field field) => Assign(field, (a, v) => a.ResultsField = v);
154+
155+
/// <inheritdoc cref="IRegressionInferenceConfig.ResultsField" />
156+
public RegressionInferenceConfigDescriptor<T> ResultsField<TValue>(Expression<Func<T, TValue>> objectPath) =>
157+
Assign(objectPath, (a, v) => a.ResultsField = v);
158+
}
159+
160+
/// <summary>
/// Options for a classification inference; this is the type of <see cref="IInferenceConfig.Classification" />.
/// </summary>
[ReadAs(typeof(ClassificationInferenceConfig))]
161+
public interface IClassificationInferenceConfig
162+
{
163+
/// <summary>
164+
/// Specifies the field to which the inference prediction is written. Defaults to <c>predicted_value</c>.
165+
/// </summary>
166+
[DataMember(Name = "results_field")]
167+
Field ResultsField { get; set; }
168+
169+
/// <summary>
170+
/// Specifies the number of top class predictions to return. Defaults to <c>0</c>.
171+
/// </summary>
172+
[DataMember(Name = "num_top_classes")]
173+
int? NumTopClasses { get; set; }
174+
175+
/// <summary>
176+
/// Specifies the field to which the top classes are written. Defaults to <c>top_classes</c>.
177+
/// </summary>
178+
[DataMember(Name = "top_classes_results_field")]
179+
Field TopClassesResultsField { get; set; }
180+
}
181+
182+
/// <summary>
/// Object initializer form of <see cref="IClassificationInferenceConfig" />.
/// </summary>
public class ClassificationInferenceConfig : IClassificationInferenceConfig
183+
{
184+
/// <summary>
185+
/// Specifies the field to which the inference prediction is written. Defaults to <c>predicted_value</c>.
186+
/// </summary>
187+
public Field ResultsField { get; set; }
188+
189+
/// <summary>
190+
/// Specifies the number of top class predictions to return. Defaults to <c>0</c>.
191+
/// </summary>
192+
public int? NumTopClasses { get; set; }
193+
194+
/// <summary>
195+
/// Specifies the field to which the top classes are written. Defaults to <c>top_classes</c>.
196+
/// </summary>
197+
public Field TopClassesResultsField { get; set; }
198+
}
199+
200+
/// <summary>
/// Fluent builder for <see cref="IClassificationInferenceConfig" />.
/// </summary>
/// <typeparam name="T">The document type that the expression-based field overloads are written against.</typeparam>
public class ClassificationInferenceConfigDescriptor<T> : DescriptorBase<ClassificationInferenceConfigDescriptor<T>, IClassificationInferenceConfig>, IClassificationInferenceConfig
201+
{
202+
// Explicit interface implementations; these are assigned by the fluent methods below.
Field IClassificationInferenceConfig.ResultsField { get; set; }
203+
int? IClassificationInferenceConfig.NumTopClasses { get; set; }
204+
Field IClassificationInferenceConfig.TopClassesResultsField { get; set; }
205+
206+
/// <inheritdoc cref="IClassificationInferenceConfig.ResultsField" />
207+
public ClassificationInferenceConfigDescriptor<T> ResultsField(Field field) => Assign(field, (a, v) => a.ResultsField = v);
208+
209+
/// <inheritdoc cref="IClassificationInferenceConfig.ResultsField" />
210+
public ClassificationInferenceConfigDescriptor<T> ResultsField<TValue>(Expression<Func<T, TValue>> objectPath) =>
211+
Assign(objectPath, (a, v) => a.ResultsField = v);
212+
213+
/// <inheritdoc cref="IClassificationInferenceConfig.NumTopClasses" />
214+
public ClassificationInferenceConfigDescriptor<T> NumTopClasses(int? numTopClasses) => Assign(numTopClasses, (a, v) => a.NumTopClasses = v);
215+
216+
/// <inheritdoc cref="IClassificationInferenceConfig.TopClassesResultsField" />
217+
public ClassificationInferenceConfigDescriptor<T> TopClassesResultsField(Field field) => Assign(field, (a, v) => a.TopClassesResultsField = v);
218+
219+
/// <inheritdoc cref="IClassificationInferenceConfig.TopClassesResultsField" />
220+
public ClassificationInferenceConfigDescriptor<T> TopClassesResultsField<TValue>(Expression<Func<T, TValue>> objectPath) =>
221+
Assign(objectPath, (a, v) => a.TopClassesResultsField = v);
222+
}
223+
}

src/Nest/Ingest/ProcessorsDescriptor.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,5 +197,9 @@ public ProcessorsDescriptor NetworkDirection<T>(Func<NetworkDirectionProcessorDe
197197
/// <inheritdoc cref="IRegisteredDomainProcessor"/>
198198
public ProcessorsDescriptor RegisteredDomain<T>(Func<RegisteredDomainProcessorDescriptor<T>, IRegisteredDomainProcessor> selector) where T : class =>
199199
Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new RegisteredDomainProcessorDescriptor<T>())));
200+
201+
/// <inheritdoc cref="IInferenceProcessor"/>
/// <typeparam name="T">The document type given to the <see cref="InferenceProcessorDescriptor{T}"/>.</typeparam>
/// <param name="selector">
/// Invoked with a new <see cref="InferenceProcessorDescriptor{T}"/>; its result is passed to <c>AddIfNotNull</c>.
/// </param>
202+
public ProcessorsDescriptor Inference<T>(Func<InferenceProcessorDescriptor<T>, IInferenceProcessor> selector) where T : class =>
203+
Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new InferenceProcessorDescriptor<T>())));
200204
}
201205
}

tests/Tests/Ingest/ProcessorAssertions.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,58 @@ public class Pipeline : ProcessorAssertion
702702
public override string Key => "pipeline";
703703
}
704704

705+
// Assertion case for the inference processor: supplies the fluent form, the object initializer form
// and the JSON they are expected to produce, using a classification inference config.
// NOTE(review): how these members are compared is defined by ProcessorAssertion, which is not shown here.
[SkipVersion("<7.6.0", "Introduced in Elasticsearch 7.6.0+")]
706+
public class Inference : ProcessorAssertion
707+
{
708+
// Fluent form; FieldMappings() is called with no selector.
public override Func<ProcessorsDescriptor, IPromise<IList<IProcessor>>> Fluent => d => d
709+
.Inference<Project>(c => c
710+
.TargetField(p => p.Name)
711+
.ModelId("model_id")
712+
.FieldMappings()
713+
.InferenceConfig(i => i
714+
.Classification(cc => cc
715+
.ResultsField("results")
716+
.NumTopClasses(10)
717+
.TopClassesResultsField("topClasses")
718+
)
719+
)
720+
);
721+
722+
// Object initializer form carrying the same values as the fluent form above.
public override IProcessor Initializer => new InferenceProcessor
723+
{
724+
TargetField = "name",
725+
ModelId = "model_id",
726+
FieldMappings = new Dictionary<Field, Field>(),
727+
InferenceConfig = new InferenceConfig
728+
{
729+
Classification = new ClassificationInferenceConfig
730+
{
731+
ResultsField = "results",
732+
NumTopClasses = 10,
733+
TopClassesResultsField = "topClasses"
734+
}
735+
}
736+
};
737+
738+
// Expected JSON body of the processor; field_mappings is an empty object.
public override object Json => new
739+
{
740+
target_field = "name",
741+
model_id = "model_id",
742+
field_mappings = new {},
743+
inference_config = new
744+
{
745+
classification = new
746+
{
747+
results_field = "results",
748+
num_top_classes = 10,
749+
top_classes_results_field = "topClasses"
750+
}
751+
}
752+
};
753+
754+
// Processor key; the same "inference" string is returned by the Name override on InferenceProcessor.
public override string Key => "inference";
755+
}
756+
705757
[SkipVersion("<7.11.0", "Uses URI parts which was introduced in 7.11.0")]
706758
public class UriParts : ProcessorAssertion
707759
{

0 commit comments

Comments
 (0)