Skip to content

Commit 388172a

Browse files
authored
Generate workflow using runtime (#2784)
* WIP generate workflow using runtime * wip update * update * update * fix hive ci
1 parent a28f009 commit 388172a

File tree

4 files changed

+570
-0
lines changed

4 files changed

+570
-0
lines changed

go/codegen/experimental/codegen.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package experimental
15+
16+
import (
17+
"fmt"
18+
"strings"
19+
20+
"sqlflow.org/sqlflow/go/ir"
21+
"sqlflow.org/sqlflow/go/parser"
22+
pb "sqlflow.org/sqlflow/go/proto"
23+
)
24+
25+
// GenerateCodeCouler generate a Couler program to submit a workflow to run the sql program.
26+
// 1. generate IR of each statement.
27+
// 2. generate runtime code of each statement
28+
// 3. generate couler program to form a workflow
29+
func GenerateCodeCouler(sqlProgram string, session *pb.Session) (string, error) {
30+
stmts, err := parseToIR(sqlProgram, session)
31+
if err != nil {
32+
return "", err
33+
}
34+
for _, stmt := range stmts {
35+
stepCode, err := generateStepCode(stmt, session)
36+
if err != nil {
37+
return "", err
38+
}
39+
fmt.Println(stepCode)
40+
}
41+
return "", nil
42+
}
43+
44+
func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error) {
45+
var dbDriver string
46+
var r ir.SQLFlowStmt
47+
var result []ir.SQLFlowStmt
48+
49+
sqlProgram, err := parser.RemoveCommentInSQLStatement(sqlProgram)
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
dbDriverParts := strings.Split(session.DbConnStr, "://")
55+
if len(dbDriverParts) != 2 {
56+
return nil, fmt.Errorf("invalid database connection string %s", session.DbConnStr)
57+
}
58+
dbDriver = dbDriverParts[0]
59+
60+
stmts, err := parser.Parse(dbDriver, sqlProgram)
61+
if err != nil {
62+
return nil, err
63+
}
64+
sqls := rewriteStatementsWithHints(stmts, dbDriver)
65+
for _, sql := range sqls {
66+
if sql.IsExtendedSyntax() {
67+
if sql.Train {
68+
// TODO(typhoonzero): use feature derivation at runtime, call GenerateTrainStmt only.
69+
r, err = ir.GenerateTrainStmtWithInferredColumns(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false, false)
70+
} else if sql.ShowTrain {
71+
r, err = ir.GenerateShowTrainStmt(sql.SQLFlowSelectStmt)
72+
} else if sql.Explain {
73+
r, err = ir.GenerateExplainStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
74+
} else if sql.Predict {
75+
r, err = ir.GeneratePredictStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
76+
} else if sql.Evaluate {
77+
r, err = ir.GenerateEvaluateStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
78+
} else if sql.Optimize {
79+
r, err = ir.GenerateOptimizeStmt(sql.SQLFlowSelectStmt)
80+
} else if sql.Run {
81+
r, err = ir.GenerateRunStmt(sql.SQLFlowSelectStmt)
82+
}
83+
} else {
84+
standardSQL := ir.NormalStmt(sql.Original)
85+
r = &standardSQL
86+
}
87+
if err != nil {
88+
return nil, err
89+
}
90+
if err = initializeAndCheckAttributes(r); err != nil {
91+
return nil, err
92+
}
93+
r.SetOriginalSQL(sql.Original)
94+
result = append(result, r)
95+
}
96+
return result, nil
97+
}
98+
99+
func generateStepCode(stmt ir.SQLFlowStmt, session *pb.Session) (string, error) {
100+
switch stmt.(type) {
101+
case *ir.TrainStmt:
102+
trainStmt := stmt.(*ir.TrainStmt)
103+
if strings.HasPrefix(strings.ToUpper(trainStmt.Estimator), "XGBOOST.") {
104+
return XGBoostGenerateTrain(trainStmt, session)
105+
}
106+
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
107+
default:
108+
return "", fmt.Errorf("not implemented stmt execution type %v", stmt)
109+
}
110+
}
111+
112+
func initializeAndCheckAttributes(stmt ir.SQLFlowStmt) error {
113+
switch s := stmt.(type) {
114+
case *ir.TrainStmt:
115+
if s.GetModelKind() == ir.XGBoost {
116+
return InitializeAttributes(s)
117+
}
118+
// TODO(typhoonzero): add below lines
119+
// else if s.GetModelKind() == ir.KMeans {
120+
// return pai.InitializeKMeansAttributes(s)
121+
// }
122+
// return tensorflow.InitializeAttributes(s)
123+
// case *ir.OptimizeStmt:
124+
// return optimize.InitializeAttributes(s)
125+
}
126+
return nil
127+
}
128+
129+
// InitializeAttributes initializes the attributes of XGBoost and does type checking for them
130+
func InitializeAttributes(trainStmt *ir.TrainStmt) error {
131+
attributeDictionary.ExportDefaults(trainStmt.Attributes)
132+
return fullAttrValidator.Validate(trainStmt.Attributes)
133+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package experimental
15+
16+
import (
17+
"os"
18+
"testing"
19+
20+
"sqlflow.org/sqlflow/go/database"
21+
pb "sqlflow.org/sqlflow/go/proto"
22+
)
23+
24+
func TestExperimentalXGBCodegen(t *testing.T) {
25+
if os.Getenv("SQLFLOW_TEST_DB") != "mysql" {
26+
t.Skipf("skip TestExperimentalXGBCodegen of DB type %s", os.Getenv("SQLFLOW_TEST_DB"))
27+
}
28+
sql := "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 LABEL class INTO sqlflow_models.xgb_classification;"
29+
s := &pb.Session{DbConnStr: database.GetTestingMySQLURL()}
30+
_, err := GenerateCodeCouler(sql, s)
31+
if err != nil {
32+
t.Errorf("error %s", err)
33+
}
34+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package experimental
15+
16+
import (
17+
"strings"
18+
19+
"sqlflow.org/sqlflow/go/parser"
20+
)
21+
22+
// RewriteStatementsWithHints combines the hints into the standard SQL(s)
23+
//
24+
// FIXME(weiguoz): I'm not happy with such an implementation.
25+
// I mean it is not clean that sqlflow handles such database relative details.
26+
func rewriteStatementsWithHints(stmts []*parser.SQLFlowStmt, dialect string) []*parser.SQLFlowStmt {
27+
hints, sqls := splitHints(stmts, dialect)
28+
if len(hints) > 0 {
29+
for _, sql := range sqls {
30+
if !sql.IsExtendedSyntax() {
31+
sql.Original = hints + sql.Original
32+
}
33+
}
34+
}
35+
return sqls
36+
}
37+
38+
func splitHints(stmts []*parser.SQLFlowStmt, dialect string) (string, []*parser.SQLFlowStmt) {
39+
hints, sqls := "", []*parser.SQLFlowStmt{}
40+
for _, stmt := range stmts {
41+
if isHint(stmt, dialect) {
42+
hints += stmt.Original + "\n" // alisa's requirements
43+
} else {
44+
sqls = append(sqls, stmt)
45+
}
46+
}
47+
return hints, sqls
48+
}
49+
50+
func isHint(stmt *parser.SQLFlowStmt, dialect string) bool {
51+
if !stmt.IsExtendedSyntax() {
52+
if dialect == "alisa" {
53+
return isAlisaHint(stmt.Original)
54+
}
55+
// TODO(weiguoz) handle if submitter is "maxcompute" or "hive"
56+
}
57+
return false
58+
}
59+
60+
func isAlisaHint(sql string) bool {
61+
for {
62+
sql = strings.TrimSpace(sql)
63+
// TODO(weiguoz): Let's remove the following code if we clean the comments before
64+
if strings.HasPrefix(sql, "--") {
65+
eol := strings.IndexAny(sql, "\n\r")
66+
if eol != -1 {
67+
sql = sql[eol+1:]
68+
} else {
69+
break
70+
}
71+
} else {
72+
break
73+
}
74+
}
75+
return strings.HasPrefix(strings.ToLower(sql), "set ")
76+
}

0 commit comments

Comments
 (0)