Merged
2 changes: 1 addition & 1 deletion sql/codegen_ant_xgboost.go
@@ -795,7 +795,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *antXGBoostFiller, db *DB) er
return nil
}

func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
func genAntXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
r, e := newAntXGBoostFiller(pr, ds, db)
if e != nil {
return e
98 changes: 98 additions & 0 deletions sql/codegen_xgboost.go
@@ -0,0 +1,98 @@
// Copyright 2019 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"fmt"
"io"
"text/template"
)

type xgbTrainConfig struct {
NumBoostRound int `json:"num_boost_round,omitempty"`
Maximize bool `json:"maximize,omitempty"`
}

type xgbFiller struct {
IsTrain bool
TrainingDatasetSQL string
ValidationDatasetSQL string
TrainCfg *xgbTrainConfig
Features []*featureMeta
Label *featureMeta
Collaborator:
I found our filler for TensorFlow in codegen.go:

sqlflow/sql/codegen.go

Lines 67 to 69 in 9a0dc86

X []*featureMeta
FeatureColumnsCode map[string][]string
Y *featureMeta

There, Features is named X and Label is named Y.
I would suggest the same naming here for consistency.
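
A minimal sketch of the suggested renaming, for illustration only (not part of this PR); it keeps the other xgbFiller fields from the diff unchanged and assumes the existing featureMeta, xgbTrainConfig, and connectionConfig types:

type xgbFiller struct {
	IsTrain              bool
	TrainingDatasetSQL   string
	ValidationDatasetSQL string
	TrainCfg             *xgbTrainConfig
	X                    []*featureMeta // was Features, matching the TensorFlow filler's X
	Y                    *featureMeta   // was Label, matching the TensorFlow filler's Y
	Save                 string
	ParamsCfgJSON        string
	TrainCfgJSON         string
	*connectionConfig
}

Such a rename would also touch newXGBFiller and the template references, e.g. {{.Features}} and {{.Label.FeatureName}} would become {{.X}} and {{.Y.FeatureName}}.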

Collaborator (Author):

Maybe a better way is to reuse the filler in codegen; I would optimize that in the next PR.

Save string
ParamsCfgJSON string
TrainCfgJSON string
*connectionConfig
}

func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*xgbFiller, error) {
var err error
training, validation := trainingAndValidationDataset(pr, ds)
r := &xgbFiller{
IsTrain: pr.train,
TrainingDatasetSQL: training,
ValidationDatasetSQL: validation,
Save: pr.save,
}
// TODO(Yancey1989): fill the train_args and parameters by WITH statement
r.TrainCfgJSON = ""
r.ParamsCfgJSON = ""

if r.connectionConfig, err = newConnectionConfig(db); err != nil {
return nil, err
}

for _, columns := range pr.columns {
feaCols, colSpecs, err := resolveTrainColumns(&columns)
if err != nil {
return nil, err
}
if len(colSpecs) != 0 {
return nil, fmt.Errorf("newXGBoostFiller doesn't support DENSE/SPARSE")
}
for _, col := range feaCols {
fm := &featureMeta{
FeatureName: col.GetKey(),
Dtype: col.GetDtype(),
Delimiter: col.GetDelimiter(),
InputShape: col.GetInputShape(),
IsSparse: false,
}
r.Features = append(r.Features, fm)
}
}
r.Label = &featureMeta{
FeatureName: pr.label,
Dtype: "int32",
Delimiter: ",",
InputShape: "[1]",
IsSparse: false,
}

return r, nil
}

func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
r, e := newXGBFiller(pr, ds, fts, db)
if e != nil {
return e
}
if pr.train {
return xgbTrainTemplate.Execute(w, r)
}
return fmt.Errorf("xgboost prediction codegen has not been implemented")
}

var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText))
25 changes: 25 additions & 0 deletions sql/codegen_xgboost_test.go
@@ -0,0 +1,25 @@
// Copyright 2019 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

const testXGBoostTrainSelectIris = `
SELECT *
FROM iris.train
TRAIN xgb.multi.softprob
WITH
train.num_boost_round = 30
COLUMN sepal_length, sepal_width, petal_length, petal_width
LABEL class
INTO sqlflow_models.my_xgboost_model;
`
15 changes: 11 additions & 4 deletions sql/executor.go
@@ -387,8 +387,15 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
var program bytes.Buffer
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
// TODO(sperlingxx): write a separate train pipeline for ant-xgboost to support remote mode
if e := genXG(&program, tr, ds, fts, db); e != nil {
return fmt.Errorf("genXG %v", e)
if e := genAntXGBoost(&program, tr, ds, fts, db); e != nil {
return fmt.Errorf("genAntXGBoost %v", e)
}
} else if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGB.`) {
// FIXME(Yancey1989): this is a temporary solution, just for the unit test; we prefer to distinguish
// xgboost and ant-xgboost with the env var SQLFLOW_WITH_ANTXGBOOST,
// issue: https://github.com/sql-machine-learning/sqlflow/issues/758
if e := genXGBoost(&program, tr, ds, fts, db); e != nil {
return fmt.Errorf("GenXGBoost %v", e)
}
} else {
if e := genTF(&program, tr, ds, fts, db); e != nil {
@@ -453,8 +460,8 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
var buf bytes.Buffer
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
// TODO(sperlingxx): write a separate pred pipeline for ant-xgboost to support remote mode
if e := genXG(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genXG %v", e)
if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genAntXGBoost %v", e)
}
} else {
if e := genTF(&buf, pr, nil, fts, db); e != nil {
10 changes: 10 additions & 0 deletions sql/executor_test.go
@@ -88,6 +88,16 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) {
})
}

func TestExecutorTrainXGBoost(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil)
a.True(goodStream(stream.ReadAll()))

})
}

func TestExecutorTrainAndPredictDNN(t *testing.T) {
a := assert.New(t)
modelDir := ""
82 changes: 82 additions & 0 deletions sql/template_xgboost.go
@@ -0,0 +1,82 @@
// Copyright 2019 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

const xgbTrainTemplateText = `
import xgboost as xgb
from sqlflow_submitter.db import connect, db_generator

driver="{{.Driver}}"

{{if ne .Database ""}}
database="{{.Database}}"
{{else}}
database=""
{{end}}

session_cfg = {}
{{ range $k, $v := .Session }}
session_cfg["{{$k}}"] = "{{$v}}"
{{end}}

{{if ne .TrainCfgJSON ""}}
train_args = {{.TrainCfgJSON}}
{{else}}
train_args = {}
{{end}}

{{if ne .ParamsCfgJSON ""}}
params = {{.ParamsCfgJSON}}
{{else}}
params = {}
{{end}}

feature_column_names = [{{range .Features}}
"{{.FeatureName}}",
{{end}}]

{{/* Convert go side featureSpec to python dict for input_fn */}}
feature_specs = dict()
{{ range $value := .Features }}
feature_specs["{{$value.FeatureName}}"] = {
"feature_name": "{{$value.FeatureName}}",
"dtype": "{{$value.Dtype}}",
"delimiter": "{{$value.Delimiter}}",
"shape": {{$value.InputShape}},
"is_sparse": "{{$value.IsSparse}}" == "true"
}
{{end}}



conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")

def xgb_dataset(fn, dataset_sql):
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Label.FeatureName}}", feature_specs)
with open(fn, 'w') as f:
for item in gen():
features, label = item
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
f.write("\t".join(row_data) + "\n")
# TODO(yancey1989): generate group and weight text files if necessary
return xgb.DMatrix(fn)

dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}")
dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}")

#TODO(Yancey1989): specify the eval metrics by WITH statement in SQL
train_args["evals"] = [(dtest, "auc")]
bst = xgb.train(params, dtrain, **train_args)
bst.save_model("{{.Save}}")
`