Skip to content

Commit cb34ca6

Browse files
authored
[ML] add new multi custom processor for data frame analytics and model inference (#67362)
This adds the `multi_encoding` custom feature processor to data frame analytics and model inference. The `multi_encoding` processor allows custom processors to be chained together so that the outputs of one processor can be used as the inputs to another.
1 parent 24ebcc8 commit cb34ca6

File tree

22 files changed

+778
-22
lines changed

22 files changed

+778
-22
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.elasticsearch.client.ml.inference;
2020

2121
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
22+
import org.elasticsearch.client.ml.inference.preprocessing.Multi;
2223
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
2324
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
2425
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
@@ -60,6 +61,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
6061
CustomWordEmbedding::fromXContent));
6162
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME),
6263
NGram::fromXContent));
64+
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(Multi.NAME),
65+
Multi::fromXContent));
6366

6467
// Model
6568
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml.inference.preprocessing;
21+
22+
23+
import java.io.IOException;
24+
import java.util.List;
25+
import java.util.Objects;
26+
27+
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
28+
import org.elasticsearch.common.ParseField;
29+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
30+
import org.elasticsearch.common.xcontent.ToXContent;
31+
import org.elasticsearch.common.xcontent.XContentBuilder;
32+
import org.elasticsearch.common.xcontent.XContentParser;
33+
34+
/**
35+
* Multi-PreProcessor for chaining together multiple processors
36+
*/
37+
public class Multi implements PreProcessor {
38+
39+
public static final String NAME = "multi_encoding";
40+
public static final ParseField PROCESSORS = new ParseField("processors");
41+
public static final ParseField CUSTOM = new ParseField("custom");
42+
43+
@SuppressWarnings("unchecked")
44+
public static final ConstructingObjectParser<Multi, Void> PARSER = new ConstructingObjectParser<>(
45+
NAME,
46+
true,
47+
a -> new Multi((List<PreProcessor>)a[0], (Boolean)a[1]));
48+
static {
49+
PARSER.declareNamedObjects(ConstructingObjectParser.constructorArg(),
50+
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
51+
(_unused) -> {/* Does not matter client side*/ },
52+
PROCESSORS);
53+
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
54+
}
55+
56+
public static Multi fromXContent(XContentParser parser) {
57+
return PARSER.apply(parser, null);
58+
}
59+
60+
private final List<PreProcessor> processors;
61+
private final Boolean custom;
62+
63+
Multi(List<PreProcessor> processors, Boolean custom) {
64+
this.processors = Objects.requireNonNull(processors, PROCESSORS.getPreferredName());
65+
this.custom = custom;
66+
}
67+
68+
@Override
69+
public String getName() {
70+
return NAME;
71+
}
72+
73+
@Override
74+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
75+
builder.startObject();
76+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PROCESSORS.getPreferredName(), processors);
77+
if (custom != null) {
78+
builder.field(CUSTOM.getPreferredName(), custom);
79+
}
80+
builder.endObject();
81+
return builder;
82+
}
83+
84+
@Override
85+
public boolean equals(Object o) {
86+
if (this == o) return true;
87+
if (o == null || getClass() != o.getClass()) return false;
88+
Multi multi = (Multi) o;
89+
return Objects.equals(multi.processors, processors) && Objects.equals(custom, multi.custom);
90+
}
91+
92+
@Override
93+
public int hashCode() {
94+
return Objects.hash(custom, processors);
95+
}
96+
97+
public static Builder builder(List<PreProcessor> processors) {
98+
return new Builder(processors);
99+
}
100+
101+
public static class Builder {
102+
private final List<PreProcessor> processors;
103+
private Boolean custom;
104+
105+
public Builder(List<PreProcessor> processors) {
106+
this.processors = processors;
107+
}
108+
109+
public Builder setCustom(boolean custom) {
110+
this.custom = custom;
111+
return this;
112+
}
113+
114+
public Multi build() {
115+
return new Multi(processors, custom);
116+
}
117+
}
118+
119+
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@
2424
import org.elasticsearch.common.xcontent.XContentParser;
2525

2626
import java.io.IOException;
27+
import java.util.ArrayList;
28+
import java.util.Collections;
2729
import java.util.List;
2830
import java.util.Objects;
31+
import java.util.function.IntFunction;
32+
import java.util.stream.IntStream;
2933

3034

