Skip to content

[7.x] [ML] calculate feature importance for multi-class results (#1071) #1075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ model training. (See {ml-pull}1034[#1034].)
* Add instrumentation information for supervised learning data frame analytics jobs.
(See {ml-pull}1031[#1031].)
* Add instrumentation information for outlier detection data frame analytics jobs.
(See {ml-pull}1068[#1068].)
* Write out feature importance for multi-class models. (See {ml-pull}1071[#1071].)

=== Bug Fixes

Expand Down
6 changes: 6 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
static const std::string BAYESIAN_OPTIMISATION_RESTARTS;
static const std::string NUM_TOP_FEATURE_IMPORTANCE_VALUES;

// Output
static const std::string IS_TRAINING_FIELD_NAME;
static const std::string FEATURE_NAME_FIELD_NAME;
static const std::string IMPORTANCE_FIELD_NAME;
static const std::string FEATURE_IMPORTANCE_FIELD_NAME;

public:
~CDataFrameTrainBoostedTreeRunner() override;

Expand Down
3 changes: 0 additions & 3 deletions include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
using TShapWriter =
std::function<void(const TSizeVec&, const TStrVec&, const TVectorVec&)>;

public:
static const std::string SHAP_PREFIX;

public:
CTreeShapFeatureImportance(const core::CDataFrame& frame,
const CDataFrameCategoryEncoder& encoder,
Expand Down
2 changes: 2 additions & 0 deletions include/test/CDataFrameAnalysisSpecificationFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {

// Classification
CDataFrameAnalysisSpecificationFactory& numberClasses(std::size_t number);
CDataFrameAnalysisSpecificationFactory& numberTopClasses(std::size_t number);
CDataFrameAnalysisSpecificationFactory& predictionFieldType(const std::string& type);

std::string outlierParams() const;
Expand Down Expand Up @@ -117,6 +118,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
TRestoreSearcherSupplier* m_RestoreSearcherSupplier = nullptr;
// Classification
std::size_t m_NumberClasses = 2;
std::size_t m_NumberTopClasses = 0;
std::string m_PredictionFieldType;
};
}
Expand Down
28 changes: 22 additions & 6 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(

if (featureImportance != nullptr) {
featureImportance->shap(
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& names,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
row, [&writer, &classValues](
const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& names,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.Key(names[i]);
// TODO fixme
writer.Double(shap[i](0));
writer.StartObject();
writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME);
writer.String(names[i]);
if (shap[i].size() == 1) {
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](0));
} else {
for (int j = 0; j < shap[i].size(); ++j) {
writer.Key(classValues[j]);
writer.Double(shap[i](j));
}
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Double(shap[i].lpNorm<1>());
}
writer.EndObject();
}
}
writer.EndArray();
});
}
writer.EndObject();
Expand Down
9 changes: 8 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,19 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& names,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.Key(names[i]);
writer.StartObject();
writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME);
writer.String(names[i]);
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](0));
writer.EndObject();
}
}
writer.EndArray();
});
}
writer.EndObject();
Expand Down
4 changes: 4 additions & 0 deletions lib/api/CDataFrameTrainBoostedTreeRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ const std::string CDataFrameTrainBoostedTreeRunner::STOP_CROSS_VALIDATION_EARLY{
const std::string CDataFrameTrainBoostedTreeRunner::NUMBER_ROUNDS_PER_HYPERPARAMETER{"number_rounds_per_hyperparameter"};
const std::string CDataFrameTrainBoostedTreeRunner::BAYESIAN_OPTIMISATION_RESTARTS{"bayesian_optimisation_restarts"};
const std::string CDataFrameTrainBoostedTreeRunner::NUM_TOP_FEATURE_IMPORTANCE_VALUES{"num_top_feature_importance_values"};
const std::string CDataFrameTrainBoostedTreeRunner::IS_TRAINING_FIELD_NAME{"is_training"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME{"feature_name"};
const std::string CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME{"importance"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME{"feature_importance"};
// clang-format on
}
}
203 changes: 185 additions & 18 deletions lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@

#include <maths/CBasicStatistics.h>
#include <maths/CDataFramePredictiveModel.h>
#include <maths/CSampling.h>
#include <maths/CTools.h>
#include <maths/CToolsDetail.h>
#include <maths/CTreeShapFeatureImportance.h>

#include <api/CDataFrameAnalyzer.h>
#include <api/CDataFrameTrainBoostedTreeRunner.h>

#include <test/CDataFrameAnalysisSpecificationFactory.h>
#include <test/CRandomNumbers.h>
Expand All @@ -27,12 +30,14 @@ using namespace ml;

namespace {
using TDoubleVec = std::vector<double>;
using TVector = maths::CDenseVector<double>;
using TStrVec = std::vector<std::string>;
using TRowItr = core::CDataFrame::TRowItr;
using TRowRef = core::CDataFrame::TRowRef;
using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<double>::TAccumulator;
using TMeanAccumulatorVec = std::vector<TMeanAccumulator>;
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
using TMemoryMappedMatrix = maths::CMemoryMappedDenseMatrix<double>;

void setupLinearRegressionData(const TStrVec& fieldNames,
TStrVec& fieldValues,
Expand Down Expand Up @@ -128,6 +133,47 @@ void setupBinaryClassificationData(const TStrVec& fieldNames,
}
}

void setupMultiClassClassificationData(const TStrVec& fieldNames,
                                       TStrVec& fieldValues,
                                       api::CDataFrameAnalyzer& analyzer,
                                       const TDoubleVec& weights,
                                       const TDoubleVec& values) {
    // Feed multi-class training rows to the analyzer: each row's class label
    // is sampled from softmax(W * x), where x is one row of feature values
    // taken from 'values' and W is a weight matrix mapped over 'w'.
    TStrVec classes{"foo", "bar", "baz"};
    maths::CPRNG::CXorOShiro128Plus rng;
    int numberFeatures{static_cast<int>(weights.size())};
    TDoubleVec w{weights};
    int numberClasses{static_cast<int>(classes.size())};
    auto probability = [&](const TDoubleVec& row) {
        // NOTE(review): W is mapped as (numberClasses x numberFeatures) but
        // 'w' holds only weights.size() == numberFeatures doubles, so rows
        // beyond the first appear to read out of bounds when numberClasses
        // > 1 — confirm the intended weight layout.
        TMemoryMappedMatrix W(&w[0], numberClasses, numberFeatures);
        TVector x(numberFeatures);
        for (int i = 0; i < numberFeatures; ++i) {
            x(i) = row[i];
        }
        TVector logit{W * x};
        return maths::CTools::softmax(std::move(logit));
    };
    auto target = [&](const TDoubleVec& row) {
        // Sample the class label from the softmax distribution.
        TDoubleVec probabilities{probability(row).to<TDoubleVec>()};
        return classes[maths::CSampling::categoricalSample(rng, probabilities)];
    };

    for (std::size_t i = 0; i < values.size(); i += weights.size()) {
        TDoubleVec row(weights.size());
        for (std::size_t j = 0; j < weights.size(); ++j) {
            row[j] = values[i + j];
        }

        // The first field is the target class, the rest are feature values.
        fieldValues[0] = target(row);
        for (std::size_t j = 0; j < row.size(); ++j) {
            fieldValues[j + 1] = core::CStringUtils::typeToStringPrecise(
                row[j], core::CIEEE754::E_DoublePrecision);
        }

        analyzer.handleRecord(fieldNames, fieldValues);
    }
}

struct SFixture {
rapidjson::Document
runRegression(std::size_t shapValues, TDoubleVec weights, double noiseVar = 0.0) {
Expand Down Expand Up @@ -231,6 +277,57 @@ struct SFixture {
return results;
}

rapidjson::Document runMultiClassClassification(std::size_t shapValues,
                                                TDoubleVec&& weights) {
    // Train a three-class classifier on uniformly random feature data and
    // return the parsed JSON results, which include the per-class feature
    // importance values requested via 'shapValues'.
    auto makeOutputWriter = [&]() {
        return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
    };
    test::CDataFrameAnalysisSpecificationFactory specFactory;
    api::CDataFrameAnalyzer analyzer{
        specFactory.rows(s_Rows)
            .memoryLimit(26000000)
            .predictionCategoricalFieldNames({"target"})
            .predictionAlpha(s_Alpha)
            .predictionLambda(s_Lambda)
            .predictionGamma(s_Gamma)
            .predictionSoftTreeDepthLimit(s_SoftTreeDepthLimit)
            .predictionSoftTreeDepthTolerance(s_SoftTreeDepthTolerance)
            .predictionEta(s_Eta)
            .predictionMaximumNumberTrees(s_MaximumNumberTrees)
            .predictionFeatureBagFraction(s_FeatureBagFraction)
            .predictionNumberTopShapValues(shapValues)
            .numberClasses(3)
            .numberTopClasses(3)
            .predictionSpec(test::CDataFrameAnalysisSpecificationFactory::classification(), "target"),
        makeOutputWriter};
    TStrVec columnNames{"target", "c1", "c2", "c3", "c4", ".", "."};
    TStrVec columnValues{"", "", "", "", "", "0", ""};
    test::CRandomNumbers testRng;

    TDoubleVec featureValues;
    testRng.generateUniformSamples(-10.0, 10.0, weights.size() * s_Rows, featureValues);

    setupMultiClassClassificationData(columnNames, columnValues, analyzer,
                                      weights, featureValues);

    // The trailing control record ("$") triggers training and result writing.
    analyzer.handleRecord(columnNames, {"", "", "", "", "", "", "$"});

    LOG_DEBUG(<< "estimated memory usage = "
              << core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
    LOG_DEBUG(<< "peak memory = "
              << core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage));
    LOG_DEBUG(<< "time to train = " << core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain)
              << "ms");

    // The run must stay within its own memory estimate.
    BOOST_TEST_REQUIRE(
        core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) <
        core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));

    rapidjson::Document parsedResults;
    rapidjson::ParseResult parseStatus(parsedResults.Parse(s_Output.str()));
    BOOST_TEST_REQUIRE(static_cast<bool>(parseStatus) == true);
    return parsedResults;
}

rapidjson::Document runRegressionWithMissingFeatures(std::size_t shapValues) {
auto outputWriterFactory = [&]() {
return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
Expand Down Expand Up @@ -289,9 +386,48 @@ struct SFixture {

template<typename RESULTS>
double readShapValue(const RESULTS& results, const std::string& shapField) {
    // Returns the overall importance written for feature 'shapField' in the
    // row's feature_importance array, or 0.0 if the array or the feature is
    // absent.
    const auto& ml = results["row_results"]["results"]["ml"];
    if (ml.HasMember(api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
        for (const auto& shapResult :
             ml[api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
                 .GetArray()) {
            if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
                    .GetString() == shapField) {
                return shapResult[api::CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME]
                    .GetDouble();
            }
        }
    }
    return 0.0;
}

template<typename RESULTS>
double readShapValue(const RESULTS& results,
                     const std::string& shapField,
                     const std::string& className) {
    // Returns the per-class importance written for feature 'shapField' and
    // class 'className', or 0.0 if the feature or the class entry is absent.
    const auto& ml = results["row_results"]["results"]["ml"];
    if (ml.HasMember(api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
        for (const auto& shapResult :
             ml[api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
                 .GetArray()) {
            if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
                    .GetString() == shapField) {
                if (shapResult.HasMember(className)) {
                    return shapResult[className].GetDouble();
                }
            }
        }
    }
    return 0.0;
}

template<typename RESULTS>
double readClassProbability(const RESULTS& results, const std::string& className) {
    // Returns the probability reported for 'className' in the row's
    // top_classes array, or 0.0 if the array or the class is absent.
    const auto& ml = results["row_results"]["results"]["ml"];
    if (ml.HasMember("top_classes")) {
        for (const auto& topClass : ml["top_classes"].GetArray()) {
            if (topClass["class_name"].GetString() == className) {
                return topClass["class_probability"].GetDouble();
            }
        }
    }
    return 0.0;
}
Expand Down Expand Up @@ -324,9 +460,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
c3Sum += std::fabs(c3);
c4Sum += std::fabs(c4);
// assert that no SHAP value for the dependent variable is returned
BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
maths::CTreeShapFeatureImportance::SHAP_PREFIX +
"target") == false);
BOOST_REQUIRE_EQUAL(readShapValue(result, "target"), 0.0);
}
}

Expand Down Expand Up @@ -421,25 +555,58 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
}

BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) {

    // Check that per-class SHAP values are written for every feature and that
    // their absolute values sum to the feature's overall importance.
    std::size_t topShapValues{4};
    auto results{runMultiClassClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};

    const TStrVec features{"c1", "c2", "c3", "c4"};
    const TStrVec classes{"foo", "bar", "baz"};

    for (const auto& result : results.GetArray()) {
        if (result.HasMember("row_results")) {
            bool atLeastOneFeatureImportant{false};
            for (const auto& feature : features) {
                double featureImportance{readShapValue(result, feature)};
                atLeastOneFeatureImportant |= (featureImportance > 0.0);

                // class shap values should sum(abs()) to the overall feature importance
                double sumClassImportances{0.0};
                for (const auto& className : classes) {
                    sumClassImportances +=
                        std::abs(readShapValue(result, feature, className));
                }
                BOOST_REQUIRE_CLOSE(featureImportance, sumClassImportances, 1e-6);
            }
            // We should have at least one feature that is important
            BOOST_TEST_REQUIRE(atLeastOneFeatureImportant);
        }
    }
}

BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {
    // Test that if topShapValue is set to 0, no feature importance values are returned.
    std::size_t topShapValues{0};
    auto results{runRegression(topShapValues, {50.0, 150.0, 50.0, -50.0})};

    for (const auto& result : results.GetArray()) {
        if (result.HasMember("row_results")) {
            // The feature_importance field must be absent entirely.
            BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
                                   api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME) ==
                               false);
        }
    }
}
Expand Down
1 change: 1 addition & 0 deletions lib/api/unittest/CDataFrameAnalyzerTrainingTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ BOOST_AUTO_TEST_CASE(testRunBoostedTreeClassifierTraining) {
api::CDataFrameAnalyzer analyzer{
specFactory.memoryLimit(6000000)
.predictionCategoricalFieldNames({"target"})
.numberTopClasses(1)
.predictionSpec(test::CDataFrameAnalysisSpecificationFactory::classification(), "target"),
outputWriterFactory};
test::CDataFrameAnalyzerTrainingFactory::addPredictionTestData(
Expand Down
Loading