Skip to content

Commit 497c7a8

Browse files
committed
[ML] adds new feature_processors field for data frame analytics
feature_processors allow users to create custom features from individual document fields.
1 parent 0022907 commit 497c7a8

File tree

39 files changed

+1063
-172
lines changed

39 files changed

+1063
-172
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
1717
import org.elasticsearch.index.mapper.FieldAliasMapper;
18+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
19+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
20+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
1821
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
1922
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2023
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
2124
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
25+
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
2226

2327
import java.io.IOException;
2428
import java.util.Arrays;
@@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
4650
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
4751
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
4852
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
53+
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
4954

5055
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
5156

@@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
5964
*/
6065
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
6166

67+
@SuppressWarnings("unchecked")
6268
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
6369
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
6470
NAME.getPreferredName(),
@@ -70,14 +76,21 @@ private static ConstructingObjectParser<Classification, Void> createParser(boole
7076
(ClassAssignmentObjective) a[8],
7177
(Integer) a[9],
7278
(Double) a[10],
73-
(Long) a[11]));
79+
(Long) a[11],
80+
(List<PreProcessor>) a[12]));
7481
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
7582
BoostedTreeParams.declareFields(parser);
7683
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
7784
parser.declareString(optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
7885
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
7986
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
8087
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
88+
parser.declareNamedObjects(optionalConstructorArg(),
89+
(p, c, n) -> lenient ?
90+
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
91+
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
92+
(classification) -> {/*TODO should we throw if this is not set?*/},
93+
FEATURE_PROCESSORS);
8194
return parser;
8295
}
8396

@@ -117,14 +130,16 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
117130
private final int numTopClasses;
118131
private final double trainingPercent;
119132
private final long randomizeSeed;
133+
private final List<PreProcessor> featureProcessors;
120134

121135
public Classification(String dependentVariable,
122136
BoostedTreeParams boostedTreeParams,
123137
@Nullable String predictionFieldName,
124138
@Nullable ClassAssignmentObjective classAssignmentObjective,
125139
@Nullable Integer numTopClasses,
126140
@Nullable Double trainingPercent,
127-
@Nullable Long randomizeSeed) {
141+
@Nullable Long randomizeSeed,
142+
@Nullable List<PreProcessor> featureProcessors) {
128143
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
129144
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
130145
}
@@ -139,10 +154,11 @@ public Classification(String dependentVariable,
139154
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
140155
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
141156
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
157+
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
142158
}
143159

144160
public Classification(String dependentVariable) {
145-
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
161+
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
146162
}
147163

148164
public Classification(StreamInput in) throws IOException {
@@ -161,6 +177,11 @@ public Classification(StreamInput in) throws IOException {
161177
} else {
162178
randomizeSeed = Randomness.get().nextLong();
163179
}
180+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
181+
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
182+
} else {
183+
featureProcessors = Collections.emptyList();
184+
}
164185
}
165186

166187
public String getDependentVariable() {
@@ -191,6 +212,10 @@ public long getRandomizeSeed() {
191212
return randomizeSeed;
192213
}
193214

215+
public List<PreProcessor> getFeatureProcessors() {
216+
return featureProcessors;
217+
}
218+
194219
@Override
195220
public String getWriteableName() {
196221
return NAME.getPreferredName();
@@ -209,6 +234,9 @@ public void writeTo(StreamOutput out) throws IOException {
209234
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
210235
out.writeOptionalLong(randomizeSeed);
211236
}
237+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
238+
out.writeNamedWriteableList(featureProcessors);
239+
}
212240
}
213241

214242
@Override
@@ -227,6 +255,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
227255
if (version.onOrAfter(Version.V_7_6_0)) {
228256
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
229257
}
258+
if (featureProcessors.isEmpty() == false) {
259+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
260+
}
230261
builder.endObject();
231262
return builder;
232263
}
@@ -247,6 +278,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
247278
}
248279
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
249280
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
281+
if (featureProcessors.isEmpty() == false) {
282+
params.put(FEATURE_PROCESSORS.getPreferredName(),
283+
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
284+
}
250285
return params;
251286
}
252287

@@ -388,14 +423,15 @@ public boolean equals(Object o) {
388423
&& Objects.equals(predictionFieldName, that.predictionFieldName)
389424
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
390425
&& Objects.equals(numTopClasses, that.numTopClasses)
426+
&& Objects.equals(featureProcessors, that.featureProcessors)
391427
&& trainingPercent == that.trainingPercent
392428
&& randomizeSeed == that.randomizeSeed;
393429
}
394430

395431
@Override
396432
public int hashCode() {
397433
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
398-
numTopClasses, trainingPercent, randomizeSeed);
434+
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
399435
}
400436