3135
/**
@@ -134,6 +138,10 @@ public Boolean getCustom() {
134138
return custom;
135139
}
136140

141+
public List<String> outputFields() {
142+
return allPossibleNGramOutputFeatureNames();
143+
}
144+
137145
@Override
138146
public boolean equals(Object o) {
139147
if (this == o) return true;
@@ -152,6 +160,30 @@ public int hashCode() {
152160
return Objects.hash(field, featurePrefix, start, length, custom, nGrams);
153161
}
154162

163+
private String nGramFeature(int nGram, int pos) {
164+
return featurePrefix
165+
+ "."
166+
+ nGram
167+
+ pos;
168+
}
169+
170+
private List<String> allPossibleNGramOutputFeatureNames() {
171+
int totalNgrams = 0;
172+
for (int nGram : nGrams) {
173+
totalNgrams += (length - (nGram - 1));
174+
}
175+
if (totalNgrams <= 0) {
176+
return Collections.emptyList();
177+
}
178+
List<String> ngramOutputs = new ArrayList<>(totalNgrams);
179+
180+
for (int nGram : nGrams) {
181+
IntFunction<String> func = i -> nGramFeature(nGram, i);
182+
IntStream.range(0, (length - (nGram - 1))).mapToObj(func).forEach(ngramOutputs::add);
183+
}
184+
return ngramOutputs;
185+
}
186+
155187
public static Builder builder(String field) {
156188
return new Builder(field);
157189
}

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
7777
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
7878
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
79+
import org.elasticsearch.client.ml.inference.preprocessing.Multi;
7980
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
8081
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
8182
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
@@ -707,7 +708,7 @@ public void testDefaultNamedXContents() {
707708

708709
public void testProvidedNamedXContents() {
709710
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
710-
assertEquals(75, namedXContents.size());
711+
assertEquals(76, namedXContents.size());
711712
Map<Class<?>, Integer> categories = new HashMap<>();
712713
List<String> names = new ArrayList<>();
713714
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -792,9 +793,16 @@ public void testProvidedNamedXContents() {
792793
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
793794
registeredMetricName(Regression.NAME, HuberMetric.NAME),
794795
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
795-
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
796+
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
796797
assertThat(names,
797-
hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME));
798+
hasItems(
799+
FrequencyEncoding.NAME,
800+
OneHotEncoding.NAME,
801+
TargetMeanEncoding.NAME,
802+
CustomWordEmbedding.NAME,
803+
NGram.NAME,
804+
Multi.NAME
805+
));
798806
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
799807
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
800808
assertEquals(Integer.valueOf(4),

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
package org.elasticsearch.client.ml.inference;
2020

2121
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
22+
import org.elasticsearch.client.ml.inference.preprocessing.MultiTests;
23+
import org.elasticsearch.client.ml.inference.preprocessing.NGramTests;
2224
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
2325
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
2426
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
@@ -66,9 +68,12 @@ public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targ
6668
return new TrainedModelDefinition.Builder()
6769
.setPreProcessors(
6870
randomBoolean() ? null :
69-
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
71+
Stream.generate(() -> randomFrom(
72+
FrequencyEncodingTests.createRandom(),
7073
OneHotEncodingTests.createRandom(),
71-
TargetMeanEncodingTests.createRandom()))
74+
TargetMeanEncodingTests.createRandom(),
75+
NGramTests.createRandom(),
76+
MultiTests.createRandom()))
7277
.limit(numberOfProcessors)
7378
.collect(Collectors.toList()))
7479
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 6, targetType),

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/FrequencyEncodingTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,19 @@ protected FrequencyEncoding createTestInstance() {
5050
}
5151

5252
/**
 * Builds a random {@link FrequencyEncoding} whose input field is a random 10-character string.
 */
public static FrequencyEncoding createRandom() {
53+
return createRandom(randomAlphaOfLength(10));
54+
}
55+
56+
/**
 * Builds a random {@link FrequencyEncoding} that reads from the given input field, so that
 * callers can point it at a specific field (for example the output of another processor).
 *
 * @param inputField passed as the first constructor argument
 */
