Skip to content

Commit ccba945

Browse files
authored
[AINode] Simplify the CREATE MODEL SQL for model training (#15840)
1 parent 719c504 commit ccba945

File tree

14 files changed

+16
-65
lines changed

14 files changed

+16
-65
lines changed

iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ dropSubscription
699699
// ---- Create Model
700700
createModel
701701
: CREATE MODEL modelName=identifier uriClause
702-
| CREATE MODEL modelType=identifier modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? (FROM MODEL existingModelId=identifier)? ON DATASET LR_BRACKET trainingData RR_BRACKET
702+
| CREATE MODEL modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? FROM MODEL existingModelId=identifier ON DATASET LR_BRACKET trainingData RR_BRACKET
703703
;
704704

705705
trainingData

iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2660,7 +2660,6 @@ public TSStatus createTraining(TCreateTrainingReq req) {
26602660

26612661
TTrainingReq trainingReq = new TTrainingReq();
26622662
trainingReq.setModelId(req.getModelId());
2663-
trainingReq.setModelType(req.getModelType());
26642663
if (req.isSetExistingModelId()) {
26652664
trainingReq.setExistingModelId(req.getExistingModelId());
26662665
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,11 +1359,7 @@ protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext c
13591359
context.setQueryType(QueryType.WRITE);
13601360

13611361
return new CreateTrainingTask(
1362-
node.getModelId(),
1363-
node.getModelType(),
1364-
node.getParameters(),
1365-
node.getExistingModelId(),
1366-
node.getTargetSql());
1362+
node.getModelId(), node.getParameters(), node.getExistingModelId(), node.getTargetSql());
13671363
}
13681364

13691365
@Override

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,6 @@ public IConfigTask visitCreateTraining(
808808
}
809809
return new CreateTrainingTask(
810810
createTrainingStatement.getModelId(),
811-
createTrainingStatement.getModelType(),
812811
createTrainingStatement.getParameters(),
813812
createTrainingStatement.getTargetTimeRanges(),
814813
createTrainingStatement.getExistingModelId(),

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3331,7 +3331,6 @@ public SettableFuture<ConfigTaskResult> showModels(final String modelName) {
33313331
@Override
33323332
public SettableFuture<ConfigTaskResult> createTraining(
33333333
String modelId,
3334-
String modelType,
33353334
boolean isTableModel,
33363335
Map<String, String> parameters,
33373336
List<List<Long>> timeRanges,
@@ -3341,7 +3340,7 @@ public SettableFuture<ConfigTaskResult> createTraining(
33413340
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
33423341
try (final ConfigNodeClient client =
33433342
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
3344-
final TCreateTrainingReq req = new TCreateTrainingReq(modelId, modelType, isTableModel);
3343+
final TCreateTrainingReq req = new TCreateTrainingReq(modelId, isTableModel, existingModelId);
33453344

33463345
if (isTableModel) {
33473346
TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
@@ -3354,7 +3353,6 @@ public SettableFuture<ConfigTaskResult> createTraining(
33543353
}
33553354
req.setParameters(parameters);
33563355
req.setTimeRanges(timeRanges);
3357-
req.setExistingModelId(existingModelId);
33583356
final TSStatus executionStatus = client.createTraining(req);
33593357
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) {
33603358
future.setException(new IoTDBException(executionStatus.message, executionStatus.code));

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ SettableFuture<ConfigTaskResult> createModel(
424424

425425
SettableFuture<ConfigTaskResult> createTraining(
426426
String modelId,
427-
String modelType,
428427
boolean isTableModel,
429428
Map<String, String> parameters,
430429
List<List<Long>> timeRanges,

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
public class CreateTrainingTask implements IConfigTask {
3232

3333
private final String modelId;
34-
private final String modelType;
3534
private final boolean isTableModel;
3635
private final Map<String, String> parameters;
3736

@@ -45,13 +44,8 @@ public class CreateTrainingTask implements IConfigTask {
4544

4645
// For table model
4746
public CreateTrainingTask(
48-
String modelId,
49-
String modelType,
50-
Map<String, String> parameters,
51-
String existingModelId,
52-
String targetSql) {
47+
String modelId, Map<String, String> parameters, String existingModelId, String targetSql) {
5348
this.modelId = modelId;
54-
this.modelType = modelType;
5549
this.parameters = parameters;
5650
this.existingModelId = existingModelId;
5751
this.targetSql = targetSql;
@@ -61,13 +55,11 @@ public CreateTrainingTask(
6155
// For tree model
6256
public CreateTrainingTask(
6357
String modelId,
64-
String modelType,
6558
Map<String, String> parameters,
6659
List<List<Long>> timeRanges,
6760
String existingModelId,
6861
List<String> targetPaths) {
6962
this.modelId = modelId;
70-
this.modelType = modelType;
7163
this.parameters = parameters;
7264
this.timeRanges = timeRanges;
7365
this.existingModelId = existingModelId;
@@ -80,13 +72,6 @@ public CreateTrainingTask(
8072
public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor configTaskExecutor)
8173
throws InterruptedException {
8274
return configTaskExecutor.createTraining(
83-
modelId,
84-
modelType,
85-
isTableModel,
86-
parameters,
87-
timeRanges,
88-
existingModelId,
89-
targetSql,
90-
targetPaths);
75+
modelId, isTableModel, parameters, timeRanges, existingModelId, targetSql, targetPaths);
9176
}
9277
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,9 +1368,7 @@ public static void validateModelName(String modelName) {
13681368
public Statement visitCreateModel(IoTDBSqlParser.CreateModelContext ctx) {
13691369
if (ctx.modelName == null) {
13701370
String modelId = ctx.modelId.getText();
1371-
String modelType = ctx.modelType.getText();
1372-
CreateTrainingStatement createTrainingStatement =
1373-
new CreateTrainingStatement(modelId, modelType);
1371+
CreateTrainingStatement createTrainingStatement = new CreateTrainingStatement(modelId);
13741372
if (ctx.hparamPair() != null) {
13751373
Map<String, String> parameterList = new HashMap<>();
13761374
for (IoTDBSqlParser.HparamPairContext hparamPairContext : ctx.hparamPair()) {

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,14 @@
2626
public class CreateTraining extends Statement {
2727

2828
private final String modelId;
29-
private final String modelType;
3029
private final String targetSql;
3130

3231
private Map<String, String> parameters;
3332
private String existingModelId = null;
3433

35-
public CreateTraining(String modelId, String modelType, String targetSql) {
34+
public CreateTraining(String modelId, String targetSql) {
3635
super(null);
3736
this.modelId = modelId;
38-
this.modelType = modelType;
3937
this.targetSql = targetSql;
4038
}
4139

@@ -56,10 +54,6 @@ public String getModelId() {
5654
return modelId;
5755
}
5856

59-
public String getModelType() {
60-
return modelType;
61-
}
62-
6357
public Map<String, String> getParameters() {
6458
return parameters;
6559
}
@@ -79,7 +73,7 @@ public List<? extends Node> getChildren() {
7973

8074
@Override
8175
public int hashCode() {
82-
return Objects.hash(modelId, modelType, targetSql, existingModelId, parameters);
76+
return Objects.hash(modelId, targetSql, existingModelId, parameters);
8377
}
8478

8579
@Override
@@ -89,7 +83,6 @@ public boolean equals(Object obj) {
8983
}
9084
CreateTraining createTraining = (CreateTraining) obj;
9185
return modelId.equals(createTraining.modelId)
92-
&& modelType.equals(createTraining.modelType)
9386
&& Objects.equals(existingModelId, createTraining.existingModelId)
9487
&& Objects.equals(parameters, createTraining.parameters)
9588
&& Objects.equals(targetSql, createTraining.targetSql);
@@ -101,9 +94,6 @@ public String toString() {
10194
+ "modelId='"
10295
+ modelId
10396
+ '\''
104-
+ ", modelType='"
105-
+ modelType
106-
+ '\''
10797
+ ", parameters="
10898
+ parameters
10999
+ ", existingModelId='"

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,13 +3558,12 @@ public static void validateModelName(String modelName) {
35583558
public Node visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
35593559
String modelId = ctx.modelId.getText();
35603560
validateModelName(modelId);
3561-
String modelType = ctx.modelType.getText();
35623561

35633562
if (ctx.targetData == null) {
35643563
throw new SemanticException("Target data in sql should be set in CREATE MODEL");
35653564
}
35663565
String targetData = ((StringLiteral) visit(ctx.targetData)).getValue();
3567-
CreateTraining createTraining = new CreateTraining(modelId, modelType, targetData);
3566+
CreateTraining createTraining = new CreateTraining(modelId, targetData);
35683567
if (ctx.HYPERPARAMETERS() != null) {
35693568
Map<String, String> parameters = new HashMap<>();
35703569
for (RelationalSqlParser.HparamPairContext hparamPairContext : ctx.hparamPair()) {

0 commit comments

Comments
 (0)