Skip to content

Commit

Permalink
EVALUATE MODEL
Browse files Browse the repository at this point in the history
Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
tmostak authored and misiugodfrey committed Aug 28, 2023
1 parent 6525f2d commit 047af0e
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 3 deletions.
61 changes: 61 additions & 0 deletions Catalog/DdlCommandExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ std::unique_ptr<RexLiteral> genLiteralBigInt(int64_t val) {
new RexLiteral(val, SQLTypes::kBIGINT, SQLTypes::kBIGINT, 0, 8, 0, 8));
}

std::unique_ptr<RexLiteral> genLiteralDouble(double val) {
return std::unique_ptr<RexLiteral>(
new RexLiteral(val, SQLTypes::kDOUBLE, SQLTypes::kDOUBLE, 0, 8, 0, 8));
}

std::unique_ptr<RexLiteral> genLiteralBoolean(bool val) {
return std::unique_ptr<RexLiteral>(
// new RexLiteral(val, SQLTypes::kBOOLEAN, SQLTypes::kBOOLEAN, 0, 0, 0, 0));
Expand Down Expand Up @@ -613,6 +618,8 @@ ExecutionResult DdlCommandExecutor::execute(bool read_only_mode) {
auto drop_model_stmt = Parser::DropModelStmt(extractPayload(*ddl_data_));
drop_model_stmt.execute(*session_ptr_, read_only_mode);
return result;
} else if (ddl_command_ == "EVALUATE_MODEL") {
result = EvaluateModelCommand{*ddl_data_, session_ptr_}.execute(read_only_mode);
} else if (ddl_command_ == "VALIDATE_SYSTEM") {
// VALIDATE should have been excuted in outer context before it reaches here
UNREACHABLE();
Expand Down Expand Up @@ -2053,6 +2060,60 @@ ExecutionResult ShowModelsCommand::execute(bool read_only_mode) {
return ExecutionResult(rSet, label_infos);
}

EvaluateModelCommand::EvaluateModelCommand(
const DdlCommandData& ddl_data,
std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr)
: DdlCommand(ddl_data, session_ptr) {}

ExecutionResult EvaluateModelCommand::execute(bool read_only_mode) {
auto execute_read_lock = legacylockmgr::getExecuteReadLock();
auto& ddl_payload = extractPayload(ddl_data_);
std::string model_name;
std::string select_query;
if (ddl_payload.HasMember("modelName")) {
model_name = ddl_payload["modelName"].GetString();
}
if (ddl_payload.HasMember("query")) {
select_query = ddl_payload["query"].GetString();
}
std::regex newline_re("\\n");
std::regex backtick_re("`");
select_query = std::regex_replace(select_query, newline_re, " ");
select_query = std::regex_replace(select_query, backtick_re, "");
std::ostringstream r2_query_oss;
r2_query_oss << "SELECT * FROM TABLE(r2_score(model_name => '" << model_name << "', "
<< "data => CURSOR(" << select_query << ")))";
std::string r2_query = r2_query_oss.str();

Parser::LocalQueryConnector local_connector;
auto query_state = query_state::QueryState::create(session_ptr_, select_query);
auto result =
local_connector.query(query_state->createQueryStateProxy(), r2_query, {}, false);
std::vector<std::string> labels{"r2"};
std::vector<TargetMetaInfo> label_infos;
for (const auto& label : labels) {
label_infos.emplace_back(label, SQLTypeInfo(kDOUBLE, true));
}
std::vector<RelLogicalValues::RowValues> logical_values;
logical_values.emplace_back(RelLogicalValues::RowValues{});

CHECK_EQ(result.size(), size_t(1));
CHECK_EQ(result[0].rs->rowCount(), size_t(1));
CHECK_EQ(result[0].rs->colCount(), size_t(1));

auto result_row = result[0].rs->getNextRow(true, true);

auto scalar_r = boost::get<ScalarTargetValue>(&result_row[0]);
auto p = boost::get<double>(scalar_r);

logical_values.back().emplace_back(genLiteralDouble(*p));

std::shared_ptr<ResultSet> rSet = std::shared_ptr<ResultSet>(
ResultSetLogicalValuesBuilder::create(label_infos, logical_values));

return ExecutionResult(rSet, label_infos);
}

ShowForeignServersCommand::ShowForeignServersCommand(
const DdlCommandData& ddl_data,
std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr)
Expand Down
8 changes: 8 additions & 0 deletions Catalog/DdlCommandExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,14 @@ class ShowModelsCommand : public DdlCommand {
ExecutionResult execute(bool read_only_mode) override;
};

class EvaluateModelCommand : public DdlCommand {
public:
EvaluateModelCommand(const DdlCommandData& ddl_data,
std::shared_ptr<Catalog_Namespace::SessionInfo const> session_ptr);

ExecutionResult execute(bool read_only_mode) override;
};

class ShowDiskCacheUsageCommand : public DdlCommand {
public:
ShowDiskCacheUsageCommand(
Expand Down
6 changes: 3 additions & 3 deletions Parser/ParserWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ ExplainInfo::ExplainInfo(std::string query_string) {
}

const std::vector<std::string> ParserWrapper::ddl_cmd = {
"ARCHIVE", "ALTER", "COPY", "CREATE", "DROP", "DUMP", "GRANT",
"KILL", "OPTIMIZE", "REFRESH", "RENAME", "RESTORE", "REVOKE", "SHOW",
"TRUNCATE", "REASSIGN", "VALIDATE", "CLEAR", "PAUSE", "RESUME"};
"ARCHIVE", "ALTER", "COPY", "CREATE", "DROP", "DUMP", "EVALUATE",
"GRANT", "KILL", "OPTIMIZE", "REFRESH", "RENAME", "RESTORE", "REVOKE",
"SHOW", "TRUNCATE", "REASSIGN", "VALIDATE", "CLEAR", "PAUSE", "RESUME"};

const std::vector<std::string> ParserWrapper::update_dml_cmd = {"INSERT",
"DELETE",
Expand Down
5 changes: 5 additions & 0 deletions java/calcite/src/main/codegen/config.fmpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ data: {
"com.mapd.parser.extension.ddl.SqlCreateView"
"com.mapd.parser.extension.ddl.SqlCreateModel"
"com.mapd.parser.extension.ddl.SqlDropModel"
"com.mapd.parser.extension.ddl.SqlEvaluateModel"
"com.mapd.parser.extension.ddl.SqlCreateUserMapping"
"com.mapd.parser.extension.ddl.SqlCreateUser"
"com.mapd.parser.extension.ddl.SqlDropUserMapping"
Expand All @@ -62,6 +63,7 @@ data: {
"com.mapd.parser.extension.ddl.SqlShowFunctions"
"com.mapd.parser.extension.ddl.SqlShowRuntimeFunctions"
"com.mapd.parser.extension.ddl.SqlShowModels"
"com.mapd.parser.extension.ddl.SqlEvaluateModel"
"com.mapd.parser.extension.ddl.SqlAlterTable"
"com.mapd.parser.extension.ddl.SqlAlterServer"
"com.mapd.parser.extension.ddl.SqlAlterDatabase"
Expand Down Expand Up @@ -133,6 +135,7 @@ data: {
"EDIT"
"EDITOR"
"EFFECTIVE"
"EVALUATE"
"FUNCTIONS"
"MAPPING"
"MODEL"
Expand Down Expand Up @@ -515,6 +518,7 @@ data: {
"EDIT"
"EDITOR"
"EFFECTIVE"
"EVALUATE"
"FUNCTIONS"
"MAPPING"
"OPTIMIZE"
Expand Down Expand Up @@ -584,6 +588,7 @@ data: {
"SqlRestoreTable(span())"
"SqlTruncateTable(span())"
"SqlOptimizeTable(span())"
"SqlEvaluateModel(span())"
"SqlCopyTable(span())"
"SqlValidateSystem(span())"
"SqlAlterSystem(span())"
Expand Down
20 changes: 20 additions & 0 deletions java/calcite/src/main/codegen/includes/ddlParser.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,26 @@ SqlDdl SqlOptimizeTable(Span s) :
}
}

/*
* Evaluate a model using the following syntax:
*
* EVALUATE MODEL <modelName> ON <query>
*
*/
SqlDdl SqlEvaluateModel(Span s) :
{
final SqlIdentifier modelName;
final SqlNode query;
}
{
<EVALUATE>
<MODEL>
modelName = CompoundIdentifier()
<ON> query = OrderedQueryOrExpr(ExprContext.ACCEPT_QUERY)
{
return new SqlEvaluateModel(s.end(this), modelName.toString(), query);
}
}

/*
* Create a view using the following syntax:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.mapd.parser.extension.ddl;

import com.google.gson.annotations.Expose;
import com.mapd.parser.extension.ddl.heavydb.HeavyDBOptionsMap;

import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSpecialOperator;
import org.apache.calcite.sql.parser.SqlParserPos;

/**
* Class that encapsulates all information associated with a EVALUATE MODEL DDL command.
*/
public class SqlEvaluateModel extends SqlCustomDdl {
private static final SqlOperator OPERATOR =
new SqlSpecialOperator("EVALUATE_MODEL", SqlKind.OTHER_DDL);

@Expose
private String modelName;
@Expose
private String query;
@Expose
private HeavyDBOptionsMap options;

public SqlEvaluateModel(final SqlParserPos pos, final String modelName, SqlNode query) {
super(OPERATOR, pos);
this.modelName = modelName;
this.query = query.toString();
}
}

0 comments on commit 047af0e

Please sign in to comment.