public static FrequencyEncoding createRandom(String inputField) {
5357
int valuesSize = randomIntBetween(1, 10);
5458
Map<String, Double> valueMap = new HashMap<>();
5559
for (int i = 0; i < valuesSize; i++) {
5660
// random 10-character key mapped to a random double drawn via randomDoubleBetween(0.0, 1.0, false)
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
5761
}
58-
return new FrequencyEncoding(randomAlphaOfLength(10),
62+
// the line under the "58-" marker is the pre-change version shown by the diff; the field is now caller-supplied
return new FrequencyEncoding(inputField,
5963
randomAlphaOfLength(10),
6064
valueMap,
6165
// last argument is left null half of the time, otherwise a random boolean
randomBoolean() ? null : randomBoolean());
6266
}
67+
6368
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.preprocessing;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
import java.util.Arrays;
24+
import java.util.List;
25+
import java.util.function.Predicate;
26+
import java.util.stream.Collectors;
27+
import java.util.stream.Stream;
28+
29+
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
30+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
31+
import org.elasticsearch.common.xcontent.XContentParser;
32+
import org.elasticsearch.test.AbstractXContentTestCase;
33+
34+
35+
public class MultiTests extends AbstractXContentTestCase<Multi> {
36+
37+
@Override
38+
protected Multi doParseInstance(XContentParser parser) throws IOException {
39+
return Multi.fromXContent(parser);
40+
}
41+
42+
@Override
43+
protected Predicate<String> getRandomFieldsExcludeFilter() {
44+
return field -> !field.isEmpty();
45+
}
46+
47+
@Override
48+
protected NamedXContentRegistry xContentRegistry() {
49+
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
50+
}
51+
52+
@Override
53+
protected boolean supportsUnknownFields() {
54+
return true;
55+
}
56+
57+
@Override
58+
protected Multi createTestInstance() {
59+
return createRandom();
60+
}
61+
62+
public static Multi createRandom() {
63+
final List<PreProcessor> processors;
64+
Boolean isCustom = randomBoolean() ? null : randomBoolean();
65+
if (isCustom == null || isCustom == false) {
66+
NGram nGram = new NGram(randomAlphaOfLength(10), Arrays.asList(1, 2), 0, 10, isCustom, "f");
67+
List<PreProcessor> preProcessorList = new ArrayList<>();
68+
preProcessorList.add(nGram);
69+
Stream.generate(() -> randomFrom(
70+
FrequencyEncodingTests.createRandom(randomFrom(nGram.outputFields())),
71+
TargetMeanEncodingTests.createRandom(randomFrom(nGram.outputFields())),
72+
OneHotEncodingTests.createRandom(randomFrom(nGram.outputFields()))
73+
)).limit(randomIntBetween(1, 10)).forEach(preProcessorList::add);
74+
processors = preProcessorList;
75+
} else {
76+
processors = Stream.generate(
77+
() -> randomFrom(
78+
FrequencyEncodingTests.createRandom(),
79+
TargetMeanEncodingTests.createRandom(),
80+
OneHotEncodingTests.createRandom(),
81+
NGramTests.createRandom()
82+
)
83+
).limit(randomIntBetween(1, 10)).collect(Collectors.toList());
84+
}
85+
return new Multi(processors, isCustom);
86+
}
87+
88+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ protected NGram createTestInstance() {
4444
}
4545

4646
/**
 * Builds a random {@link NGram} for tests. The lines under the "48-" and "50-" markers are the
 * pre-change versions shown by the diff; the lines under "49+" and "51+" replace them.
 */
public static NGram createRandom() {
47+
// drawn once so that the n-gram sizes below can be bounded by the same value passed as the length
int length = randomIntBetween(1, 10);
4748
return new NGram(randomAlphaOfLength(10),
48-
IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()),
49+
// n-gram sizes run from 1 up to min(4, length), so no size exceeds the drawn length
IntStream.range(1, Math.min(5, length + 1)).limit(5).boxed().collect(Collectors.toList()),
4950
randomBoolean() ? null : randomIntBetween(0, 10),
50-
randomBoolean() ? null : randomIntBetween(1, 10),
51+
// the length argument is still left null half of the time
randomBoolean() ? null : length,
5152
randomBoolean() ? null : randomBoolean(),
5253
randomBoolean() ? null : randomAlphaOfLength(10));
5354
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/OneHotEncodingTests.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,18 @@ protected OneHotEncoding createTestInstance() {
5050
}
5151

5252
/**
 * Builds a random {@link OneHotEncoding} whose input field is a random 10-character string.
 */
public static OneHotEncoding createRandom() {
53+
return createRandom(randomAlphaOfLength(10));
54+
}
55+
56+
/**
 * Builds a random {@link OneHotEncoding} that reads from the given input field, so that callers
 * can point it at a specific field (for example the output of another processor).
 *
 * @param inputField passed as the first constructor argument
 */
public static OneHotEncoding createRandom(String inputField) {
5357
int valuesSize = randomIntBetween(1, 10);
5458
Map<String, String> valueMap = new HashMap<>();
5559
for (int i = 0; i < valuesSize; i++) {
5660
// random 10-character key mapped to a random 10-character value
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
5761
}
58-
return new OneHotEncoding(randomAlphaOfLength(10), valueMap, randomBoolean() ? null : randomBoolean());
62+
// the line under the "58-" marker is the pre-change version shown by the diff; the field is now caller-supplied
return new OneHotEncoding(inputField,
63+
valueMap,
64+
// last argument is left null half of the time, otherwise a random boolean
randomBoolean() ? null : randomBoolean());
5965
}
6066

6167
}

0 commit comments

Comments
 (0)