Skip to content

Commit 6974de2

Browse files
authored
Generate workflow step for normal statement run (#2824)
* generate workflow step for normal statement run * clean up * build step image before run workflow test * fix is_query
1 parent a96cb79 commit 6974de2

File tree

5 files changed

+81
-1
lines changed

5 files changed

+81
-1
lines changed

.github/workflows/main.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ jobs:
8181
set -e
8282
bash scripts/test/prepare.sh
8383
source build/env/bin/activate
84+
# build sqlflow binaries under build/
85+
bash docker/dev/build.sh
8486
docker pull sqlflow/sqlflow:step
87+
docker build --cache-from sqlflow/sqlflow:step -t sqlflow/sqlflow:step --build-arg FIND_FASTED_MIRROR="false" -f docker/step/Dockerfile .
8588
bash scripts/test/workflow.sh
8689
# bash scripts/travis/upload_codecov.sh
8790
push-images:

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,9 @@ func TestEnd2EndFluidWorkflow(t *testing.T) {
362362
func CaseWorkflowTrainXgboost(t *testing.T) {
363363
a := assert.New(t)
364364

365-
sqlProgram := `SELECT * FROM iris.train
365+
sqlProgram := `SELECT * FROM iris.train LIMIT 100;
366+
367+
SELECT * FROM iris.train
366368
TO TRAIN xgboost.gbtree
367369
WITH objective="multi:softmax",num_class=3
368370
LABEL class

go/codegen/experimental/codegen.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ func generateStepCode(stmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (
2929
return XGBoostGenerateTrain(trainStmt, stepIndex, session)
3030
}
3131
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
32+
case *ir.NormalStmt:
33+
stmt := stmt.(*ir.NormalStmt)
34+
return GenerateNormalStmtStep(string(*stmt), session, stepIndex)
3235
default:
3336
return "", fmt.Errorf("not implemented stmt execution type %v", stmt)
3437
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package experimental
15+
16+
import (
17+
"bytes"
18+
"text/template"
19+
20+
pb "sqlflow.org/sqlflow/go/proto"
21+
)
22+
23+
var normalStmtStepTmpl = `
24+
def step_entry_{{.StepIndex}}():
25+
import runtime
26+
import runtime.dbapi
27+
conn = runtime.dbapi.connect("{{.DataSource}}")
28+
stmt = """{{.Stmt}}"""
29+
if conn.is_query(stmt):
30+
rs = conn.query(stmt)
31+
# write rs to stdout using protobuf table writer
32+
else:
33+
success = conn.execute(stmt)
34+
if not success:
35+
raise Exception("execute statment error: " % stmt)
36+
`
37+
38+
var normalStmtStepTemplate = template.Must(template.New("NormalStmtStep").Parse(normalStmtStepTmpl))
39+
40+
type normalStmtFiller struct {
41+
StepIndex int
42+
DataSource string
43+
Stmt string
44+
}
45+
46+
// GenerateNormalStmtStep generate step Python code to run a normal SQL statement.
47+
func GenerateNormalStmtStep(stmt string, session *pb.Session, stepIndex int) (string, error) {
48+
filler := &normalStmtFiller{
49+
StepIndex: stepIndex,
50+
DataSource: session.DbConnStr,
51+
Stmt: stmt,
52+
}
53+
var program bytes.Buffer
54+
if err := normalStmtStepTemplate.Execute(&program, filler); err != nil {
55+
return "", err
56+
}
57+
return program.String(), nil
58+
}

python/runtime/dbapi/connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,20 @@ def query(self, statement):
143143
"""
144144
return self._get_result_set(statement)
145145

146+
def is_query(self, statement):
147+
"""Return true if the statement is a query SQL statement."""
148+
s = statement.strip()
149+
s = s.upper()
150+
151+
if s.startswith("SELECT") and s.find("INTO") == -1:
152+
return True
153+
if s.startswith("SHOW") and s.find("CREATE") >= 0 or s.find(
154+
"DATABASES") >= 0 or s.find("TABLES") >= 0:
155+
return True
156+
if s.startswith("DESC") or s.startswith("EXPLAIN"):
157+
return True
158+
return False
159+
146160
def execute(self, statement):
147161
"""Execute given statement and return True on success
148162

0 commit comments

Comments
 (0)