Skip to content

Commit ced5708

Browse files
typhoonzeroweiguoz
authored and committed
move xgboost objective to attr (#793)
1 parent afda622 commit ced5708

File tree

3 files changed

+20
-13
lines changed

3 files changed

+20
-13
lines changed

doc/xgboost_on_sqlflow_design.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ To explain the benefit of integrating XGBoost with SQLFlow, let us start with an
1010

1111
``` sql
1212
SELECT * FROM train_table
13-
TRAIN xgboost.multi.softmax
13+
TRAIN xgboost.gbtree
1414
WITH
15+
objective=multi:softmax,
1516
train.num_round=2,
1617
max_depth=2,
1718
eta=1
@@ -29,10 +30,9 @@ USING my_xgb_model;
2930

3031
In the above examples,
3132
- `my_xgb_model` names the trained model.
32-
- `xgboost.multi.softmax` is the model spec, where
33-
- the prefix `xgboost.` tells the model is a XGBoost one, but not a Tensorflow model, and
34-
- `multi.softmax` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters).
35-
- In the `WITH` clause,
33+
- `xgboost.gbtree` is the model name. To use a different model provided by XGBoost, use `xgboost.gblinear` or `xgboost.dart`; see the [XGBoost general parameters](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters) for details.
34+
- In the `WITH` clause,
35+
- `objective` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters),
3636
- keys with the prefix `train.` identify parameters of XGBoost API [`xgboost.train`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train), and
3737
- keys without any prefix identify [XGBoost Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html), including the `objective` parameter, which is now given in the `WITH` clause rather than by the identifier after the keyword `TRAIN`.
3838

sql/codegen_xgboost.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,15 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro
7676
return params, nil
7777
}
7878

79-
func resolveObjective(pr *extendedSelect) (string, error) {
79+
func resolveModelName(pr *extendedSelect) (string, error) {
8080
estimatorParts := strings.Split(pr.estimator, ".")
81-
if len(estimatorParts) != 3 {
82-
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator)
81+
if len(estimatorParts) != 2 {
82+
return "", fmt.Errorf("XGBoost Estimator should be xgboost.modelname, current: %s", pr.estimator)
8383
}
84-
return strings.Join(estimatorParts[1:], ":"), nil
84+
if strings.ToUpper(estimatorParts[1]) != "GBTREE" {
85+
return "", fmt.Errorf("model name %s is not supported yet", estimatorParts[1])
86+
}
87+
return estimatorParts[1], nil
8588
}
8689

8790
func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFiller, error) {
@@ -109,18 +112,21 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
109112
}
110113

111114
if isTrain {
115+
objective := getStringAttr(attrs, "objective", "gbtree")
112116
// resolve the attribute keys without any prefix as the XGBoost Parameters
113117
params, err := resolveParamsCfg(attrs)
114118
if err != nil {
115119
return nil, err
116120
}
121+
params["objective"] = objective
117122

118-
// fill learning target
119-
objective, err := resolveObjective(pr)
123+
// get model name, could be gbtree, gblinear or dart.
124+
// TODO(typhoonzero): only gbtree is supported here, use model name to generate
125+
// different training code.
126+
_, err = resolveModelName(pr)
120127
if err != nil {
121128
return nil, err
122129
}
123-
params["objective"] = objective
124130

125131
paramsJSON, err := json.Marshal(params)
126132
if err != nil {

sql/codegen_xgboost_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ import (
2323
const testXGBoostTrainSelectIris = `
2424
SELECT *
2525
FROM iris.train
26-
TRAIN xgboost.multi.softprob
26+
TRAIN xgboost.gbtree
2727
WITH
28+
objective="multi:softprob",
2829
train.num_boost_round = 30,
2930
eta = 3.1,
3031
num_class = 3

0 commit comments

Comments
 (0)