15
15
import org .elasticsearch .common .xcontent .XContentBuilder ;
16
16
import org .elasticsearch .common .xcontent .XContentParser ;
17
17
import org .elasticsearch .index .mapper .NumberFieldMapper ;
18
+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .LenientlyParsedPreProcessor ;
19
+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .PreProcessor ;
20
+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .StrictlyParsedPreProcessor ;
18
21
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .InferenceConfig ;
19
22
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .RegressionConfig ;
20
23
import org .elasticsearch .xpack .core .ml .utils .ExceptionsHelper ;
24
+ import org .elasticsearch .xpack .core .ml .utils .NamedXContentObjectHelper ;
21
25
22
26
import java .io .IOException ;
23
27
import java .util .Arrays ;
28
32
import java .util .Map ;
29
33
import java .util .Objects ;
30
34
import java .util .Set ;
35
+ import java .util .stream .Collectors ;
31
36
32
37
import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
33
38
import static org .elasticsearch .common .xcontent .ConstructingObjectParser .optionalConstructorArg ;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
42
47
public static final ParseField RANDOMIZE_SEED = new ParseField ("randomize_seed" );
43
48
public static final ParseField LOSS_FUNCTION = new ParseField ("loss_function" );
44
49
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField ("loss_function_parameter" );
50
+ public static final ParseField FEATURE_PROCESSORS = new ParseField ("feature_processors" );
45
51
46
52
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1" ;
47
53
48
54
private static final ConstructingObjectParser <Regression , Void > LENIENT_PARSER = createParser (true );
49
55
private static final ConstructingObjectParser <Regression , Void > STRICT_PARSER = createParser (false );
50
56
57
+ @ SuppressWarnings ("unchecked" )
51
58
private static ConstructingObjectParser <Regression , Void > createParser (boolean lenient ) {
52
59
ConstructingObjectParser <Regression , Void > parser = new ConstructingObjectParser <>(
53
60
NAME .getPreferredName (),
@@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
59
66
(Double ) a [8 ],
60
67
(Long ) a [9 ],
61
68
(LossFunction ) a [10 ],
62
- (Double ) a [11 ]));
69
+ (Double ) a [11 ],
70
+ (List <PreProcessor >) a [12 ]));
63
71
parser .declareString (constructorArg (), DEPENDENT_VARIABLE );
64
72
BoostedTreeParams .declareFields (parser );
65
73
parser .declareString (optionalConstructorArg (), PREDICTION_FIELD_NAME );
66
74
parser .declareDouble (optionalConstructorArg (), TRAINING_PERCENT );
67
75
parser .declareLong (optionalConstructorArg (), RANDOMIZE_SEED );
68
76
parser .declareString (optionalConstructorArg (), LossFunction ::fromString , LOSS_FUNCTION );
69
77
parser .declareDouble (optionalConstructorArg (), LOSS_FUNCTION_PARAMETER );
78
+ parser .declareNamedObjects (optionalConstructorArg (),
79
+ (p , c , n ) -> lenient ?
80
+ p .namedObject (LenientlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )) :
81
+ p .namedObject (StrictlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )),
82
+ (regression ) -> {/*TODO should we throw if this is not set?*/ },
83
+ FEATURE_PROCESSORS );
70
84
return parser ;
71
85
}
72
86
@@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
90
104
private final long randomizeSeed ;
91
105
private final LossFunction lossFunction ;
92
106
private final Double lossFunctionParameter ;
107
+ private final List <PreProcessor > featureProcessors ;
93
108
94
109
public Regression (String dependentVariable ,
95
110
BoostedTreeParams boostedTreeParams ,
96
111
@ Nullable String predictionFieldName ,
97
112
@ Nullable Double trainingPercent ,
98
113
@ Nullable Long randomizeSeed ,
99
114
@ Nullable LossFunction lossFunction ,
100
- @ Nullable Double lossFunctionParameter ) {
115
+ @ Nullable Double lossFunctionParameter ,
116
+ @ Nullable List <PreProcessor > featureProcessors ) {
101
117
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0 )) {
102
118
throw ExceptionsHelper .badRequestException ("[{}] must be a double in [1, 100]" , TRAINING_PERCENT .getPreferredName ());
103
119
}
@@ -112,10 +128,11 @@ public Regression(String dependentVariable,
112
128
throw ExceptionsHelper .badRequestException ("[{}] must be a positive double" , LOSS_FUNCTION_PARAMETER .getPreferredName ());
113
129
}
114
130
this .lossFunctionParameter = lossFunctionParameter ;
131
+ this .featureProcessors = featureProcessors == null ? Collections .emptyList () : Collections .unmodifiableList (featureProcessors );
115
132
}
116
133
117
134
/**
 * Convenience constructor: builds a {@code Regression} for the given dependent
 * variable with default boosted-tree parameters and every optional setting
 * (prediction field name, training percent, randomize seed, loss function,
 * loss function parameter, feature processors) left unset.
 *
 * @param dependentVariable the field this regression is trained to predict
 */
