Skip to content

Commit 907cc34

Browse files
committed
update
1 parent 99e61d9 commit 907cc34

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

go/codegen/experimental/codegen.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error)
6565
for _, sql := range sqls {
6666
if sql.IsExtendedSyntax() {
6767
if sql.Train {
68-
r, err = ir.GenerateTrainStmt(sql.SQLFlowSelectStmt)
68+
// TODO(typhoonzero): use feature derivation at runtime, call GenerateTrainStmt only.
69+
r, err = ir.GenerateTrainStmtWithInferredColumns(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false, false)
6970
} else if sql.ShowTrain {
7071
r, err = ir.GenerateShowTrainStmt(sql.SQLFlowSelectStmt)
7172
} else if sql.Explain {

go/codegen/experimental/codegen_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ import (
2121
)
2222

2323
func TestXGBCodegen(t *testing.T) {
24-
t.Skip("enable this test when finished generate code using runtime.feature.*")
25-
26-
sql := "SELECT * FROM iris TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 LABEL class INTO sqlflow_models.xgb_classification;"
24+
sql := "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 LABEL class INTO sqlflow_models.xgb_classification;"
2725
s := &pb.Session{DbConnStr: database.GetTestingMySQLURL()}
2826
_, err := GenerateCodeCouler(sql, s)
2927
if err != nil {

go/codegen/experimental/xgboost.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, session *pb.Session) (string,
6868
delete(params["train."], "num_workers")
6969
}
7070

71+
// TODO(typhoonzero): use feature derivation at runtime.
7172
if len(trainStmt.Features) != 1 {
72-
return "", fmt.Errorf("xgboost only support 1 feature column set, received %d", len(trainStmt.Features))
73+
return "", fmt.Errorf("xgboost only support 0 or 1 feature column set, received %d", len(trainStmt.Features))
7374
}
7475

7576
featureColumnCode, featureFieldDesc, labelFieldDesc, err := deriveFeatureColumnCodeAndFieldDescs(trainStmt.Features["feature_columns"], trainStmt.Label)
@@ -128,16 +129,19 @@ def step_entry():
128129
feature_metas = json.loads('''{{.FieldDescJSON}}''')
129130
label_meta = json.loads('''{{.LabelJSON}}''')
130131
131-
feature_column_names = [{{range .FeatureColumnNames}}
132-
"{{.}}",
133-
{{end}}]
134-
135132
ds = "{{.DataSource}}"
136133
is_pai = False
137134
pai_train_table = ""
138135
select = "{{.Select}}"
139136
val_select = "{{.ValidationSelect}}"
140137
138+
# Derive feature columns at runtime like:
139+
# fcmap, fc_label = infer_feature_columns(conn, select, features, label, n=1000)
140+
141+
feature_column_names = [{{range .FeatureColumnNames}}
142+
"{{.}}",
143+
{{end}}]
144+
141145
# NOTE: in the current implementation, we are generating a transform_fn from COLUMN clause.
142146
# The transform_fn is executed during the process of dumping the original data into DMatrix SVM file.
143147
feature_column_list = [{{.FeatureColumnCode}}]

0 commit comments

Comments
 (0)