Skip to content

Commit

Permalink
Support SHOW MODEL FEATURE DETAILS
Browse files Browse the repository at this point in the history
Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
tmostak authored and misiugodfrey committed Sep 1, 2023
1 parent 226faca commit 5d05a5f
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 53 deletions.
132 changes: 131 additions & 1 deletion Catalog/DdlCommandExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,9 @@ ExecutionResult DdlCommandExecutor::execute(bool read_only_mode) {
result = ShowModelsCommand{*ddl_data_, session_ptr_}.execute(read_only_mode);
} else if (ddl_command_ == "SHOW_MODEL_DETAILS") {
result = ShowModelDetailsCommand{*ddl_data_, session_ptr_}.execute(read_only_mode);
} else if (ddl_command_ == "SHOW_MODEL_FEATURE_DETAILS") {
result =
ShowModelFeatureDetailsCommand{*ddl_data_, session_ptr_}.execute(read_only_mode);
} else if (ddl_command_ == "ALTER_SERVER") {
result = AlterForeignServerCommand{*ddl_data_, session_ptr_}.execute(read_only_mode);
} else if (ddl_command_ == "ALTER_DATABASE") {
Expand Down Expand Up @@ -2091,7 +2094,6 @@ ShowModelDetailsCommand::ShowModelDetailsCommand(
ExecutionResult ShowModelDetailsCommand::execute(bool read_only_mode) {
auto execute_read_lock = legacylockmgr::getExecuteReadLock();

std::vector<std::string> labels{"model_name", "model_type", "training_query"};
std::vector<TargetMetaInfo> label_infos;
label_infos.emplace_back("model_name", SQLTypeInfo(kTEXT, true));
label_infos.emplace_back("model_type", SQLTypeInfo(kTEXT, true));
Expand Down Expand Up @@ -2171,6 +2173,134 @@ std::vector<std::string> ShowModelDetailsCommand::getFilteredModelNames() {
}
}

// Constructs the SHOW MODEL FEATURE DETAILS command and validates up front that
// the command is allowed to run.
//
// Throws std::runtime_error when:
//   - ML functions are disabled (g_enable_ml_functions is false), or
//   - model metadata is restricted to superusers
//     (g_restrict_ml_model_metadata_to_superusers) and the session's current
//     user is not a superuser.
ShowModelFeatureDetailsCommand::ShowModelFeatureDetailsCommand(
    const DdlCommandData& ddl_data,
    std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr)
    : DdlCommand(ddl_data, session_ptr) {
  if (!g_enable_ml_functions) {
    throw std::runtime_error(
        "Cannot show model feature details. ML functions are disabled.");
  }
  // The session is only consulted when the superuser restriction is switched on.
  const bool superuser_required = g_restrict_ml_model_metadata_to_superusers;
  if (superuser_required && !session_ptr->get_currentUser().isSuper) {
    throw std::runtime_error(
        "Cannot show model feature details. Showing model information to "
        "non-superusers is disabled.");
  }
}

// Executes SHOW MODEL FEATURE DETAILS for the model named by the "modelName"
// member of the DDL payload.
//
// The result always has the columns feature_id, feature, sub_feature_id and
// sub_feature. Depending on the model type, extra columns are appended:
//   - LINEAR_REG:        coefficient (plus a leading "intercept" row, id 0)
//   - RANDOM_FOREST_REG: feature_importance (only if the model has scores)
//   - PCA:               eigenvalue and eigenvector (formatted as "[a, b, ...]")
// The random forest and PCA cases are only compiled in when HAVE_ONEDAL is
// defined.
//
// One row is produced per physical feature: a feature with categorical keys
// yields one row per key, any other feature yields a single row with an empty
// sub_feature. feature_id and sub_feature_id are one-based.
//
// Model invariants (successful downcast to the concrete model class, per-feature
// metadata covering every emitted row) are enforced with CHECK instead of being
// assumed, so a mismatch fails loudly rather than reading out of bounds.
ExecutionResult ShowModelFeatureDetailsCommand::execute(bool read_only_mode) {
  auto execute_read_lock = legacylockmgr::getExecuteReadLock();
  auto& ddl_payload = extractPayload(ddl_data_);
  CHECK(ddl_payload.HasMember("modelName")) << "Model name missing.";
  const auto model_name = ddl_payload["modelName"].GetString();
  const auto model_metadata = g_ml_models.getModelMetadata(model_name);
  const auto model_type = model_metadata.getModelType();

  // Columns common to every model type.
  std::vector<TargetMetaInfo> label_infos;
  label_infos.emplace_back("feature_id", SQLTypeInfo(kBIGINT, true));
  label_infos.emplace_back("feature", SQLTypeInfo(kTEXT, true));
  label_infos.emplace_back("sub_feature_id", SQLTypeInfo(kBIGINT, true));
  label_infos.emplace_back("sub_feature", SQLTypeInfo(kTEXT, true));

  // Per-physical-feature scalar (coefficient / importance / eigenvalue) and,
  // for PCA only, the per-physical-feature eigenvector.
  std::vector<double> extra_metadata;
  std::vector<std::vector<double>> eigenvectors;

  // model_metadata and model both outlive these references, so no copy is needed.
  const auto& features = model_metadata.getFeatures();
  const auto model = g_ml_models.getModel(model_name);
  CHECK(model) << "Model lookup returned a null model.";
  const auto& cat_sub_features = model->getCatFeatureKeys();

  switch (model_type) {
    case MLModelType::LINEAR_REG: {
      label_infos.emplace_back("coefficient", SQLTypeInfo(kDOUBLE, true));
      const auto linear_reg_model =
          std::dynamic_pointer_cast<LinearRegressionModel>(model);
      CHECK(linear_reg_model) << "Model is not a linear regression model.";
      extra_metadata = linear_reg_model->getCoefs();
      break;
    }
#ifdef HAVE_ONEDAL
    case MLModelType::RANDOM_FOREST_REG: {
      const auto random_forest_reg_model =
          std::dynamic_pointer_cast<RandomForestRegressionModel>(model);
      CHECK(random_forest_reg_model) << "Model is not a random forest model.";
      extra_metadata = random_forest_reg_model->getVariableImportanceScores();
      // The importance column is only exposed when scores are available.
      if (!extra_metadata.empty()) {
        label_infos.emplace_back("feature_importance", SQLTypeInfo(kDOUBLE, true));
      }
      break;
    }
    case MLModelType::PCA: {
      label_infos.emplace_back("eigenvalue", SQLTypeInfo(kDOUBLE, true));
      label_infos.emplace_back("eigenvector", SQLTypeInfo(kTEXT, true));
      const auto pca_model = std::dynamic_pointer_cast<PcaModel>(model);
      CHECK(pca_model) << "Model is not a PCA model.";
      extra_metadata = pca_model->getEigenvalues();
      eigenvectors = pca_model->getEigenvectors();
      CHECK_EQ(eigenvectors.size(), extra_metadata.size());

      break;
    }
#endif  // HAVE_ONEDAL
    default: {
      break;
    }
  }

  const int64_t num_features = static_cast<int64_t>(features.size());
  std::vector<RelLogicalValues::RowValues> logical_values;

  if (model_type == MLModelType::LINEAR_REG) {
    // The first coefficient is reported as a dedicated "intercept" row and then
    // dropped so that the remaining entries line up with the physical features.
    CHECK(!extra_metadata.empty())
        << "Linear regression model has no coefficients.";
    logical_values.emplace_back(RelLogicalValues::RowValues{});
    auto& intercept_row = logical_values.back();
    intercept_row.emplace_back(genLiteralBigInt(0));
    intercept_row.emplace_back(genLiteralStr("intercept"));
    intercept_row.emplace_back(genLiteralBigInt(1));
    intercept_row.emplace_back(genLiteralStr(""));
    intercept_row.emplace_back(genLiteralDouble(extra_metadata[0]));
    extra_metadata.erase(extra_metadata.begin());
  }

  // Index into the per-physical-feature vectors; advances once per emitted row.
  int64_t physical_feature_idx = 0;
  for (int64_t feature_idx = 0; feature_idx < num_features; ++feature_idx) {
    int64_t num_sub_features =
        feature_idx >= static_cast<int64_t>(cat_sub_features.size())
            ? 0
            : static_cast<int64_t>(cat_sub_features[feature_idx].size());
    const bool has_sub_features = num_sub_features > 0;
    // A feature without categorical keys still produces exactly one row.
    num_sub_features = num_sub_features == 0 ? 1 : num_sub_features;
    for (int64_t sub_feature_idx = 0; sub_feature_idx < num_sub_features;
         ++sub_feature_idx) {
      logical_values.emplace_back(RelLogicalValues::RowValues{});
      auto& row = logical_values.back();
      // Make feature id one-based
      row.emplace_back(genLiteralBigInt(feature_idx + 1));
      row.emplace_back(genLiteralStr(features[feature_idx]));
      row.emplace_back(genLiteralBigInt(sub_feature_idx + 1));
      if (has_sub_features) {
        row.emplace_back(
            genLiteralStr(cat_sub_features[feature_idx][sub_feature_idx]));
      } else {
        row.emplace_back(genLiteralStr(""));
      }
      if (!extra_metadata.empty()) {
        CHECK(static_cast<size_t>(physical_feature_idx) < extra_metadata.size())
            << "Model metadata does not cover all physical features.";
        row.emplace_back(genLiteralDouble(extra_metadata[physical_feature_idx]));
      }
      if (!eigenvectors.empty()) {
        CHECK(static_cast<size_t>(physical_feature_idx) < eigenvectors.size())
            << "Model eigenvectors do not cover all physical features.";
        const auto& eigenvector = eigenvectors[physical_feature_idx];
        std::ostringstream eigenvector_oss;
        eigenvector_oss << "[";
        for (size_t i = 0; i < eigenvector.size(); ++i) {
          if (i > 0) {
            eigenvector_oss << ", ";
          }
          eigenvector_oss << eigenvector[i];
        }
        eigenvector_oss << "]";
        row.emplace_back(genLiteralStr(eigenvector_oss.str()));
      }
      physical_feature_idx++;
    }
  }

  // Create ResultSet
  std::shared_ptr<ResultSet> rSet = std::shared_ptr<ResultSet>(
      ResultSetLogicalValuesBuilder::create(label_infos, logical_values));

  return ExecutionResult(rSet, label_infos);
}

EvaluateModelCommand::EvaluateModelCommand(
const DdlCommandData& ddl_data,
std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr)
Expand Down
9 changes: 9 additions & 0 deletions Catalog/DdlCommandExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,15 @@ class ShowModelDetailsCommand : public DdlCommand {
std::vector<std::string> getFilteredModelNames();
};

// DDL command that implements SHOW MODEL FEATURE DETAILS: lists the features
// (and categorical sub-features) of a single ML model together with
// model-type-specific values such as regression coefficients.
class ShowModelFeatureDetailsCommand : public DdlCommand {
public:
// Validates that the command may run for this session. Throws
// std::runtime_error if ML functions are disabled, or if model metadata is
// restricted to superusers and the current user is not one.
ShowModelFeatureDetailsCommand(
const DdlCommandData& ddl_data,
std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr);

// Builds and returns the result set describing the features of the model named
// by the "modelName" member of the DDL payload.
ExecutionResult execute(bool read_only_mode) override;
};

class EvaluateModelCommand : public DdlCommand {
public:
EvaluateModelCommand(const DdlCommandData& ddl_data,
Expand Down
4 changes: 2 additions & 2 deletions QueryEngine/TableFunctions/SystemFunctions/os/ML/MLModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ class RandomForestRegressionModel : public virtual AbstractTreeModel {
double out_of_bag_error_;
};

#endif // #ifdef HAVE_ONEDAL

class PcaModel : public AbstractMLModel {
public:
PcaModel(const std::vector<double>& col_means,
Expand Down Expand Up @@ -374,6 +376,4 @@ class PcaModel : public AbstractMLModel {
std::vector<double> eigenvalues_;
};

#endif // #ifdef HAVE_ONEDAL

#endif // #ifndef __CUDACC__
5 changes: 5 additions & 0 deletions Tests/DBHandlerTestHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,11 @@ class DBHandlerTestFixture : public TestHelpers::TbbPrivateServerKiller {
// Convenience overload: forwards the query result's row set to the row-set
// overload of getRowCount and returns its answer.
size_t getRowCount(const TQueryResult& result) {
  const auto& row_set = result.row_set;
  return getRowCount(row_set);
}

// Convenience overload: forwards the query result's row set to the row-set
// overload of getColumnCount and returns its answer.
size_t getColumnCount(const TQueryResult& result) {
  const auto& row_set = result.row_set;
  return getColumnCount(row_set);
}

// Convenience overload: forwards the query result's row set and the requested
// row index to the row-set overload of getRow and returns its answer.
std::vector<TDatum> getRow(const TQueryResult& result, const size_t index) {
  const auto& row_set = result.row_set;
  return getRow(row_set, index);
}
Expand Down
Loading

0 comments on commit 5d05a5f

Please sign in to comment.