public Regression(String dependentVariable) {
    this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
}
120
137
121
138
public Regression (StreamInput in ) throws IOException {
@@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException {
126
143
randomizeSeed = in .readOptionalLong ();
127
144
lossFunction = in .readEnum (LossFunction .class );
128
145
lossFunctionParameter = in .readOptionalDouble ();
146
+ if (in .getVersion ().onOrAfter (Version .V_8_0_0 )) {
147
+ featureProcessors = Collections .unmodifiableList (in .readNamedWriteableList (PreProcessor .class ));
148
+ } else {
149
+ featureProcessors = Collections .emptyList ();
150
+ }
129
151
}
130
152
131
153
public String getDependentVariable () {
@@ -156,6 +178,10 @@ public Double getLossFunctionParameter() {
156
178
return lossFunctionParameter ;
157
179
}
158
180
181
+ public List <PreProcessor > getFeatureProcessors () {
182
+ return featureProcessors ;
183
+ }
184
+
159
185
@ Override
160
186
public String getWriteableName () {
161
187
return NAME .getPreferredName ();
@@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException {
170
196
out .writeOptionalLong (randomizeSeed );
171
197
out .writeEnum (lossFunction );
172
198
out .writeOptionalDouble (lossFunctionParameter );
199
+ if (out .getVersion ().onOrAfter (Version .V_8_0_0 )) {
200
+ out .writeNamedWriteableList (featureProcessors );
201
+ }
173
202
}
174
203
175
204
@ Override
@@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
190
219
if (lossFunctionParameter != null ) {
191
220
builder .field (LOSS_FUNCTION_PARAMETER .getPreferredName (), lossFunctionParameter );
192
221
}
222
+ if (featureProcessors .isEmpty () == false ) {
223
+ NamedXContentObjectHelper .writeNamedObjects (builder , params , true , FEATURE_PROCESSORS .getPreferredName (), featureProcessors );
224
+ }
193
225
builder .endObject ();
194
226
return builder ;
195
227
}
@@ -207,6 +239,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
207
239
if (lossFunctionParameter != null ) {
208
240
params .put (LOSS_FUNCTION_PARAMETER .getPreferredName (), lossFunctionParameter );
209
241
}
242
+ if (featureProcessors .isEmpty () == false ) {
243
+ params .put (FEATURE_PROCESSORS .getPreferredName (),
244
+ featureProcessors .stream ().map (p -> Collections .singletonMap (p .getName (), p )).collect (Collectors .toList ()));
245
+ }
210
246
return params ;
211
247
}
212
248
@@ -290,13 +326,14 @@ public boolean equals(Object o) {
290
326
&& trainingPercent == that .trainingPercent
291
327
&& randomizeSeed == that .randomizeSeed
292
328
&& lossFunction == that .lossFunction
329
+ && Objects .equals (featureProcessors , that .featureProcessors )
293
330
&& Objects .equals (lossFunctionParameter , that .lossFunctionParameter );
294
331
}
295
332
296
333
@ Override
297
334
public int hashCode () {
298
335
return Objects .hash (dependentVariable , boostedTreeParams , predictionFieldName , trainingPercent , randomizeSeed , lossFunction ,
299
- lossFunctionParameter );
336
+ lossFunctionParameter , featureProcessors );
300
337
}
301
338
302
339
public enum LossFunction {
0 commit comments