Skip to content

Commit f670c04

Browse files
authored
Blacklist a number of prediction field names. (#861) (#868)
1 parent 7275f29 commit f670c04

6 files changed

+82
-0
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
4646

4747
=== Bug Fixes
4848
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
49+
* Prevents prediction_field_name from clashing with other fields in ml results.
50+
(See {ml-pull}861[#861].)
4951

5052

5153
== {es} version 7.5.0

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
6868
this->dependentVariableFieldName()) == categoricalFieldNames.end()) {
6969
HANDLE_FATAL(<< "Input error: trying to perform classification with numeric target.");
7070
}
71+
const std::set<std::string> predictionFieldNameBlacklist{
72+
IS_TRAINING_FIELD_NAME, PREDICTION_PROBABILITY_FIELD_NAME, TOP_CLASSES_FIELD_NAME};
73+
if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) {
74+
HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of "
75+
<< core::CContainerPrinter::print(predictionFieldNameBlacklist)
76+
<< ".");
77+
}
7178
}
7279

7380
CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifierRunner(

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegression
5252
this->dependentVariableFieldName()) != categoricalFieldNames.end()) {
5353
HANDLE_FATAL(<< "Input error: trying to perform regression with categorical target.");
5454
}
55+
const std::set<std::string> predictionFieldNameBlacklist{IS_TRAINING_FIELD_NAME};
56+
if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) {
57+
HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of "
58+
<< core::CContainerPrinter::print(predictionFieldNameBlacklist)
59+
<< ".");
60+
}
5561
}
5662

5763
CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegressionRunner(

lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,26 @@ using TStrVec = std::vector<std::string>;
2525
using TStrVecVec = std::vector<TStrVec>;
2626
}
2727

28+
// Checks that constructing a classifier runner whose prediction_field_name is
// "is_training" reports exactly one fatal error carrying the blacklist message.
BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
29+
TStrVec errors;
30+
// Record every fatal error message in `errors` so it can be asserted on below.
auto errorHandler = [&errors](std::string error) { errors.push_back(error); };
31+
// NOTE(review): presumably installs the handler for the lifetime of `scope`
// and restores the previous one afterwards - confirm against core::CLogger.
core::CLogger::CScopeSetFatalErrorHandler scope{errorHandler};
32+
33+
// NOTE(review): the trailing {"dep_var"} presumably marks the target as
// categorical so the "numeric target" check does not add a second error;
// the meaning of the numeric arguments is defined by the factory - confirm.
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
34+
"classification", "dep_var", 5, 6, 13000000, 0, 0, {"dep_var"})};
35+
rapidjson::Document jsonParameters;
36+
jsonParameters.Parse("{"
37+
" \"dependent_variable\": \"dep_var\","
38+
" \"prediction_field_name\": \"is_training\""
39+
"}");
40+
const auto parameters{
41+
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
42+
// Only the construction matters: the blacklist check runs in the constructor.
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
43+
44+
// Exactly one fatal error is expected, with the exact message text.
BOOST_TEST_REQUIRE(errors.size() == 1);
45+
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
46+
}
47+
2848
BOOST_AUTO_TEST_CASE(testWriteOneRow) {
2949
// Prepare input data frame
3050
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
lib/api/unittest/CDataFrameTrainBoostedTreeRegressionRunnerTest.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
#include <core/CDataFrame.h>
8+
9+
#include <api/CDataFrameAnalysisConfigReader.h>
10+
#include <api/CDataFrameTrainBoostedTreeRegressionRunner.h>
11+
12+
#include <test/CDataFrameAnalysisSpecificationFactory.h>
13+
14+
#include <boost/test/unit_test.hpp>
15+
16+
#include <string>
17+
#include <vector>
18+
19+
BOOST_AUTO_TEST_SUITE(CDataFrameTrainBoostedTreeRegressionRunnerTest)
20+
21+
using namespace ml;
22+
namespace {
23+
using TStrVec = std::vector<std::string>;
24+
}
25+
26+
// Checks that constructing a regression runner whose prediction_field_name is
// "is_training" reports exactly one fatal error carrying the blacklist message.
BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
27+
TStrVec errors;
28+
// Record every fatal error message in `errors` so it can be asserted on below.
auto errorHandler = [&errors](std::string error) { errors.push_back(error); };
29+
// NOTE(review): presumably installs the handler for the lifetime of `scope`
// and restores the previous one afterwards - confirm against core::CLogger.
core::CLogger::CScopeSetFatalErrorHandler scope{errorHandler};
30+
31+
// NOTE(review): the meaning of the numeric arguments is defined by the
// factory; no categorical field list is passed here - confirm the defaults.
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
32+
"regression", "dep_var", 5, 6, 13000000, 0, 0)};
33+
rapidjson::Document jsonParameters;
34+
jsonParameters.Parse("{"
35+
" \"dependent_variable\": \"dep_var\","
36+
" \"prediction_field_name\": \"is_training\""
37+
"}");
38+
const auto parameters{
39+
api::CDataFrameTrainBoostedTreeRegressionRunner::parameterReader().read(jsonParameters)};
40+
// Only the construction matters: the blacklist check runs in the constructor.
api::CDataFrameTrainBoostedTreeRegressionRunner runner(*spec, parameters);
41+
42+
// Exactly one fatal error is expected, with the exact message text.
BOOST_TEST_REQUIRE(errors.size() == 1);
43+
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training].");
44+
}
45+
46+
BOOST_AUTO_TEST_SUITE_END()

lib/api/unittest/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ SRCS=\
3131
CDataFrameAnalyzerOutlierTest.cc \
3232
CDataFrameAnalyzerTrainingTest.cc \
3333
CDataFrameTrainBoostedTreeClassifierRunnerTest.cc \
34+
CDataFrameTrainBoostedTreeRegressionRunnerTest.cc \
3435
CDataFrameMockAnalysisRunner.cc \
3536
CDetectionRulesJsonParserTest.cc \
3637
CFieldConfigTest.cc \

0 commit comments

Comments
 (0)