401437
public enum ClassAssignmentObjective {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
1717
import org.elasticsearch.index.mapper.NumberFieldMapper;
18+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
19+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
20+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
1821
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
1922
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
2023
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
24+
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
2125

2226
import java.io.IOException;
2327
import java.util.Arrays;
@@ -28,6 +32,7 @@
2832
import java.util.Map;
2933
import java.util.Objects;
3034
import java.util.Set;
35+
import java.util.stream.Collectors;
3136

3237
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3338
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
4247
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
4348
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
4449
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
50+
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
4551

4652
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
4753

4854
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
4955
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
5056

57+
@SuppressWarnings("unchecked")
5158
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
5259
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
5360
NAME.getPreferredName(),
@@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
5966
(Double) a[8],
6067
(Long) a[9],
6168
(LossFunction) a[10],
62-
(Double) a[11]));
69+
(Double) a[11],
70+
(List<PreProcessor>) a[12]));
6371
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
6472
BoostedTreeParams.declareFields(parser);
6573
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
6674
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
6775
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
6876
parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
6977
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
78+
parser.declareNamedObjects(optionalConstructorArg(),
79+
(p, c, n) -> lenient ?
80+
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
81+
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
82+
(regression) -> {/*TODO should we throw if this is not set?*/},
83+
FEATURE_PROCESSORS);
7084
return parser;
7185
}
7286

@@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
90104
private final long randomizeSeed;
91105
private final LossFunction lossFunction;
92106
private final Double lossFunctionParameter;
107+
private final List<PreProcessor> featureProcessors;
93108

94109
public Regression(String dependentVariable,
95110
BoostedTreeParams boostedTreeParams,
96111
@Nullable String predictionFieldName,
97112
@Nullable Double trainingPercent,
98113
@Nullable Long randomizeSeed,
99114
@Nullable LossFunction lossFunction,
100-
@Nullable Double lossFunctionParameter) {
115+
@Nullable Double lossFunctionParameter,
116+
@Nullable List<PreProcessor> featureProcessors) {
101117
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
102118
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
103119
}
@@ -112,10 +128,11 @@ public Regression(String dependentVariable,
112128
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
113129
}
114130
this.lossFunctionParameter = lossFunctionParameter;
131+
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
115132
}
116133

117134
public Regression(String dependentVariable) {
118-
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
135+
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
119136
}
120137

121138
public Regression(StreamInput in) throws IOException {
@@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException {
126143
randomizeSeed = in.readOptionalLong();
127144
lossFunction = in.readEnum(LossFunction.class);
128145
lossFunctionParameter = in.readOptionalDouble();
146+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
147+
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
148+
} else {
149+
featureProcessors = Collections.emptyList();
150+
}
129151
}
130152

131153
public String getDependentVariable() {
@@ -156,6 +178,10 @@ public Double getLossFunctionParameter() {
156178
return lossFunctionParameter;
157179
}
158180

181+
public List<PreProcessor> getFeatureProcessors() {
182+
return featureProcessors;
183+
}
184+
159185
@Override
160186
public String getWriteableName() {
161187
return NAME.getPreferredName();
@@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException {
170196
out.writeOptionalLong(randomizeSeed);
171197
out.writeEnum(lossFunction);
172198
out.writeOptionalDouble(lossFunctionParameter);
199+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
200+
out.writeNamedWriteableList(featureProcessors);
201+
}
173202
}
174203

175204
@Override
@@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
190219
if (lossFunctionParameter != null) {
191220
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
192221
}
222+
if (featureProcessors.isEmpty() == false) {
223+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
224+
}
193225
builder.endObject();
194226
return builder;
195227
}
@@ -207,6 +239,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
207239
if (lossFunctionParameter != null) {
208240
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
209241
}
242+
if (featureProcessors.isEmpty() == false) {
243+
params.put(FEATURE_PROCESSORS.getPreferredName(),
244+
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
245+
}
210246
return params;
211247
}
212248

@@ -290,13 +326,14 @@ public boolean equals(Object o) {
290326
&& trainingPercent == that.trainingPercent
291327
&& randomizeSeed == that.randomizeSeed
292328
&& lossFunction == that.lossFunction
329+
&& Objects.equals(featureProcessors, that.featureProcessors)
293330
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
294331
}
295332

296333
@Override
297334
public int hashCode() {
298335
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
299-
lossFunctionParameter);
336+
lossFunctionParameter, featureProcessors);
300337
}
301338

302339
public enum LossFunction {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,23 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
5757

5858
// PreProcessing Lenient
5959
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
60-
OneHotEncoding::fromXContentLenient));
60+
(p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6161
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
62-
TargetMeanEncoding::fromXContentLenient));
62+
(p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6363
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
64-
FrequencyEncoding::fromXContentLenient));
64+
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6565
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
66-
CustomWordEmbedding::fromXContentLenient));
66+
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
6767

6868
// PreProcessing Strict
6969
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
70-
OneHotEncoding::fromXContentStrict));
70+
(p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7171
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
72-
TargetMeanEncoding::fromXContentStrict));
72+
(p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7373
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
74-
FrequencyEncoding::fromXContentStrict));
74+
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7575
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
76-
CustomWordEmbedding::fromXContentStrict));
76+
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
7777

7878
// Model Lenient
7979
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
5656
TRAINED_MODEL);
5757
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
5858
(p, c, n) -> ignoreUnknownFields ?
59-
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
60-
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
59+
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
60+
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
6161
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
6262
PREPROCESSORS);
6363
return parser;

0 commit comments

Comments
 (0)